Skip to main content

dynoxide/actions/
execute_transaction.rs

1use crate::actions::helpers;
2use crate::errors::{CancellationReason, DynoxideError, Result};
3use crate::partiql;
4use crate::storage_backend::StorageBackend;
5use crate::types::{AttributeValue, Item};
6use serde::{Deserialize, Serialize};
7
8#[derive(Debug, Default, Deserialize)]
9pub struct ExecuteTransactionRequest {
10    #[serde(rename = "TransactStatements")]
11    pub transact_statements: Vec<ParameterizedStatement>,
12    #[serde(rename = "ClientRequestToken", default)]
13    pub client_request_token: Option<String>,
14    #[serde(rename = "ReturnConsumedCapacity", default)]
15    pub return_consumed_capacity: Option<String>,
16}
17
18#[derive(Debug, Default, Deserialize)]
19pub struct ParameterizedStatement {
20    #[serde(rename = "Statement")]
21    pub statement: String,
22    #[serde(rename = "Parameters", default)]
23    pub parameters: Option<Vec<AttributeValue>>,
24}
25
26#[derive(Debug, Default, Serialize)]
27pub struct ExecuteTransactionResponse {
28    #[serde(rename = "Responses", skip_serializing_if = "Option::is_none")]
29    pub responses: Option<Vec<ItemResponse>>,
30    #[serde(rename = "ConsumedCapacity", skip_serializing_if = "Option::is_none")]
31    pub consumed_capacity: Option<Vec<crate::types::ConsumedCapacity>>,
32}
33
34#[derive(Debug, Default, Serialize)]
35pub struct ItemResponse {
36    #[serde(rename = "Item", skip_serializing_if = "Option::is_none")]
37    pub item: Option<Item>,
38}
39
40pub async fn execute<S: StorageBackend>(
41    storage: &S,
42    request: ExecuteTransactionRequest,
43) -> Result<ExecuteTransactionResponse> {
44    let statements = &request.transact_statements;
45
46    // Validate: must have between 1 and 100 statements
47    if statements.is_empty() {
48        return Err(DynoxideError::ValidationException(
49            "1 validation error detected: Value at 'transactStatements' failed to satisfy constraint: Member must have length greater than or equal to 1".to_string(),
50        ));
51    }
52    if statements.len() > 100 {
53        return Err(DynoxideError::ValidationException(
54            "Member must have length less than or equal to 100".to_string(),
55        ));
56    }
57
58    // Parse all statements before executing any, to fail fast on syntax errors
59    let mut parsed = Vec::with_capacity(statements.len());
60    for stmt in statements {
61        let ast = partiql::parser::parse(&stmt.statement).map_err(|e| {
62            DynoxideError::ValidationException(format!(
63                "Statement wasn't well formed, got error: {e}"
64            ))
65        })?;
66        let params = stmt.parameters.clone().unwrap_or_default();
67        parsed.push((ast, params));
68    }
69
70    // All statements run inside one SQLite transaction (all-or-nothing).
71    let responses =
72        helpers::with_write_transaction(storage, execute_within_transaction(storage, &parsed))
73            .await?;
74
75    // Build ConsumedCapacity if requested (simple estimate: 1 WCU per statement)
76    let consumed_capacity = if matches!(
77        request.return_consumed_capacity.as_deref(),
78        Some("TOTAL") | Some("INDEXES")
79    ) {
80        // Aggregate capacity by table name from parsed statements
81        let mut table_units: std::collections::HashMap<String, f64> =
82            std::collections::HashMap::new();
83        for (stmt, _) in &parsed {
84            if let Some(tbl) = partiql::parser::table_name(stmt) {
85                *table_units.entry(tbl.to_string()).or_default() += 1.0;
86            }
87        }
88        let caps: Vec<_> = table_units
89            .iter()
90            .filter_map(|(table, &units)| {
91                crate::types::consumed_capacity(table, units, &request.return_consumed_capacity)
92            })
93            .collect();
94        Some(caps)
95    } else {
96        None
97    };
98
99    Ok(ExecuteTransactionResponse {
100        responses: Some(responses),
101        consumed_capacity,
102    })
103}
104
105async fn execute_within_transaction<S: StorageBackend>(
106    storage: &S,
107    parsed: &[(partiql::parser::Statement, Vec<AttributeValue>)],
108) -> Result<Vec<ItemResponse>> {
109    let mut responses = Vec::with_capacity(parsed.len());
110    let mut cancellation_reasons: Vec<CancellationReason> = Vec::with_capacity(parsed.len());
111
112    for (stmt, params) in parsed {
113        match partiql::executor::execute(storage, stmt, params, None).await {
114            Ok(result) => {
115                let item = result.and_then(|items| items.into_iter().next());
116                responses.push(ItemResponse { item });
117                cancellation_reasons.push(CancellationReason {
118                    code: "None".to_string(),
119                    message: None,
120                    item: None,
121                });
122            }
123            Err(e) => {
124                // Record the failure reason
125                let message = Some(e.to_string());
126                let (code, item) = match e {
127                    DynoxideError::ConditionalCheckFailedException(_, item) => {
128                        ("ConditionalCheckFailed".to_string(), item)
129                    }
130                    DynoxideError::DuplicateItemException(_) => ("DuplicateItem".to_string(), None),
131                    DynoxideError::ValidationException(_) => ("ValidationError".to_string(), None),
132                    _ => ("InternalError".to_string(), None),
133                };
134                responses.push(ItemResponse { item: None });
135                cancellation_reasons.push(CancellationReason {
136                    code,
137                    message,
138                    item,
139                });
140
141                // Fill remaining slots with None and stop — don't execute
142                // statements that will be rolled back.
143                for _ in responses.len()..parsed.len() {
144                    responses.push(ItemResponse { item: None });
145                    cancellation_reasons.push(CancellationReason {
146                        code: "None".to_string(),
147                        message: None,
148                        item: None,
149                    });
150                }
151
152                let codes: Vec<&str> = cancellation_reasons
153                    .iter()
154                    .map(|r| r.code.as_str())
155                    .collect();
156                let message = format!(
157                    "Transaction cancelled, please refer cancellation reasons for specific reasons [{}]",
158                    codes.join(", ")
159                );
160                return Err(DynoxideError::TransactionCanceledException(
161                    message,
162                    cancellation_reasons,
163                ));
164            }
165        }
166    }
167
168    Ok(responses)
169}