//! `BatchWriteItem` action.
//!
//! Source path: `dynoxide/actions/batch_write_item.rs`.

use crate::actions::helpers;
use crate::errors::{DynoxideError, Result};
use crate::storage_backend::StorageBackend;
use crate::types::{self, AttributeValue, Item};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

8#[derive(Debug, Default, Deserialize)]
9pub struct BatchWriteItemRequest {
10    #[serde(rename = "RequestItems")]
11    pub request_items: HashMap<String, Vec<WriteRequest>>,
12    #[serde(rename = "ReturnConsumedCapacity", default)]
13    pub return_consumed_capacity: Option<String>,
14    #[serde(rename = "ReturnItemCollectionMetrics", default)]
15    pub return_item_collection_metrics: Option<String>,
16}
17
18#[derive(Debug, Default, Deserialize)]
19pub struct WriteRequest {
20    #[serde(rename = "PutRequest", default)]
21    pub put_request: Option<PutRequest>,
22    #[serde(rename = "DeleteRequest", default)]
23    pub delete_request: Option<DeleteRequest>,
24}
25
26#[derive(Debug, Default, Deserialize)]
27pub struct PutRequest {
28    #[serde(rename = "Item")]
29    pub item: Item,
30}
31
32#[derive(Debug, Default, Deserialize)]
33pub struct DeleteRequest {
34    #[serde(rename = "Key")]
35    pub key: HashMap<String, AttributeValue>,
36}
37
38#[derive(Debug, Default, Serialize)]
39pub struct BatchWriteItemResponse {
40    #[serde(rename = "UnprocessedItems")]
41    pub unprocessed_items: HashMap<String, serde_json::Value>,
42    #[serde(rename = "ConsumedCapacity", skip_serializing_if = "Option::is_none")]
43    pub consumed_capacity: Option<Vec<crate::types::ConsumedCapacity>>,
44    #[serde(
45        rename = "ItemCollectionMetrics",
46        skip_serializing_if = "Option::is_none"
47    )]
48    pub item_collection_metrics: Option<HashMap<String, Vec<crate::types::ItemCollectionMetrics>>>,
49}
50
51pub async fn execute<S: StorageBackend>(
52    storage: &S,
53    mut request: BatchWriteItemRequest,
54) -> Result<BatchWriteItemResponse> {
55    const MAX_REQUEST_SIZE: usize = 16 * 1024 * 1024; // 16MB
56
57    // Validate RequestItems is not empty.
58    // AWS routes the empty-map case through a separate parameter-required path
59    // rather than the standard "N validation errors detected" envelope.
60    if request.request_items.is_empty() {
61        return Err(DynoxideError::ValidationException(
62            "The requestItems parameter is required for BatchWriteItem".to_string(),
63        ));
64    }
65
66    // Validate each table entry has at least one write request
67    for (table_name, wrs) in &request.request_items {
68        if wrs.is_empty() {
69            return Err(DynoxideError::ValidationException(format!(
70                "1 validation error detected: Value at 'requestItems.{table_name}.member' failed to satisfy constraint: Member must have length greater than or equal to 1"
71            )));
72        }
73    }
74
75    // Validate table name format for all tables before checking existence
76    for table_name in request.request_items.keys() {
77        crate::validation::validate_table_name(table_name)?;
78    }
79
80    // Validate total request count.
81    // AWS surfaces this as the standard "1 validation error detected" envelope
82    // and echoes the WriteRequest list inside `Value '{<table>=[<dump>]}'`. The
83    // conformance suite anchors a regex around the envelope and the constraint
84    // phrase but leaves the dump body unconstrained (because the AWS SDK's
85    // Java-toString shape adds new AttributeValue fields over time). We emit
86    // the table name verbatim and a Rust Debug dump of the WriteRequests so
87    // the envelope matches without coupling to a specific SDK version. If a
88    // future suite tightens the regex to pin the dump exactly, this site
89    // will need a follow-up change.
90    let total_requests: usize = request.request_items.values().map(|v| v.len()).sum();
91    if total_requests > 25 {
92        let empty: Vec<WriteRequest> = Vec::new();
93        let (table_name, requests) = request
94            .request_items
95            .iter()
96            .max_by_key(|(_, v)| v.len())
97            .map(|(name, v)| (name.as_str(), v))
98            .unwrap_or(("", &empty));
99        let dump = format!("{requests:?}");
100        return Err(DynoxideError::ValidationException(format!(
101            "1 validation error detected: Value '{{{table_name}=[{dump}]}}' at 'requestItems' failed to satisfy constraint: Map value must satisfy constraint: [Member must have length less than or equal to 25, Member must have length greater than or equal to 1]"
102        )));
103    }
104
105    // --- Pre-table validations ---
106    // DynamoDB validates attribute values, item size, and empty write requests
107    // BEFORE checking table existence.
108    for write_requests in request.request_items.values() {
109        for wr in write_requests {
110            if wr.put_request.is_none() && wr.delete_request.is_none() {
111                return Err(DynoxideError::ValidationException(
112                    "Supplied AttributeValue has more than one datatypes set, must contain exactly one of the supported datatypes".to_string(),
113                ));
114            }
115            if let Some(ref put_req) = wr.put_request {
116                // Validate attribute values (empty strings, empty sets, invalid numbers)
117                crate::validation::validate_item_attribute_values(&put_req.item)?;
118
119                // Validate item size before table lookup
120                let size = types::item_size(&put_req.item);
121                if size > types::MAX_ITEM_SIZE {
122                    return Err(DynoxideError::ValidationException(
123                        "Item size has exceeded the maximum allowed size".to_string(),
124                    ));
125                }
126            }
127            if let Some(ref del_req) = wr.delete_request {
128                crate::validation::validate_item_attribute_values(&del_req.key)?;
129            }
130        }
131    }
132
133    // Validate aggregate request size
134    let total_size: usize = request
135        .request_items
136        .values()
137        .flat_map(|wrs| wrs.iter())
138        .map(|wr| {
139            if let Some(ref put_req) = wr.put_request {
140                types::item_size(&put_req.item)
141            } else if let Some(ref del_req) = wr.delete_request {
142                types::item_size(&del_req.key)
143            } else {
144                0
145            }
146        })
147        .sum();
148    if total_size > MAX_REQUEST_SIZE {
149        return Err(DynoxideError::ValidationException(
150            "Item collection too large: aggregate size of items in BatchWriteItem exceeds 16MB limit".to_string(),
151        ));
152    }
153
154    // Validate: no duplicate keys across all operations
155    {
156        let mut seen_keys: std::collections::HashSet<(String, String, String)> =
157            std::collections::HashSet::new();
158        for (table_name, write_requests) in &request.request_items {
159            let meta = helpers::require_table_for_item_op(storage, table_name).await?;
160            let key_schema = helpers::parse_key_schema(&meta)?;
161            for wr in write_requests {
162                // Validate keys BEFORE extract_key_strings: that helper returns
163                // InternalServerError (HTTP 500) for a missing partition or sort
164                // key, but a key-less request is client input and must surface as
165                // a 400 ValidationException. Mirrors put_item.rs, which validates
166                // before extracting.
167                let key_item = if let Some(ref put) = wr.put_request {
168                    helpers::validate_item_keys_for_batch(&put.item, &key_schema, &meta)?;
169                    &put.item
170                } else if let Some(ref del) = wr.delete_request {
171                    helpers::validate_key_only(&del.key, &key_schema)?;
172                    &del.key
173                } else {
174                    continue;
175                };
176                let (pk, sk) = helpers::extract_key_strings(key_item, &key_schema)?;
177                let key = (table_name.clone(), pk, sk);
178                if !seen_keys.insert(key) {
179                    return Err(DynoxideError::ValidationException(
180                        "Provided list of item keys contains duplicates".to_string(),
181                    ));
182                }
183            }
184        }
185    }
186
187    // Track per-table GSI capacity and affected partition keys for deferred metrics
188    let mut table_gsi_units: HashMap<String, HashMap<String, f64>> = HashMap::new();
189    // Track per-table WCU (table-level, excludes GSI)
190    let mut table_wcu: HashMap<String, f64> = HashMap::new();
191    // Collect unique (table, pk_str, pk_attr, pk_value) for deferred metrics computation
192    let mut affected_partitions: Vec<(String, String, String, AttributeValue)> = Vec::new();
193
194    // OPTIMISATION: maintain_gsis_after_write/maintain_lsis_after_write each
195    // deserialise GSI/LSI definitions from JSON on every call. For batch writes
196    // of 25 items against one table, that's 50 redundant deserialise calls.
197    // A future improvement would hoist parse_gsi_defs/parse_lsi_defs to this
198    // level and pass pre-parsed defs into the maintenance functions.
199
200    for (table_name, write_requests) in &mut request.request_items {
201        let meta = helpers::require_table_for_item_op(storage, table_name).await?;
202        let key_schema = helpers::parse_key_schema(&meta)?;
203
204        for wr in write_requests {
205            if let Some(ref mut put_req) = wr.put_request {
206                // Validate keys
207                helpers::validate_item_keys_for_batch(&put_req.item, &key_schema, &meta)?;
208
209                // Validate attribute values (empty strings, empty sets)
210                crate::validation::validate_item_attribute_values(&put_req.item)?;
211
212                // Normalize sets (deduplication)
213                crate::validation::normalize_item_sets(&mut put_req.item);
214
215                // Validate item size
216                let size = types::item_size(&put_req.item);
217                if size > types::MAX_ITEM_SIZE {
218                    return Err(DynoxideError::ValidationException(
219                        "Item size has exceeded the maximum allowed size".to_string(),
220                    ));
221                }
222
223                // TODO: validation must precede this call -- if reaching this line, caller has already validated keys.
224                let (pk, sk) = helpers::extract_key_strings(&put_req.item, &key_schema)?;
225                let item_json = serde_json::to_string(&put_req.item)
226                    .map_err(|e| DynoxideError::InternalServerError(e.to_string()))?;
227                let hash_prefix = put_req
228                    .item
229                    .get(&key_schema.partition_key)
230                    .map(crate::storage::compute_hash_prefix)
231                    .unwrap_or_default();
232                // Base write + index fan-out + stream are one atomic unit per
233                // item: a mid-fan-out failure rolls this item's write back.
234                // BatchWriteItem items are independent, so this is one
235                // transaction per write request.
236                let gsi_units = helpers::with_write_transaction(storage, async {
237                    let old_json = storage
238                        .put_item_with_hash(table_name, &pk, &sk, &item_json, size, &hash_prefix)
239                        .await?;
240                    let gsi_units = super::gsi::maintain_gsis_after_write(
241                        storage,
242                        table_name,
243                        &meta,
244                        &pk,
245                        &sk,
246                        &put_req.item,
247                        &key_schema.partition_key,
248                        key_schema.sort_key.as_deref(),
249                    )
250                    .await?;
251                    super::lsi::maintain_lsis_after_write(
252                        storage,
253                        table_name,
254                        &meta,
255                        &pk,
256                        &sk,
257                        &put_req.item,
258                        &key_schema.partition_key,
259                        key_schema.sort_key.as_deref(),
260                    )
261                    .await?;
262                    let old_item: Option<Item> =
263                        old_json.and_then(|j| serde_json::from_str(&j).ok());
264                    crate::streams::record_stream_event(
265                        storage,
266                        &meta,
267                        old_item.as_ref(),
268                        Some(&put_req.item),
269                    )
270                    .await?;
271                    Ok(gsi_units)
272                })
273                .await?;
274
275                // Accumulate WCU based on item size
276                *table_wcu.entry(table_name.clone()).or_insert(0.0) +=
277                    types::write_capacity_units(size);
278
279                // Accumulate GSI units per table
280                let table_entry = table_gsi_units.entry(table_name.clone()).or_default();
281                for (gsi_name, units) in &gsi_units {
282                    *table_entry.entry(gsi_name.clone()).or_insert(0.0) += units;
283                }
284
285                // Track affected partition for deferred metrics
286                if let Some(pk_val) = put_req.item.get(&key_schema.partition_key) {
287                    affected_partitions.push((
288                        table_name.clone(),
289                        pk.clone(),
290                        key_schema.partition_key.clone(),
291                        pk_val.clone(),
292                    ));
293                }
294            } else if let Some(ref del_req) = wr.delete_request {
295                helpers::validate_key_only(&del_req.key, &key_schema)?;
296                // TODO: validation must precede this call -- if reaching this line, caller has already validated keys.
297                let (pk, sk) = helpers::extract_key_strings(&del_req.key, &key_schema)?;
298
299                // Base delete + index fan-out + stream are one atomic unit per item.
300                let (old_item, gsi_units) = helpers::with_write_transaction(storage, async {
301                    let old_json = storage.delete_item(table_name, &pk, &sk).await?;
302                    let old_item: Option<Item> =
303                        old_json.as_ref().and_then(|j| serde_json::from_str(j).ok());
304                    let gsi_units = super::gsi::maintain_gsis_after_delete(
305                        storage, table_name, &meta, &pk, &sk,
306                    )
307                    .await?;
308                    super::lsi::maintain_lsis_after_delete(storage, table_name, &meta, &pk, &sk)
309                        .await?;
310                    if old_item.is_some() {
311                        crate::streams::record_stream_event(
312                            storage,
313                            &meta,
314                            old_item.as_ref(),
315                            None,
316                        )
317                        .await?;
318                    }
319                    Ok((old_item, gsi_units))
320                })
321                .await?;
322
323                // Accumulate WCU: based on old item size if it existed, else 1 WCU
324                let delete_wcu = if let Some(ref old) = old_item {
325                    types::write_capacity_units(types::item_size(old))
326                } else {
327                    1.0
328                };
329                *table_wcu.entry(table_name.clone()).or_insert(0.0) += delete_wcu;
330
331                // Accumulate GSI units per table
332                let table_entry = table_gsi_units.entry(table_name.clone()).or_default();
333                for (gsi_name, units) in &gsi_units {
334                    *table_entry.entry(gsi_name.clone()).or_insert(0.0) += units;
335                }
336
337                // Track affected partition for deferred metrics
338                if let Some(pk_val) = del_req.key.get(&key_schema.partition_key) {
339                    affected_partitions.push((
340                        table_name.clone(),
341                        pk.clone(),
342                        key_schema.partition_key.clone(),
343                        pk_val.clone(),
344                    ));
345                }
346            } else {
347                return Err(DynoxideError::ValidationException(
348                    "WriteRequest must contain either PutRequest or DeleteRequest".to_string(),
349                ));
350            }
351        }
352    }
353
354    // Build consumed capacity per table using pre-tracked WCU
355    let consumed_capacity = if matches!(
356        request.return_consumed_capacity.as_deref(),
357        Some("TOTAL") | Some("INDEXES")
358    ) {
359        let mut caps = Vec::new();
360        for table_name in request.request_items.keys() {
361            let total_wcu = table_wcu.get(table_name).copied().unwrap_or(0.0);
362            let gsi_units = table_gsi_units.get(table_name).cloned().unwrap_or_default();
363            if let Some(cc) = crate::types::consumed_capacity_with_indexes(
364                table_name,
365                total_wcu,
366                &gsi_units,
367                &request.return_consumed_capacity,
368            ) {
369                caps.push(cc);
370            }
371        }
372        Some(caps)
373    } else {
374        None
375    };
376
377    // Compute item collection metrics once per unique (table, pk) — deferred from the write loop
378    let mut all_item_collection_metrics: HashMap<String, Vec<crate::types::ItemCollectionMetrics>> =
379        HashMap::new();
380    if matches!(
381        request.return_item_collection_metrics.as_deref(),
382        Some("SIZE")
383    ) {
384        // Deduplicate by (table, pk) to avoid redundant queries
385        let mut seen = std::collections::HashSet::new();
386        for (tbl, pk_str, pk_attr, pk_val) in &affected_partitions {
387            let key = (tbl.as_str(), pk_str.as_str());
388            if !seen.insert(key) {
389                continue;
390            }
391            let meta = helpers::require_table(storage, tbl).await?;
392            if let Some(icm) = helpers::build_item_collection_metrics(
393                storage,
394                &meta,
395                tbl,
396                pk_str,
397                pk_attr,
398                pk_val,
399                &request.return_item_collection_metrics,
400            )
401            .await?
402            {
403                all_item_collection_metrics
404                    .entry(tbl.clone())
405                    .or_default()
406                    .push(icm);
407            }
408        }
409    }
410    let item_collection_metrics = if all_item_collection_metrics.is_empty() {
411        None
412    } else {
413        Some(all_item_collection_metrics)
414    };
415
416    Ok(BatchWriteItemResponse {
417        unprocessed_items: HashMap::new(),
418        consumed_capacity,
419        item_collection_metrics,
420    })
421}
422
#[cfg(test)]
mod tests {
    use crate::actions::{batch_write_item, create_table};
    use crate::storage::Storage;
    use crate::storage_backend::StorageBackend;

    /// Each batch put is atomic with its own GSI fan-out: a mid-fan-out failure
    /// rolls that item's base write back rather than leaving a torn index.
    #[test]
    fn batch_put_rolls_back_base_write_when_gsi_fan_out_fails() {
        let storage = Storage::memory().unwrap();

        // Table with two GSIs so the fan-out has more than one step.
        let create = serde_json::from_value(serde_json::json!({
            "TableName": "Orders",
            "KeySchema": [{"AttributeName": "UserId", "KeyType": "HASH"}],
            "AttributeDefinitions": [
                {"AttributeName": "UserId", "AttributeType": "S"},
                {"AttributeName": "Status", "AttributeType": "S"},
                {"AttributeName": "Priority", "AttributeType": "S"}
            ],
            "GlobalSecondaryIndexes": [
                {"IndexName": "StatusIndex", "KeySchema": [{"AttributeName": "Status", "KeyType": "HASH"}], "Projection": {"ProjectionType": "ALL"}},
                {"IndexName": "PriorityIndex", "KeySchema": [{"AttributeName": "Priority", "KeyType": "HASH"}], "Projection": {"ProjectionType": "ALL"}}
            ]
        }))
        .unwrap();
        pollster::block_on(create_table::execute(&storage, create)).unwrap();

        // Break the second GSI's fan-out by dropping its physical table.
        storage.drop_gsi_table("Orders", "PriorityIndex").unwrap();

        let batch = serde_json::from_value(serde_json::json!({
            "RequestItems": {
                "Orders": [
                    {"PutRequest": {"Item": {"UserId": {"S": "u1"}, "Status": {"S": "SHIPPED"}, "Priority": {"S": "HIGH"}}}}
                ]
            }
        }))
        .unwrap();
        let res = pollster::block_on(batch_write_item::execute(&storage, batch));
        assert!(
            res.is_err(),
            "a mid-fan-out failure must surface as an error"
        );

        // The base write must roll back: no item landed.
        let count =
            pollster::block_on(<Storage as StorageBackend>::count_items(&storage, "Orders"))
                .unwrap();
        assert_eq!(
            count, 0,
            "batch put base write must roll back when fan-out fails"
        );
    }
}