kstone_core/partiql/
ast.rs

//! Simplified AST for PartiQL statements.
//!
//! This module provides a simplified Abstract Syntax Tree, derived from sqlparser's AST,
//! with types specific to DynamoDB-style PartiQL operations.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9/// Top-level PartiQL statement
10#[derive(Debug, Clone, PartialEq)]
11pub enum PartiQLStatement {
12    Select(SelectStatement),
13    Insert(InsertStatement),
14    Update(UpdateStatement),
15    Delete(DeleteStatement),
16}
17
18/// SELECT statement
19#[derive(Debug, Clone, PartialEq)]
20pub struct SelectStatement {
21    /// Table name (DynamoDB table)
22    pub table_name: String,
23    /// Optional index name (for LSI/GSI queries)
24    pub index_name: Option<String>,
25    /// Attributes to select (None = SELECT *)
26    pub select_list: SelectList,
27    /// WHERE clause conditions
28    pub where_clause: Option<WhereClause>,
29    /// ORDER BY clause
30    pub order_by: Option<OrderBy>,
31    /// LIMIT clause (max number of items to return)
32    pub limit: Option<usize>,
33    /// OFFSET clause (number of items to skip)
34    pub offset: Option<usize>,
35}
36
/// The projection part of a `SELECT` statement.
#[derive(Clone, Debug, PartialEq)]
pub enum SelectList {
    /// `SELECT *`: return every attribute.
    All,
    /// `SELECT a, b, ...`: return only the named attributes, in the listed order.
    Attributes(Vec<String>),
}
45
46/// WHERE clause with conditions
47#[derive(Debug, Clone, PartialEq)]
48pub struct WhereClause {
49    /// List of conditions (implicitly AND-ed)
50    pub conditions: Vec<Condition>,
51}
52
53impl WhereClause {
54    /// Get condition for a specific attribute
55    pub fn get_condition(&self, attr_name: &str) -> Option<&Condition> {
56        self.conditions.iter().find(|c| c.attribute == attr_name)
57    }
58
59    /// Check if clause contains a condition for given attribute
60    pub fn has_condition(&self, attr_name: &str) -> bool {
61        self.get_condition(attr_name).is_some()
62    }
63}
64
65/// Single condition in WHERE clause
66#[derive(Debug, Clone, PartialEq)]
67pub struct Condition {
68    /// Attribute name
69    pub attribute: String,
70    /// Comparison operator
71    pub operator: CompareOp,
72    /// Value(s) to compare against
73    pub value: SqlValue,
74}
75
76impl Condition {
77    /// Check if this is a key attribute condition (pk or sk)
78    pub fn is_key_attribute(&self) -> bool {
79        self.attribute == "pk" || self.attribute == "sk"
80    }
81}
82
/// Operator used by a [`Condition`] to compare an attribute with its operand(s).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum CompareOp {
    /// `=`
    Equal,
    /// `<>`
    NotEqual,
    /// `<`
    LessThan,
    /// `<=`
    LessThanOrEqual,
    /// `>`
    GreaterThan,
    /// `>=`
    GreaterThanOrEqual,
    /// `IN (...)`: membership in a set of values.
    In,
    /// `BETWEEN x AND y`: range check.
    Between,
}
103
104/// SQL values (simplified from sqlparser)
105#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
106pub enum SqlValue {
107    /// Number (stored as string for precision)
108    Number(String),
109    /// String
110    String(String),
111    /// Boolean
112    Boolean(bool),
113    /// Null
114    Null,
115    /// List/Array
116    List(Vec<SqlValue>),
117    /// Map/Object
118    Map(HashMap<String, SqlValue>),
119}
120
121impl SqlValue {
122    /// Convert to KeystoneDB Value type
123    pub fn to_kstone_value(&self) -> crate::Value {
124        match self {
125            SqlValue::Number(s) => crate::Value::N(s.clone()),
126            SqlValue::String(s) => crate::Value::S(s.clone()),
127            SqlValue::Boolean(b) => crate::Value::Bool(*b),
128            SqlValue::Null => crate::Value::Null,
129            SqlValue::List(items) => {
130                let values: Vec<crate::Value> = items.iter()
131                    .map(|item| item.to_kstone_value())
132                    .collect();
133                crate::Value::L(values)
134            }
135            SqlValue::Map(map) => {
136                let mut kv_map = std::collections::HashMap::new();
137                for (k, v) in map {
138                    kv_map.insert(k.clone(), v.to_kstone_value());
139                }
140                crate::Value::M(kv_map)
141            }
142        }
143    }
144
145    /// Create from KeystoneDB Value
146    pub fn from_kstone_value(value: &crate::Value) -> Self {
147        match value {
148            crate::Value::N(s) => SqlValue::Number(s.clone()),
149            crate::Value::S(s) => SqlValue::String(s.clone()),
150            crate::Value::Bool(b) => SqlValue::Boolean(*b),
151            crate::Value::Null => SqlValue::Null,
152            crate::Value::L(items) => {
153                let sql_values: Vec<SqlValue> = items.iter()
154                    .map(SqlValue::from_kstone_value)
155                    .collect();
156                SqlValue::List(sql_values)
157            }
158            crate::Value::M(map) => {
159                let mut sql_map = HashMap::new();
160                for (k, v) in map {
161                    sql_map.insert(k.clone(), SqlValue::from_kstone_value(v));
162                }
163                SqlValue::Map(sql_map)
164            }
165            crate::Value::B(bytes) => {
166                // Encode binary as base64 string
167                SqlValue::String(base64_encode(bytes))
168            }
169            crate::Value::VecF32(vec) => {
170                // Convert to list of numbers
171                let numbers: Vec<SqlValue> = vec.iter()
172                    .map(|f| SqlValue::Number(f.to_string()))
173                    .collect();
174                SqlValue::List(numbers)
175            }
176            crate::Value::Ts(ts) => SqlValue::Number(ts.to_string()),
177        }
178    }
179}
180
181fn base64_encode(bytes: &bytes::Bytes) -> String {
182    use std::io::Write;
183    let mut buf = Vec::new();
184    {
185        let mut encoder = base64::write::EncoderWriter::new(&mut buf, &base64::engine::general_purpose::STANDARD);
186        encoder.write_all(bytes).unwrap();
187        encoder.finish().unwrap();
188    }
189    String::from_utf8(buf).unwrap()
190}
191
/// Sort specification from an `ORDER BY` clause.
#[derive(Clone, Debug, PartialEq)]
pub struct OrderBy {
    /// Attribute whose values determine the order.
    pub attribute: String,
    /// `true` for ascending (`ASC`), `false` for descending (`DESC`).
    pub ascending: bool,
}
200
201/// INSERT statement
202#[derive(Debug, Clone, PartialEq)]
203pub struct InsertStatement {
204    /// Table name
205    pub table_name: String,
206    /// Value to insert (must be a Map with pk and optional sk)
207    pub value: SqlValue,
208}
209
210/// UPDATE statement
211#[derive(Debug, Clone, PartialEq)]
212pub struct UpdateStatement {
213    /// Table name
214    pub table_name: String,
215    /// WHERE clause (must contain pk, optional sk)
216    pub where_clause: WhereClause,
217    /// SET assignments
218    pub set_assignments: Vec<SetAssignment>,
219    /// REMOVE attributes
220    pub remove_attributes: Vec<String>,
221}
222
223/// SET assignment (SET attr = value)
224#[derive(Debug, Clone, PartialEq)]
225pub struct SetAssignment {
226    /// Attribute name
227    pub attribute: String,
228    /// Value expression (can be literal or arithmetic like "age + 1")
229    pub value: SetValue,
230}
231
232/// Value in SET clause
233#[derive(Debug, Clone, PartialEq)]
234pub enum SetValue {
235    /// Literal value
236    Literal(SqlValue),
237    /// Arithmetic expression (attribute + value)
238    Add {
239        attribute: String,
240        value: SqlValue,
241    },
242    /// Arithmetic expression (attribute - value)
243    Subtract {
244        attribute: String,
245        value: SqlValue,
246    },
247}
248
249/// DELETE statement
250#[derive(Debug, Clone, PartialEq)]
251pub struct DeleteStatement {
252    /// Table name
253    pub table_name: String,
254    /// WHERE clause (must contain full key: pk and optional sk)
255    pub where_clause: WhereClause,
256}
257
#[cfg(test)]
mod tests {
    //! Unit tests for the AST helpers and the `SqlValue` <-> `crate::Value` conversions.

    use super::*;

    #[test]
    fn test_where_clause_get_condition() {
        let where_clause = WhereClause {
            conditions: vec![
                Condition {
                    attribute: "pk".to_string(),
                    operator: CompareOp::Equal,
                    value: SqlValue::String("user#123".to_string()),
                },
                Condition {
                    attribute: "age".to_string(),
                    operator: CompareOp::GreaterThan,
                    value: SqlValue::Number("18".to_string()),
                },
            ],
        };

        assert!(where_clause.get_condition("pk").is_some());
        assert!(where_clause.get_condition("age").is_some());
        assert!(where_clause.get_condition("name").is_none());
    }

    #[test]
    fn test_where_clause_has_condition_and_first_match_wins() {
        // Two bounds on the same attribute, as in `sk > 'a' AND sk < 'm'`.
        let where_clause = WhereClause {
            conditions: vec![
                Condition {
                    attribute: "sk".to_string(),
                    operator: CompareOp::GreaterThan,
                    value: SqlValue::String("a".to_string()),
                },
                Condition {
                    attribute: "sk".to_string(),
                    operator: CompareOp::LessThan,
                    value: SqlValue::String("m".to_string()),
                },
            ],
        };

        assert!(where_clause.has_condition("sk"));
        assert!(!where_clause.has_condition("pk"));
        // `get_condition` returns the earliest matching condition.
        assert_eq!(
            where_clause.get_condition("sk").map(|c| &c.operator),
            Some(&CompareOp::GreaterThan)
        );
    }

    #[test]
    fn test_where_clause_empty() {
        let where_clause = WhereClause { conditions: Vec::new() };
        assert!(where_clause.get_condition("pk").is_none());
        assert!(!where_clause.has_condition("pk"));
    }

    #[test]
    fn test_condition_is_key_attribute() {
        let pk_cond = Condition {
            attribute: "pk".to_string(),
            operator: CompareOp::Equal,
            value: SqlValue::String("user#123".to_string()),
        };

        let sk_cond = Condition {
            attribute: "sk".to_string(),
            operator: CompareOp::Equal,
            value: SqlValue::String("profile".to_string()),
        };

        let data_cond = Condition {
            attribute: "age".to_string(),
            operator: CompareOp::GreaterThan,
            value: SqlValue::Number("18".to_string()),
        };

        assert!(pk_cond.is_key_attribute());
        assert!(sk_cond.is_key_attribute());
        assert!(!data_cond.is_key_attribute());
    }

    #[test]
    fn test_sql_value_to_kstone_value() {
        let sql_num = SqlValue::Number("42".to_string());
        assert!(matches!(sql_num.to_kstone_value(), crate::Value::N(s) if s == "42"));

        let sql_str = SqlValue::String("hello".to_string());
        assert!(matches!(sql_str.to_kstone_value(), crate::Value::S(s) if s == "hello"));

        let sql_bool = SqlValue::Boolean(true);
        assert!(matches!(sql_bool.to_kstone_value(), crate::Value::Bool(true)));

        let sql_null = SqlValue::Null;
        assert!(matches!(sql_null.to_kstone_value(), crate::Value::Null));
    }

    #[test]
    fn test_sql_value_from_kstone_value() {
        let kv_num = crate::Value::N("42".to_string());
        assert_eq!(SqlValue::from_kstone_value(&kv_num), SqlValue::Number("42".to_string()));

        let kv_str = crate::Value::S("hello".to_string());
        assert_eq!(SqlValue::from_kstone_value(&kv_str), SqlValue::String("hello".to_string()));

        let kv_bool = crate::Value::Bool(true);
        assert_eq!(SqlValue::from_kstone_value(&kv_bool), SqlValue::Boolean(true));
    }

    #[test]
    fn test_sql_value_nested_round_trip() {
        // Nested maps and lists must survive SqlValue -> Value -> SqlValue unchanged.
        let mut fields = std::collections::HashMap::new();
        fields.insert("age".to_string(), SqlValue::Number("30".to_string()));
        fields.insert(
            "tags".to_string(),
            SqlValue::List(vec![SqlValue::String("a".to_string()), SqlValue::Null]),
        );
        let map = SqlValue::Map(fields);
        assert_eq!(SqlValue::from_kstone_value(&map.to_kstone_value()), map);

        let list = SqlValue::List(vec![
            SqlValue::Boolean(false),
            SqlValue::Number("1.5".to_string()),
        ]);
        assert_eq!(SqlValue::from_kstone_value(&list.to_kstone_value()), list);

        let empty = SqlValue::List(Vec::new());
        assert_eq!(SqlValue::from_kstone_value(&empty.to_kstone_value()), empty);
    }

    #[test]
    fn test_sql_value_from_binary_is_base64() {
        let kv_bin = crate::Value::B(bytes::Bytes::from_static(b"hello"));
        assert_eq!(
            SqlValue::from_kstone_value(&kv_bin),
            SqlValue::String("aGVsbG8=".to_string())
        );

        let kv_empty = crate::Value::B(bytes::Bytes::new());
        assert_eq!(SqlValue::from_kstone_value(&kv_empty), SqlValue::String(String::new()));
    }
}