1pub mod aggregate;
31pub mod ast;
32pub mod bridge;
33pub mod compatibility;
34pub mod error;
35pub mod lexer;
36pub mod parser;
37pub mod token;
38
39pub use ast::*;
40pub use bridge::{ExecutionResult as BridgeExecutionResult, SqlBridge, SqlConnection};
41pub use compatibility::{
42 CompatibilityMatrix, FeatureSupport, SqlDialect, SqlFeature, get_feature_support,
43};
44pub use error::{SqlError, SqlResult};
45pub use lexer::{LexError, Lexer};
46pub use parser::{ParseError, Parser};
47pub use token::{Span, Token, TokenKind};
48
49use sochdb_core::SochValue;
50use std::collections::HashMap;
51
/// Outcome of executing a single SQL statement.
#[derive(Debug, Clone)]
pub enum ExecutionResult {
    /// Row set produced by a `SELECT`.
    Rows {
        /// Output column names, one per projected item (or the table's columns for `*`).
        columns: Vec<String>,
        /// One map per result row, keyed by output column name.
        rows: Vec<HashMap<String, SochValue>>,
    },
    /// Number of rows inserted, updated or deleted.
    RowsAffected(usize),
    /// Statement succeeded with nothing to report (DDL, transaction control).
    Ok,
}
65
66impl ExecutionResult {
67 pub fn rows(&self) -> Option<&Vec<HashMap<String, SochValue>>> {
69 match self {
70 ExecutionResult::Rows { rows, .. } => Some(rows),
71 _ => None,
72 }
73 }
74
75 pub fn columns(&self) -> Option<&Vec<String>> {
77 match self {
78 ExecutionResult::Rows { columns, .. } => Some(columns),
79 _ => None,
80 }
81 }
82
83 pub fn rows_affected(&self) -> usize {
85 match self {
86 ExecutionResult::RowsAffected(n) => *n,
87 ExecutionResult::Rows { rows, .. } => rows.len(),
88 ExecutionResult::Ok => 0,
89 }
90 }
91}
92
/// In-memory SQL executor: parses statements and runs them directly against
/// tables held in a `HashMap`.
pub struct SqlExecutor {
    /// Tables keyed by table name.
    tables: HashMap<String, TableData>,
}
100
/// Schema and contents of one in-memory table.
#[derive(Debug, Clone)]
pub struct TableData {
    /// Column names in `CREATE TABLE` declaration order.
    pub columns: Vec<String>,
    /// Declared column types, parallel to `columns`.
    /// NOTE(review): only recorded at CREATE time; nothing in this module checks values against them.
    pub column_types: Vec<DataType>,
    /// Stored rows; each row's values are positionally aligned with `columns`.
    pub rows: Vec<Vec<SochValue>>,
}
108
109impl Default for SqlExecutor {
110 fn default() -> Self {
111 Self::new()
112 }
113}
114
115impl SqlExecutor {
116 pub fn new() -> Self {
118 Self {
119 tables: HashMap::new(),
120 }
121 }
122
123 pub fn execute(&mut self, sql: &str) -> SqlResult<ExecutionResult> {
125 self.execute_with_params(sql, &[])
126 }
127
128 pub fn execute_with_params(
130 &mut self,
131 sql: &str,
132 params: &[SochValue],
133 ) -> SqlResult<ExecutionResult> {
134 let stmt = Parser::parse(sql).map_err(SqlError::from_parse_errors)?;
135 self.execute_statement(&stmt, params)
136 }
137
138 pub fn execute_statement(
140 &mut self,
141 stmt: &Statement,
142 params: &[SochValue],
143 ) -> SqlResult<ExecutionResult> {
144 match stmt {
145 Statement::Select(select) => self.execute_select(select, params),
146 Statement::Insert(insert) => self.execute_insert(insert, params),
147 Statement::Update(update) => self.execute_update(update, params),
148 Statement::Delete(delete) => self.execute_delete(delete, params),
149 Statement::CreateTable(create) => self.execute_create_table(create),
150 Statement::DropTable(drop) => self.execute_drop_table(drop),
151 Statement::Begin(_) => Ok(ExecutionResult::Ok),
152 Statement::Commit => Ok(ExecutionResult::Ok),
153 Statement::Rollback(_) => Ok(ExecutionResult::Ok),
154 _ => Err(SqlError::NotImplemented(
155 "Statement type not yet supported".into(),
156 )),
157 }
158 }
159
160 fn execute_select(
161 &self,
162 select: &SelectStmt,
163 params: &[SochValue],
164 ) -> SqlResult<ExecutionResult> {
165 let from = select
167 .from
168 .as_ref()
169 .ok_or_else(|| SqlError::InvalidArgument("SELECT requires FROM clause".into()))?;
170
171 if from.tables.len() != 1 {
172 return Err(SqlError::NotImplemented(
173 "Multi-table queries not yet supported".into(),
174 ));
175 }
176
177 let table_name = match &from.tables[0] {
178 TableRef::Table { name, .. } => name.name().to_string(),
179 _ => {
180 return Err(SqlError::NotImplemented(
181 "Complex table references not yet supported".into(),
182 ));
183 }
184 };
185
186 let table = self
187 .tables
188 .get(&table_name)
189 .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
190
191 let mut source_rows = Vec::new();
193
194 for row in &table.rows {
195 let row_map: HashMap<String, SochValue> = table
197 .columns
198 .iter()
199 .zip(row.iter())
200 .map(|(col, val)| (col.clone(), val.clone()))
201 .collect();
202
203 if let Some(where_clause) = &select.where_clause
205 && !self.evaluate_where(where_clause, &row_map, params)?
206 {
207 continue;
208 }
209
210 source_rows.push(row_map);
211 }
212
213 if !select.order_by.is_empty() {
215 source_rows.sort_by(|a, b| {
216 for order_item in &select.order_by {
217 if let Expr::Column(col_ref) = &order_item.expr {
218 let a_val = a.get(&col_ref.column);
219 let b_val = b.get(&col_ref.column);
220
221 let cmp = self.compare_values(a_val, b_val);
222 if cmp != std::cmp::Ordering::Equal {
223 return if order_item.asc { cmp } else { cmp.reverse() };
224 }
225 }
226 }
227 std::cmp::Ordering::Equal
228 });
229 }
230
231 if let Some(Expr::Literal(Literal::Integer(n))) = &select.offset {
233 let n = *n as usize;
234 if n < source_rows.len() {
235 source_rows = source_rows.into_iter().skip(n).collect();
236 } else {
237 source_rows.clear();
238 }
239 }
240
241 if let Some(Expr::Literal(Literal::Integer(n))) = &select.limit {
243 source_rows.truncate(*n as usize);
244 }
245
246 let mut output_columns: Vec<String> = Vec::new();
248 let mut result_rows: Vec<HashMap<String, SochValue>> = Vec::new();
249
250 let is_wildcard = matches!(&select.columns[..], [SelectItem::Wildcard]);
252
253 if is_wildcard {
254 output_columns = table.columns.clone();
255 result_rows = source_rows;
256 } else {
257 for item in &select.columns {
259 match item {
260 SelectItem::Wildcard => output_columns.push("*".to_string()),
261 SelectItem::QualifiedWildcard(t) => output_columns.push(format!("{}.*", t)),
262 SelectItem::Expr { expr, alias } => {
263 let col_name = alias.clone().unwrap_or_else(|| match expr {
264 Expr::Column(col) => col.column.clone(),
265 Expr::Function(func) => format!("{}()", func.name.name()),
266 _ => "?column?".to_string(),
267 });
268 output_columns.push(col_name);
269 }
270 }
271 }
272
273 for source_row in &source_rows {
275 let mut result_row = HashMap::new();
276
277 for (idx, item) in select.columns.iter().enumerate() {
278 let col_name = &output_columns[idx];
279
280 match item {
281 SelectItem::Wildcard => {
282 for (k, v) in source_row {
284 result_row.insert(k.clone(), v.clone());
285 }
286 }
287 SelectItem::QualifiedWildcard(_) => {
288 for (k, v) in source_row {
290 result_row.insert(k.clone(), v.clone());
291 }
292 }
293 SelectItem::Expr { expr, .. } => {
294 let value = self.evaluate_expr(expr, source_row, params)?;
295 result_row.insert(col_name.clone(), value);
296 }
297 }
298 }
299
300 result_rows.push(result_row);
301 }
302 }
303
304 Ok(ExecutionResult::Rows {
305 columns: output_columns,
306 rows: result_rows,
307 })
308 }
309
310 fn execute_insert(
311 &mut self,
312 insert: &InsertStmt,
313 params: &[SochValue],
314 ) -> SqlResult<ExecutionResult> {
315 let table_name = insert.table.name().to_string();
316
317 let table_columns = {
319 let table = self
320 .tables
321 .get(&table_name)
322 .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
323 table.columns.clone()
324 };
325
326 let mut rows_affected = 0;
327 let mut new_rows = Vec::new();
328
329 match &insert.source {
330 InsertSource::Values(rows) => {
331 for value_exprs in rows {
332 let mut row_values = Vec::new();
333
334 if let Some(columns) = &insert.columns {
335 if columns.len() != value_exprs.len() {
336 return Err(SqlError::InvalidArgument(format!(
337 "Column count ({}) doesn't match value count ({})",
338 columns.len(),
339 value_exprs.len()
340 )));
341 }
342
343 for table_col in &table_columns {
345 if let Some(pos) = columns.iter().position(|c| c == table_col) {
346 let value =
347 self.evaluate_expr(&value_exprs[pos], &HashMap::new(), params)?;
348 row_values.push(value);
349 } else {
350 row_values.push(SochValue::Null);
351 }
352 }
353 } else {
354 for expr in value_exprs {
356 let value = self.evaluate_expr(expr, &HashMap::new(), params)?;
357 row_values.push(value);
358 }
359 }
360
361 new_rows.push(row_values);
362 rows_affected += 1;
363 }
364 }
365 InsertSource::Query(_) => {
366 return Err(SqlError::NotImplemented(
367 "INSERT ... SELECT not yet supported".into(),
368 ));
369 }
370 InsertSource::Default => {
371 return Err(SqlError::NotImplemented(
372 "INSERT DEFAULT VALUES not yet supported".into(),
373 ));
374 }
375 }
376
377 let table = self.tables.get_mut(&table_name).unwrap();
379 for row in new_rows {
380 table.rows.push(row);
381 }
382
383 Ok(ExecutionResult::RowsAffected(rows_affected))
384 }
385
386 fn execute_update(
387 &mut self,
388 update: &UpdateStmt,
389 params: &[SochValue],
390 ) -> SqlResult<ExecutionResult> {
391 let table_name = update.table.name().to_string();
392
393 let (_table_columns, updates_to_apply) = {
395 let table = self
396 .tables
397 .get(&table_name)
398 .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
399
400 let mut updates = Vec::new();
401
402 for row_idx in 0..table.rows.len() {
403 let row_map: HashMap<String, SochValue> = table
405 .columns
406 .iter()
407 .zip(table.rows[row_idx].iter())
408 .map(|(col, val)| (col.clone(), val.clone()))
409 .collect();
410
411 let matches = if let Some(where_clause) = &update.where_clause {
413 self.evaluate_where(where_clause, &row_map, params)?
414 } else {
415 true
416 };
417
418 if matches {
419 let mut row_updates = Vec::new();
421 for assignment in &update.assignments {
422 if let Some(col_idx) =
423 table.columns.iter().position(|c| c == &assignment.column)
424 {
425 let value = self.evaluate_expr(&assignment.value, &row_map, params)?;
426 row_updates.push((col_idx, value));
427 }
428 }
429 updates.push((row_idx, row_updates));
430 }
431 }
432
433 (table.columns.clone(), updates)
434 };
435
436 let rows_affected = updates_to_apply.len();
437
438 let table = self.tables.get_mut(&table_name).unwrap();
440 for (row_idx, row_updates) in updates_to_apply {
441 for (col_idx, value) in row_updates {
442 table.rows[row_idx][col_idx] = value;
443 }
444 }
445
446 Ok(ExecutionResult::RowsAffected(rows_affected))
447 }
448
449 fn execute_delete(
450 &mut self,
451 delete: &DeleteStmt,
452 params: &[SochValue],
453 ) -> SqlResult<ExecutionResult> {
454 let table_name = delete.table.name().to_string();
455
456 let indices_to_remove = {
458 let table = self
459 .tables
460 .get(&table_name)
461 .ok_or_else(|| SqlError::TableNotFound(table_name.clone()))?;
462
463 let mut indices = Vec::new();
464
465 for (row_idx, row) in table.rows.iter().enumerate() {
466 let row_map: HashMap<String, SochValue> = table
468 .columns
469 .iter()
470 .zip(row.iter())
471 .map(|(col, val)| (col.clone(), val.clone()))
472 .collect();
473
474 let matches = if let Some(where_clause) = &delete.where_clause {
476 self.evaluate_where(where_clause, &row_map, params)?
477 } else {
478 true
479 };
480
481 if matches {
482 indices.push(row_idx);
483 }
484 }
485
486 indices
487 };
488
489 let rows_affected = indices_to_remove.len();
490
491 let table = self.tables.get_mut(&table_name).unwrap();
493 for idx in indices_to_remove.into_iter().rev() {
495 table.rows.remove(idx);
496 }
497
498 Ok(ExecutionResult::RowsAffected(rows_affected))
499 }
500
501 fn execute_create_table(&mut self, create: &CreateTableStmt) -> SqlResult<ExecutionResult> {
502 let table_name = create.name.name().to_string();
503
504 if self.tables.contains_key(&table_name) {
505 if create.if_not_exists {
506 return Ok(ExecutionResult::Ok);
507 }
508 return Err(SqlError::ConstraintViolation(format!(
509 "Table '{}' already exists",
510 table_name
511 )));
512 }
513
514 let columns: Vec<String> = create.columns.iter().map(|c| c.name.clone()).collect();
515 let column_types: Vec<DataType> =
516 create.columns.iter().map(|c| c.data_type.clone()).collect();
517
518 self.tables.insert(
519 table_name,
520 TableData {
521 columns,
522 column_types,
523 rows: Vec::new(),
524 },
525 );
526
527 Ok(ExecutionResult::Ok)
528 }
529
530 fn execute_drop_table(&mut self, drop: &DropTableStmt) -> SqlResult<ExecutionResult> {
531 for name in &drop.names {
532 let table_name = name.name().to_string();
533 if self.tables.remove(&table_name).is_none() && !drop.if_exists {
534 return Err(SqlError::TableNotFound(table_name));
535 }
536 }
537
538 Ok(ExecutionResult::Ok)
539 }
540
541 fn evaluate_where(
544 &self,
545 expr: &Expr,
546 row: &HashMap<String, SochValue>,
547 params: &[SochValue],
548 ) -> SqlResult<bool> {
549 let value = self.evaluate_expr(expr, row, params)?;
550 match value {
551 SochValue::Bool(b) => Ok(b),
552 SochValue::Null => Ok(false),
553 _ => Err(SqlError::TypeError(
554 "WHERE clause must evaluate to boolean".into(),
555 )),
556 }
557 }
558
559 fn evaluate_expr(
560 &self,
561 expr: &Expr,
562 row: &HashMap<String, SochValue>,
563 params: &[SochValue],
564 ) -> SqlResult<SochValue> {
565 match expr {
566 Expr::Literal(lit) => Ok(self.literal_to_value(lit)),
567
568 Expr::Column(col_ref) => row
569 .get(&col_ref.column)
570 .cloned()
571 .ok_or_else(|| SqlError::ColumnNotFound(col_ref.column.clone())),
572
573 Expr::Placeholder(n) => params
574 .get((*n as usize).saturating_sub(1))
575 .cloned()
576 .ok_or_else(|| SqlError::InvalidArgument(format!("Parameter ${} not provided", n))),
577
578 Expr::BinaryOp { left, op, right } => {
579 let left_val = self.evaluate_expr(left, row, params)?;
580 let right_val = self.evaluate_expr(right, row, params)?;
581 self.evaluate_binary_op(&left_val, op, &right_val)
582 }
583
584 Expr::UnaryOp { op, expr } => {
585 let val = self.evaluate_expr(expr, row, params)?;
586 self.evaluate_unary_op(op, &val)
587 }
588
589 Expr::IsNull { expr, negated } => {
590 let val = self.evaluate_expr(expr, row, params)?;
591 let is_null = matches!(val, SochValue::Null);
592 Ok(SochValue::Bool(if *negated { !is_null } else { is_null }))
593 }
594
595 Expr::InList {
596 expr,
597 list,
598 negated,
599 } => {
600 let val = self.evaluate_expr(expr, row, params)?;
601 let mut found = false;
602 for item in list {
603 let item_val = self.evaluate_expr(item, row, params)?;
604 if self.values_equal(&val, &item_val) {
605 found = true;
606 break;
607 }
608 }
609 Ok(SochValue::Bool(if *negated { !found } else { found }))
610 }
611
612 Expr::Between {
613 expr,
614 low,
615 high,
616 negated,
617 } => {
618 let val = self.evaluate_expr(expr, row, params)?;
619 let low_val = self.evaluate_expr(low, row, params)?;
620 let high_val = self.evaluate_expr(high, row, params)?;
621
622 let cmp_low = self.compare_values(Some(&val), Some(&low_val));
623 let cmp_high = self.compare_values(Some(&val), Some(&high_val));
624
625 let in_range =
626 cmp_low != std::cmp::Ordering::Less && cmp_high != std::cmp::Ordering::Greater;
627
628 Ok(SochValue::Bool(if *negated { !in_range } else { in_range }))
629 }
630
631 Expr::Like {
632 expr,
633 pattern,
634 negated,
635 ..
636 } => {
637 let val = self.evaluate_expr(expr, row, params)?;
638 let pattern_val = self.evaluate_expr(pattern, row, params)?;
639
640 match (&val, &pattern_val) {
641 (SochValue::Text(s), SochValue::Text(p)) => {
642 let matches = crate::like::like_match(s, p);
646 Ok(SochValue::Bool(if *negated { !matches } else { matches }))
647 }
648 _ => Ok(SochValue::Bool(false)),
649 }
650 }
651
652 Expr::Function(func) => self.evaluate_function(func, row, params),
653
654 Expr::Case {
655 operand,
656 conditions,
657 else_result,
658 } => {
659 if let Some(op) = operand {
660 let op_val = self.evaluate_expr(op, row, params)?;
662 for (when_expr, then_expr) in conditions {
663 let when_val = self.evaluate_expr(when_expr, row, params)?;
664 if self.values_equal(&op_val, &when_val) {
665 return self.evaluate_expr(then_expr, row, params);
666 }
667 }
668 } else {
669 for (when_expr, then_expr) in conditions {
671 let when_val = self.evaluate_expr(when_expr, row, params)?;
672 if matches!(when_val, SochValue::Bool(true)) {
673 return self.evaluate_expr(then_expr, row, params);
674 }
675 }
676 }
677
678 if let Some(else_expr) = else_result {
679 self.evaluate_expr(else_expr, row, params)
680 } else {
681 Ok(SochValue::Null)
682 }
683 }
684
685 _ => Err(SqlError::NotImplemented(format!(
686 "Expression type {:?} not yet supported",
687 expr
688 ))),
689 }
690 }
691
692 fn literal_to_value(&self, lit: &Literal) -> SochValue {
693 match lit {
694 Literal::Null => SochValue::Null,
695 Literal::Boolean(b) => SochValue::Bool(*b),
696 Literal::Integer(n) => SochValue::Int(*n),
697 Literal::Float(f) => SochValue::Float(*f),
698 Literal::String(s) => SochValue::Text(s.clone()),
699 Literal::Blob(b) => SochValue::Binary(b.clone()),
700 }
701 }
702
703 fn evaluate_binary_op(
704 &self,
705 left: &SochValue,
706 op: &BinaryOperator,
707 right: &SochValue,
708 ) -> SqlResult<SochValue> {
709 match op {
710 BinaryOperator::Eq => Ok(SochValue::Bool(self.values_equal(left, right))),
711 BinaryOperator::Ne => Ok(SochValue::Bool(!self.values_equal(left, right))),
712 BinaryOperator::Lt => Ok(SochValue::Bool(
713 self.compare_values(Some(left), Some(right)) == std::cmp::Ordering::Less,
714 )),
715 BinaryOperator::Le => Ok(SochValue::Bool(
716 self.compare_values(Some(left), Some(right)) != std::cmp::Ordering::Greater,
717 )),
718 BinaryOperator::Gt => Ok(SochValue::Bool(
719 self.compare_values(Some(left), Some(right)) == std::cmp::Ordering::Greater,
720 )),
721 BinaryOperator::Ge => Ok(SochValue::Bool(
722 self.compare_values(Some(left), Some(right)) != std::cmp::Ordering::Less,
723 )),
724
725 BinaryOperator::And => match (left, right) {
726 (SochValue::Bool(l), SochValue::Bool(r)) => Ok(SochValue::Bool(*l && *r)),
727 (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
728 _ => Err(SqlError::TypeError("AND requires boolean operands".into())),
729 },
730
731 BinaryOperator::Or => match (left, right) {
732 (SochValue::Bool(l), SochValue::Bool(r)) => Ok(SochValue::Bool(*l || *r)),
733 (SochValue::Bool(true), _) | (_, SochValue::Bool(true)) => {
734 Ok(SochValue::Bool(true))
735 }
736 (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
737 _ => Err(SqlError::TypeError("OR requires boolean operands".into())),
738 },
739
740 BinaryOperator::Plus => self.arithmetic_op(left, right, |a, b| a + b, |a, b| a + b),
741 BinaryOperator::Minus => self.arithmetic_op(left, right, |a, b| a - b, |a, b| a - b),
742 BinaryOperator::Multiply => self.arithmetic_op(left, right, |a, b| a * b, |a, b| a * b),
743 BinaryOperator::Divide => self.arithmetic_op(
744 left,
745 right,
746 |a, b| if b != 0 { a / b } else { 0 },
747 |a, b| a / b,
748 ),
749 BinaryOperator::Modulo => self.arithmetic_op(
750 left,
751 right,
752 |a, b| if b != 0 { a % b } else { 0 },
753 |a, b| a % b,
754 ),
755
756 BinaryOperator::Concat => match (left, right) {
757 (SochValue::Text(l), SochValue::Text(r)) => {
758 Ok(SochValue::Text(format!("{}{}", l, r)))
759 }
760 (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
761 _ => Err(SqlError::TypeError("|| requires string operands".into())),
762 },
763
764 _ => Err(SqlError::NotImplemented(format!(
765 "Operator {:?} not implemented",
766 op
767 ))),
768 }
769 }
770
771 fn evaluate_unary_op(&self, op: &UnaryOperator, val: &SochValue) -> SqlResult<SochValue> {
772 match op {
773 UnaryOperator::Not => match val {
774 SochValue::Bool(b) => Ok(SochValue::Bool(!b)),
775 SochValue::Null => Ok(SochValue::Null),
776 _ => Err(SqlError::TypeError("NOT requires boolean operand".into())),
777 },
778 UnaryOperator::Minus => match val {
779 SochValue::Int(n) => Ok(SochValue::Int(-n)),
780 SochValue::Float(f) => Ok(SochValue::Float(-f)),
781 SochValue::Null => Ok(SochValue::Null),
782 _ => Err(SqlError::TypeError(
783 "Unary minus requires numeric operand".into(),
784 )),
785 },
786 UnaryOperator::Plus => Ok(val.clone()),
787 UnaryOperator::BitNot => match val {
788 SochValue::Int(n) => Ok(SochValue::Int(!n)),
789 _ => Err(SqlError::TypeError("~ requires integer operand".into())),
790 },
791 }
792 }
793
794 fn evaluate_function(
795 &self,
796 func: &FunctionCall,
797 row: &HashMap<String, SochValue>,
798 params: &[SochValue],
799 ) -> SqlResult<SochValue> {
800 let func_name = func.name.name().to_uppercase();
801
802 match func_name.as_str() {
803 "COALESCE" => {
804 for arg in &func.args {
805 let val = self.evaluate_expr(arg, row, params)?;
806 if !matches!(val, SochValue::Null) {
807 return Ok(val);
808 }
809 }
810 Ok(SochValue::Null)
811 }
812
813 "NULLIF" => {
814 if func.args.len() != 2 {
815 return Err(SqlError::InvalidArgument(
816 "NULLIF requires 2 arguments".into(),
817 ));
818 }
819 let val1 = self.evaluate_expr(&func.args[0], row, params)?;
820 let val2 = self.evaluate_expr(&func.args[1], row, params)?;
821 if self.values_equal(&val1, &val2) {
822 Ok(SochValue::Null)
823 } else {
824 Ok(val1)
825 }
826 }
827
828 "ABS" => {
829 if func.args.len() != 1 {
830 return Err(SqlError::InvalidArgument("ABS requires 1 argument".into()));
831 }
832 let val = self.evaluate_expr(&func.args[0], row, params)?;
833 match val {
834 SochValue::Int(n) => Ok(SochValue::Int(n.abs())),
835 SochValue::Float(f) => Ok(SochValue::Float(f.abs())),
836 SochValue::Null => Ok(SochValue::Null),
837 _ => Err(SqlError::TypeError("ABS requires numeric argument".into())),
838 }
839 }
840
841 "LENGTH" | "LEN" => {
842 if func.args.len() != 1 {
843 return Err(SqlError::InvalidArgument(
844 "LENGTH requires 1 argument".into(),
845 ));
846 }
847 let val = self.evaluate_expr(&func.args[0], row, params)?;
848 match val {
849 SochValue::Text(s) => Ok(SochValue::Int(s.len() as i64)),
850 SochValue::Binary(b) => Ok(SochValue::Int(b.len() as i64)),
851 SochValue::Null => Ok(SochValue::Null),
852 _ => Err(SqlError::TypeError(
853 "LENGTH requires string argument".into(),
854 )),
855 }
856 }
857
858 "UPPER" => {
859 if func.args.len() != 1 {
860 return Err(SqlError::InvalidArgument(
861 "UPPER requires 1 argument".into(),
862 ));
863 }
864 let val = self.evaluate_expr(&func.args[0], row, params)?;
865 match val {
866 SochValue::Text(s) => Ok(SochValue::Text(s.to_uppercase())),
867 SochValue::Null => Ok(SochValue::Null),
868 _ => Err(SqlError::TypeError("UPPER requires string argument".into())),
869 }
870 }
871
872 "LOWER" => {
873 if func.args.len() != 1 {
874 return Err(SqlError::InvalidArgument(
875 "LOWER requires 1 argument".into(),
876 ));
877 }
878 let val = self.evaluate_expr(&func.args[0], row, params)?;
879 match val {
880 SochValue::Text(s) => Ok(SochValue::Text(s.to_lowercase())),
881 SochValue::Null => Ok(SochValue::Null),
882 _ => Err(SqlError::TypeError("LOWER requires string argument".into())),
883 }
884 }
885
886 "TRIM" => {
887 if func.args.len() != 1 {
888 return Err(SqlError::InvalidArgument("TRIM requires 1 argument".into()));
889 }
890 let val = self.evaluate_expr(&func.args[0], row, params)?;
891 match val {
892 SochValue::Text(s) => Ok(SochValue::Text(s.trim().to_string())),
893 SochValue::Null => Ok(SochValue::Null),
894 _ => Err(SqlError::TypeError("TRIM requires string argument".into())),
895 }
896 }
897
898 "SUBSTR" | "SUBSTRING" => {
899 if func.args.len() < 2 || func.args.len() > 3 {
900 return Err(SqlError::InvalidArgument(
901 "SUBSTR requires 2 or 3 arguments".into(),
902 ));
903 }
904 let val = self.evaluate_expr(&func.args[0], row, params)?;
905 let start = self.evaluate_expr(&func.args[1], row, params)?;
906 let len = if func.args.len() == 3 {
907 Some(self.evaluate_expr(&func.args[2], row, params)?)
908 } else {
909 None
910 };
911
912 match (val, start) {
913 (SochValue::Text(s), SochValue::Int(start)) => {
914 let start = (start.max(1) - 1) as usize;
915 if start >= s.len() {
916 return Ok(SochValue::Text(String::new()));
917 }
918 let result = if let Some(SochValue::Int(len)) = len {
919 s.chars().skip(start).take(len as usize).collect()
920 } else {
921 s.chars().skip(start).collect()
922 };
923 Ok(SochValue::Text(result))
924 }
925 (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
926 _ => Err(SqlError::TypeError(
927 "SUBSTR requires string and integer arguments".into(),
928 )),
929 }
930 }
931
932 _ => Err(SqlError::NotImplemented(format!(
933 "Function {} not implemented",
934 func_name
935 ))),
936 }
937 }
938
939 fn values_equal(&self, left: &SochValue, right: &SochValue) -> bool {
942 match (left, right) {
943 (SochValue::Null, _) | (_, SochValue::Null) => false,
944 (SochValue::Int(l), SochValue::Int(r)) => l == r,
945 (SochValue::Float(l), SochValue::Float(r)) => (l - r).abs() < f64::EPSILON,
946 (SochValue::Int(l), SochValue::Float(r)) => (*l as f64 - r).abs() < f64::EPSILON,
947 (SochValue::Float(l), SochValue::Int(r)) => (l - *r as f64).abs() < f64::EPSILON,
948 (SochValue::Text(l), SochValue::Text(r)) => l == r,
949 (SochValue::Bool(l), SochValue::Bool(r)) => l == r,
950 (SochValue::Binary(l), SochValue::Binary(r)) => l == r,
951 (SochValue::UInt(l), SochValue::UInt(r)) => l == r,
952 (SochValue::Int(l), SochValue::UInt(r)) => *l >= 0 && (*l as u64) == *r,
953 (SochValue::UInt(l), SochValue::Int(r)) => *r >= 0 && *l == (*r as u64),
954 _ => false,
955 }
956 }
957
958 fn compare_values(
959 &self,
960 left: Option<&SochValue>,
961 right: Option<&SochValue>,
962 ) -> std::cmp::Ordering {
963 match (left, right) {
964 (None, None) => std::cmp::Ordering::Equal,
965 (None, _) => std::cmp::Ordering::Less,
966 (_, None) => std::cmp::Ordering::Greater,
967 (Some(SochValue::Null), _) | (_, Some(SochValue::Null)) => std::cmp::Ordering::Equal,
968 (Some(SochValue::Int(l)), Some(SochValue::Int(r))) => l.cmp(r),
969 (Some(SochValue::Float(l)), Some(SochValue::Float(r))) => {
970 l.partial_cmp(r).unwrap_or(std::cmp::Ordering::Equal)
971 }
972 (Some(SochValue::Int(l)), Some(SochValue::Float(r))) => (*l as f64)
973 .partial_cmp(r)
974 .unwrap_or(std::cmp::Ordering::Equal),
975 (Some(SochValue::Float(l)), Some(SochValue::Int(r))) => l
976 .partial_cmp(&(*r as f64))
977 .unwrap_or(std::cmp::Ordering::Equal),
978 (Some(SochValue::Text(l)), Some(SochValue::Text(r))) => l.cmp(r),
979 (Some(SochValue::UInt(l)), Some(SochValue::UInt(r))) => l.cmp(r),
980 _ => std::cmp::Ordering::Equal,
981 }
982 }
983
984 fn arithmetic_op<FI, FF>(
985 &self,
986 left: &SochValue,
987 right: &SochValue,
988 int_op: FI,
989 float_op: FF,
990 ) -> SqlResult<SochValue>
991 where
992 FI: Fn(i64, i64) -> i64,
993 FF: Fn(f64, f64) -> f64,
994 {
995 match (left, right) {
996 (SochValue::Null, _) | (_, SochValue::Null) => Ok(SochValue::Null),
997 (SochValue::Int(l), SochValue::Int(r)) => Ok(SochValue::Int(int_op(*l, *r))),
998 (SochValue::Float(l), SochValue::Float(r)) => Ok(SochValue::Float(float_op(*l, *r))),
999 (SochValue::Int(l), SochValue::Float(r)) => {
1000 Ok(SochValue::Float(float_op(*l as f64, *r)))
1001 }
1002 (SochValue::Float(l), SochValue::Int(r)) => {
1003 Ok(SochValue::Float(float_op(*l, *r as f64)))
1004 }
1005 (SochValue::UInt(l), SochValue::UInt(r)) => {
1006 Ok(SochValue::Int(int_op(*l as i64, *r as i64)))
1007 }
1008 (SochValue::Int(l), SochValue::UInt(r)) => Ok(SochValue::Int(int_op(*l, *r as i64))),
1009 (SochValue::UInt(l), SochValue::Int(r)) => Ok(SochValue::Int(int_op(*l as i64, *r))),
1010 _ => Err(SqlError::TypeError(
1011 "Arithmetic requires numeric operands".into(),
1012 )),
1013 }
1014 }
1015}
1016
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the integer values of `column` from a row-set result, in row order.
    fn int_column(result: &ExecutionResult, column: &str) -> Vec<i64> {
        result
            .rows()
            .expect("expected a row set")
            .iter()
            .filter_map(|row| match row.get(column) {
                Some(SochValue::Int(n)) => Some(*n),
                _ => None,
            })
            .collect()
    }

    /// Creates table `nums (n INTEGER)` holding `1..=count` in insertion order.
    fn executor_with_nums(count: i64) -> SqlExecutor {
        let mut executor = SqlExecutor::new();
        executor.execute("CREATE TABLE nums (n INTEGER)").unwrap();
        for i in 1..=count {
            executor
                .execute(&format!("INSERT INTO nums (n) VALUES ({})", i))
                .unwrap();
        }
        executor
    }

    #[test]
    fn test_create_table_and_insert() {
        let mut executor = SqlExecutor::new();

        let result = executor
            .execute("CREATE TABLE users (id INTEGER, name VARCHAR(100))")
            .unwrap();
        assert!(matches!(result, ExecutionResult::Ok));

        // Re-creating fails unless IF NOT EXISTS is given.
        assert!(executor.execute("CREATE TABLE users (id INTEGER)").is_err());
        assert!(
            executor
                .execute("CREATE TABLE IF NOT EXISTS users (id INTEGER)")
                .is_ok()
        );

        let result = executor
            .execute("INSERT INTO users (id, name) VALUES (1, 'Alice')")
            .unwrap();
        assert_eq!(result.rows_affected(), 1);

        let result = executor
            .execute("INSERT INTO users (id, name) VALUES (2, 'Bob')")
            .unwrap();
        assert_eq!(result.rows_affected(), 1);

        let result = executor.execute("SELECT * FROM users").unwrap();
        assert_eq!(int_column(&result, "id"), vec![1, 2]);
    }

    #[test]
    fn test_select_with_where() {
        let mut executor = SqlExecutor::new();

        executor
            .execute("CREATE TABLE products (id INTEGER, name TEXT, price FLOAT)")
            .unwrap();
        executor
            .execute("INSERT INTO products (id, name, price) VALUES (1, 'Apple', 1.50)")
            .unwrap();
        executor
            .execute("INSERT INTO products (id, name, price) VALUES (2, 'Banana', 0.75)")
            .unwrap();
        executor
            .execute("INSERT INTO products (id, name, price) VALUES (3, 'Orange', 2.00)")
            .unwrap();

        let result = executor
            .execute("SELECT * FROM products WHERE price > 1.0")
            .unwrap();
        assert_eq!(int_column(&result, "id"), vec![1, 3]);
    }

    #[test]
    fn test_update() {
        let mut executor = SqlExecutor::new();

        executor
            .execute("CREATE TABLE users (id INTEGER, name TEXT)")
            .unwrap();
        executor
            .execute("INSERT INTO users (id, name) VALUES (1, 'Alice')")
            .unwrap();

        let result = executor
            .execute("UPDATE users SET name = 'Alicia' WHERE id = 1")
            .unwrap();
        assert_eq!(result.rows_affected(), 1);

        let result = executor
            .execute("SELECT * FROM users WHERE name = 'Alicia'")
            .unwrap();
        assert_eq!(int_column(&result, "id"), vec![1]);

        // The old value must be gone.
        let result = executor
            .execute("SELECT * FROM users WHERE name = 'Alice'")
            .unwrap();
        assert_eq!(result.rows_affected(), 0);
    }

    #[test]
    fn test_delete() {
        let mut executor = SqlExecutor::new();

        executor
            .execute("CREATE TABLE users (id INTEGER, name TEXT)")
            .unwrap();
        executor
            .execute("INSERT INTO users (id, name) VALUES (1, 'Alice')")
            .unwrap();
        executor
            .execute("INSERT INTO users (id, name) VALUES (2, 'Bob')")
            .unwrap();

        let result = executor.execute("DELETE FROM users WHERE id = 1").unwrap();
        assert_eq!(result.rows_affected(), 1);

        // Only the non-matching row survives.
        let result = executor.execute("SELECT * FROM users").unwrap();
        assert_eq!(int_column(&result, "id"), vec![2]);
    }

    #[test]
    fn test_drop_table() {
        let mut executor = executor_with_nums(1);

        executor.execute("DROP TABLE nums").unwrap();
        assert!(executor.execute("SELECT * FROM nums").is_err());
        // Dropping a missing table fails unless IF EXISTS is given.
        assert!(executor.execute("DROP TABLE nums").is_err());
        assert!(executor.execute("DROP TABLE IF EXISTS nums").is_ok());
    }

    #[test]
    fn test_functions() {
        let mut executor = SqlExecutor::new();

        executor.execute("CREATE TABLE t (s TEXT)").unwrap();
        executor
            .execute("INSERT INTO t (s) VALUES ('hello')")
            .unwrap();

        let result = executor.execute("SELECT UPPER(s) FROM t").unwrap();
        let rows = result.rows().expect("Expected rows");
        assert!(
            rows[0]
                .values()
                .any(|v| matches!(v, SochValue::Text(s) if s == "HELLO"))
        );
    }

    #[test]
    fn test_order_by() {
        let mut executor = SqlExecutor::new();

        executor.execute("CREATE TABLE nums (n INTEGER)").unwrap();
        executor.execute("INSERT INTO nums (n) VALUES (3)").unwrap();
        executor.execute("INSERT INTO nums (n) VALUES (1)").unwrap();
        executor.execute("INSERT INTO nums (n) VALUES (2)").unwrap();

        let result = executor
            .execute("SELECT * FROM nums ORDER BY n ASC")
            .unwrap();
        assert_eq!(int_column(&result, "n"), vec![1, 2, 3]);
    }

    #[test]
    fn test_limit_offset() {
        let mut executor = executor_with_nums(10);

        // OFFSET skips 1 and 2; LIMIT keeps the next three rows.
        let result = executor
            .execute("SELECT * FROM nums LIMIT 3 OFFSET 2")
            .unwrap();
        assert_eq!(int_column(&result, "n"), vec![3, 4, 5]);
    }

    #[test]
    fn test_between() {
        let mut executor = executor_with_nums(10);

        let result = executor
            .execute("SELECT * FROM nums WHERE n BETWEEN 3 AND 7")
            .unwrap();
        assert_eq!(int_column(&result, "n"), vec![3, 4, 5, 6, 7]);
    }

    #[test]
    fn test_in_list() {
        let mut executor = executor_with_nums(5);

        let result = executor
            .execute("SELECT * FROM nums WHERE n IN (1, 3, 5)")
            .unwrap();
        assert_eq!(int_column(&result, "n"), vec![1, 3, 5]);
    }
}