kstone_core/partiql/
validator.rs

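//! DynamoDB-specific validation for PartiQL statements.
//!
//! Checks that parsed PartiQL statements respect DynamoDB key constraints:
//! - SELECT is classified as a Query only when the partition key (`pk`) is constrained with `=` or `IN`
//! - INSERT values must be a map containing a `pk` field
//! - UPDATE/DELETE must specify the partition key (`pk`) in their WHERE clause
//! - Full table scans are not rejected here; they are reported as `QueryType::Scan`, so any explicit scan opt-in must be enforced by the caller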
7
8use crate::partiql::ast::*;
9use crate::{Error, Result};
10
11/// Query type determination for SELECT statements
12#[derive(Debug, Clone, PartialEq)]
13pub enum QueryType {
14    /// Query operation (partition key specified with = or IN)
15    Query {
16        /// Partition key condition
17        pk_condition: Condition,
18        /// Optional sort key condition
19        sk_condition: Option<Condition>,
20    },
21    /// Scan operation (full table scan)
22    Scan,
23}
24
25/// DynamoDB constraint validator
26pub struct DynamoDBValidator;
27
28impl DynamoDBValidator {
29    /// Validate SELECT statement and determine query type
30    pub fn validate_select(stmt: &SelectStatement) -> Result<QueryType> {
31        // Extract WHERE clause
32        let where_clause = match &stmt.where_clause {
33            Some(wc) => wc,
34            None => return Ok(QueryType::Scan),
35        };
36
37        // Look for pk condition
38        if let Some(pk_cond) = where_clause.get_condition("pk") {
39            // Check if pk uses = or IN operator
40            match pk_cond.operator {
41                CompareOp::Equal | CompareOp::In => {
42                    // This is a Query operation
43                    let sk_condition = where_clause.get_condition("sk").cloned();
44                    Ok(QueryType::Query {
45                        pk_condition: pk_cond.clone(),
46                        sk_condition,
47                    })
48                }
49                _ => {
50                    // pk exists but not with = or IN, this is a Scan
51                    Ok(QueryType::Scan)
52                }
53            }
54        } else {
55            // No pk condition, this is a Scan
56            Ok(QueryType::Scan)
57        }
58    }
59
60    /// Validate INSERT statement
61    pub fn validate_insert(stmt: &InsertStatement) -> Result<()> {
62        // Ensure value is a Map
63        match &stmt.value {
64            SqlValue::Map(map) => {
65                // Ensure pk exists
66                if !map.contains_key("pk") {
67                    return Err(Error::InvalidQuery(
68                        "INSERT value must contain 'pk' field".into(),
69                    ));
70                }
71                Ok(())
72            }
73            _ => Err(Error::InvalidQuery(
74                "INSERT value must be a map/object".into(),
75            )),
76        }
77    }
78
79    /// Validate UPDATE statement
80    pub fn validate_update(stmt: &UpdateStatement) -> Result<()> {
81        // Ensure WHERE clause has pk
82        if !stmt.where_clause.has_condition("pk") {
83            return Err(Error::InvalidQuery(
84                "UPDATE must specify partition key (pk) in WHERE clause".into(),
85            ));
86        }
87        Ok(())
88    }
89
90    /// Validate DELETE statement
91    pub fn validate_delete(stmt: &DeleteStatement) -> Result<()> {
92        // Ensure WHERE clause has pk
93        if !stmt.where_clause.has_condition("pk") {
94            return Err(Error::InvalidQuery(
95                "DELETE must specify partition key (pk) in WHERE clause".into(),
96            ));
97        }
98        Ok(())
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a string-valued `SqlValue`.
    fn string(s: &str) -> SqlValue {
        SqlValue::String(s.to_string())
    }

    /// Builds a single `WHERE`-clause condition.
    fn cond(attribute: &str, operator: CompareOp, value: SqlValue) -> Condition {
        Condition {
            attribute: attribute.to_string(),
            operator,
            value,
        }
    }

    /// Builds `SELECT * FROM users` with the given optional `WHERE` clause.
    fn select_from_users(where_clause: Option<WhereClause>) -> SelectStatement {
        SelectStatement {
            table_name: "users".to_string(),
            index_name: None,
            select_list: SelectList::All,
            where_clause,
            order_by: None,
            limit: None,
            offset: None,
        }
    }

    /// Builds `UPDATE users ... WHERE <conditions>` with no assignments.
    fn update_users_where(conditions: Vec<Condition>) -> UpdateStatement {
        UpdateStatement {
            table_name: "users".to_string(),
            where_clause: WhereClause { conditions },
            set_assignments: vec![],
            remove_attributes: vec![],
        }
    }

    /// Builds `INSERT INTO users VALUE <value>`.
    fn insert_into_users(value: SqlValue) -> InsertStatement {
        InsertStatement {
            table_name: "users".to_string(),
            value,
        }
    }

    #[test]
    fn test_validate_select_query_with_pk_equal() {
        let stmt = select_from_users(Some(WhereClause {
            conditions: vec![cond("pk", CompareOp::Equal, string("user#123"))],
        }));

        match DynamoDBValidator::validate_select(&stmt).unwrap() {
            QueryType::Query {
                pk_condition,
                sk_condition,
            } => {
                assert_eq!(pk_condition.attribute, "pk");
                assert_eq!(pk_condition.operator, CompareOp::Equal);
                assert!(sk_condition.is_none());
            }
            QueryType::Scan => panic!("Expected Query type"),
        }
    }

    #[test]
    fn test_validate_select_query_with_pk_and_sk() {
        let stmt = select_from_users(Some(WhereClause {
            conditions: vec![
                cond("pk", CompareOp::Equal, string("user#123")),
                cond("sk", CompareOp::GreaterThan, string("post#")),
            ],
        }));

        match DynamoDBValidator::validate_select(&stmt).unwrap() {
            QueryType::Query {
                pk_condition,
                sk_condition,
            } => {
                assert_eq!(pk_condition.attribute, "pk");
                let sk = sk_condition.expect("sk condition should be captured");
                assert_eq!(sk.attribute, "sk");
            }
            QueryType::Scan => panic!("Expected Query type"),
        }
    }

    #[test]
    fn test_validate_select_scan_no_where() {
        let stmt = select_from_users(None);
        assert_eq!(
            DynamoDBValidator::validate_select(&stmt).unwrap(),
            QueryType::Scan
        );
    }

    #[test]
    fn test_validate_select_scan_empty_where() {
        let stmt = select_from_users(Some(WhereClause { conditions: vec![] }));
        assert_eq!(
            DynamoDBValidator::validate_select(&stmt).unwrap(),
            QueryType::Scan
        );
    }

    #[test]
    fn test_validate_select_scan_no_pk() {
        let stmt = select_from_users(Some(WhereClause {
            conditions: vec![cond(
                "age",
                CompareOp::GreaterThan,
                SqlValue::Number("18".to_string()),
            )],
        }));
        assert_eq!(
            DynamoDBValidator::validate_select(&stmt).unwrap(),
            QueryType::Scan
        );
    }

    #[test]
    fn test_validate_select_scan_pk_non_equality() {
        // A range predicate on the partition key cannot drive a key-based Query.
        let stmt = select_from_users(Some(WhereClause {
            conditions: vec![cond("pk", CompareOp::GreaterThan, string("user#"))],
        }));
        assert_eq!(
            DynamoDBValidator::validate_select(&stmt).unwrap(),
            QueryType::Scan
        );
    }

    #[test]
    fn test_validate_insert_with_pk() {
        let mut map = HashMap::new();
        map.insert("pk".to_string(), string("user#123"));
        map.insert("name".to_string(), string("Alice"));

        let stmt = insert_into_users(SqlValue::Map(map));
        assert!(DynamoDBValidator::validate_insert(&stmt).is_ok());
    }

    #[test]
    fn test_validate_insert_missing_pk() {
        let mut map = HashMap::new();
        map.insert("name".to_string(), string("Alice"));

        let stmt = insert_into_users(SqlValue::Map(map));
        let result = DynamoDBValidator::validate_insert(&stmt);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("pk"));
    }

    #[test]
    fn test_validate_insert_non_map() {
        let stmt = insert_into_users(string("user#123"));
        let err = DynamoDBValidator::validate_insert(&stmt).unwrap_err();
        assert!(err.to_string().contains("map"));
    }

    #[test]
    fn test_validate_update_with_pk() {
        let stmt = update_users_where(vec![cond("pk", CompareOp::Equal, string("user#123"))]);
        assert!(DynamoDBValidator::validate_update(&stmt).is_ok());
    }

    #[test]
    fn test_validate_update_missing_pk() {
        let stmt = update_users_where(vec![cond(
            "age",
            CompareOp::GreaterThan,
            SqlValue::Number("18".to_string()),
        )]);

        let result = DynamoDBValidator::validate_update(&stmt);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("pk"));
    }
}