kstone_core/partiql/
translator.rs

//! Translates PartiQL AST statements into KeystoneDB operations.
//!
//! Maps SELECT to Query, MultiGet, or Scan; INSERT to Put; UPDATE to Update; DELETE to Delete.
4
5use crate::partiql::ast::*;
6use crate::partiql::validator::{DynamoDBValidator, QueryType};
7use crate::{Error, Key, Result};
8use bytes::Bytes;
9
10/// PartiQL to KeystoneDB translator
11pub struct PartiQLTranslator;
12
13impl PartiQLTranslator {
14    /// Translate SELECT statement to Query or Scan parameters
15    pub fn translate_select(stmt: &SelectStatement) -> Result<SelectTranslation> {
16        // Validate and determine query type
17        let query_type = DynamoDBValidator::validate_select(stmt)?;
18
19        match query_type {
20            QueryType::Query { pk_condition, sk_condition } => {
21                // Translate to Query
22                let (pk_bytes, multiple_pks) = Self::extract_pk_bytes(&pk_condition)?;
23
24                if multiple_pks {
25                    // IN clause with multiple PKs - need to execute multiple gets
26                    Ok(SelectTranslation::MultiGet {
27                        keys: pk_bytes,
28                        index_name: stmt.index_name.clone(),
29                    })
30                } else {
31                    // Single PK - regular Query
32                    let pk = pk_bytes.into_iter().next().unwrap();
33                    let sk_condition_translated = sk_condition
34                        .as_ref()
35                        .map(Self::translate_sk_condition)
36                        .transpose()?;
37
38                    Ok(SelectTranslation::Query {
39                        pk,
40                        sk_condition: sk_condition_translated,
41                        index_name: stmt.index_name.clone(),
42                        forward: stmt.order_by.as_ref().map_or(true, |o| o.ascending),
43                    })
44                }
45            }
46            QueryType::Scan => {
47                // Translate to Scan
48                Ok(SelectTranslation::Scan {
49                    filter_conditions: stmt
50                        .where_clause
51                        .as_ref()
52                        .map(|wc| wc.conditions.clone())
53                        .unwrap_or_default(),
54                })
55            }
56        }
57    }
58
59    /// Extract partition key bytes from condition
60    fn extract_pk_bytes(condition: &Condition) -> Result<(Vec<Bytes>, bool)> {
61        match &condition.operator {
62            CompareOp::Equal => {
63                let bytes = Self::value_to_bytes(&condition.value)?;
64                Ok((vec![bytes], false))
65            }
66            CompareOp::In => {
67                // IN clause - extract all values
68                match &condition.value {
69                    SqlValue::List(values) => {
70                        let bytes_vec: Result<Vec<Bytes>> = values
71                            .iter()
72                            .map(Self::value_to_bytes)
73                            .collect();
74                        Ok((bytes_vec?, true))
75                    }
76                    _ => Err(Error::InvalidQuery("IN value must be a list".into())),
77                }
78            }
79            _ => Err(Error::InvalidQuery(
80                "Partition key must use = or IN operator".into(),
81            )),
82        }
83    }
84
85    /// Convert SqlValue to Bytes for key
86    fn value_to_bytes(value: &SqlValue) -> Result<Bytes> {
87        match value {
88            SqlValue::String(s) => Ok(Bytes::copy_from_slice(s.as_bytes())),
89            SqlValue::Number(n) => Ok(Bytes::copy_from_slice(n.as_bytes())),
90            _ => Err(Error::InvalidQuery(format!(
91                "Unsupported key value type: {:?}",
92                value
93            ))),
94        }
95    }
96
97    /// Translate sort key condition to KeystoneDB SortKeyCondition
98    fn translate_sk_condition(condition: &Condition) -> Result<SortKeyConditionType> {
99        let sk_bytes = Self::value_to_bytes(&condition.value)?;
100
101        match condition.operator {
102            CompareOp::Equal => Ok(SortKeyConditionType::Equal(sk_bytes)),
103            CompareOp::LessThan => Ok(SortKeyConditionType::LessThan(sk_bytes)),
104            CompareOp::LessThanOrEqual => Ok(SortKeyConditionType::LessThanOrEqual(sk_bytes)),
105            CompareOp::GreaterThan => Ok(SortKeyConditionType::GreaterThan(sk_bytes)),
106            CompareOp::GreaterThanOrEqual => Ok(SortKeyConditionType::GreaterThanOrEqual(sk_bytes)),
107            CompareOp::Between => {
108                match &condition.value {
109                    SqlValue::List(values) if values.len() == 2 => {
110                        let low = Self::value_to_bytes(&values[0])?;
111                        let high = Self::value_to_bytes(&values[1])?;
112                        Ok(SortKeyConditionType::Between(low, high))
113                    }
114                    _ => Err(Error::InvalidQuery("BETWEEN requires exactly 2 values".into())),
115                }
116            }
117            _ => Err(Error::InvalidQuery(format!(
118                "Unsupported sort key operator: {:?}",
119                condition.operator
120            ))),
121        }
122    }
123
124    /// Translate INSERT statement
125    pub fn translate_insert(stmt: &InsertStatement) -> Result<InsertTranslation> {
126        // Validate
127        DynamoDBValidator::validate_insert(stmt)?;
128
129        // Extract key and item from value map
130        let value_map = match &stmt.value {
131            SqlValue::Map(map) => map,
132            _ => return Err(Error::InvalidQuery("INSERT value must be a map".into())),
133        };
134
135        // Extract pk
136        let pk_value = value_map
137            .get("pk")
138            .ok_or_else(|| Error::InvalidQuery("INSERT value must contain 'pk'".into()))?;
139        let pk_bytes = Self::value_to_bytes(pk_value)?;
140
141        // Extract optional sk
142        let sk_bytes = value_map
143            .get("sk")
144            .map(Self::value_to_bytes)
145            .transpose()?;
146
147        // Build key
148        let key = if let Some(sk) = sk_bytes {
149            Key::with_sk(pk_bytes.to_vec(), sk.to_vec())
150        } else {
151            Key::new(pk_bytes.to_vec())
152        };
153
154        // Convert remaining attributes to Item
155        let mut item = std::collections::HashMap::new();
156        for (attr_name, attr_value) in value_map {
157            if attr_name != "pk" && attr_name != "sk" {
158                item.insert(attr_name.clone(), attr_value.to_kstone_value());
159            }
160        }
161
162        Ok(InsertTranslation { key, item })
163    }
164
165    /// Translate UPDATE statement
166    pub fn translate_update(stmt: &UpdateStatement) -> Result<UpdateTranslation> {
167        // Validate
168        DynamoDBValidator::validate_update(stmt)?;
169
170        // Extract key from WHERE clause
171        let pk_cond = stmt
172            .where_clause
173            .get_condition("pk")
174            .ok_or_else(|| Error::InvalidQuery("UPDATE must specify pk in WHERE clause".into()))?;
175        let pk_bytes = Self::value_to_bytes(&pk_cond.value)?;
176
177        let sk_bytes = stmt
178            .where_clause
179            .get_condition("sk")
180            .map(|c| Self::value_to_bytes(&c.value))
181            .transpose()?;
182
183        let key = if let Some(sk) = sk_bytes {
184            Key::with_sk(pk_bytes.to_vec(), sk.to_vec())
185        } else {
186            Key::new(pk_bytes.to_vec())
187        };
188
189        // Build UPDATE expression and values map
190        let mut expression_parts = Vec::new();
191        let mut values = std::collections::HashMap::new();
192        let mut value_counter = 1;
193
194        // Process SET assignments
195        if !stmt.set_assignments.is_empty() {
196            let mut set_exprs = Vec::new();
197            for assignment in &stmt.set_assignments {
198                match &assignment.value {
199                    SetValue::Literal(sql_value) => {
200                        // SET attr = :v1
201                        let placeholder = format!(":v{}", value_counter);
202                        set_exprs.push(format!("{} = {}", assignment.attribute, placeholder));
203                        values.insert(placeholder, sql_value.to_kstone_value());
204                        value_counter += 1;
205                    }
206                    SetValue::Add { attribute, value } => {
207                        // SET attr = attr + :v1
208                        let placeholder = format!(":v{}", value_counter);
209                        set_exprs.push(format!(
210                            "{} = {} + {}",
211                            assignment.attribute, attribute, placeholder
212                        ));
213                        values.insert(placeholder, value.to_kstone_value());
214                        value_counter += 1;
215                    }
216                    SetValue::Subtract { attribute, value } => {
217                        // SET attr = attr - :v1
218                        let placeholder = format!(":v{}", value_counter);
219                        set_exprs.push(format!(
220                            "{} = {} - {}",
221                            assignment.attribute, attribute, placeholder
222                        ));
223                        values.insert(placeholder, value.to_kstone_value());
224                        value_counter += 1;
225                    }
226                }
227            }
228            expression_parts.push(format!("SET {}", set_exprs.join(", ")));
229        }
230
231        // Process REMOVE attributes
232        if !stmt.remove_attributes.is_empty() {
233            expression_parts.push(format!("REMOVE {}", stmt.remove_attributes.join(", ")));
234        }
235
236        let expression = expression_parts.join(" ");
237
238        Ok(UpdateTranslation {
239            key,
240            expression,
241            values,
242        })
243    }
244
245    /// Translate DELETE statement
246    pub fn translate_delete(stmt: &DeleteStatement) -> Result<DeleteTranslation> {
247        // Validate
248        DynamoDBValidator::validate_delete(stmt)?;
249
250        // Extract key from WHERE clause
251        let pk_cond = stmt
252            .where_clause
253            .get_condition("pk")
254            .ok_or_else(|| Error::InvalidQuery("DELETE must specify pk in WHERE clause".into()))?;
255        let pk_bytes = Self::value_to_bytes(&pk_cond.value)?;
256
257        let sk_bytes = stmt
258            .where_clause
259            .get_condition("sk")
260            .map(|c| Self::value_to_bytes(&c.value))
261            .transpose()?;
262
263        let key = if let Some(sk) = sk_bytes {
264            Key::with_sk(pk_bytes.to_vec(), sk.to_vec())
265        } else {
266            Key::new(pk_bytes.to_vec())
267        };
268
269        Ok(DeleteTranslation { key })
270    }
271}
272
273/// SELECT statement translation result
274#[derive(Debug)]
275pub enum SelectTranslation {
276    /// Query operation (single partition)
277    Query {
278        pk: Bytes,
279        sk_condition: Option<SortKeyConditionType>,
280        index_name: Option<String>,
281        forward: bool,
282    },
283    /// Multiple get operations (IN clause on pk)
284    MultiGet {
285        keys: Vec<Bytes>,
286        index_name: Option<String>,
287    },
288    /// Scan operation (full table scan)
289    Scan {
290        filter_conditions: Vec<Condition>,
291    },
292}
293
294/// Sort key condition type
295#[derive(Debug, Clone)]
296pub enum SortKeyConditionType {
297    Equal(Bytes),
298    LessThan(Bytes),
299    LessThanOrEqual(Bytes),
300    GreaterThan(Bytes),
301    GreaterThanOrEqual(Bytes),
302    Between(Bytes, Bytes),
303}
304
305/// INSERT translation result
306#[derive(Debug)]
307pub struct InsertTranslation {
308    pub key: Key,
309    pub item: crate::Item,
310}
311
312/// UPDATE translation result
313#[derive(Debug)]
314pub struct UpdateTranslation {
315    pub key: Key,
316    pub expression: String,
317    pub values: std::collections::HashMap<String, crate::Value>,
318}
319
320/// DELETE translation result
321#[derive(Debug)]
322pub struct DeleteTranslation {
323    pub key: Key,
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    /// Partition key value shared by every fixture.
    const USER_PK: &str = "user#123";

    /// `SqlValue::String` from a string slice.
    fn text(value: &str) -> SqlValue {
        SqlValue::String(value.to_string())
    }

    /// `SqlValue::Number` from its textual form.
    fn number(value: &str) -> SqlValue {
        SqlValue::Number(value.to_string())
    }

    /// Equality condition `attribute = 'value'`.
    fn string_eq(attribute: &str, value: &str) -> Condition {
        Condition {
            attribute: attribute.to_string(),
            operator: CompareOp::Equal,
            value: text(value),
        }
    }

    /// WHERE clause matching `pk = USER_PK`, followed by any extra conditions.
    fn where_user(extra: Vec<Condition>) -> WhereClause {
        let mut conditions = vec![string_eq("pk", USER_PK)];
        conditions.extend(extra);
        WhereClause { conditions }
    }

    /// `SELECT * FROM users` with the given WHERE clause and nothing else set.
    fn select_users(where_clause: Option<WhereClause>) -> SelectStatement {
        SelectStatement {
            table_name: "users".to_string(),
            index_name: None,
            select_list: SelectList::All,
            where_clause,
            order_by: None,
            limit: None,
            offset: None,
        }
    }

    /// UPDATE on `users` with the given pieces.
    fn update_users(
        where_clause: WhereClause,
        set_assignments: Vec<SetAssignment>,
        remove_attributes: &[&str],
    ) -> UpdateStatement {
        UpdateStatement {
            table_name: "users".to_string(),
            where_clause,
            set_assignments,
            remove_attributes: remove_attributes
                .iter()
                .map(|name| (*name).to_string())
                .collect(),
        }
    }

    /// Assignment `attribute = <value>`.
    fn assign(attribute: &str, value: SetValue) -> SetAssignment {
        SetAssignment {
            attribute: attribute.to_string(),
            value,
        }
    }

    #[test]
    fn test_translate_select_query() {
        let stmt = select_users(Some(where_user(vec![])));

        match PartiQLTranslator::translate_select(&stmt).unwrap() {
            SelectTranslation::Query { pk, .. } => assert_eq!(pk, Bytes::from(USER_PK)),
            _ => panic!("Expected Query translation"),
        }
    }

    #[test]
    fn test_translate_select_scan() {
        let stmt = select_users(None);

        let translation = PartiQLTranslator::translate_select(&stmt).unwrap();
        assert!(
            matches!(translation, SelectTranslation::Scan { .. }),
            "Expected Scan translation"
        );
    }

    #[test]
    fn test_translate_insert() {
        let entries = vec![
            ("pk", text(USER_PK)),
            ("name", text("Alice")),
            ("age", number("30")),
        ];
        let map: std::collections::HashMap<_, _> = entries
            .into_iter()
            .map(|(name, value)| (name.to_string(), value))
            .collect();

        let stmt = InsertStatement {
            table_name: "users".to_string(),
            value: SqlValue::Map(map),
        };

        let translation = PartiQLTranslator::translate_insert(&stmt).unwrap();
        assert_eq!(translation.key.pk.as_ref(), USER_PK.as_bytes());
        // Only `name` and `age` remain once pk/sk are excluded.
        assert_eq!(translation.item.len(), 2);
    }

    #[test]
    fn test_translate_delete() {
        let stmt = DeleteStatement {
            table_name: "users".to_string(),
            where_clause: where_user(vec![]),
        };

        let translation = PartiQLTranslator::translate_delete(&stmt).unwrap();
        assert_eq!(translation.key.pk.as_ref(), USER_PK.as_bytes());
    }

    #[test]
    fn test_translate_update_simple() {
        let stmt = update_users(
            where_user(vec![]),
            vec![
                assign("name", SetValue::Literal(text("Alice"))),
                assign("age", SetValue::Literal(number("30"))),
            ],
            &[],
        );

        let translation = PartiQLTranslator::translate_update(&stmt).unwrap();
        assert_eq!(translation.key.pk.as_ref(), USER_PK.as_bytes());
        assert!(translation.expression.contains("SET"));
        // One placeholder per assignment: :v1 and :v2.
        assert_eq!(translation.values.len(), 2);
    }

    #[test]
    fn test_translate_update_with_arithmetic() {
        let stmt = update_users(
            where_user(vec![]),
            vec![
                assign(
                    "age",
                    SetValue::Add {
                        attribute: "age".to_string(),
                        value: number("1"),
                    },
                ),
                assign(
                    "count",
                    SetValue::Subtract {
                        attribute: "count".to_string(),
                        value: number("5"),
                    },
                ),
            ],
            &[],
        );

        let translation = PartiQLTranslator::translate_update(&stmt).unwrap();
        assert!(translation.expression.contains("age = age + :v1"));
        assert!(translation.expression.contains("count = count - :v2"));
        assert_eq!(translation.values.len(), 2);
    }

    #[test]
    fn test_translate_update_with_remove() {
        let stmt = update_users(
            where_user(vec![]),
            vec![assign("name", SetValue::Literal(text("Alice")))],
            &["tags", "metadata"],
        );

        let translation = PartiQLTranslator::translate_update(&stmt).unwrap();
        assert!(translation.expression.contains("SET"));
        assert!(translation.expression.contains("REMOVE tags, metadata"));
        // Only `name` needs a placeholder (:v1).
        assert_eq!(translation.values.len(), 1);
    }

    #[test]
    fn test_translate_update_remove_only() {
        let stmt = update_users(
            where_user(vec![string_eq("sk", "profile")]),
            vec![],
            &["tags", "metadata"],
        );

        let translation = PartiQLTranslator::translate_update(&stmt).unwrap();
        assert_eq!(translation.key.pk.as_ref(), USER_PK.as_bytes());
        assert_eq!(
            translation.key.sk.as_ref().map(|b| b.as_ref()),
            Some("profile".as_bytes())
        );
        assert_eq!(translation.expression, "REMOVE tags, metadata");
        // REMOVE uses no placeholders.
        assert!(translation.values.is_empty());
    }
}