Skip to main content

dynoxide/actions/
execute_transaction.rs

1use crate::errors::{CancellationReason, DynoxideError, Result};
2use crate::partiql;
3use crate::storage::Storage;
4use crate::types::{AttributeValue, Item};
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Default, Deserialize)]
8pub struct ExecuteTransactionRequest {
9    #[serde(rename = "TransactStatements")]
10    pub transact_statements: Vec<ParameterizedStatement>,
11    #[serde(rename = "ClientRequestToken", default)]
12    pub client_request_token: Option<String>,
13    #[serde(rename = "ReturnConsumedCapacity", default)]
14    pub return_consumed_capacity: Option<String>,
15}
16
17#[derive(Debug, Default, Deserialize)]
18pub struct ParameterizedStatement {
19    #[serde(rename = "Statement")]
20    pub statement: String,
21    #[serde(rename = "Parameters", default)]
22    pub parameters: Option<Vec<AttributeValue>>,
23}
24
25#[derive(Debug, Default, Serialize)]
26pub struct ExecuteTransactionResponse {
27    #[serde(rename = "Responses", skip_serializing_if = "Option::is_none")]
28    pub responses: Option<Vec<ItemResponse>>,
29    #[serde(rename = "ConsumedCapacity", skip_serializing_if = "Option::is_none")]
30    pub consumed_capacity: Option<Vec<crate::types::ConsumedCapacity>>,
31}
32
33#[derive(Debug, Default, Serialize)]
34pub struct ItemResponse {
35    #[serde(rename = "Item", skip_serializing_if = "Option::is_none")]
36    pub item: Option<Item>,
37}
38
39pub fn execute(
40    storage: &Storage,
41    request: ExecuteTransactionRequest,
42) -> Result<ExecuteTransactionResponse> {
43    let statements = &request.transact_statements;
44
45    // Validate: must have between 1 and 100 statements
46    if statements.is_empty() {
47        return Err(DynoxideError::ValidationException(
48            "1 validation error detected: Value at 'transactStatements' failed to satisfy constraint: Member must have length greater than or equal to 1".to_string(),
49        ));
50    }
51    if statements.len() > 100 {
52        return Err(DynoxideError::ValidationException(
53            "Member must have length less than or equal to 100".to_string(),
54        ));
55    }
56
57    // Parse all statements before executing any, to fail fast on syntax errors
58    let mut parsed = Vec::with_capacity(statements.len());
59    for stmt in statements {
60        let ast = partiql::parser::parse(&stmt.statement).map_err(|e| {
61            DynoxideError::ValidationException(format!(
62                "Statement wasn't well formed, got error: {e}"
63            ))
64        })?;
65        let params = stmt.parameters.clone().unwrap_or_default();
66        parsed.push((ast, params));
67    }
68
69    // Begin SQLite transaction
70    storage.begin_transaction()?;
71
72    let result = execute_within_transaction(storage, &parsed);
73
74    match result {
75        Ok(responses) => {
76            storage.commit()?;
77
78            // Build ConsumedCapacity if requested (simple estimate: 1 WCU per statement)
79            let consumed_capacity = if matches!(
80                request.return_consumed_capacity.as_deref(),
81                Some("TOTAL") | Some("INDEXES")
82            ) {
83                // Aggregate capacity by table name from parsed statements
84                let mut table_units: std::collections::HashMap<String, f64> =
85                    std::collections::HashMap::new();
86                for (stmt, _) in &parsed {
87                    if let Some(tbl) = partiql::parser::table_name(stmt) {
88                        *table_units.entry(tbl.to_string()).or_default() += 1.0;
89                    }
90                }
91                let caps: Vec<_> = table_units
92                    .iter()
93                    .filter_map(|(table, &units)| {
94                        crate::types::consumed_capacity(
95                            table,
96                            units,
97                            &request.return_consumed_capacity,
98                        )
99                    })
100                    .collect();
101                Some(caps)
102            } else {
103                None
104            };
105
106            Ok(ExecuteTransactionResponse {
107                responses: Some(responses),
108                consumed_capacity,
109            })
110        }
111        Err(e) => {
112            if let Err(rb_err) = storage.rollback() {
113                return Err(DynoxideError::InternalServerError(format!(
114                    "Transaction failed ({e}) and rollback also failed ({rb_err})"
115                )));
116            }
117            Err(e)
118        }
119    }
120}
121
122fn execute_within_transaction(
123    storage: &Storage,
124    parsed: &[(partiql::parser::Statement, Vec<AttributeValue>)],
125) -> Result<Vec<ItemResponse>> {
126    let mut responses = Vec::with_capacity(parsed.len());
127    let mut cancellation_reasons: Vec<CancellationReason> = Vec::with_capacity(parsed.len());
128
129    for (stmt, params) in parsed {
130        match partiql::executor::execute(storage, stmt, params, None) {
131            Ok(result) => {
132                let item = result.and_then(|items| items.into_iter().next());
133                responses.push(ItemResponse { item });
134                cancellation_reasons.push(CancellationReason {
135                    code: "None".to_string(),
136                    message: None,
137                    item: None,
138                });
139            }
140            Err(e) => {
141                // Record the failure reason
142                let message = Some(e.to_string());
143                let (code, item) = match e {
144                    DynoxideError::ConditionalCheckFailedException(_, item) => {
145                        ("ConditionalCheckFailed".to_string(), item)
146                    }
147                    DynoxideError::DuplicateItemException(_) => ("DuplicateItem".to_string(), None),
148                    DynoxideError::ValidationException(_) => ("ValidationError".to_string(), None),
149                    _ => ("InternalError".to_string(), None),
150                };
151                responses.push(ItemResponse { item: None });
152                cancellation_reasons.push(CancellationReason {
153                    code,
154                    message,
155                    item,
156                });
157
158                // Fill remaining slots with None and stop — don't execute
159                // statements that will be rolled back.
160                for _ in responses.len()..parsed.len() {
161                    responses.push(ItemResponse { item: None });
162                    cancellation_reasons.push(CancellationReason {
163                        code: "None".to_string(),
164                        message: None,
165                        item: None,
166                    });
167                }
168
169                let codes: Vec<&str> = cancellation_reasons
170                    .iter()
171                    .map(|r| r.code.as_str())
172                    .collect();
173                let message = format!(
174                    "Transaction cancelled, please refer cancellation reasons for specific reasons [{}]",
175                    codes.join(", ")
176                );
177                return Err(DynoxideError::TransactionCanceledException(
178                    message,
179                    cancellation_reasons,
180                ));
181            }
182        }
183    }
184
185    Ok(responses)
186}