1use crate::partiql::ast::*;
9use crate::{Error, Result};
10
11#[derive(Debug, Clone, PartialEq)]
13pub enum QueryType {
14 Query {
16 pk_condition: Condition,
18 sk_condition: Option<Condition>,
20 },
21 Scan,
23}
24
25pub struct DynamoDBValidator;
27
28impl DynamoDBValidator {
29 pub fn validate_select(stmt: &SelectStatement) -> Result<QueryType> {
31 let where_clause = match &stmt.where_clause {
33 Some(wc) => wc,
34 None => return Ok(QueryType::Scan),
35 };
36
37 if let Some(pk_cond) = where_clause.get_condition("pk") {
39 match pk_cond.operator {
41 CompareOp::Equal | CompareOp::In => {
42 let sk_condition = where_clause.get_condition("sk").cloned();
44 Ok(QueryType::Query {
45 pk_condition: pk_cond.clone(),
46 sk_condition,
47 })
48 }
49 _ => {
50 Ok(QueryType::Scan)
52 }
53 }
54 } else {
55 Ok(QueryType::Scan)
57 }
58 }
59
60 pub fn validate_insert(stmt: &InsertStatement) -> Result<()> {
62 match &stmt.value {
64 SqlValue::Map(map) => {
65 if !map.contains_key("pk") {
67 return Err(Error::InvalidQuery(
68 "INSERT value must contain 'pk' field".into(),
69 ));
70 }
71 Ok(())
72 }
73 _ => Err(Error::InvalidQuery(
74 "INSERT value must be a map/object".into(),
75 )),
76 }
77 }
78
79 pub fn validate_update(stmt: &UpdateStatement) -> Result<()> {
81 if !stmt.where_clause.has_condition("pk") {
83 return Err(Error::InvalidQuery(
84 "UPDATE must specify partition key (pk) in WHERE clause".into(),
85 ));
86 }
87 Ok(())
88 }
89
90 pub fn validate_delete(stmt: &DeleteStatement) -> Result<()> {
92 if !stmt.where_clause.has_condition("pk") {
94 return Err(Error::InvalidQuery(
95 "DELETE must specify partition key (pk) in WHERE clause".into(),
96 ));
97 }
98 Ok(())
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds one WHERE-clause condition.
    fn cond(attribute: &str, operator: CompareOp, value: SqlValue) -> Condition {
        Condition {
            attribute: attribute.to_string(),
            operator,
            value,
        }
    }

    /// `SELECT * FROM users`, with a WHERE clause built from `conditions`
    /// (`None` means no WHERE clause at all).
    fn select_users(conditions: Option<Vec<Condition>>) -> SelectStatement {
        SelectStatement {
            table_name: "users".to_string(),
            index_name: None,
            select_list: SelectList::All,
            where_clause: conditions.map(|conditions| WhereClause { conditions }),
            order_by: None,
            limit: None,
            offset: None,
        }
    }

    /// `INSERT INTO users VALUE <value>`.
    fn insert_users(value: SqlValue) -> InsertStatement {
        InsertStatement {
            table_name: "users".to_string(),
            value,
        }
    }

    /// `UPDATE users ... WHERE <conditions>` with no SET/REMOVE actions.
    fn update_users(conditions: Vec<Condition>) -> UpdateStatement {
        UpdateStatement {
            table_name: "users".to_string(),
            where_clause: WhereClause { conditions },
            set_assignments: vec![],
            remove_attributes: vec![],
        }
    }

    /// Unwraps a `QueryType::Query` into its conditions, panicking on `Scan`.
    fn expect_query(query_type: QueryType) -> (Condition, Option<Condition>) {
        match query_type {
            QueryType::Query {
                pk_condition,
                sk_condition,
            } => (pk_condition, sk_condition),
            QueryType::Scan => panic!("Expected Query type"),
        }
    }

    #[test]
    fn test_validate_select_query_with_pk_equal() {
        let stmt = select_users(Some(vec![cond(
            "pk",
            CompareOp::Equal,
            SqlValue::String("user#123".to_string()),
        )]));

        let (pk_condition, sk_condition) =
            expect_query(DynamoDBValidator::validate_select(&stmt).unwrap());
        assert_eq!(pk_condition.attribute, "pk");
        assert_eq!(pk_condition.operator, CompareOp::Equal);
        assert!(sk_condition.is_none());
    }

    #[test]
    fn test_validate_select_query_with_pk_in() {
        // The validator only inspects the operator, so any value will do here.
        let stmt = select_users(Some(vec![cond(
            "pk",
            CompareOp::In,
            SqlValue::String("user#123".to_string()),
        )]));

        let (pk_condition, sk_condition) =
            expect_query(DynamoDBValidator::validate_select(&stmt).unwrap());
        assert_eq!(pk_condition.operator, CompareOp::In);
        assert!(sk_condition.is_none());
    }

    #[test]
    fn test_validate_select_query_with_pk_and_sk() {
        let stmt = select_users(Some(vec![
            cond(
                "pk",
                CompareOp::Equal,
                SqlValue::String("user#123".to_string()),
            ),
            cond(
                "sk",
                CompareOp::GreaterThan,
                SqlValue::String("post#".to_string()),
            ),
        ]));

        let (pk_condition, sk_condition) =
            expect_query(DynamoDBValidator::validate_select(&stmt).unwrap());
        assert_eq!(pk_condition.attribute, "pk");
        let sk_condition = sk_condition.expect("sk condition should be carried through");
        assert_eq!(sk_condition.attribute, "sk");
    }

    #[test]
    fn test_validate_select_scan_no_where() {
        let stmt = select_users(None);

        let query_type = DynamoDBValidator::validate_select(&stmt).unwrap();
        assert_eq!(query_type, QueryType::Scan);
    }

    #[test]
    fn test_validate_select_scan_no_pk() {
        let stmt = select_users(Some(vec![cond(
            "age",
            CompareOp::GreaterThan,
            SqlValue::Number("18".to_string()),
        )]));

        let query_type = DynamoDBValidator::validate_select(&stmt).unwrap();
        assert_eq!(query_type, QueryType::Scan);
    }

    #[test]
    fn test_validate_select_scan_pk_not_equality() {
        // A range condition on pk cannot drive a Query.
        let stmt = select_users(Some(vec![cond(
            "pk",
            CompareOp::GreaterThan,
            SqlValue::String("user#".to_string()),
        )]));

        let query_type = DynamoDBValidator::validate_select(&stmt).unwrap();
        assert_eq!(query_type, QueryType::Scan);
    }

    #[test]
    fn test_validate_insert_with_pk() {
        let mut map = std::collections::HashMap::new();
        map.insert("pk".to_string(), SqlValue::String("user#123".to_string()));
        map.insert("name".to_string(), SqlValue::String("Alice".to_string()));

        let stmt = insert_users(SqlValue::Map(map));
        assert!(DynamoDBValidator::validate_insert(&stmt).is_ok());
    }

    #[test]
    fn test_validate_insert_missing_pk() {
        let mut map = std::collections::HashMap::new();
        map.insert("name".to_string(), SqlValue::String("Alice".to_string()));

        let stmt = insert_users(SqlValue::Map(map));
        let result = DynamoDBValidator::validate_insert(&stmt);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("pk"));
    }

    #[test]
    fn test_validate_insert_rejects_non_map() {
        let stmt = insert_users(SqlValue::String("Alice".to_string()));

        let result = DynamoDBValidator::validate_insert(&stmt);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("map"));
    }

    #[test]
    fn test_validate_update_with_pk() {
        let stmt = update_users(vec![cond(
            "pk",
            CompareOp::Equal,
            SqlValue::String("user#123".to_string()),
        )]);

        assert!(DynamoDBValidator::validate_update(&stmt).is_ok());
    }

    #[test]
    fn test_validate_update_missing_pk() {
        let stmt = update_users(vec![cond(
            "age",
            CompareOp::GreaterThan,
            SqlValue::Number("18".to_string()),
        )]);

        let result = DynamoDBValidator::validate_update(&stmt);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("pk"));
    }
}