1use std::sync::Arc;
10
11use async_trait::async_trait;
12use http::StatusCode;
13use serde_json::{json, Value};
14
15use fakecloud_core::service::{AwsRequest, AwsResponse, AwsService, AwsServiceError};
16
17use crate::state::{DynamoTable, SharedDynamoDbState};
18
19pub struct DynamoDbStreamsService {
20 state: SharedDynamoDbState,
21}
22
23impl DynamoDbStreamsService {
24 pub fn new(state: SharedDynamoDbState) -> Self {
25 Self { state }
26 }
27}
28
29#[async_trait]
30impl AwsService for DynamoDbStreamsService {
31 fn service_name(&self) -> &str {
32 "dynamodbstreams"
33 }
34
35 async fn handle(&self, req: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
36 let body: Value = serde_json::from_slice(&req.body).unwrap_or_default();
37 match req.action.as_str() {
38 "ListStreams" => self.list_streams(&req, &body),
39 "DescribeStream" => self.describe_stream(&req, &body),
40 "GetShardIterator" => self.get_shard_iterator(&req, &body),
41 "GetRecords" => self.get_records(&req, &body),
42 _ => Err(AwsServiceError::action_not_implemented(
43 "dynamodbstreams",
44 &req.action,
45 )),
46 }
47 }
48
49 fn supported_actions(&self) -> &[&str] {
50 &[
51 "ListStreams",
52 "DescribeStream",
53 "GetShardIterator",
54 "GetRecords",
55 ]
56 }
57}
58
59impl DynamoDbStreamsService {
60 fn list_streams(&self, req: &AwsRequest, body: &Value) -> Result<AwsResponse, AwsServiceError> {
61 let table_filter = body["TableName"].as_str();
62 let accounts = self.state.read();
63 let state = match accounts.get(&req.account_id) {
64 Some(s) => s,
65 None => return Ok(AwsResponse::ok_json(json!({ "Streams": [] }))),
66 };
67 let mut streams = Vec::new();
68 for table in state.tables.values() {
69 if let Some(name) = table_filter {
70 if table.name != name {
71 continue;
72 }
73 }
74 if !table.stream_enabled {
75 continue;
76 }
77 let Some(arn) = table.stream_arn.as_ref() else {
78 continue;
79 };
80 let label = stream_label(arn);
81 streams.push(json!({
82 "StreamArn": arn,
83 "TableName": table.name,
84 "StreamLabel": label,
85 }));
86 }
87 Ok(AwsResponse::ok_json(json!({ "Streams": streams })))
88 }
89
90 fn describe_stream(
91 &self,
92 req: &AwsRequest,
93 body: &Value,
94 ) -> Result<AwsResponse, AwsServiceError> {
95 let stream_arn = require_string(body, "StreamArn")?;
96 let accounts = self.state.read();
97 let state = accounts
98 .get(&req.account_id)
99 .ok_or_else(|| not_found("Stream", &stream_arn))?;
100 let table = state
101 .tables
102 .values()
103 .find(|t| t.stream_arn.as_deref() == Some(stream_arn.as_str()))
104 .ok_or_else(|| not_found("Stream", &stream_arn))?;
105
106 let view_type = table
107 .stream_view_type
108 .clone()
109 .unwrap_or_else(|| "NEW_AND_OLD_IMAGES".to_string());
110 let label = stream_label(&stream_arn);
111 let key_schema: Vec<Value> = table
112 .key_schema
113 .iter()
114 .map(|k| {
115 json!({
116 "AttributeName": k.attribute_name,
117 "KeyType": k.key_type,
118 })
119 })
120 .collect();
121
122 let body = json!({
123 "StreamDescription": {
124 "StreamArn": stream_arn,
125 "StreamLabel": label,
126 "StreamStatus": "ENABLED",
127 "StreamViewType": view_type,
128 "CreationRequestDateTime": table.created_at.timestamp() as f64,
129 "TableName": table.name,
130 "KeySchema": key_schema,
131 "Shards": [{
132 "ShardId": "shardId-00000000000000000000-00000001",
133 "SequenceNumberRange": {
134 "StartingSequenceNumber": "0",
135 },
136 }],
137 }
138 });
139 Ok(AwsResponse::ok_json(body))
140 }
141
142 fn get_shard_iterator(
143 &self,
144 req: &AwsRequest,
145 body: &Value,
146 ) -> Result<AwsResponse, AwsServiceError> {
147 let stream_arn = require_string(body, "StreamArn")?;
148 let shard_id = require_string(body, "ShardId")?;
149 let iterator_type = require_string(body, "ShardIteratorType")?;
150
151 let accounts = self.state.read();
152 let state = accounts
153 .get(&req.account_id)
154 .ok_or_else(|| not_found("Stream", &stream_arn))?;
155 let table = state
156 .tables
157 .values()
158 .find(|t| t.stream_arn.as_deref() == Some(stream_arn.as_str()))
159 .ok_or_else(|| not_found("Stream", &stream_arn))?;
160
161 let records = table.stream_records.read();
169 let after_seq: String = match iterator_type.as_str() {
170 "TRIM_HORIZON" => "0".to_string(),
173 "LATEST" => records
176 .iter()
177 .map(|r| r.dynamodb.sequence_number.clone())
178 .max_by(|a, b| cmp_seq(a, b))
179 .unwrap_or_else(|| "0".to_string()),
180 "AT_SEQUENCE_NUMBER" => {
182 let seq = require_string(body, "SequenceNumber")?;
183 if !records.iter().any(|r| r.dynamodb.sequence_number == seq) {
184 return Err(invalid_argument("SequenceNumber not found"));
185 }
186 exclusive_before(&seq)
187 }
188 "AFTER_SEQUENCE_NUMBER" => {
190 let seq = require_string(body, "SequenceNumber")?;
191 if !records.iter().any(|r| r.dynamodb.sequence_number == seq) {
192 return Err(invalid_argument("SequenceNumber not found"));
193 }
194 seq
195 }
196 other => {
197 return Err(invalid_argument(&format!(
198 "Unsupported ShardIteratorType: {other}",
199 )))
200 }
201 };
202
203 let token = format!("{stream_arn}|{shard_id}|{after_seq}");
204 Ok(AwsResponse::ok_json(json!({ "ShardIterator": token })))
205 }
206
207 fn get_records(&self, req: &AwsRequest, body: &Value) -> Result<AwsResponse, AwsServiceError> {
208 let iterator = require_string(body, "ShardIterator")?;
209 let limit = body["Limit"].as_u64().unwrap_or(1000) as usize;
210
211 let parts: Vec<&str> = iterator.splitn(3, '|').collect();
212 if parts.len() != 3 {
213 return Err(invalid_argument("ShardIterator is invalid"));
214 }
215 let stream_arn = parts[0].to_string();
216 let shard_id = parts[1].to_string();
217 let after_seq = parts[2].to_string();
222
223 let accounts = self.state.read();
224 let state = accounts
225 .get(&req.account_id)
226 .ok_or_else(|| not_found("Stream", &stream_arn))?;
227 let table = state
228 .tables
229 .values()
230 .find(|t| t.stream_arn.as_deref() == Some(stream_arn.as_str()))
231 .ok_or_else(|| not_found("Stream", &stream_arn))?;
232
233 let records = table.stream_records.read();
238 let selected: Vec<&crate::state::StreamRecord> = records
239 .iter()
240 .filter(|r| {
241 cmp_seq(&r.dynamodb.sequence_number, &after_seq) == std::cmp::Ordering::Greater
242 })
243 .take(limit)
244 .collect();
245
246 let next_seq = selected
247 .last()
248 .map(|r| r.dynamodb.sequence_number.clone())
249 .unwrap_or(after_seq);
250 let records_json: Vec<Value> = selected
251 .iter()
252 .map(|r| stream_record_to_json(r, table))
253 .collect();
254
255 let next_token = format!("{stream_arn}|{shard_id}|{next_seq}");
256 Ok(AwsResponse::ok_json(json!({
257 "Records": records_json,
258 "NextShardIterator": next_token,
259 })))
260 }
261}
262
263fn stream_record_to_json(r: &crate::state::StreamRecord, table: &DynamoTable) -> Value {
264 let mut dynamodb = json!({
265 "ApproximateCreationDateTime": r.timestamp.timestamp() as f64,
266 "Keys": &r.dynamodb.keys,
267 "SequenceNumber": r.dynamodb.sequence_number,
268 "SizeBytes": r.dynamodb.size_bytes,
269 "StreamViewType": r.dynamodb.stream_view_type,
270 });
271 if let Some(ni) = r.dynamodb.new_image.as_ref() {
272 dynamodb["NewImage"] = json!(ni);
273 }
274 if let Some(oi) = r.dynamodb.old_image.as_ref() {
275 dynamodb["OldImage"] = json!(oi);
276 }
277 let mut out = json!({
278 "eventID": r.event_id,
279 "eventName": r.event_name,
280 "eventVersion": r.event_version,
281 "eventSource": r.event_source,
282 "awsRegion": r.aws_region,
283 "eventSourceARN": table.stream_arn.clone().unwrap_or_default(),
284 "dynamodb": dynamodb,
285 });
286 if let Some(ui) = r.user_identity.as_ref() {
287 out["userIdentity"] = json!({
291 "PrincipalId": ui.principal_id,
292 "Type": ui.identity_type,
293 });
294 }
295 out
296}
297
/// Returns the stream label: the part of the stream ARN after its last `/`.
///
/// If the ARN contains no `/`, the whole string is returned.
fn stream_label(stream_arn: &str) -> String {
    let label = match stream_arn.rsplit_once('/') {
        Some((_, tail)) => tail,
        None => stream_arn,
    };
    label.to_string()
}
301
/// Orders two stream sequence numbers.
///
/// Sequence numbers are decimal strings that may be zero-padded, and DynamoDB
/// issues values up to 40 digits long, which exceeds `u128::MAX` (39 digits).
/// All-digit strings are therefore compared by magnitude without parsing:
/// leading zeros are stripped, the shorter string is smaller, and equal lengths
/// compare lexically. Any other input is compared numerically if both sides
/// parse as `u128`, and lexically otherwise.
fn cmp_seq(a: &str, b: &str) -> std::cmp::Ordering {
    fn is_digits(s: &str) -> bool {
        !s.is_empty() && s.bytes().all(|c| c.is_ascii_digit())
    }

    if is_digits(a) && is_digits(b) {
        let a = a.trim_start_matches('0');
        let b = b.trim_start_matches('0');
        return a.len().cmp(&b.len()).then_with(|| a.cmp(b));
    }
    match (a.parse::<u128>(), b.parse::<u128>()) {
        (Ok(x), Ok(y)) => x.cmp(&y),
        _ => a.cmp(b),
    }
}
313
/// Returns the sequence number immediately before `seq`, preserving its width.
///
/// An iterator positioned there yields `seq` itself as its first record. This
/// is how `AT_SEQUENCE_NUMBER` is expressed on top of the exclusive
/// "after" position used by iterator tokens. All-digit inputs are decremented
/// as decimal strings (borrowing from the right), so 40-digit DynamoDB sequence
/// numbers beyond `u128::MAX` are handled as well. Zero, empty, or
/// non-numeric input yields "0".
fn exclusive_before(seq: &str) -> String {
    if !seq.is_empty() && seq.bytes().all(|c| c.is_ascii_digit()) {
        if seq.bytes().all(|c| c == b'0') {
            return "0".to_string();
        }
        let mut digits = seq.as_bytes().to_vec();
        for d in digits.iter_mut().rev() {
            if *d == b'0' {
                *d = b'9';
            } else {
                *d -= 1;
                break;
            }
        }
        return String::from_utf8(digits).expect("decrementing ASCII digits keeps them ASCII");
    }
    // Non-plain-digit input (e.g. with a sign): keep the original numeric path.
    match seq.parse::<u128>() {
        Ok(n) if n > 0 => format!("{:0width$}", n - 1, width = seq.len()),
        _ => "0".to_string(),
    }
}
328
329fn require_string(body: &Value, field: &str) -> Result<String, AwsServiceError> {
330 body[field]
331 .as_str()
332 .map(|s| s.to_string())
333 .ok_or_else(|| invalid_argument(&format!("{field} is required")))
334}
335
336fn invalid_argument(msg: &str) -> AwsServiceError {
337 AwsServiceError::aws_error(StatusCode::BAD_REQUEST, "ValidationException", msg)
338}
339
340fn not_found(kind: &str, target: &str) -> AwsServiceError {
341 AwsServiceError::aws_error(
342 StatusCode::BAD_REQUEST,
343 "ResourceNotFoundException",
344 format!("{kind} not found: {target}"),
345 )
346}
347
348pub fn shared(state: SharedDynamoDbState) -> Arc<dyn AwsService> {
349 Arc::new(DynamoDbStreamsService::new(state))
350}
351
#[cfg(test)]
mod tests {
    use super::*;
    use crate::state::{DynamoDbStreamRecord, DynamoTable, ProvisionedThroughput, StreamRecord};
    use bytes::Bytes;
    use chrono::Utc;
    use http::{HeaderMap, Method};
    use parking_lot::RwLock;
    use std::collections::{BTreeMap, HashMap};
    use std::sync::Arc;

    /// Builds an empty shared multi-account state.
    // NOTE(review): the arguments presumably name the default account id and
    // region; the meaning of the third (empty) argument is not visible here.
    fn make_state() -> SharedDynamoDbState {
        Arc::new(RwLock::new(
            fakecloud_core::multi_account::MultiAccountState::new("123456789012", "us-east-1", ""),
        ))
    }

    /// Builds a POST `AwsRequest` for `action` in the test account, with `body`
    /// serialized as the JSON payload.
    fn req(action: &str, body: Value) -> AwsRequest {
        AwsRequest {
            service: "dynamodbstreams".into(),
            action: action.into(),
            region: "us-east-1".into(),
            account_id: "123456789012".into(),
            request_id: "r".into(),
            headers: HeaderMap::new(),
            query_params: HashMap::new(),
            body: Bytes::from(serde_json::to_vec(&body).unwrap()),
            body_stream: parking_lot::Mutex::new(None),
            path_segments: vec![],
            raw_path: "/".into(),
            raw_query: String::new(),
            method: Method::POST,
            is_query_protocol: false,
            access_key_id: None,
            principal: None,
        }
    }

    /// Inserts a stream-enabled `widgets` table holding one INSERT record with
    /// sequence number "1" into the test account, and returns its stream ARN.
    fn seed_table(state: &SharedDynamoDbState) -> String {
        let mut accts = state.write();
        let s = accts.get_or_create("123456789012");
        let arn =
            "arn:aws:dynamodb:us-east-1:123456789012:table/widgets/stream/2026-05-03T00:00:00.000"
                .to_string();
        let table = DynamoTable {
            name: "widgets".to_string(),
            arn: "arn:aws:dynamodb:us-east-1:123456789012:table/widgets".to_string(),
            table_id: "id".to_string(),
            key_schema: Vec::new(),
            attribute_definitions: Vec::new(),
            provisioned_throughput: ProvisionedThroughput {
                read_capacity_units: 0,
                write_capacity_units: 0,
            },
            items: Vec::new(),
            gsi: Vec::new(),
            lsi: Vec::new(),
            tags: BTreeMap::new(),
            created_at: Utc::now(),
            status: "ACTIVE".to_string(),
            item_count: 0,
            size_bytes: 0,
            billing_mode: "PAY_PER_REQUEST".to_string(),
            ttl_attribute: None,
            ttl_enabled: false,
            resource_policy: None,
            pitr_enabled: false,
            kinesis_destinations: Vec::new(),
            contributor_insights_status: "DISABLED".to_string(),
            contributor_insights_counters: BTreeMap::new(),
            stream_enabled: true,
            stream_view_type: Some("NEW_AND_OLD_IMAGES".to_string()),
            stream_arn: Some(arn.clone()),
            stream_records: Arc::new(RwLock::new(Vec::new())),
            sse_type: None,
            sse_kms_key_arn: None,
            deletion_protection_enabled: false,
            on_demand_throughput: None,
            table_class: "STANDARD".to_string(),
        };
        let rec = StreamRecord {
            event_id: "e1".into(),
            event_name: "INSERT".into(),
            event_version: "1.1".into(),
            event_source: "aws:dynamodb".into(),
            aws_region: "us-east-1".into(),
            event_source_arn: arn.clone(),
            timestamp: Utc::now(),
            dynamodb: DynamoDbStreamRecord {
                keys: HashMap::new(),
                new_image: Some(HashMap::new()),
                old_image: None,
                sequence_number: "1".into(),
                size_bytes: 16,
                stream_view_type: "NEW_AND_OLD_IMAGES".into(),
            },
            user_identity: None,
        };
        table.stream_records.write().push(rec);
        s.tables.insert("widgets".to_string(), table);
        arn
    }

    // ListStreams with no filter reports the single seeded stream.
    #[tokio::test]
    async fn list_streams_returns_enabled_streams() {
        let state = make_state();
        let arn = seed_table(&state);
        let svc = DynamoDbStreamsService::new(state);
        let resp = svc.handle(req("ListStreams", json!({}))).await.unwrap();
        let body: Value = serde_json::from_slice(resp.body.expect_bytes()).unwrap();
        let streams = body["Streams"].as_array().unwrap();
        assert_eq!(streams.len(), 1);
        assert_eq!(streams[0]["StreamArn"].as_str().unwrap(), arn);
    }

    // DescribeStream on an enabled stream reports ENABLED and exactly one shard.
    #[tokio::test]
    async fn describe_stream_returns_shard() {
        let state = make_state();
        let arn = seed_table(&state);
        let svc = DynamoDbStreamsService::new(state);
        let resp = svc
            .handle(req("DescribeStream", json!({"StreamArn": arn})))
            .await
            .unwrap();
        let body: Value = serde_json::from_slice(resp.body.expect_bytes()).unwrap();
        let desc = &body["StreamDescription"];
        assert_eq!(desc["StreamStatus"].as_str().unwrap(), "ENABLED");
        assert_eq!(desc["Shards"].as_array().unwrap().len(), 1);
    }

    // A TRIM_HORIZON iterator fed to GetRecords returns the seeded record.
    #[tokio::test]
    async fn get_records_round_trip_via_shard_iterator() {
        let state = make_state();
        let arn = seed_table(&state);
        let svc = DynamoDbStreamsService::new(state);
        let it_resp = svc
            .handle(req(
                "GetShardIterator",
                json!({
                    "StreamArn": arn,
                    "ShardId": "shardId-00000000000000000000-00000001",
                    "ShardIteratorType": "TRIM_HORIZON",
                }),
            ))
            .await
            .unwrap();
        let it_body: Value = serde_json::from_slice(it_resp.body.expect_bytes()).unwrap();
        let iterator = it_body["ShardIterator"].as_str().unwrap().to_string();
        let resp = svc
            .handle(req("GetRecords", json!({"ShardIterator": iterator})))
            .await
            .unwrap();
        let body: Value = serde_json::from_slice(resp.body.expect_bytes()).unwrap();
        let recs = body["Records"].as_array().unwrap();
        assert_eq!(recs.len(), 1);
        assert_eq!(recs[0]["eventName"].as_str().unwrap(), "INSERT");
    }

    /// Appends an INSERT record with sequence number `seq`, timestamped
    /// `age_hours` in the past, to the seeded `widgets` table.
    /// Panics if `seed_table` has not been called first.
    fn push_record(state: &SharedDynamoDbState, seq: &str, age_hours: i64, event_id: &str) {
        let mut accts = state.write();
        let s = accts.get_or_create("123456789012");
        let table = s.tables.get_mut("widgets").unwrap();
        let rec = StreamRecord {
            event_id: event_id.into(),
            event_name: "INSERT".into(),
            event_version: "1.1".into(),
            event_source: "aws:dynamodb".into(),
            aws_region: "us-east-1".into(),
            event_source_arn: table.stream_arn.clone().unwrap(),
            timestamp: Utc::now() - chrono::Duration::hours(age_hours),
            dynamodb: DynamoDbStreamRecord {
                keys: HashMap::new(),
                new_image: Some(HashMap::new()),
                old_image: None,
                sequence_number: seq.into(),
                size_bytes: 16,
                stream_view_type: "NEW_AND_OLD_IMAGES".into(),
            },
            user_identity: None,
        };
        table.stream_records.write().push(rec);
    }

    /// Removes up to `n` records from the front of the `widgets` stream log.
    /// This simulates retention trimming.
    fn trim_front(state: &SharedDynamoDbState, n: usize) {
        let accts = state.read();
        let s = accts.get("123456789012").unwrap();
        let table = s.tables.get("widgets").unwrap();
        let mut recs = table.stream_records.write();
        for _ in 0..n {
            if !recs.is_empty() {
                recs.remove(0);
            }
        }
    }

    // GetRecords filters on the sequence number stored in the iterator, not on
    // a position in the log. Dropping already-consumed records from the front
    // between calls must neither skip nor replay anything.
    #[tokio::test]
    async fn iterator_survives_front_trim_without_skip_or_replay() {
        let state = make_state();
        let arn = seed_table(&state);
        // Start from an empty log so every sequence number below is controlled here.
        {
            let accts = state.read();
            let s = accts.get("123456789012").unwrap();
            s.tables
                .get("widgets")
                .unwrap()
                .stream_records
                .write()
                .clear();
        }
        for i in 1..=5u64 {
            // The first two records are 30 hours old. Presumably that is past a
            // 24h retention window; they are trimmed explicitly below either way.
            let age = if i <= 2 { 30 } else { 0 };
            // 21-digit zero-padded sequence numbers.
            push_record(&state, &format!("{i:021}"), age, &format!("e{i}"));
        }
        let svc = DynamoDbStreamsService::new(state.clone());

        let it_resp = svc
            .handle(req(
                "GetShardIterator",
                json!({
                    "StreamArn": arn,
                    "ShardId": "shardId-00000000000000000000-00000001",
                    "ShardIteratorType": "TRIM_HORIZON",
                }),
            ))
            .await
            .unwrap();
        let it: Value = serde_json::from_slice(it_resp.body.expect_bytes()).unwrap();
        let iterator = it["ShardIterator"].as_str().unwrap().to_string();

        // First page: e1..e3.
        let r1 = svc
            .handle(req(
                "GetRecords",
                json!({"ShardIterator": iterator, "Limit": 3}),
            ))
            .await
            .unwrap();
        let b1: Value = serde_json::from_slice(r1.body.expect_bytes()).unwrap();
        let recs1 = b1["Records"].as_array().unwrap();
        assert_eq!(recs1.len(), 3);
        assert_eq!(recs1[0]["eventID"].as_str().unwrap(), "e1");
        assert_eq!(recs1[2]["eventID"].as_str().unwrap(), "e3");
        let next = b1["NextShardIterator"].as_str().unwrap().to_string();

        // Drop e1 and e2; an index-based iterator would now skip e4.
        trim_front(&state, 2);

        let r2 = svc
            .handle(req("GetRecords", json!({"ShardIterator": next})))
            .await
            .unwrap();
        let b2: Value = serde_json::from_slice(r2.body.expect_bytes()).unwrap();
        let recs2 = b2["Records"].as_array().unwrap();
        assert_eq!(
            recs2.len(),
            2,
            "must return exactly the un-consumed records after a front trim"
        );
        assert_eq!(recs2[0]["eventID"].as_str().unwrap(), "e4");
        assert_eq!(recs2[1]["eventID"].as_str().unwrap(), "e5");
    }

    // A stream ARN that no table owns yields ResourceNotFoundException.
    #[tokio::test]
    async fn describe_stream_unknown_arn_404s() {
        let state = make_state();
        let _ = seed_table(&state);
        let svc = DynamoDbStreamsService::new(state);
        let err = svc
            .handle(req(
                "DescribeStream",
                json!({"StreamArn": "arn:aws:dynamodb:us-east-1:123456789012:table/missing/stream/x"}),
            ))
            .await
            .err()
            .expect("expected ResourceNotFound");
        assert!(format!("{:?}", err).contains("ResourceNotFoundException"));
    }
}