1use crate::partiql::ast::*;
7use crate::{Error, Result};
8use sqlparser::ast::{self as sql_ast, Statement};
9use sqlparser::dialect::GenericDialect;
10use sqlparser::parser::Parser as SqlParser;
11
/// Parser for the PartiQL subset accepted by this crate.
///
/// The type carries no state; all functionality is exposed through associated
/// functions such as `PartiQLParser::parse`.
#[derive(Debug, Clone, Copy, Default)]
pub struct PartiQLParser;
14
15impl PartiQLParser {
16 pub fn parse(sql: &str) -> Result<PartiQLStatement> {
18 if sql.is_empty() {
20 return Err(Error::InvalidQuery("Statement cannot be empty".into()));
21 }
22 if sql.len() > 8192 {
23 return Err(Error::InvalidQuery(format!(
24 "Statement too long: {} chars (max 8192)",
25 sql.len()
26 )));
27 }
28
29 if sql.trim().to_uppercase().starts_with("INSERT") && sql.contains('{') {
31 return Self::parse_insert_with_json_map(sql);
32 }
33
34 if sql.trim().to_uppercase().starts_with("UPDATE") && sql.to_uppercase().contains(" REMOVE ") {
36 return Self::parse_update_with_remove(sql);
37 }
38
39 let normalized_sql = sql.replace(" VALUE ", " VALUES ");
42
43 let dialect = GenericDialect {};
45 let statements = SqlParser::parse_sql(&dialect, &normalized_sql).map_err(|e| {
46 Error::InvalidQuery(format!("Failed to parse SQL: {}", e))
47 })?;
48
49 if statements.is_empty() {
51 return Err(Error::InvalidQuery("No statement found".into()));
52 }
53 if statements.len() > 1 {
54 return Err(Error::InvalidQuery("Multiple statements not supported".into()));
55 }
56
57 Self::convert_statement(&statements[0])
59 }
60
61 fn parse_insert_with_json_map(sql: &str) -> Result<PartiQLStatement> {
64 let sql_upper = sql.to_uppercase();
66
67 let into_idx = sql_upper.find("INTO ").ok_or_else(|| {
69 Error::InvalidQuery("INSERT requires INTO clause".into())
70 })?;
71 let value_idx = sql_upper.find(" VALUE ").or_else(|| sql_upper.find(" VALUES ")).ok_or_else(|| {
72 Error::InvalidQuery("INSERT requires VALUE clause".into())
73 })?;
74
75 let table_part = sql[into_idx + 5..value_idx].trim();
77 let table_name = table_part.split_whitespace().next().ok_or_else(|| {
78 Error::InvalidQuery("Could not extract table name".into())
79 })?.to_string();
80
81 let brace_start = sql.find('{').ok_or_else(|| {
83 Error::InvalidQuery("Expected JSON map starting with {".into())
84 })?;
85 let brace_end = sql.rfind('}').ok_or_else(|| {
86 Error::InvalidQuery("Expected JSON map ending with }".into())
87 })?;
88
89 if brace_end <= brace_start {
90 return Err(Error::InvalidQuery("Invalid JSON map syntax".into()));
91 }
92
93 let json_str = &sql[brace_start..=brace_end];
94 let value_map = Self::parse_json_string(json_str)?;
95
96 Ok(PartiQLStatement::Insert(InsertStatement {
97 table_name,
98 value: value_map,
99 }))
100 }
101
102 fn convert_statement(stmt: &Statement) -> Result<PartiQLStatement> {
104 match stmt {
105 Statement::Query(query) => {
106 let select_stmt = Self::convert_select(query)?;
107 Ok(PartiQLStatement::Select(select_stmt))
108 }
109 Statement::Insert(insert) => {
110 let insert_stmt = Self::convert_insert(insert)?;
111 Ok(PartiQLStatement::Insert(insert_stmt))
112 }
113 Statement::Update {
114 table,
115 assignments,
116 selection,
117 ..
118 } => {
119 let update_stmt = Self::convert_update(table, assignments, selection, &[])?;
120 Ok(PartiQLStatement::Update(update_stmt))
121 }
122 Statement::Delete(delete) => {
123 let delete_stmt = Self::convert_delete(delete)?;
124 Ok(PartiQLStatement::Delete(delete_stmt))
125 }
126 _ => Err(Error::InvalidQuery(format!(
127 "Unsupported statement type: {:?}",
128 stmt
129 ))),
130 }
131 }
132
133 fn convert_select(query: &sql_ast::Query) -> Result<SelectStatement> {
135 if query.with.is_some() {
137 return Err(Error::InvalidQuery("WITH clause not supported".into()));
138 }
139 if query.fetch.is_some() {
142 return Err(Error::InvalidQuery("FETCH clause not supported".into()));
143 }
144
145 let set_expr = match &*query.body {
147 sql_ast::SetExpr::Select(select) => select,
148 _ => return Err(Error::InvalidQuery("Unsupported query type (no UNION/INTERSECT/EXCEPT)".into())),
149 };
150
151 if !set_expr.cluster_by.is_empty() {
153 return Err(Error::InvalidQuery("CLUSTER BY not supported".into()));
154 }
155 if !set_expr.distribute_by.is_empty() {
156 return Err(Error::InvalidQuery("DISTRIBUTE BY not supported".into()));
157 }
158 if set_expr.group_by != sql_ast::GroupByExpr::Expressions(vec![], vec![]) {
159 return Err(Error::InvalidQuery("GROUP BY not supported".into()));
160 }
161 if set_expr.having.is_some() {
162 return Err(Error::InvalidQuery("HAVING clause not supported".into()));
163 }
164 if !set_expr.named_window.is_empty() {
165 return Err(Error::InvalidQuery("Window functions not supported".into()));
166 }
167 if !set_expr.qualify.is_none() {
168 return Err(Error::InvalidQuery("QUALIFY clause not supported".into()));
169 }
170 if let Some(_top) = &set_expr.top {
171 return Err(Error::InvalidQuery("TOP clause not supported".into()));
172 }
173
174 let (table_name, index_name) = if set_expr.from.is_empty() {
176 return Err(Error::InvalidQuery("FROM clause required".into()));
177 } else if set_expr.from.len() > 1 {
178 return Err(Error::InvalidQuery("Multiple tables not supported (no JOINs)".into()));
179 } else {
180 Self::extract_table_reference(&set_expr.from[0])?
181 };
182
183 let select_list = Self::convert_select_list(&set_expr.projection)?;
185
186 let where_clause = match &set_expr.selection {
188 Some(expr) => Some(Self::convert_where_clause(expr)?),
189 None => None,
190 };
191
192 let order_by = match &query.order_by {
194 Some(order_by_clause) => Some(Self::convert_order_by(&order_by_clause.exprs)?),
195 None => None,
196 };
197
198 let limit = match &query.limit {
200 Some(expr) => Some(Self::extract_limit_value(expr)?),
201 None => None,
202 };
203
204 let offset = match &query.offset {
206 Some(offset_expr) => Some(Self::extract_offset_value(&offset_expr.value)?),
207 None => None,
208 };
209
210 Ok(SelectStatement {
211 table_name,
212 index_name,
213 select_list,
214 where_clause,
215 order_by,
216 limit,
217 offset,
218 })
219 }
220
221 fn extract_table_reference(from: &sql_ast::TableWithJoins) -> Result<(String, Option<String>)> {
223 if !from.joins.is_empty() {
225 return Err(Error::InvalidQuery("JOIN not supported".into()));
226 }
227
228 match &from.relation {
230 sql_ast::TableFactor::Table { name, .. } => {
231 let table_parts: Vec<&sql_ast::Ident> = name.0.iter().collect();
232
233 if table_parts.is_empty() {
234 return Err(Error::InvalidQuery("Empty table name".into()));
235 }
236
237 if table_parts.len() == 1 {
239 Ok((table_parts[0].value.clone(), None))
240 } else if table_parts.len() == 2 {
241 Ok((table_parts[0].value.clone(), Some(table_parts[1].value.clone())))
242 } else {
243 Err(Error::InvalidQuery(format!(
244 "Invalid table reference: expected 'table' or 'table.index', got {:?}",
245 name
246 )))
247 }
248 }
249 _ => Err(Error::InvalidQuery("Unsupported FROM clause (subqueries not allowed)".into())),
250 }
251 }
252
253 fn convert_select_list(projection: &[sql_ast::SelectItem]) -> Result<SelectList> {
255 if projection.is_empty() {
256 return Err(Error::InvalidQuery("Empty SELECT list".into()));
257 }
258
259 if projection.len() == 1 {
261 if let sql_ast::SelectItem::Wildcard(_) = &projection[0] {
262 return Ok(SelectList::All);
263 }
264 }
265
266 let mut attributes = Vec::new();
268 for item in projection {
269 match item {
270 sql_ast::SelectItem::UnnamedExpr(expr) => {
271 let attr_name = Self::extract_attribute_name(expr)?;
272 attributes.push(attr_name);
273 }
274 sql_ast::SelectItem::ExprWithAlias { expr, alias: _ } => {
275 let attr_name = Self::extract_attribute_name(expr)?;
276 attributes.push(attr_name);
277 }
278 sql_ast::SelectItem::Wildcard(_) => {
279 return Err(Error::InvalidQuery("Cannot mix * with other columns".into()));
280 }
281 _ => {
282 return Err(Error::InvalidQuery("Unsupported SELECT item".into()));
283 }
284 }
285 }
286
287 Ok(SelectList::Attributes(attributes))
288 }
289
290 fn extract_attribute_name(expr: &sql_ast::Expr) -> Result<String> {
292 match expr {
293 sql_ast::Expr::Identifier(ident) => Ok(ident.value.clone()),
294 sql_ast::Expr::CompoundIdentifier(parts) => {
295 if parts.len() == 1 {
296 Ok(parts[0].value.clone())
297 } else {
298 Err(Error::InvalidQuery(format!(
299 "Compound identifiers not supported: {:?}",
300 parts
301 )))
302 }
303 }
304 _ => Err(Error::InvalidQuery(format!(
305 "Unsupported expression in SELECT list: {:?}",
306 expr
307 ))),
308 }
309 }
310
311 fn convert_where_clause(expr: &sql_ast::Expr) -> Result<WhereClause> {
313 let mut conditions = Vec::new();
314 Self::extract_conditions(expr, &mut conditions)?;
315 Ok(WhereClause { conditions })
316 }
317
318 fn extract_conditions(expr: &sql_ast::Expr, conditions: &mut Vec<Condition>) -> Result<()> {
320 match expr {
321 sql_ast::Expr::BinaryOp { left, op, right } => {
322 use sqlparser::ast::BinaryOperator;
323
324 match op {
325 BinaryOperator::And => {
326 Self::extract_conditions(left, conditions)?;
328 Self::extract_conditions(right, conditions)?;
329 }
330 BinaryOperator::Or => {
331 return Err(Error::InvalidQuery("OR not supported in WHERE clause (use AND only)".into()));
332 }
333 BinaryOperator::Eq
334 | BinaryOperator::NotEq
335 | BinaryOperator::Lt
336 | BinaryOperator::LtEq
337 | BinaryOperator::Gt
338 | BinaryOperator::GtEq => {
339 let condition = Self::convert_comparison(left, op, right)?;
341 conditions.push(condition);
342 }
343 _ => {
344 return Err(Error::InvalidQuery(format!(
345 "Unsupported operator in WHERE clause: {:?}",
346 op
347 )));
348 }
349 }
350 }
351 sql_ast::Expr::InList { expr, list, negated } => {
352 if *negated {
353 return Err(Error::InvalidQuery("NOT IN not supported".into()));
354 }
355 let condition = Self::convert_in_condition(expr, list)?;
356 conditions.push(condition);
357 }
358 sql_ast::Expr::Between { expr, negated, low, high } => {
359 if *negated {
360 return Err(Error::InvalidQuery("NOT BETWEEN not supported".into()));
361 }
362 let condition = Self::convert_between_condition(expr, low, high)?;
363 conditions.push(condition);
364 }
365 _ => {
366 return Err(Error::InvalidQuery(format!(
367 "Unsupported WHERE clause expression: {:?}",
368 expr
369 )));
370 }
371 }
372 Ok(())
373 }
374
375 fn convert_comparison(
377 left: &sql_ast::Expr,
378 op: &sql_ast::BinaryOperator,
379 right: &sql_ast::Expr,
380 ) -> Result<Condition> {
381 use sqlparser::ast::BinaryOperator;
382
383 let attribute = Self::extract_attribute_name(left)?;
384 let compare_op = match op {
385 BinaryOperator::Eq => CompareOp::Equal,
386 BinaryOperator::NotEq => CompareOp::NotEqual,
387 BinaryOperator::Lt => CompareOp::LessThan,
388 BinaryOperator::LtEq => CompareOp::LessThanOrEqual,
389 BinaryOperator::Gt => CompareOp::GreaterThan,
390 BinaryOperator::GtEq => CompareOp::GreaterThanOrEqual,
391 _ => unreachable!(),
392 };
393 let value = Self::convert_value(right)?;
394
395 Ok(Condition {
396 attribute,
397 operator: compare_op,
398 value,
399 })
400 }
401
402 fn convert_in_condition(
404 expr: &sql_ast::Expr,
405 list: &[sql_ast::Expr],
406 ) -> Result<Condition> {
407 let attribute = Self::extract_attribute_name(expr)?;
408 let values: Result<Vec<SqlValue>> = list.iter().map(Self::convert_value).collect();
409
410 Ok(Condition {
411 attribute,
412 operator: CompareOp::In,
413 value: SqlValue::List(values?),
414 })
415 }
416
417 fn convert_between_condition(
419 expr: &sql_ast::Expr,
420 low: &sql_ast::Expr,
421 high: &sql_ast::Expr,
422 ) -> Result<Condition> {
423 let attribute = Self::extract_attribute_name(expr)?;
424 let low_val = Self::convert_value(low)?;
425 let high_val = Self::convert_value(high)?;
426
427 Ok(Condition {
428 attribute,
429 operator: CompareOp::Between,
430 value: SqlValue::List(vec![low_val, high_val]),
431 })
432 }
433
434 fn convert_value(expr: &sql_ast::Expr) -> Result<SqlValue> {
436 match expr {
437 sql_ast::Expr::Value(val) => Self::convert_sql_value(val),
438 _ => Err(Error::InvalidQuery(format!(
439 "Unsupported value expression: {:?}",
440 expr
441 ))),
442 }
443 }
444
445 fn convert_sql_value(val: &sql_ast::Value) -> Result<SqlValue> {
447 match val {
448 sql_ast::Value::Number(n, _) => Ok(SqlValue::Number(n.clone())),
449 sql_ast::Value::SingleQuotedString(s) | sql_ast::Value::DoubleQuotedString(s) => {
450 Ok(SqlValue::String(s.clone()))
451 }
452 sql_ast::Value::Boolean(b) => Ok(SqlValue::Boolean(*b)),
453 sql_ast::Value::Null => Ok(SqlValue::Null),
454 _ => Err(Error::InvalidQuery(format!("Unsupported SQL value: {:?}", val))),
455 }
456 }
457
458 fn convert_order_by(order_by: &[sql_ast::OrderByExpr]) -> Result<OrderBy> {
460 if order_by.is_empty() {
461 return Err(Error::InvalidQuery("Empty ORDER BY clause".into()));
462 }
463 if order_by.len() > 1 {
464 return Err(Error::InvalidQuery("ORDER BY multiple columns not supported".into()));
465 }
466
467 let expr = &order_by[0];
468 let attribute = Self::extract_attribute_name(&expr.expr)?;
469 let ascending = expr.asc.unwrap_or(true);
470
471 Ok(OrderBy {
472 attribute,
473 ascending,
474 })
475 }
476
477 fn convert_insert(insert: &sql_ast::Insert) -> Result<InsertStatement> {
479 let table_name = match &insert.table_name {
481 sql_ast::ObjectName(parts) => {
482 if parts.is_empty() {
483 return Err(Error::InvalidQuery("Empty table name in INSERT".into()));
484 }
485 parts[0].value.clone()
486 }
487 };
488
489 let value_map = match &insert.source {
492 Some(source) => {
493 match &*source.body {
496 sql_ast::SetExpr::Values(values) => {
497 if values.rows.is_empty() {
498 return Err(Error::InvalidQuery("INSERT requires VALUE clause".into()));
499 }
500 if values.rows.len() > 1 {
501 return Err(Error::InvalidQuery(
502 "INSERT can only insert one item at a time".into(),
503 ));
504 }
505
506 Self::parse_map_literal(&values.rows[0])?
508 }
509 _ => {
510 return Err(Error::InvalidQuery(
511 "INSERT only supports VALUE clause, not SELECT".into(),
512 ));
513 }
514 }
515 }
516 None => return Err(Error::InvalidQuery("INSERT requires VALUE clause".into())),
517 };
518
519 Ok(InsertStatement {
520 table_name,
521 value: value_map,
522 })
523 }
524
525 fn parse_map_literal(row: &[sql_ast::Expr]) -> Result<SqlValue> {
528 if row.len() != 1 {
532 return Err(Error::InvalidQuery(format!(
533 "Expected single map literal in INSERT, got {} expressions",
534 row.len()
535 )));
536 }
537
538 match &row[0] {
545 sql_ast::Expr::Function(func) => {
547 Self::parse_function_as_map(func)
549 }
550 sql_ast::Expr::JsonAccess { .. } | sql_ast::Expr::CompositeAccess { .. } => {
552 Err(Error::InvalidQuery(
553 "Composite/JSON access expressions not yet supported for INSERT".into(),
554 ))
555 }
556 sql_ast::Expr::Value(sql_ast::Value::SingleQuotedString(s)) => {
558 Self::parse_json_string(s)
559 }
560 _ => {
561 Err(Error::InvalidQuery(format!(
562 "Unsupported INSERT VALUE format. Expression type: {:?}",
563 row[0]
564 )))
565 }
566 }
567 }
568
569 fn parse_function_as_map(func: &sql_ast::Function) -> Result<SqlValue> {
571 let mut map = std::collections::HashMap::new();
573
574 let args_list = match &func.args {
576 sql_ast::FunctionArguments::List(args) => &args.args,
577 _ => return Err(Error::InvalidQuery("Expected argument list in function".into())),
578 };
579
580 for arg in args_list {
581 match arg {
582 sql_ast::FunctionArg::Unnamed(expr_wrapper) => {
583 let expr = match expr_wrapper {
584 sql_ast::FunctionArgExpr::Expr(e) => e,
585 _ => return Err(Error::InvalidQuery("Expected expression in argument".into())),
586 };
587
588 if let sql_ast::Expr::BinaryOp { left, op, right } = expr {
590 if matches!(op, sql_ast::BinaryOperator::Eq) {
591 let key = Self::extract_string_literal(&**left)?;
592 let value = Self::convert_value(&**right)?;
593 map.insert(key, value);
594 } else {
595 return Err(Error::InvalidQuery(
596 "Expected key = value pairs in map".into(),
597 ));
598 }
599 } else {
600 return Err(Error::InvalidQuery(
601 "Expected key = value pairs in map".into(),
602 ));
603 }
604 }
605 _ => {
606 return Err(Error::InvalidQuery(
607 "Unsupported function argument in map".into(),
608 ));
609 }
610 }
611 }
612
613 Ok(SqlValue::Map(map))
614 }
615
616 fn extract_string_literal(expr: &sql_ast::Expr) -> Result<String> {
618 match expr {
619 sql_ast::Expr::Value(sql_ast::Value::SingleQuotedString(s))
620 | sql_ast::Expr::Value(sql_ast::Value::DoubleQuotedString(s)) => Ok(s.clone()),
621 sql_ast::Expr::Identifier(ident) => Ok(ident.value.clone()),
622 _ => Err(Error::InvalidQuery(format!(
623 "Expected string literal, got: {:?}",
624 expr
625 ))),
626 }
627 }
628
629 fn extract_limit_value(expr: &sql_ast::Expr) -> Result<usize> {
631 match expr {
632 sql_ast::Expr::Value(sql_ast::Value::Number(n, _)) => {
633 n.parse::<usize>().map_err(|_| {
634 Error::InvalidQuery(format!("Invalid LIMIT value: {}", n))
635 })
636 }
637 _ => Err(Error::InvalidQuery(format!(
638 "LIMIT must be a positive integer, got: {:?}",
639 expr
640 ))),
641 }
642 }
643
644 fn extract_offset_value(expr: &sql_ast::Expr) -> Result<usize> {
646 match expr {
647 sql_ast::Expr::Value(sql_ast::Value::Number(n, _)) => {
648 n.parse::<usize>().map_err(|_| {
649 Error::InvalidQuery(format!("Invalid OFFSET value: {}", n))
650 })
651 }
652 _ => Err(Error::InvalidQuery(format!(
653 "OFFSET must be a positive integer, got: {:?}",
654 expr
655 ))),
656 }
657 }
658
659 fn parse_json_string(s: &str) -> Result<SqlValue> {
661 let json_normalized = s.replace('\'', "\"");
664
665 let json_value: serde_json::Value = serde_json::from_str(&json_normalized).map_err(|e| {
666 Error::InvalidQuery(format!("Failed to parse JSON: {}", e))
667 })?;
668
669 Self::json_to_sql_value(&json_value)
670 }
671
672 fn json_to_sql_value(value: &serde_json::Value) -> Result<SqlValue> {
674 match value {
675 serde_json::Value::Null => Ok(SqlValue::Null),
676 serde_json::Value::Bool(b) => Ok(SqlValue::Boolean(*b)),
677 serde_json::Value::Number(n) => Ok(SqlValue::Number(n.to_string())),
678 serde_json::Value::String(s) => Ok(SqlValue::String(s.clone())),
679 serde_json::Value::Array(arr) => {
680 let items: Result<Vec<SqlValue>> = arr.iter().map(Self::json_to_sql_value).collect();
681 Ok(SqlValue::List(items?))
682 }
683 serde_json::Value::Object(obj) => {
684 let mut map = std::collections::HashMap::new();
685 for (k, v) in obj {
686 map.insert(k.clone(), Self::json_to_sql_value(v)?);
687 }
688 Ok(SqlValue::Map(map))
689 }
690 }
691 }
692
693 fn convert_delete(delete: &sql_ast::Delete) -> Result<DeleteStatement> {
695 let from_tables = match &delete.from {
698 sql_ast::FromTable::WithFromKeyword(tables) => tables,
699 sql_ast::FromTable::WithoutKeyword(tables) => tables,
700 };
701
702 if from_tables.is_empty() {
703 return Err(Error::InvalidQuery("DELETE requires table name".into()));
704 }
705
706 let (table_name, index_name) = Self::extract_table_reference(&from_tables[0])?;
708
709 if index_name.is_some() {
711 return Err(Error::InvalidQuery("DELETE does not support index syntax".into()));
712 }
713
714 let where_clause = match &delete.selection {
716 Some(expr) => Self::convert_where_clause(expr)?,
717 None => return Err(Error::InvalidQuery("DELETE requires WHERE clause".into())),
718 };
719
720 Ok(DeleteStatement {
721 table_name,
722 where_clause,
723 })
724 }
725
726 fn parse_update_with_remove(sql: &str) -> Result<PartiQLStatement> {
730 let sql_upper = sql.to_uppercase();
731
732 let remove_idx = sql_upper.find(" REMOVE ").ok_or_else(|| {
734 Error::InvalidQuery("Expected REMOVE clause".into())
735 })?;
736 let where_idx = sql_upper.find(" WHERE ").ok_or_else(|| {
737 Error::InvalidQuery("UPDATE requires WHERE clause".into())
738 })?;
739
740 if where_idx <= remove_idx {
741 return Err(Error::InvalidQuery("REMOVE must come before WHERE".into()));
742 }
743
744 let remove_part = &sql[remove_idx + 8..where_idx].trim();
746 let remove_attributes: Vec<String> = remove_part
747 .split(',')
748 .map(|s| s.trim().to_string())
749 .filter(|s| !s.is_empty())
750 .collect();
751
752 if remove_attributes.is_empty() {
753 return Err(Error::InvalidQuery("REMOVE requires at least one attribute".into()));
754 }
755
756 let before_remove = &sql[..remove_idx];
758 let has_set = before_remove.to_uppercase().contains(" SET ");
759
760 let sql_without_remove = if has_set {
762 format!(
764 "{} {}",
765 &sql[..remove_idx].trim(),
766 &sql[where_idx..].trim()
767 )
768 } else {
769 format!(
772 "{} SET __dummy__ = 0 {}",
773 &sql[..remove_idx].trim(),
774 &sql[where_idx..].trim()
775 )
776 };
777
778 let dialect = GenericDialect {};
780 let statements = SqlParser::parse_sql(&dialect, &sql_without_remove).map_err(|e| {
781 Error::InvalidQuery(format!("Failed to parse UPDATE: {}", e))
782 })?;
783
784 if statements.is_empty() {
785 return Err(Error::InvalidQuery("No statement found".into()));
786 }
787
788 match &statements[0] {
790 Statement::Update {
791 table,
792 assignments,
793 selection,
794 ..
795 } => {
796 let update_stmt = Self::convert_update(table, assignments, selection, &remove_attributes)?;
797 Ok(PartiQLStatement::Update(update_stmt))
798 }
799 _ => Err(Error::InvalidQuery("Expected UPDATE statement".into())),
800 }
801 }
802
803 fn convert_update(
805 table: &sql_ast::TableWithJoins,
806 assignments: &[sql_ast::Assignment],
807 selection: &Option<sql_ast::Expr>,
808 remove_attributes: &[String],
809 ) -> Result<UpdateStatement> {
810 let (table_name, index_name) = Self::extract_table_reference(table)?;
812
813 if index_name.is_some() {
815 return Err(Error::InvalidQuery("UPDATE does not support index syntax".into()));
816 }
817
818 let where_clause = match selection {
820 Some(expr) => Self::convert_where_clause(expr)?,
821 None => return Err(Error::InvalidQuery("UPDATE requires WHERE clause".into())),
822 };
823
824 let set_assignments: Vec<SetAssignment> = assignments
826 .iter()
827 .map(Self::convert_assignment)
828 .collect::<Result<Vec<_>>>()?
829 .into_iter()
830 .filter(|a| a.attribute != "__dummy__")
831 .collect();
832
833 Ok(UpdateStatement {
834 table_name,
835 where_clause,
836 set_assignments,
837 remove_attributes: remove_attributes.to_vec(),
838 })
839 }
840
841 fn convert_assignment(assignment: &sql_ast::Assignment) -> Result<SetAssignment> {
843 let attribute = match &assignment.target {
845 sql_ast::AssignmentTarget::ColumnName(name) => {
846 match name {
847 sql_ast::ObjectName(parts) => {
848 if parts.is_empty() {
849 return Err(Error::InvalidQuery("Empty attribute name".into()));
850 }
851 parts[0].value.clone()
852 }
853 }
854 }
855 _ => return Err(Error::InvalidQuery("Unsupported assignment target".into())),
856 };
857
858 let value = Self::convert_set_value(&assignment.value)?;
860
861 Ok(SetAssignment { attribute, value })
862 }
863
864 fn convert_set_value(expr: &sql_ast::Expr) -> Result<SetValue> {
866 match expr {
867 sql_ast::Expr::Value(v) => {
869 let sql_value = Self::convert_sql_value(v)?;
870 Ok(SetValue::Literal(sql_value))
871 }
872 sql_ast::Expr::BinaryOp { left, op, right } => {
874 let attribute = match &**left {
876 sql_ast::Expr::Identifier(ident) => ident.value.clone(),
877 _ => return Err(Error::InvalidQuery(
878 "Arithmetic expression left side must be an attribute".into()
879 )),
880 };
881
882 let value = match &**right {
884 sql_ast::Expr::Value(v) => Self::convert_sql_value(v)?,
885 _ => return Err(Error::InvalidQuery(
886 "Arithmetic expression right side must be a literal".into()
887 )),
888 };
889
890 match op {
892 sql_ast::BinaryOperator::Plus => Ok(SetValue::Add { attribute, value }),
893 sql_ast::BinaryOperator::Minus => Ok(SetValue::Subtract { attribute, value }),
894 _ => Err(Error::InvalidQuery(format!(
895 "Unsupported arithmetic operator in SET: {:?}",
896 op
897 ))),
898 }
899 }
900 _ => Err(Error::InvalidQuery(format!(
901 "Unsupported SET value expression: {:?}",
902 expr
903 ))),
904 }
905 }
906}
907
908#[cfg(test)]
909mod tests {
910 use super::*;
911
912 #[test]
913 fn test_parse_simple_select() {
914 let sql = "SELECT * FROM users WHERE pk = 'user#123'";
915 let stmt = PartiQLParser::parse(sql).unwrap();
916
917 match stmt {
918 PartiQLStatement::Select(select) => {
919 assert_eq!(select.table_name, "users");
920 assert_eq!(select.index_name, None);
921 assert_eq!(select.select_list, SelectList::All);
922
923 let where_clause = select.where_clause.unwrap();
924 assert_eq!(where_clause.conditions.len(), 1);
925 assert_eq!(where_clause.conditions[0].attribute, "pk");
926 assert_eq!(where_clause.conditions[0].operator, CompareOp::Equal);
927 }
928 _ => panic!("Expected SELECT statement"),
929 }
930 }
931
932 #[test]
933 fn test_parse_select_with_index() {
934 let sql = "SELECT * FROM users.email_index WHERE pk = 'org#acme'";
935 let stmt = PartiQLParser::parse(sql).unwrap();
936
937 match stmt {
938 PartiQLStatement::Select(select) => {
939 assert_eq!(select.table_name, "users");
940 assert_eq!(select.index_name, Some("email_index".to_string()));
941 }
942 _ => panic!("Expected SELECT statement"),
943 }
944 }
945
946 #[test]
947 fn test_parse_select_with_attributes() {
948 let sql = "SELECT name, age FROM users WHERE pk = 'user#123'";
949 let stmt = PartiQLParser::parse(sql).unwrap();
950
951 match stmt {
952 PartiQLStatement::Select(select) => {
953 match select.select_list {
954 SelectList::Attributes(attrs) => {
955 assert_eq!(attrs, vec!["name", "age"]);
956 }
957 _ => panic!("Expected attribute list"),
958 }
959 }
960 _ => panic!("Expected SELECT statement"),
961 }
962 }
963
964 #[test]
965 fn test_parse_select_with_order_by() {
966 let sql = "SELECT * FROM users WHERE pk = 'user#123' ORDER BY sk DESC";
967 let stmt = PartiQLParser::parse(sql).unwrap();
968
969 match stmt {
970 PartiQLStatement::Select(select) => {
971 let order_by = select.order_by.unwrap();
972 assert_eq!(order_by.attribute, "sk");
973 assert_eq!(order_by.ascending, false);
974 }
975 _ => panic!("Expected SELECT statement"),
976 }
977 }
978
979 #[test]
980 fn test_parse_select_with_in() {
981 let sql = "SELECT * FROM users WHERE pk IN ('user#1', 'user#2')";
982 let stmt = PartiQLParser::parse(sql).unwrap();
983
984 match stmt {
985 PartiQLStatement::Select(select) => {
986 let where_clause = select.where_clause.unwrap();
987 assert_eq!(where_clause.conditions[0].operator, CompareOp::In);
988 }
989 _ => panic!("Expected SELECT statement"),
990 }
991 }
992
993 #[test]
994 fn test_parse_select_with_between() {
995 let sql = "SELECT * FROM users WHERE pk = 'user#123' AND age BETWEEN 18 AND 65";
996 let stmt = PartiQLParser::parse(sql).unwrap();
997
998 match stmt {
999 PartiQLStatement::Select(select) => {
1000 let where_clause = select.where_clause.unwrap();
1001 assert_eq!(where_clause.conditions.len(), 2);
1002
1003 let between_cond = where_clause.conditions.iter()
1005 .find(|c| c.operator == CompareOp::Between)
1006 .unwrap();
1007 assert_eq!(between_cond.attribute, "age");
1008 }
1009 _ => panic!("Expected SELECT statement"),
1010 }
1011 }
1012
1013 #[test]
1014 fn test_reject_join() {
1015 let sql = "SELECT * FROM users JOIN orders ON users.pk = orders.user_id";
1016 let result = PartiQLParser::parse(sql);
1017 assert!(result.is_err());
1018 assert!(result.unwrap_err().to_string().contains("JOIN"));
1019 }
1020
1021 #[test]
1022 fn test_reject_or() {
1023 let sql = "SELECT * FROM users WHERE pk = 'user#123' OR pk = 'user#456'";
1024 let result = PartiQLParser::parse(sql);
1025 assert!(result.is_err());
1026 assert!(result.unwrap_err().to_string().contains("OR"));
1027 }
1028
1029 #[test]
1030 fn test_reject_group_by() {
1031 let sql = "SELECT COUNT(*) FROM users GROUP BY status";
1032 let result = PartiQLParser::parse(sql);
1033 assert!(result.is_err());
1034 assert!(result.unwrap_err().to_string().contains("GROUP BY"));
1035 }
1036
1037 #[test]
1038 fn test_reject_too_long() {
1039 let sql = "SELECT * FROM users WHERE pk = '".to_string() + &"x".repeat(10000) + "'";
1040 let result = PartiQLParser::parse(&sql);
1041 assert!(result.is_err());
1042 assert!(result.unwrap_err().to_string().contains("too long"));
1043 }
1044
1045 #[test]
1047 fn test_parse_delete_with_pk() {
1048 let sql = "DELETE FROM users WHERE pk = 'user#123'";
1049 let stmt = PartiQLParser::parse(sql).unwrap();
1050
1051 match stmt {
1052 PartiQLStatement::Delete(delete) => {
1053 assert_eq!(delete.table_name, "users");
1054 assert_eq!(delete.where_clause.conditions.len(), 1);
1055 assert_eq!(delete.where_clause.conditions[0].attribute, "pk");
1056 assert_eq!(delete.where_clause.conditions[0].operator, CompareOp::Equal);
1057 }
1058 _ => panic!("Expected DELETE statement"),
1059 }
1060 }
1061
1062 #[test]
1063 fn test_parse_delete_with_pk_and_sk() {
1064 let sql = "DELETE FROM users WHERE pk = 'user#123' AND sk = 'profile'";
1065 let stmt = PartiQLParser::parse(sql).unwrap();
1066
1067 match stmt {
1068 PartiQLStatement::Delete(delete) => {
1069 assert_eq!(delete.table_name, "users");
1070 assert_eq!(delete.where_clause.conditions.len(), 2);
1071
1072 let pk_cond = delete.where_clause.get_condition("pk").unwrap();
1073 assert_eq!(pk_cond.operator, CompareOp::Equal);
1074
1075 let sk_cond = delete.where_clause.get_condition("sk").unwrap();
1076 assert_eq!(sk_cond.operator, CompareOp::Equal);
1077 }
1078 _ => panic!("Expected DELETE statement"),
1079 }
1080 }
1081
1082 #[test]
1083 fn test_reject_delete_without_where() {
1084 let sql = "DELETE FROM users";
1085 let result = PartiQLParser::parse(sql);
1086 assert!(result.is_err());
1087 assert!(result.unwrap_err().to_string().contains("WHERE"));
1088 }
1089
1090 #[test]
1092 fn test_parse_insert_simple() {
1093 let sql = "INSERT INTO users VALUE {'pk': 'user#123', 'name': 'Alice', 'age': 30}";
1094 let stmt = PartiQLParser::parse(sql).unwrap();
1095
1096 match stmt {
1097 PartiQLStatement::Insert(insert) => {
1098 assert_eq!(insert.table_name, "users");
1099 match &insert.value {
1100 SqlValue::Map(map) => {
1101 assert_eq!(map.len(), 3);
1102 assert_eq!(map.get("pk"), Some(&SqlValue::String("user#123".to_string())));
1103 assert_eq!(map.get("name"), Some(&SqlValue::String("Alice".to_string())));
1104 assert_eq!(map.get("age"), Some(&SqlValue::Number("30".to_string())));
1105 }
1106 _ => panic!("Expected Map value"),
1107 }
1108 }
1109 _ => panic!("Expected INSERT statement"),
1110 }
1111 }
1112
1113 #[test]
1114 fn test_parse_insert_with_sk() {
1115 let sql = "INSERT INTO users VALUE {'pk': 'user#123', 'sk': 'profile', 'email': 'alice@example.com'}";
1116 let stmt = PartiQLParser::parse(sql).unwrap();
1117
1118 match stmt {
1119 PartiQLStatement::Insert(insert) => {
1120 assert_eq!(insert.table_name, "users");
1121 match &insert.value {
1122 SqlValue::Map(map) => {
1123 assert!(map.contains_key("pk"));
1124 assert!(map.contains_key("sk"));
1125 assert_eq!(map.get("sk"), Some(&SqlValue::String("profile".to_string())));
1126 }
1127 _ => panic!("Expected Map value"),
1128 }
1129 }
1130 _ => panic!("Expected INSERT statement"),
1131 }
1132 }
1133
1134 #[test]
1135 fn test_parse_insert_nested_values() {
1136 let sql = r#"INSERT INTO users VALUE {'pk': 'user#123', 'profile': {'name': 'Alice', 'age': 30}, 'tags': ['admin', 'active']}"#;
1137 let stmt = PartiQLParser::parse(sql).unwrap();
1138
1139 match stmt {
1140 PartiQLStatement::Insert(insert) => {
1141 match &insert.value {
1142 SqlValue::Map(map) => {
1143 match map.get("profile") {
1145 Some(SqlValue::Map(profile)) => {
1146 assert_eq!(profile.get("name"), Some(&SqlValue::String("Alice".to_string())));
1147 }
1148 _ => panic!("Expected nested map for profile"),
1149 }
1150
1151 match map.get("tags") {
1153 Some(SqlValue::List(tags)) => {
1154 assert_eq!(tags.len(), 2);
1155 }
1156 _ => panic!("Expected list for tags"),
1157 }
1158 }
1159 _ => panic!("Expected Map value"),
1160 }
1161 }
1162 _ => panic!("Expected INSERT statement"),
1163 }
1164 }
1165
/// INSERT covering scalar literal kinds: a decimal number, a boolean and `null` each map to
/// their dedicated `SqlValue` variant.
#[test]
fn test_parse_insert_various_types() {
    let stmt = PartiQLParser::parse(
        "INSERT INTO items VALUE {'pk': 'item#1', 'price': 29.99, 'active': true, 'description': null}",
    )
    .unwrap();

    let insert = match stmt {
        PartiQLStatement::Insert(i) => i,
        _ => panic!("Expected INSERT statement"),
    };
    let map = match &insert.value {
        SqlValue::Map(m) => m,
        _ => panic!("Expected Map value"),
    };

    // `SqlValue::Number` holds text, so compare against the exact literal spelling.
    if let Some(SqlValue::Number(price)) = map.get("price") {
        assert_eq!(price, "29.99");
    } else {
        panic!("Expected number for price");
    }
    assert_eq!(map.get("active"), Some(&SqlValue::Boolean(true)));
    assert_eq!(map.get("description"), Some(&SqlValue::Null));
}
1193
/// An INSERT whose VALUE is a bare string instead of a `{...}` map must be rejected.
// NOTE(review): the statement contains no '{', so `parse` does not take the JSON-map path
// and hands the text to sqlparser instead; this presumably fails as a generic syntax error
// rather than a dedicated "value must be a map" check. Confirm whether a specific error
// message should be asserted here.
#[test]
fn test_reject_insert_without_map() {
    assert!(PartiQLParser::parse("INSERT INTO users VALUE 'not a map'").is_err());
}
1200
/// UPDATE with two literal SET assignments and a WHERE clause: both assignments are captured
/// in order with their literal values, and no REMOVE attributes are produced.
#[test]
fn test_parse_update_simple() {
    let sql = "UPDATE users SET name = 'Alice', age = 30 WHERE pk = 'user#123'";
    let stmt = PartiQLParser::parse(sql).unwrap();

    match stmt {
        PartiQLStatement::Update(update) => {
            assert_eq!(update.table_name, "users");
            assert_eq!(update.set_assignments.len(), 2);
            assert_eq!(update.remove_attributes.len(), 0);

            assert_eq!(update.set_assignments[0].attribute, "name");
            match &update.set_assignments[0].value {
                SetValue::Literal(SqlValue::String(s)) => assert_eq!(s, "Alice"),
                _ => panic!("Expected string literal"),
            }

            // Also cover the numeric assignment: a plain number on the right-hand side must
            // stay a literal and must not be treated as an arithmetic update.
            assert_eq!(update.set_assignments[1].attribute, "age");
            match &update.set_assignments[1].value {
                SetValue::Literal(SqlValue::Number(n)) => assert_eq!(n, "30"),
                _ => panic!("Expected number literal"),
            }

            assert!(update.where_clause.has_condition("pk"));
        }
        _ => panic!("Expected UPDATE statement"),
    }
}
1226
/// UPDATE using `attr = attr + n` and `attr = attr - n`: these become `SetValue::Add` and
/// `SetValue::Subtract`, each carrying the attribute name and the numeric operand.
#[test]
fn test_parse_update_with_arithmetic() {
    let stmt = PartiQLParser::parse(
        "UPDATE users SET age = age + 1, count = count - 5 WHERE pk = 'user#123'",
    )
    .unwrap();

    let update = match stmt {
        PartiQLStatement::Update(u) => u,
        _ => panic!("Expected UPDATE statement"),
    };
    assert_eq!(update.set_assignments.len(), 2);

    // `age = age + 1` -> Add on `age` by 1.
    if let SetValue::Add { attribute, value } = &update.set_assignments[0].value {
        assert_eq!(attribute, "age");
        if let SqlValue::Number(n) = value {
            assert_eq!(n, "1");
        } else {
            panic!("Expected number");
        }
    } else {
        panic!("Expected Add operation");
    }

    // `count = count - 5` -> Subtract on `count` by 5.
    if let SetValue::Subtract { attribute, value } = &update.set_assignments[1].value {
        assert_eq!(attribute, "count");
        if let SqlValue::Number(n) = value {
            assert_eq!(n, "5");
        } else {
            panic!("Expected number");
        }
    } else {
        panic!("Expected Subtract operation");
    }
}
1263
/// UPDATE mixing SET and REMOVE: one assignment is captured, and the REMOVE list keeps its
/// attributes in source order.
#[test]
fn test_parse_update_with_remove() {
    let stmt = PartiQLParser::parse(
        "UPDATE users SET name = 'Alice' REMOVE tags, metadata WHERE pk = 'user#123'",
    )
    .unwrap();

    let update = match stmt {
        PartiQLStatement::Update(update) => update,
        _ => panic!("Expected UPDATE statement"),
    };
    assert_eq!(update.table_name, "users");
    assert_eq!(update.set_assignments.len(), 1);
    assert_eq!(update.remove_attributes.len(), 2);
    for (i, expected) in ["tags", "metadata"].iter().enumerate() {
        assert_eq!(update.remove_attributes[i], *expected);
    }
}
1280
/// UPDATE with only a REMOVE clause and a composite-key WHERE: no SET assignments, both
/// attributes removed in source order, and both key conditions (`pk`, `sk`) captured.
#[test]
fn test_parse_update_remove_only() {
    let sql = "UPDATE users REMOVE tags, metadata WHERE pk = 'user#123' AND sk = 'profile'";
    let stmt = PartiQLParser::parse(sql).unwrap();

    match stmt {
        PartiQLStatement::Update(update) => {
            assert_eq!(update.table_name, "users");
            assert_eq!(update.set_assignments.len(), 0);

            // Check which attributes are removed, not just how many.
            assert_eq!(update.remove_attributes.len(), 2);
            assert_eq!(update.remove_attributes[0], "tags");
            assert_eq!(update.remove_attributes[1], "metadata");

            // Both halves of the AND-ed key condition must be present.
            assert_eq!(update.where_clause.conditions.len(), 2);
            assert!(update.where_clause.has_condition("pk"));
            assert!(update.where_clause.has_condition("sk"));
        }
        _ => panic!("Expected UPDATE statement"),
    }
}
1295
/// UPDATE without a WHERE clause must be rejected, and the error text must mention `WHERE`
/// so the caller can see why.
#[test]
fn test_reject_update_without_where() {
    let result = PartiQLParser::parse("UPDATE users SET name = 'Alice'");
    assert!(matches!(&result, Err(e) if e.to_string().contains("WHERE")));
}
1303
/// SELECT with LIMIT only: the limit is captured and the offset stays unset.
#[test]
fn test_parse_select_with_limit() {
    let stmt = PartiQLParser::parse("SELECT * FROM users WHERE pk = 'user#123' LIMIT 10").unwrap();

    let select = match stmt {
        PartiQLStatement::Select(select) => select,
        _ => panic!("Expected SELECT statement"),
    };
    assert_eq!(select.table_name, "users");
    assert_eq!((select.limit, select.offset), (Some(10), None));
}
1318
/// SELECT with OFFSET only: the offset is captured and the limit stays unset.
#[test]
fn test_parse_select_with_offset() {
    let stmt = PartiQLParser::parse("SELECT * FROM users WHERE pk = 'user#123' OFFSET 5").unwrap();

    let select = match stmt {
        PartiQLStatement::Select(select) => select,
        _ => panic!("Expected SELECT statement"),
    };
    assert_eq!(select.table_name, "users");
    assert_eq!((select.limit, select.offset), (None, Some(5)));
}
1333
/// SELECT with both LIMIT and OFFSET: each value lands in its own field.
#[test]
fn test_parse_select_with_limit_and_offset() {
    let stmt = PartiQLParser::parse("SELECT * FROM users WHERE pk = 'user#123' LIMIT 20 OFFSET 10")
        .unwrap();

    let select = match stmt {
        PartiQLStatement::Select(select) => select,
        _ => panic!("Expected SELECT statement"),
    };
    assert_eq!(select.table_name, "users");
    assert_eq!((select.limit, select.offset), (Some(20), Some(10)));
}
1348}