1use crate::partiql::ast::*;
6use crate::partiql::validator::{DynamoDBValidator, QueryType};
7use crate::{Error, Key, Result};
8use bytes::Bytes;
9
/// Stateless namespace for functions that translate parsed PartiQL statements into
/// storage-level operations.
///
/// Every `translate_*` entry point first runs the statement through the matching
/// `DynamoDBValidator::validate_*` check and only then builds the translation.
pub struct PartiQLTranslator;

13impl PartiQLTranslator {
14 pub fn translate_select(stmt: &SelectStatement) -> Result<SelectTranslation> {
16 let query_type = DynamoDBValidator::validate_select(stmt)?;
18
19 match query_type {
20 QueryType::Query { pk_condition, sk_condition } => {
21 let (pk_bytes, multiple_pks) = Self::extract_pk_bytes(&pk_condition)?;
23
24 if multiple_pks {
25 Ok(SelectTranslation::MultiGet {
27 keys: pk_bytes,
28 index_name: stmt.index_name.clone(),
29 })
30 } else {
31 let pk = pk_bytes.into_iter().next().unwrap();
33 let sk_condition_translated = sk_condition
34 .as_ref()
35 .map(Self::translate_sk_condition)
36 .transpose()?;
37
38 Ok(SelectTranslation::Query {
39 pk,
40 sk_condition: sk_condition_translated,
41 index_name: stmt.index_name.clone(),
42 forward: stmt.order_by.as_ref().map_or(true, |o| o.ascending),
43 })
44 }
45 }
46 QueryType::Scan => {
47 Ok(SelectTranslation::Scan {
49 filter_conditions: stmt
50 .where_clause
51 .as_ref()
52 .map(|wc| wc.conditions.clone())
53 .unwrap_or_default(),
54 })
55 }
56 }
57 }
58
59 fn extract_pk_bytes(condition: &Condition) -> Result<(Vec<Bytes>, bool)> {
61 match &condition.operator {
62 CompareOp::Equal => {
63 let bytes = Self::value_to_bytes(&condition.value)?;
64 Ok((vec![bytes], false))
65 }
66 CompareOp::In => {
67 match &condition.value {
69 SqlValue::List(values) => {
70 let bytes_vec: Result<Vec<Bytes>> = values
71 .iter()
72 .map(Self::value_to_bytes)
73 .collect();
74 Ok((bytes_vec?, true))
75 }
76 _ => Err(Error::InvalidQuery("IN value must be a list".into())),
77 }
78 }
79 _ => Err(Error::InvalidQuery(
80 "Partition key must use = or IN operator".into(),
81 )),
82 }
83 }
84
85 fn value_to_bytes(value: &SqlValue) -> Result<Bytes> {
87 match value {
88 SqlValue::String(s) => Ok(Bytes::copy_from_slice(s.as_bytes())),
89 SqlValue::Number(n) => Ok(Bytes::copy_from_slice(n.as_bytes())),
90 _ => Err(Error::InvalidQuery(format!(
91 "Unsupported key value type: {:?}",
92 value
93 ))),
94 }
95 }
96
97 fn translate_sk_condition(condition: &Condition) -> Result<SortKeyConditionType> {
99 let sk_bytes = Self::value_to_bytes(&condition.value)?;
100
101 match condition.operator {
102 CompareOp::Equal => Ok(SortKeyConditionType::Equal(sk_bytes)),
103 CompareOp::LessThan => Ok(SortKeyConditionType::LessThan(sk_bytes)),
104 CompareOp::LessThanOrEqual => Ok(SortKeyConditionType::LessThanOrEqual(sk_bytes)),
105 CompareOp::GreaterThan => Ok(SortKeyConditionType::GreaterThan(sk_bytes)),
106 CompareOp::GreaterThanOrEqual => Ok(SortKeyConditionType::GreaterThanOrEqual(sk_bytes)),
107 CompareOp::Between => {
108 match &condition.value {
109 SqlValue::List(values) if values.len() == 2 => {
110 let low = Self::value_to_bytes(&values[0])?;
111 let high = Self::value_to_bytes(&values[1])?;
112 Ok(SortKeyConditionType::Between(low, high))
113 }
114 _ => Err(Error::InvalidQuery("BETWEEN requires exactly 2 values".into())),
115 }
116 }
117 _ => Err(Error::InvalidQuery(format!(
118 "Unsupported sort key operator: {:?}",
119 condition.operator
120 ))),
121 }
122 }
123
124 pub fn translate_insert(stmt: &InsertStatement) -> Result<InsertTranslation> {
126 DynamoDBValidator::validate_insert(stmt)?;
128
129 let value_map = match &stmt.value {
131 SqlValue::Map(map) => map,
132 _ => return Err(Error::InvalidQuery("INSERT value must be a map".into())),
133 };
134
135 let pk_value = value_map
137 .get("pk")
138 .ok_or_else(|| Error::InvalidQuery("INSERT value must contain 'pk'".into()))?;
139 let pk_bytes = Self::value_to_bytes(pk_value)?;
140
141 let sk_bytes = value_map
143 .get("sk")
144 .map(Self::value_to_bytes)
145 .transpose()?;
146
147 let key = if let Some(sk) = sk_bytes {
149 Key::with_sk(pk_bytes.to_vec(), sk.to_vec())
150 } else {
151 Key::new(pk_bytes.to_vec())
152 };
153
154 let mut item = std::collections::HashMap::new();
156 for (attr_name, attr_value) in value_map {
157 if attr_name != "pk" && attr_name != "sk" {
158 item.insert(attr_name.clone(), attr_value.to_kstone_value());
159 }
160 }
161
162 Ok(InsertTranslation { key, item })
163 }
164
165 pub fn translate_update(stmt: &UpdateStatement) -> Result<UpdateTranslation> {
167 DynamoDBValidator::validate_update(stmt)?;
169
170 let pk_cond = stmt
172 .where_clause
173 .get_condition("pk")
174 .ok_or_else(|| Error::InvalidQuery("UPDATE must specify pk in WHERE clause".into()))?;
175 let pk_bytes = Self::value_to_bytes(&pk_cond.value)?;
176
177 let sk_bytes = stmt
178 .where_clause
179 .get_condition("sk")
180 .map(|c| Self::value_to_bytes(&c.value))
181 .transpose()?;
182
183 let key = if let Some(sk) = sk_bytes {
184 Key::with_sk(pk_bytes.to_vec(), sk.to_vec())
185 } else {
186 Key::new(pk_bytes.to_vec())
187 };
188
189 let mut expression_parts = Vec::new();
191 let mut values = std::collections::HashMap::new();
192 let mut value_counter = 1;
193
194 if !stmt.set_assignments.is_empty() {
196 let mut set_exprs = Vec::new();
197 for assignment in &stmt.set_assignments {
198 match &assignment.value {
199 SetValue::Literal(sql_value) => {
200 let placeholder = format!(":v{}", value_counter);
202 set_exprs.push(format!("{} = {}", assignment.attribute, placeholder));
203 values.insert(placeholder, sql_value.to_kstone_value());
204 value_counter += 1;
205 }
206 SetValue::Add { attribute, value } => {
207 let placeholder = format!(":v{}", value_counter);
209 set_exprs.push(format!(
210 "{} = {} + {}",
211 assignment.attribute, attribute, placeholder
212 ));
213 values.insert(placeholder, value.to_kstone_value());
214 value_counter += 1;
215 }
216 SetValue::Subtract { attribute, value } => {
217 let placeholder = format!(":v{}", value_counter);
219 set_exprs.push(format!(
220 "{} = {} - {}",
221 assignment.attribute, attribute, placeholder
222 ));
223 values.insert(placeholder, value.to_kstone_value());
224 value_counter += 1;
225 }
226 }
227 }
228 expression_parts.push(format!("SET {}", set_exprs.join(", ")));
229 }
230
231 if !stmt.remove_attributes.is_empty() {
233 expression_parts.push(format!("REMOVE {}", stmt.remove_attributes.join(", ")));
234 }
235
236 let expression = expression_parts.join(" ");
237
238 Ok(UpdateTranslation {
239 key,
240 expression,
241 values,
242 })
243 }
244
245 pub fn translate_delete(stmt: &DeleteStatement) -> Result<DeleteTranslation> {
247 DynamoDBValidator::validate_delete(stmt)?;
249
250 let pk_cond = stmt
252 .where_clause
253 .get_condition("pk")
254 .ok_or_else(|| Error::InvalidQuery("DELETE must specify pk in WHERE clause".into()))?;
255 let pk_bytes = Self::value_to_bytes(&pk_cond.value)?;
256
257 let sk_bytes = stmt
258 .where_clause
259 .get_condition("sk")
260 .map(|c| Self::value_to_bytes(&c.value))
261 .transpose()?;
262
263 let key = if let Some(sk) = sk_bytes {
264 Key::with_sk(pk_bytes.to_vec(), sk.to_vec())
265 } else {
266 Key::new(pk_bytes.to_vec())
267 };
268
269 Ok(DeleteTranslation { key })
270 }
271}
272
/// Storage-level read plan produced by `PartiQLTranslator::translate_select`.
#[derive(Debug)]
pub enum SelectTranslation {
    /// Read a single partition (`pk = v`), optionally narrowed by a sort-key condition.
    Query {
        /// Encoded partition-key value.
        pk: Bytes,
        /// Optional sort-key predicate; `None` when the query constrains only `pk`.
        sk_condition: Option<SortKeyConditionType>,
        /// Index named by the statement, if any.
        index_name: Option<String>,
        /// `true` for ascending order; also `true` when the query has no `ORDER BY`.
        forward: bool,
    },
    /// Fetch several partitions by key, produced for `pk IN (...)`.
    MultiGet {
        /// Encoded partition-key values, in the order they appeared in the `IN` list.
        keys: Vec<Bytes>,
        /// Index named by the statement, if any.
        index_name: Option<String>,
    },
    /// Full scan for statements that cannot be served by a key lookup.
    Scan {
        /// The statement's raw `WHERE` conditions; empty when there is no `WHERE` clause.
        filter_conditions: Vec<Condition>,
    },
}
293
/// Sort-key predicate attached to a [`SelectTranslation::Query`].
///
/// Each operand is an encoded key value: the UTF-8 bytes of a string, or the textual
/// bytes of a number.
#[derive(Debug, Clone)]
pub enum SortKeyConditionType {
    /// `sk = value`
    Equal(Bytes),
    /// `sk < value`
    LessThan(Bytes),
    /// `sk <= value`
    LessThanOrEqual(Bytes),
    /// `sk > value`
    GreaterThan(Bytes),
    /// `sk >= value`
    GreaterThanOrEqual(Bytes),
    /// `sk BETWEEN low AND high`, with the bounds kept in the order written.
    // NOTE(review): whether the bounds are inclusive is decided by the executor;
    // presumably inclusive as in SQL/DynamoDB — confirm.
    Between(Bytes, Bytes),
}
304
/// Result of translating an `INSERT`: the item's key and its non-key attributes.
#[derive(Debug)]
pub struct InsertTranslation {
    /// Key built from the inserted map's `pk` and optional `sk` entries.
    pub key: Key,
    /// Every other attribute of the inserted map; `pk` and `sk` are not repeated here.
    pub item: crate::Item,
}
311
/// Result of translating an `UPDATE`: target key, update expression, and bound values.
#[derive(Debug)]
pub struct UpdateTranslation {
    /// Key of the item to update, taken from the `WHERE` clause.
    pub key: Key,
    /// Update expression such as `SET a = :v1, b = b + :v2 REMOVE c`.
    pub expression: String,
    /// Values bound to the expression's placeholders, keyed by placeholder (`:v1`, ...).
    pub values: std::collections::HashMap<String, crate::Value>,
}
319
/// Result of translating a `DELETE`: the key of the item to remove.
#[derive(Debug)]
pub struct DeleteTranslation {
    /// Key of the item to delete, taken from the `WHERE` clause.
    pub key: Key,
}
325
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds the condition `attribute = 'value'` with a string literal.
    fn eq_cond(attribute: &str, value: &str) -> Condition {
        Condition {
            attribute: attribute.to_string(),
            operator: CompareOp::Equal,
            value: SqlValue::String(value.to_string()),
        }
    }

    /// Builds `SELECT * FROM users` with the given WHERE conditions (`None` = no WHERE).
    fn select_users(conditions: Option<Vec<Condition>>) -> SelectStatement {
        SelectStatement {
            table_name: "users".to_string(),
            index_name: None,
            select_list: SelectList::All,
            where_clause: conditions.map(|conditions| WhereClause { conditions }),
            order_by: None,
            limit: None,
            offset: None,
        }
    }

    /// Builds `UPDATE users ... WHERE <conditions>`.
    fn update_users(
        conditions: Vec<Condition>,
        set_assignments: Vec<SetAssignment>,
        remove_attributes: &[&str],
    ) -> UpdateStatement {
        UpdateStatement {
            table_name: "users".to_string(),
            where_clause: WhereClause { conditions },
            set_assignments,
            remove_attributes: remove_attributes.iter().map(|s| s.to_string()).collect(),
        }
    }

    /// Builds the assignment `attribute = <value>`.
    fn set_literal(attribute: &str, value: SqlValue) -> SetAssignment {
        SetAssignment {
            attribute: attribute.to_string(),
            value: SetValue::Literal(value),
        }
    }

    /// Builds `DELETE FROM users WHERE <conditions>`.
    fn delete_users(conditions: Vec<Condition>) -> DeleteStatement {
        DeleteStatement {
            table_name: "users".to_string(),
            where_clause: WhereClause { conditions },
        }
    }

    #[test]
    fn test_translate_select_query() {
        let stmt = select_users(Some(vec![eq_cond("pk", "user#123")]));

        match PartiQLTranslator::translate_select(&stmt).unwrap() {
            SelectTranslation::Query {
                pk,
                sk_condition,
                index_name,
                forward,
            } => {
                assert_eq!(pk, Bytes::from("user#123"));
                assert!(sk_condition.is_none());
                assert!(index_name.is_none());
                // No ORDER BY means ascending.
                assert!(forward);
            }
            other => panic!("Expected Query translation, got {:?}", other),
        }
    }

    #[test]
    fn test_translate_select_query_with_sort_key_equality() {
        let stmt = select_users(Some(vec![
            eq_cond("pk", "user#123"),
            eq_cond("sk", "profile"),
        ]));

        match PartiQLTranslator::translate_select(&stmt).unwrap() {
            SelectTranslation::Query {
                pk,
                sk_condition: Some(SortKeyConditionType::Equal(sk)),
                ..
            } => {
                assert_eq!(pk, Bytes::from("user#123"));
                assert_eq!(sk, Bytes::from("profile"));
            }
            other => panic!("Expected Query with sort-key equality, got {:?}", other),
        }
    }

    #[test]
    fn test_translate_select_scan() {
        let stmt = select_users(None);

        match PartiQLTranslator::translate_select(&stmt).unwrap() {
            SelectTranslation::Scan { filter_conditions } => {
                assert!(filter_conditions.is_empty());
            }
            other => panic!("Expected Scan translation, got {:?}", other),
        }
    }

    #[test]
    fn test_translate_insert() {
        let mut map = HashMap::new();
        map.insert("pk".to_string(), SqlValue::String("user#123".to_string()));
        map.insert("name".to_string(), SqlValue::String("Alice".to_string()));
        map.insert("age".to_string(), SqlValue::Number("30".to_string()));

        let stmt = InsertStatement {
            table_name: "users".to_string(),
            value: SqlValue::Map(map),
        };

        let translation = PartiQLTranslator::translate_insert(&stmt).unwrap();
        assert_eq!(translation.key.pk.as_ref(), "user#123".as_bytes());
        assert!(translation.key.sk.is_none());
        // `pk` becomes the key; only `name` and `age` remain as attributes.
        assert_eq!(translation.item.len(), 2);
        assert!(!translation.item.contains_key("pk"));
    }

    #[test]
    fn test_translate_insert_with_sort_key() {
        let mut map = HashMap::new();
        map.insert("pk".to_string(), SqlValue::String("user#123".to_string()));
        map.insert("sk".to_string(), SqlValue::String("profile".to_string()));
        map.insert("name".to_string(), SqlValue::String("Alice".to_string()));

        let stmt = InsertStatement {
            table_name: "users".to_string(),
            value: SqlValue::Map(map),
        };

        let translation = PartiQLTranslator::translate_insert(&stmt).unwrap();
        assert_eq!(translation.key.pk.as_ref(), "user#123".as_bytes());
        assert_eq!(
            translation.key.sk.as_ref().map(|b| b.as_ref()),
            Some("profile".as_bytes())
        );
        assert_eq!(translation.item.len(), 1);
        assert!(!translation.item.contains_key("pk"));
        assert!(!translation.item.contains_key("sk"));
    }

    #[test]
    fn test_translate_delete() {
        let stmt = delete_users(vec![eq_cond("pk", "user#123")]);

        let translation = PartiQLTranslator::translate_delete(&stmt).unwrap();
        assert_eq!(translation.key.pk.as_ref(), "user#123".as_bytes());
        assert!(translation.key.sk.is_none());
    }

    #[test]
    fn test_translate_delete_with_sort_key() {
        let stmt = delete_users(vec![eq_cond("pk", "user#123"), eq_cond("sk", "profile")]);

        let translation = PartiQLTranslator::translate_delete(&stmt).unwrap();
        assert_eq!(translation.key.pk.as_ref(), "user#123".as_bytes());
        assert_eq!(
            translation.key.sk.as_ref().map(|b| b.as_ref()),
            Some("profile".as_bytes())
        );
    }

    #[test]
    fn test_translate_update_simple() {
        let stmt = update_users(
            vec![eq_cond("pk", "user#123")],
            vec![
                set_literal("name", SqlValue::String("Alice".to_string())),
                set_literal("age", SqlValue::Number("30".to_string())),
            ],
            &[],
        );

        let translation = PartiQLTranslator::translate_update(&stmt).unwrap();
        assert_eq!(translation.key.pk.as_ref(), "user#123".as_bytes());
        assert_eq!(translation.expression, "SET name = :v1, age = :v2");
        // One bound value per SET assignment.
        assert_eq!(translation.values.len(), 2);
        assert!(translation.values.contains_key(":v1"));
        assert!(translation.values.contains_key(":v2"));
    }

    #[test]
    fn test_translate_update_with_arithmetic() {
        let stmt = update_users(
            vec![eq_cond("pk", "user#123")],
            vec![
                SetAssignment {
                    attribute: "age".to_string(),
                    value: SetValue::Add {
                        attribute: "age".to_string(),
                        value: SqlValue::Number("1".to_string()),
                    },
                },
                SetAssignment {
                    attribute: "count".to_string(),
                    value: SetValue::Subtract {
                        attribute: "count".to_string(),
                        value: SqlValue::Number("5".to_string()),
                    },
                },
            ],
            &[],
        );

        let translation = PartiQLTranslator::translate_update(&stmt).unwrap();
        assert_eq!(
            translation.expression,
            "SET age = age + :v1, count = count - :v2"
        );
        assert_eq!(translation.values.len(), 2);
    }

    #[test]
    fn test_translate_update_with_remove() {
        let stmt = update_users(
            vec![eq_cond("pk", "user#123")],
            vec![set_literal("name", SqlValue::String("Alice".to_string()))],
            &["tags", "metadata"],
        );

        let translation = PartiQLTranslator::translate_update(&stmt).unwrap();
        assert_eq!(translation.expression, "SET name = :v1 REMOVE tags, metadata");
        // REMOVE binds no values; only the single SET literal is bound.
        assert_eq!(translation.values.len(), 1);
    }

    #[test]
    fn test_translate_update_remove_only() {
        let stmt = update_users(
            vec![eq_cond("pk", "user#123"), eq_cond("sk", "profile")],
            vec![],
            &["tags", "metadata"],
        );

        let translation = PartiQLTranslator::translate_update(&stmt).unwrap();
        assert_eq!(translation.key.pk.as_ref(), "user#123".as_bytes());
        assert_eq!(
            translation.key.sk.as_ref().map(|b| b.as_ref()),
            Some("profile".as_bytes())
        );
        assert_eq!(translation.expression, "REMOVE tags, metadata");
        // No SET clause, so nothing is bound.
        assert_eq!(translation.values.len(), 0);
    }
}