1use std::sync::Arc;
10
11use async_trait::async_trait;
12use http::StatusCode;
13use serde_json::{json, Value};
14
15use fakecloud_core::service::{AwsRequest, AwsResponse, AwsService, AwsServiceError};
16
17use crate::state::{DynamoTable, SharedDynamoDbState};
18
19pub struct DynamoDbStreamsService {
20 state: SharedDynamoDbState,
21}
22
23impl DynamoDbStreamsService {
24 pub fn new(state: SharedDynamoDbState) -> Self {
25 Self { state }
26 }
27}
28
29#[async_trait]
30impl AwsService for DynamoDbStreamsService {
31 fn service_name(&self) -> &str {
32 "dynamodbstreams"
33 }
34
35 async fn handle(&self, req: AwsRequest) -> Result<AwsResponse, AwsServiceError> {
36 let body: Value = serde_json::from_slice(&req.body).unwrap_or_default();
37 match req.action.as_str() {
38 "ListStreams" => self.list_streams(&req, &body),
39 "DescribeStream" => self.describe_stream(&req, &body),
40 "GetShardIterator" => self.get_shard_iterator(&req, &body),
41 "GetRecords" => self.get_records(&req, &body),
42 _ => Err(AwsServiceError::action_not_implemented(
43 "dynamodbstreams",
44 &req.action,
45 )),
46 }
47 }
48
49 fn supported_actions(&self) -> &[&str] {
50 &[
51 "ListStreams",
52 "DescribeStream",
53 "GetShardIterator",
54 "GetRecords",
55 ]
56 }
57}
58
59impl DynamoDbStreamsService {
60 fn list_streams(&self, req: &AwsRequest, body: &Value) -> Result<AwsResponse, AwsServiceError> {
61 let table_filter = body["TableName"].as_str();
62 let accounts = self.state.read();
63 let state = match accounts.get(&req.account_id) {
64 Some(s) => s,
65 None => return Ok(AwsResponse::ok_json(json!({ "Streams": [] }))),
66 };
67 let mut streams = Vec::new();
68 for table in state.tables.values() {
69 if let Some(name) = table_filter {
70 if table.name != name {
71 continue;
72 }
73 }
74 if !table.stream_enabled {
75 continue;
76 }
77 let Some(arn) = table.stream_arn.as_ref() else {
78 continue;
79 };
80 let label = stream_label(arn);
81 streams.push(json!({
82 "StreamArn": arn,
83 "TableName": table.name,
84 "StreamLabel": label,
85 }));
86 }
87 Ok(AwsResponse::ok_json(json!({ "Streams": streams })))
88 }
89
90 fn describe_stream(
91 &self,
92 req: &AwsRequest,
93 body: &Value,
94 ) -> Result<AwsResponse, AwsServiceError> {
95 let stream_arn = require_string(body, "StreamArn")?;
96 let accounts = self.state.read();
97 let state = accounts
98 .get(&req.account_id)
99 .ok_or_else(|| not_found("Stream", &stream_arn))?;
100 let table = state
101 .tables
102 .values()
103 .find(|t| t.stream_arn.as_deref() == Some(stream_arn.as_str()))
104 .ok_or_else(|| not_found("Stream", &stream_arn))?;
105
106 let view_type = table
107 .stream_view_type
108 .clone()
109 .unwrap_or_else(|| "NEW_AND_OLD_IMAGES".to_string());
110 let label = stream_label(&stream_arn);
111 let key_schema: Vec<Value> = table
112 .key_schema
113 .iter()
114 .map(|k| {
115 json!({
116 "AttributeName": k.attribute_name,
117 "KeyType": k.key_type,
118 })
119 })
120 .collect();
121
122 let body = json!({
123 "StreamDescription": {
124 "StreamArn": stream_arn,
125 "StreamLabel": label,
126 "StreamStatus": "ENABLED",
127 "StreamViewType": view_type,
128 "CreationRequestDateTime": table.created_at.timestamp() as f64,
129 "TableName": table.name,
130 "KeySchema": key_schema,
131 "Shards": [{
132 "ShardId": "shardId-00000000000000000000-00000001",
133 "SequenceNumberRange": {
134 "StartingSequenceNumber": "0",
135 },
136 }],
137 }
138 });
139 Ok(AwsResponse::ok_json(body))
140 }
141
142 fn get_shard_iterator(
143 &self,
144 req: &AwsRequest,
145 body: &Value,
146 ) -> Result<AwsResponse, AwsServiceError> {
147 let stream_arn = require_string(body, "StreamArn")?;
148 let shard_id = require_string(body, "ShardId")?;
149 let iterator_type = require_string(body, "ShardIteratorType")?;
150
151 let accounts = self.state.read();
152 let state = accounts
153 .get(&req.account_id)
154 .ok_or_else(|| not_found("Stream", &stream_arn))?;
155 let table = state
156 .tables
157 .values()
158 .find(|t| t.stream_arn.as_deref() == Some(stream_arn.as_str()))
159 .ok_or_else(|| not_found("Stream", &stream_arn))?;
160
161 let records = table.stream_records.read();
162 let start_index: usize = match iterator_type.as_str() {
163 "TRIM_HORIZON" => 0,
164 "LATEST" => records.len(),
165 "AT_SEQUENCE_NUMBER" => {
166 let seq = require_string(body, "SequenceNumber")?;
167 records
168 .iter()
169 .position(|r| r.dynamodb.sequence_number == seq)
170 .ok_or_else(|| invalid_argument("SequenceNumber not found"))?
171 }
172 "AFTER_SEQUENCE_NUMBER" => {
173 let seq = require_string(body, "SequenceNumber")?;
174 let idx = records
175 .iter()
176 .position(|r| r.dynamodb.sequence_number == seq)
177 .ok_or_else(|| invalid_argument("SequenceNumber not found"))?;
178 idx + 1
179 }
180 other => {
181 return Err(invalid_argument(&format!(
182 "Unsupported ShardIteratorType: {other}",
183 )))
184 }
185 };
186
187 let token = format!("{stream_arn}|{shard_id}|{start_index}");
188 Ok(AwsResponse::ok_json(json!({ "ShardIterator": token })))
189 }
190
191 fn get_records(&self, req: &AwsRequest, body: &Value) -> Result<AwsResponse, AwsServiceError> {
192 let iterator = require_string(body, "ShardIterator")?;
193 let limit = body["Limit"].as_u64().unwrap_or(1000) as usize;
194
195 let parts: Vec<&str> = iterator.splitn(3, '|').collect();
196 if parts.len() != 3 {
197 return Err(invalid_argument("ShardIterator is invalid"));
198 }
199 let stream_arn = parts[0].to_string();
200 let shard_id = parts[1].to_string();
201 let start_index: usize = parts[2]
202 .parse()
203 .map_err(|_| invalid_argument("ShardIterator is invalid"))?;
204
205 let accounts = self.state.read();
206 let state = accounts
207 .get(&req.account_id)
208 .ok_or_else(|| not_found("Stream", &stream_arn))?;
209 let table = state
210 .tables
211 .values()
212 .find(|t| t.stream_arn.as_deref() == Some(stream_arn.as_str()))
213 .ok_or_else(|| not_found("Stream", &stream_arn))?;
214
215 let records = table.stream_records.read();
216 let end_index = records.len().min(start_index.saturating_add(limit));
217 let records_json: Vec<Value> = records[start_index..end_index]
218 .iter()
219 .map(|r| stream_record_to_json(r, table))
220 .collect();
221
222 let next_token = format!("{stream_arn}|{shard_id}|{end_index}");
223 Ok(AwsResponse::ok_json(json!({
224 "Records": records_json,
225 "NextShardIterator": next_token,
226 })))
227 }
228}
229
230fn stream_record_to_json(r: &crate::state::StreamRecord, table: &DynamoTable) -> Value {
231 let mut dynamodb = json!({
232 "ApproximateCreationDateTime": r.timestamp.timestamp() as f64,
233 "Keys": &r.dynamodb.keys,
234 "SequenceNumber": r.dynamodb.sequence_number,
235 "SizeBytes": r.dynamodb.size_bytes,
236 "StreamViewType": r.dynamodb.stream_view_type,
237 });
238 if let Some(ni) = r.dynamodb.new_image.as_ref() {
239 dynamodb["NewImage"] = json!(ni);
240 }
241 if let Some(oi) = r.dynamodb.old_image.as_ref() {
242 dynamodb["OldImage"] = json!(oi);
243 }
244 json!({
245 "eventID": r.event_id,
246 "eventName": r.event_name,
247 "eventVersion": r.event_version,
248 "eventSource": r.event_source,
249 "awsRegion": r.aws_region,
250 "eventSourceARN": table.stream_arn.clone().unwrap_or_default(),
251 "dynamodb": dynamodb,
252 })
253}
254
/// Returns the text after the last `/` of a stream ARN, or the whole string
/// when there is no `/`. For a DynamoDB stream ARN this is the stream label.
fn stream_label(stream_arn: &str) -> String {
    let tail = match stream_arn.rfind('/') {
        Some(slash) => &stream_arn[slash + 1..],
        None => stream_arn,
    };
    tail.to_owned()
}
258
259fn require_string(body: &Value, field: &str) -> Result<String, AwsServiceError> {
260 body[field]
261 .as_str()
262 .map(|s| s.to_string())
263 .ok_or_else(|| invalid_argument(&format!("{field} is required")))
264}
265
266fn invalid_argument(msg: &str) -> AwsServiceError {
267 AwsServiceError::aws_error(StatusCode::BAD_REQUEST, "ValidationException", msg)
268}
269
270fn not_found(kind: &str, target: &str) -> AwsServiceError {
271 AwsServiceError::aws_error(
272 StatusCode::BAD_REQUEST,
273 "ResourceNotFoundException",
274 format!("{kind} not found: {target}"),
275 )
276}
277
278pub fn shared(state: SharedDynamoDbState) -> Arc<dyn AwsService> {
279 Arc::new(DynamoDbStreamsService::new(state))
280}
281
#[cfg(test)]
mod tests {
    // Each test seeds a fresh state and drives the service through
    // `AwsService::handle`, the same entry point real requests use.
    use super::*;
    use crate::state::{DynamoDbStreamRecord, DynamoTable, ProvisionedThroughput, StreamRecord};
    use bytes::Bytes;
    use chrono::Utc;
    use http::{HeaderMap, Method};
    use parking_lot::RwLock;
    use std::collections::{BTreeMap, HashMap};
    use std::sync::Arc;

    /// Builds an empty shared multi-account state.
    // NOTE(review): the arguments presumably set the default account ID and
    // region; the meaning of the third ("") is not visible here.
    fn make_state() -> SharedDynamoDbState {
        Arc::new(RwLock::new(
            fakecloud_core::multi_account::MultiAccountState::new("123456789012", "us-east-1", ""),
        ))
    }

    /// Builds a POST request for `action` with `body` serialized as JSON,
    /// addressed to account 123456789012 (the account `seed_table` populates).
    fn req(action: &str, body: Value) -> AwsRequest {
        AwsRequest {
            service: "dynamodbstreams".into(),
            action: action.into(),
            region: "us-east-1".into(),
            account_id: "123456789012".into(),
            request_id: "r".into(),
            headers: HeaderMap::new(),
            query_params: HashMap::new(),
            body: Bytes::from(serde_json::to_vec(&body).unwrap()),
            body_stream: parking_lot::Mutex::new(None),
            path_segments: vec![],
            raw_path: "/".into(),
            raw_query: String::new(),
            method: Method::POST,
            is_query_protocol: false,
            access_key_id: None,
            principal: None,
        }
    }

    /// Inserts a stream-enabled `widgets` table into account 123456789012.
    /// Its buffer holds one INSERT record with sequence number "1".
    /// Returns the table's stream ARN.
    fn seed_table(state: &SharedDynamoDbState) -> String {
        let mut accts = state.write();
        let s = accts.get_or_create("123456789012");
        let arn =
            "arn:aws:dynamodb:us-east-1:123456789012:table/widgets/stream/2026-05-03T00:00:00.000"
                .to_string();
        let table = DynamoTable {
            name: "widgets".to_string(),
            arn: "arn:aws:dynamodb:us-east-1:123456789012:table/widgets".to_string(),
            table_id: "id".to_string(),
            key_schema: Vec::new(),
            attribute_definitions: Vec::new(),
            provisioned_throughput: ProvisionedThroughput {
                read_capacity_units: 0,
                write_capacity_units: 0,
            },
            items: Vec::new(),
            gsi: Vec::new(),
            lsi: Vec::new(),
            tags: BTreeMap::new(),
            created_at: Utc::now(),
            status: "ACTIVE".to_string(),
            item_count: 0,
            size_bytes: 0,
            billing_mode: "PAY_PER_REQUEST".to_string(),
            ttl_attribute: None,
            ttl_enabled: false,
            resource_policy: None,
            pitr_enabled: false,
            kinesis_destinations: Vec::new(),
            contributor_insights_status: "DISABLED".to_string(),
            contributor_insights_counters: BTreeMap::new(),
            stream_enabled: true,
            stream_view_type: Some("NEW_AND_OLD_IMAGES".to_string()),
            stream_arn: Some(arn.clone()),
            stream_records: Arc::new(RwLock::new(Vec::new())),
            sse_type: None,
            sse_kms_key_arn: None,
            deletion_protection_enabled: false,
            on_demand_throughput: None,
        };
        // The single record every test reads back.
        let rec = StreamRecord {
            event_id: "e1".into(),
            event_name: "INSERT".into(),
            event_version: "1.1".into(),
            event_source: "aws:dynamodb".into(),
            aws_region: "us-east-1".into(),
            event_source_arn: arn.clone(),
            timestamp: Utc::now(),
            dynamodb: DynamoDbStreamRecord {
                keys: HashMap::new(),
                new_image: Some(HashMap::new()),
                old_image: None,
                sequence_number: "1".into(),
                size_bytes: 16,
                stream_view_type: "NEW_AND_OLD_IMAGES".into(),
            },
        };
        table.stream_records.write().push(rec);
        s.tables.insert("widgets".to_string(), table);
        arn
    }

    /// ListStreams with no filter returns exactly the seeded stream.
    #[tokio::test]
    async fn list_streams_returns_enabled_streams() {
        let state = make_state();
        let arn = seed_table(&state);
        let svc = DynamoDbStreamsService::new(state);
        let resp = svc.handle(req("ListStreams", json!({}))).await.unwrap();
        let body: Value = serde_json::from_slice(resp.body.expect_bytes()).unwrap();
        let streams = body["Streams"].as_array().unwrap();
        assert_eq!(streams.len(), 1);
        assert_eq!(streams[0]["StreamArn"].as_str().unwrap(), arn);
    }

    /// DescribeStream reports an ENABLED stream with exactly one shard.
    #[tokio::test]
    async fn describe_stream_returns_shard() {
        let state = make_state();
        let arn = seed_table(&state);
        let svc = DynamoDbStreamsService::new(state);
        let resp = svc
            .handle(req("DescribeStream", json!({"StreamArn": arn})))
            .await
            .unwrap();
        let body: Value = serde_json::from_slice(resp.body.expect_bytes()).unwrap();
        let desc = &body["StreamDescription"];
        assert_eq!(desc["StreamStatus"].as_str().unwrap(), "ENABLED");
        assert_eq!(desc["Shards"].as_array().unwrap().len(), 1);
    }

    /// A TRIM_HORIZON iterator fed to GetRecords returns the seeded INSERT.
    #[tokio::test]
    async fn get_records_round_trip_via_shard_iterator() {
        let state = make_state();
        let arn = seed_table(&state);
        let svc = DynamoDbStreamsService::new(state);
        let it_resp = svc
            .handle(req(
                "GetShardIterator",
                json!({
                    "StreamArn": arn,
                    "ShardId": "shardId-00000000000000000000-00000001",
                    "ShardIteratorType": "TRIM_HORIZON",
                }),
            ))
            .await
            .unwrap();
        let it_body: Value = serde_json::from_slice(it_resp.body.expect_bytes()).unwrap();
        let iterator = it_body["ShardIterator"].as_str().unwrap().to_string();
        let resp = svc
            .handle(req("GetRecords", json!({"ShardIterator": iterator})))
            .await
            .unwrap();
        let body: Value = serde_json::from_slice(resp.body.expect_bytes()).unwrap();
        let recs = body["Records"].as_array().unwrap();
        assert_eq!(recs.len(), 1);
        assert_eq!(recs[0]["eventName"].as_str().unwrap(), "INSERT");
    }

    /// DescribeStream with an ARN no table owns fails with ResourceNotFoundException.
    #[tokio::test]
    async fn describe_stream_unknown_arn_404s() {
        let state = make_state();
        let _ = seed_table(&state);
        let svc = DynamoDbStreamsService::new(state);
        let err = svc
            .handle(req(
                "DescribeStream",
                json!({"StreamArn": "arn:aws:dynamodb:us-east-1:123456789012:table/missing/stream/x"}),
            ))
            .await
            .err()
            .expect("expected ResourceNotFound");
        assert!(format!("{:?}", err).contains("ResourceNotFoundException"));
    }
}