Skip to main content

fakecloud_dynamodb/
streams_dataplane.rs

1//! DynamoDB Streams data plane (`DynamoDBStreams_20120810`).
2//!
3//! Lambda event source mappings against `arn:aws:dynamodb:.../stream/...`
4//! depend on `ListStreams`, `DescribeStream`, `GetShardIterator`, and
5//! `GetRecords`. The control plane's `EnableStream` / `DescribeTable`
6//! already populate `DynamoTable::stream_records` on every mutation;
7//! this module is the consumer side that surfaces those records.
8
9use std::sync::Arc;
10
11use async_trait::async_trait;
12use http::StatusCode;
13use serde_json::{json, Value};
14
15use fakecloud_core::service::{AwsRequest, AwsResponse, AwsService, AwsServiceError};
16
17use crate::state::{DynamoTable, SharedDynamoDbState};
18
19pub struct DynamoDbStreamsService {
20    state: SharedDynamoDbState,
21}
22
23impl DynamoDbStreamsService {
24    pub fn new(state: SharedDynamoDbState) -> Self {
25        Self { state }
26    }
27}
28
29#[async_trait]
30impl AwsService for DynamoDbStreamsService {
31    fn service_name(&self) -> &str {
32        "dynamodbstreams"
33    }
34
35    async fn handle(&self, req: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
36        let body: Value = serde_json::from_slice(&req.body).unwrap_or_default();
37        match req.action.as_str() {
38            "ListStreams" => self.list_streams(&req, &body),
39            "DescribeStream" => self.describe_stream(&req, &body),
40            "GetShardIterator" => self.get_shard_iterator(&req, &body),
41            "GetRecords" => self.get_records(&req, &body),
42            _ => Err(AwsServiceError::action_not_implemented(
43                "dynamodbstreams",
44                &req.action,
45            )),
46        }
47    }
48
49    fn supported_actions(&self) -> &[&str] {
50        &[
51            "ListStreams",
52            "DescribeStream",
53            "GetShardIterator",
54            "GetRecords",
55        ]
56    }
57}
58
59impl DynamoDbStreamsService {
60    fn list_streams(&self, req: &AwsRequest, body: &Value) -> Result<AwsResponse, AwsServiceError> {
61        let table_filter = body["TableName"].as_str();
62        let accounts = self.state.read();
63        let state = match accounts.get(&req.account_id) {
64            Some(s) => s,
65            None => return Ok(AwsResponse::ok_json(json!({ "Streams": [] }))),
66        };
67        let mut streams = Vec::new();
68        for table in state.tables.values() {
69            if let Some(name) = table_filter {
70                if table.name != name {
71                    continue;
72                }
73            }
74            if !table.stream_enabled {
75                continue;
76            }
77            let Some(arn) = table.stream_arn.as_ref() else {
78                continue;
79            };
80            let label = stream_label(arn);
81            streams.push(json!({
82                "StreamArn": arn,
83                "TableName": table.name,
84                "StreamLabel": label,
85            }));
86        }
87        Ok(AwsResponse::ok_json(json!({ "Streams": streams })))
88    }
89
90    fn describe_stream(
91        &self,
92        req: &AwsRequest,
93        body: &Value,
94    ) -> Result<AwsResponse, AwsServiceError> {
95        let stream_arn = require_string(body, "StreamArn")?;
96        let accounts = self.state.read();
97        let state = accounts
98            .get(&req.account_id)
99            .ok_or_else(|| not_found("Stream", &stream_arn))?;
100        let table = state
101            .tables
102            .values()
103            .find(|t| t.stream_arn.as_deref() == Some(stream_arn.as_str()))
104            .ok_or_else(|| not_found("Stream", &stream_arn))?;
105
106        let view_type = table
107            .stream_view_type
108            .clone()
109            .unwrap_or_else(|| "NEW_AND_OLD_IMAGES".to_string());
110        let label = stream_label(&stream_arn);
111        let key_schema: Vec<Value> = table
112            .key_schema
113            .iter()
114            .map(|k| {
115                json!({
116                    "AttributeName": k.attribute_name,
117                    "KeyType": k.key_type,
118                })
119            })
120            .collect();
121
122        let body = json!({
123            "StreamDescription": {
124                "StreamArn": stream_arn,
125                "StreamLabel": label,
126                "StreamStatus": "ENABLED",
127                "StreamViewType": view_type,
128                "CreationRequestDateTime": table.created_at.timestamp() as f64,
129                "TableName": table.name,
130                "KeySchema": key_schema,
131                "Shards": [{
132                    "ShardId": "shardId-00000000000000000000-00000001",
133                    "SequenceNumberRange": {
134                        "StartingSequenceNumber": "0",
135                    },
136                }],
137            }
138        });
139        Ok(AwsResponse::ok_json(body))
140    }
141
142    fn get_shard_iterator(
143        &self,
144        req: &AwsRequest,
145        body: &Value,
146    ) -> Result<AwsResponse, AwsServiceError> {
147        let stream_arn = require_string(body, "StreamArn")?;
148        let shard_id = require_string(body, "ShardId")?;
149        let iterator_type = require_string(body, "ShardIteratorType")?;
150
151        let accounts = self.state.read();
152        let state = accounts
153            .get(&req.account_id)
154            .ok_or_else(|| not_found("Stream", &stream_arn))?;
155        let table = state
156            .tables
157            .values()
158            .find(|t| t.stream_arn.as_deref() == Some(stream_arn.as_str()))
159            .ok_or_else(|| not_found("Stream", &stream_arn))?;
160
161        let records = table.stream_records.read();
162        let start_index: usize = match iterator_type.as_str() {
163            "TRIM_HORIZON" => 0,
164            "LATEST" => records.len(),
165            "AT_SEQUENCE_NUMBER" => {
166                let seq = require_string(body, "SequenceNumber")?;
167                records
168                    .iter()
169                    .position(|r| r.dynamodb.sequence_number == seq)
170                    .ok_or_else(|| invalid_argument("SequenceNumber not found"))?
171            }
172            "AFTER_SEQUENCE_NUMBER" => {
173                let seq = require_string(body, "SequenceNumber")?;
174                let idx = records
175                    .iter()
176                    .position(|r| r.dynamodb.sequence_number == seq)
177                    .ok_or_else(|| invalid_argument("SequenceNumber not found"))?;
178                idx + 1
179            }
180            other => {
181                return Err(invalid_argument(&format!(
182                    "Unsupported ShardIteratorType: {other}",
183                )))
184            }
185        };
186
187        let token = format!("{stream_arn}|{shard_id}|{start_index}");
188        Ok(AwsResponse::ok_json(json!({ "ShardIterator": token })))
189    }
190
191    fn get_records(&self, req: &AwsRequest, body: &Value) -> Result<AwsResponse, AwsServiceError> {
192        let iterator = require_string(body, "ShardIterator")?;
193        let limit = body["Limit"].as_u64().unwrap_or(1000) as usize;
194
195        let parts: Vec<&str> = iterator.splitn(3, '|').collect();
196        if parts.len() != 3 {
197            return Err(invalid_argument("ShardIterator is invalid"));
198        }
199        let stream_arn = parts[0].to_string();
200        let shard_id = parts[1].to_string();
201        let start_index: usize = parts[2]
202            .parse()
203            .map_err(|_| invalid_argument("ShardIterator is invalid"))?;
204
205        let accounts = self.state.read();
206        let state = accounts
207            .get(&req.account_id)
208            .ok_or_else(|| not_found("Stream", &stream_arn))?;
209        let table = state
210            .tables
211            .values()
212            .find(|t| t.stream_arn.as_deref() == Some(stream_arn.as_str()))
213            .ok_or_else(|| not_found("Stream", &stream_arn))?;
214
215        let records = table.stream_records.read();
216        let end_index = records.len().min(start_index.saturating_add(limit));
217        let records_json: Vec<Value> = records[start_index..end_index]
218            .iter()
219            .map(|r| stream_record_to_json(r, table))
220            .collect();
221
222        let next_token = format!("{stream_arn}|{shard_id}|{end_index}");
223        Ok(AwsResponse::ok_json(json!({
224            "Records": records_json,
225            "NextShardIterator": next_token,
226        })))
227    }
228}
229
230fn stream_record_to_json(r: &crate::state::StreamRecord, table: &DynamoTable) -> Value {
231    let mut dynamodb = json!({
232        "ApproximateCreationDateTime": r.timestamp.timestamp() as f64,
233        "Keys": &r.dynamodb.keys,
234        "SequenceNumber": r.dynamodb.sequence_number,
235        "SizeBytes": r.dynamodb.size_bytes,
236        "StreamViewType": r.dynamodb.stream_view_type,
237    });
238    if let Some(ni) = r.dynamodb.new_image.as_ref() {
239        dynamodb["NewImage"] = json!(ni);
240    }
241    if let Some(oi) = r.dynamodb.old_image.as_ref() {
242        dynamodb["OldImage"] = json!(oi);
243    }
244    json!({
245        "eventID": r.event_id,
246        "eventName": r.event_name,
247        "eventVersion": r.event_version,
248        "eventSource": r.event_source,
249        "awsRegion": r.aws_region,
250        "eventSourceARN": table.stream_arn.clone().unwrap_or_default(),
251        "dynamodb": dynamodb,
252    })
253}
254
/// Returns the stream label: everything after the last `/` of a stream ARN, or
/// the whole input when it contains no `/`.
fn stream_label(stream_arn: &str) -> String {
    stream_arn
        .rsplit_once('/')
        .map_or(stream_arn, |(_, label)| label)
        .to_owned()
}
258
259fn require_string(body: &Value, field: &str) -> Result<String, AwsServiceError> {
260    body[field]
261        .as_str()
262        .map(|s| s.to_string())
263        .ok_or_else(|| invalid_argument(&format!("{field} is required")))
264}
265
266fn invalid_argument(msg: &str) -> AwsServiceError {
267    AwsServiceError::aws_error(StatusCode::BAD_REQUEST, "ValidationException", msg)
268}
269
270fn not_found(kind: &str, target: &str) -> AwsServiceError {
271    AwsServiceError::aws_error(
272        StatusCode::BAD_REQUEST,
273        "ResourceNotFoundException",
274        format!("{kind} not found: {target}"),
275    )
276}
277
278pub fn shared(state: SharedDynamoDbState) -> Arc<dyn AwsService> {
279    Arc::new(DynamoDbStreamsService::new(state))
280}
281
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::{DynamoDbStreamRecord, DynamoTable, ProvisionedThroughput, StreamRecord};
    use bytes::Bytes;
    use chrono::Utc;
    use http::{HeaderMap, Method};
    use parking_lot::RwLock;
    use std::collections::{BTreeMap, HashMap};
    use std::sync::Arc;

    /// Account used by every fixture and request in this module.
    const ACCOUNT_ID: &str = "123456789012";
    /// Region used by every fixture and request in this module.
    const REGION: &str = "us-east-1";
    /// Stream ARN attached to the seeded `widgets` table.
    const WIDGETS_STREAM_ARN: &str =
        "arn:aws:dynamodb:us-east-1:123456789012:table/widgets/stream/2026-05-03T00:00:00.000";

    /// Empty multi-account DynamoDB state.
    fn make_state() -> SharedDynamoDbState {
        let accounts =
            fakecloud_core::multi_account::MultiAccountState::new(ACCOUNT_ID, REGION, "");
        Arc::new(RwLock::new(accounts))
    }

    /// Builds a JSON POST request for `action` whose body is `body`.
    fn req(action: &str, body: Value) -> AwsRequest {
        let encoded = serde_json::to_vec(&body).unwrap();
        AwsRequest {
            service: "dynamodbstreams".into(),
            action: action.into(),
            region: REGION.into(),
            account_id: ACCOUNT_ID.into(),
            request_id: "r".into(),
            method: Method::POST,
            headers: HeaderMap::new(),
            query_params: HashMap::new(),
            path_segments: vec![],
            raw_path: "/".into(),
            raw_query: String::new(),
            body: Bytes::from(encoded),
            body_stream: parking_lot::Mutex::new(None),
            is_query_protocol: false,
            access_key_id: None,
            principal: None,
        }
    }

    /// Inserts a `widgets` table with an enabled stream holding one INSERT record
    /// (sequence number "1") and returns that stream's ARN.
    fn seed_table(state: &SharedDynamoDbState) -> String {
        let stream_arn = WIDGETS_STREAM_ARN.to_string();
        let table = DynamoTable {
            name: "widgets".to_string(),
            arn: "arn:aws:dynamodb:us-east-1:123456789012:table/widgets".to_string(),
            table_id: "id".to_string(),
            key_schema: Vec::new(),
            attribute_definitions: Vec::new(),
            provisioned_throughput: ProvisionedThroughput {
                read_capacity_units: 0,
                write_capacity_units: 0,
            },
            items: Vec::new(),
            gsi: Vec::new(),
            lsi: Vec::new(),
            tags: BTreeMap::new(),
            created_at: Utc::now(),
            status: "ACTIVE".to_string(),
            item_count: 0,
            size_bytes: 0,
            billing_mode: "PAY_PER_REQUEST".to_string(),
            ttl_attribute: None,
            ttl_enabled: false,
            resource_policy: None,
            pitr_enabled: false,
            kinesis_destinations: Vec::new(),
            contributor_insights_status: "DISABLED".to_string(),
            contributor_insights_counters: BTreeMap::new(),
            stream_enabled: true,
            stream_view_type: Some("NEW_AND_OLD_IMAGES".to_string()),
            stream_arn: Some(stream_arn.clone()),
            stream_records: Arc::new(RwLock::new(Vec::new())),
            sse_type: None,
            sse_kms_key_arn: None,
            deletion_protection_enabled: false,
            on_demand_throughput: None,
        };
        let insert = StreamRecord {
            event_id: "e1".into(),
            event_name: "INSERT".into(),
            event_version: "1.1".into(),
            event_source: "aws:dynamodb".into(),
            aws_region: REGION.into(),
            event_source_arn: stream_arn.clone(),
            timestamp: Utc::now(),
            dynamodb: DynamoDbStreamRecord {
                keys: HashMap::new(),
                new_image: Some(HashMap::new()),
                old_image: None,
                sequence_number: "1".into(),
                size_bytes: 16,
                stream_view_type: "NEW_AND_OLD_IMAGES".into(),
            },
        };
        table.stream_records.write().push(insert);

        let mut accounts = state.write();
        accounts
            .get_or_create(ACCOUNT_ID)
            .tables
            .insert("widgets".to_string(), table);
        stream_arn
    }

    /// Invokes `action` with `payload`, expecting success, and decodes the JSON body.
    async fn call(svc: &DynamoDbStreamsService, action: &str, payload: Value) -> Value {
        let resp = svc.handle(req(action, payload)).await.unwrap();
        serde_json::from_slice(resp.body.expect_bytes()).unwrap()
    }

    #[tokio::test]
    async fn list_streams_returns_enabled_streams() {
        let state = make_state();
        let arn = seed_table(&state);
        let svc = DynamoDbStreamsService::new(state);

        let listed = call(&svc, "ListStreams", json!({})).await;
        let streams = listed["Streams"].as_array().unwrap();
        assert_eq!(streams.len(), 1);
        assert_eq!(streams[0]["StreamArn"].as_str().unwrap(), arn);
    }

    #[tokio::test]
    async fn describe_stream_returns_shard() {
        let state = make_state();
        let arn = seed_table(&state);
        let svc = DynamoDbStreamsService::new(state);

        let described = call(&svc, "DescribeStream", json!({ "StreamArn": arn })).await;
        let description = &described["StreamDescription"];
        assert_eq!(description["StreamStatus"].as_str().unwrap(), "ENABLED");
        assert_eq!(description["Shards"].as_array().unwrap().len(), 1);
    }

    #[tokio::test]
    async fn get_records_round_trip_via_shard_iterator() {
        let state = make_state();
        let arn = seed_table(&state);
        let svc = DynamoDbStreamsService::new(state);

        let iterator_body = call(
            &svc,
            "GetShardIterator",
            json!({
                "StreamArn": arn,
                "ShardId": "shardId-00000000000000000000-00000001",
                "ShardIteratorType": "TRIM_HORIZON",
            }),
        )
        .await;
        let iterator = iterator_body["ShardIterator"].as_str().unwrap().to_string();

        let page = call(&svc, "GetRecords", json!({ "ShardIterator": iterator })).await;
        let records = page["Records"].as_array().unwrap();
        assert_eq!(records.len(), 1);
        assert_eq!(records[0]["eventName"].as_str().unwrap(), "INSERT");
    }

    #[tokio::test]
    async fn describe_stream_unknown_arn_404s() {
        let state = make_state();
        let _ = seed_table(&state);
        let svc = DynamoDbStreamsService::new(state);

        let missing = json!({
            "StreamArn": "arn:aws:dynamodb:us-east-1:123456789012:table/missing/stream/x"
        });
        let err = svc
            .handle(req("DescribeStream", missing))
            .await
            .err()
            .expect("expected ResourceNotFound");
        assert!(format!("{:?}", err).contains("ResourceNotFoundException"));
    }
}