1use std::cmp::Ordering;
4use std::collections::HashMap;
5
6use regex::Regex;
7
8use crate::errors::MdqlError;
9use crate::model::{Row, Value};
10use crate::query_parser::*;
11use crate::schema::Schema;
12
13pub fn execute_query(
14 query: &SelectQuery,
15 rows: &[Row],
16 _schema: &Schema,
17) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
18 execute(query, rows, None)
19}
20
21pub fn execute_query_indexed(
23 query: &SelectQuery,
24 rows: &[Row],
25 schema: &Schema,
26 index: Option<&crate::index::TableIndex>,
27 searcher: Option<&crate::search::TableSearcher>,
28) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
29 let fts_results = if let (Some(ref wc), Some(searcher)) = (&query.where_clause, searcher) {
31 collect_fts_results(wc, schema, searcher)
32 } else {
33 HashMap::new()
34 };
35
36 execute_with_fts(query, rows, index, &fts_results)
37}
38
39fn collect_fts_results(
42 clause: &WhereClause,
43 schema: &Schema,
44 searcher: &crate::search::TableSearcher,
45) -> HashMap<(String, String), std::collections::HashSet<String>> {
46 let mut results = HashMap::new();
47 collect_fts_results_inner(clause, schema, searcher, &mut results);
48 results
49}
50
51fn collect_fts_results_inner(
52 clause: &WhereClause,
53 schema: &Schema,
54 searcher: &crate::search::TableSearcher,
55 results: &mut HashMap<(String, String), std::collections::HashSet<String>>,
56) {
57 match clause {
58 WhereClause::Comparison(cmp) => {
59 if (cmp.op == "LIKE" || cmp.op == "NOT LIKE") && schema.sections.contains_key(&cmp.column) {
60 if let Some(SqlValue::String(pattern)) = &cmp.value {
61 let search_term = pattern.replace('%', " ").replace('_', " ").trim().to_string();
63 if !search_term.is_empty() {
64 if let Ok(paths) = searcher.search(&search_term, Some(&cmp.column)) {
65 let key = (cmp.column.clone(), pattern.clone());
66 results.insert(key, paths.into_iter().collect());
67 }
68 }
69 }
70 }
71 }
72 WhereClause::BoolOp(bop) => {
73 collect_fts_results_inner(&bop.left, schema, searcher, results);
74 collect_fts_results_inner(&bop.right, schema, searcher, results);
75 }
76 }
77}
78
/// Pre-computed full-text search hits: maps a `(column, LIKE pattern)` pair to
/// the set of row `path` values the searcher returned for that pattern.
type FtsResults = HashMap<(String, String), std::collections::HashSet<String>>;
80
81fn execute_with_fts(
82 query: &SelectQuery,
83 rows: &[Row],
84 index: Option<&crate::index::TableIndex>,
85 fts: &FtsResults,
86) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
87 let mut all_columns: Vec<String> = Vec::new();
89 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
90 for r in rows {
91 for k in r.keys() {
92 if seen.insert(k.clone()) {
93 all_columns.push(k.clone());
94 }
95 }
96 }
97
98 let has_aggregates = match &query.columns {
100 ColumnList::Named(exprs) => exprs.iter().any(|e| e.is_aggregate()),
101 _ => false,
102 };
103
104 let columns: Vec<String> = match &query.columns {
106 ColumnList::All => all_columns,
107 ColumnList::Named(exprs) => exprs.iter().map(|e| e.output_name()).collect(),
108 };
109
110 let filtered: Vec<Row> = if let Some(ref wc) = query.where_clause {
112 let candidate_paths = index.and_then(|idx| try_index_filter(wc, idx));
113 if let Some(paths) = candidate_paths {
114 rows.iter()
115 .filter(|r| {
116 r.get("path")
117 .and_then(|v| v.as_str())
118 .map_or(false, |p| paths.contains(p))
119 })
120 .filter(|r| evaluate_with_fts(wc, r, fts))
121 .cloned()
122 .collect()
123 } else {
124 rows.iter()
125 .filter(|r| evaluate_with_fts(wc, r, fts))
126 .cloned()
127 .collect()
128 }
129 } else {
130 rows.to_vec()
131 };
132
133 let mut result = if has_aggregates || query.group_by.is_some() {
135 let exprs = match &query.columns {
136 ColumnList::Named(exprs) => exprs.clone(),
137 _ => return Err(MdqlError::QueryExecution(
138 "SELECT * with GROUP BY is not supported".into(),
139 )),
140 };
141 let group_keys = query.group_by.as_deref().unwrap_or(&[]);
142 aggregate_rows(&filtered, &exprs, group_keys)?
143 } else {
144 filtered
145 };
146
147 if let Some(ref order_by) = query.order_by {
149 let resolved = resolve_order_aliases(order_by, &query.columns);
150 sort_rows(&mut result, &resolved);
151 }
152
153 if let Some(limit) = query.limit {
155 result.truncate(limit as usize);
156 }
157
158 if !matches!(query.columns, ColumnList::All) {
160 let named_exprs = match &query.columns {
161 ColumnList::Named(exprs) => exprs,
162 _ => unreachable!(),
163 };
164
165 let has_expr_cols = named_exprs.iter().any(|e| matches!(e, SelectExpr::Expr { .. }));
167 if has_expr_cols {
168 for row in &mut result {
169 for expr in named_exprs {
170 if let SelectExpr::Expr { expr: e, alias } = expr {
171 let name = alias.clone().unwrap_or_else(|| e.display_name());
172 let val = evaluate_expr(e, row);
173 row.insert(name, val);
174 }
175 }
176 }
177 }
178
179 let col_set: std::collections::HashSet<&str> =
180 columns.iter().map(|s| s.as_str()).collect();
181 for row in &mut result {
182 row.retain(|k, _| col_set.contains(k.as_str()));
183 }
184 }
185
186 Ok((result, columns))
187}
188
189fn aggregate_rows(
190 rows: &[Row],
191 exprs: &[SelectExpr],
192 group_keys: &[String],
193) -> crate::errors::Result<Vec<Row>> {
194 let mut groups: Vec<(Vec<Value>, Vec<&Row>)> = Vec::new();
196 let mut key_index: HashMap<Vec<String>, usize> = HashMap::new();
197
198 if group_keys.is_empty() {
199 let all_refs: Vec<&Row> = rows.iter().collect();
201 groups.push((vec![], all_refs));
202 } else {
203 for row in rows {
204 let key: Vec<String> = group_keys
205 .iter()
206 .map(|k| {
207 row.get(k)
208 .map(|v| v.to_display_string())
209 .unwrap_or_default()
210 })
211 .collect();
212 let key_vals: Vec<Value> = group_keys
213 .iter()
214 .map(|k| row.get(k).cloned().unwrap_or(Value::Null))
215 .collect();
216 if let Some(&idx) = key_index.get(&key) {
217 groups[idx].1.push(row);
218 } else {
219 let idx = groups.len();
220 key_index.insert(key, idx);
221 groups.push((key_vals, vec![row]));
222 }
223 }
224 }
225
226 let mut result = Vec::new();
228 for (key_vals, group_rows) in &groups {
229 let mut out = Row::new();
230
231 for (i, k) in group_keys.iter().enumerate() {
233 out.insert(k.clone(), key_vals[i].clone());
234 }
235
236 for expr in exprs {
238 match expr {
239 SelectExpr::Column(name) => {
240 if !out.contains_key(name) {
242 if let Some(first) = group_rows.first() {
243 out.insert(
244 name.clone(),
245 first.get(name).cloned().unwrap_or(Value::Null),
246 );
247 }
248 }
249 }
250 SelectExpr::Aggregate { func, arg, arg_expr, alias } => {
251 let out_name = alias
252 .clone()
253 .unwrap_or_else(|| expr.output_name());
254 let val = compute_aggregate(func, arg, arg_expr.as_ref(), group_rows);
255 out.insert(out_name, val);
256 }
257 SelectExpr::Expr { expr: e, alias } => {
258 let out_name = alias.clone().unwrap_or_else(|| e.display_name());
259 if let Some(first) = group_rows.first() {
260 let val = evaluate_expr(e, first);
261 out.insert(out_name, val);
262 }
263 }
264 }
265 }
266
267 result.push(out);
268 }
269
270 Ok(result)
271}
272
273fn resolve_agg_value<'a>(arg: &str, arg_expr: Option<&Expr>, row: &'a Row) -> Value {
276 if let Some(expr) = arg_expr {
277 evaluate_expr(expr, row)
278 } else {
279 row.get(arg).cloned().unwrap_or(Value::Null)
280 }
281}
282
283fn compute_aggregate(func: &AggFunc, arg: &str, arg_expr: Option<&Expr>, rows: &[&Row]) -> Value {
284 match func {
285 AggFunc::Count => {
286 if arg == "*" && arg_expr.is_none() {
287 Value::Int(rows.len() as i64)
288 } else {
289 let count = rows
290 .iter()
291 .filter(|r| {
292 let v = resolve_agg_value(arg, arg_expr, r);
293 !v.is_null()
294 })
295 .count();
296 Value::Int(count as i64)
297 }
298 }
299 AggFunc::Sum => {
300 let mut total = 0.0f64;
301 let mut has_any = false;
302 for r in rows {
303 let v = resolve_agg_value(arg, arg_expr, r);
304 match v {
305 Value::Int(n) => { total += n as f64; has_any = true; }
306 Value::Float(f) => { total += f; has_any = true; }
307 _ => {}
308 }
309 }
310 if has_any { Value::Float(total) } else { Value::Null }
311 }
312 AggFunc::Avg => {
313 let mut total = 0.0f64;
314 let mut count = 0usize;
315 for r in rows {
316 let v = resolve_agg_value(arg, arg_expr, r);
317 match v {
318 Value::Int(n) => { total += n as f64; count += 1; }
319 Value::Float(f) => { total += f; count += 1; }
320 _ => {}
321 }
322 }
323 if count > 0 { Value::Float(total / count as f64) } else { Value::Null }
324 }
325 AggFunc::Min => {
326 let mut min_val: Option<Value> = None;
327 for r in rows {
328 let v = resolve_agg_value(arg, arg_expr, r);
329 if v.is_null() { continue; }
330 min_val = Some(match min_val {
331 None => v,
332 Some(ref current) => {
333 if v.partial_cmp(current) == Some(std::cmp::Ordering::Less) {
334 v
335 } else {
336 current.clone()
337 }
338 }
339 });
340 }
341 min_val.unwrap_or(Value::Null)
342 }
343 AggFunc::Max => {
344 let mut max_val: Option<Value> = None;
345 for r in rows {
346 let v = resolve_agg_value(arg, arg_expr, r);
347 if v.is_null() { continue; }
348 max_val = Some(match max_val {
349 None => v,
350 Some(ref current) => {
351 if v.partial_cmp(current) == Some(std::cmp::Ordering::Greater) {
352 v
353 } else {
354 current.clone()
355 }
356 }
357 });
358 }
359 max_val.unwrap_or(Value::Null)
360 }
361 }
362}
363
364fn evaluate_with_fts(clause: &WhereClause, row: &Row, fts: &FtsResults) -> bool {
365 match clause {
366 WhereClause::BoolOp(bop) => {
367 let left = evaluate_with_fts(&bop.left, row, fts);
368 match bop.op.as_str() {
369 "AND" => left && evaluate_with_fts(&bop.right, row, fts),
370 "OR" => left || evaluate_with_fts(&bop.right, row, fts),
371 _ => false,
372 }
373 }
374 WhereClause::Comparison(cmp) => {
375 if cmp.op == "LIKE" || cmp.op == "NOT LIKE" {
377 if let Some(SqlValue::String(pattern)) = &cmp.value {
378 let key = (cmp.column.clone(), pattern.clone());
379 if let Some(matching_paths) = fts.get(&key) {
380 let row_path = row.get("path").and_then(|v| v.as_str()).unwrap_or("");
381 let matched = matching_paths.contains(row_path);
382 return if cmp.op == "LIKE" { matched } else { !matched };
383 }
384 }
385 }
386 evaluate_comparison(cmp, row)
387 }
388 }
389}
390
391pub fn execute_join_query(
392 query: &SelectQuery,
393 tables: &HashMap<String, (Schema, Vec<Row>)>,
394) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
395 if query.joins.is_empty() {
396 return Err(MdqlError::QueryExecution("No JOIN clause in query".into()));
397 }
398
399 let left_name = &query.table;
400 let left_alias = query.table_alias.as_deref().unwrap_or(left_name);
401
402 let mut aliases: HashMap<String, String> = HashMap::new();
404 aliases.insert(left_name.clone(), left_name.clone());
405 if let Some(ref a) = query.table_alias {
406 aliases.insert(a.clone(), left_name.clone());
407 }
408 for join in &query.joins {
409 aliases.insert(join.table.clone(), join.table.clone());
410 if let Some(ref a) = join.alias {
411 aliases.insert(a.clone(), join.table.clone());
412 }
413 }
414
415 let (_left_schema, left_rows) = tables.get(left_name.as_str()).ok_or_else(|| {
417 MdqlError::QueryExecution(format!("Unknown table '{}'", left_name))
418 })?;
419
420 let mut current_rows: Vec<Row> = left_rows
421 .iter()
422 .map(|r| {
423 let mut prefixed = Row::new();
424 for (k, v) in r {
425 prefixed.insert(format!("{}.{}", left_alias, k), v.clone());
426 }
427 prefixed
428 })
429 .collect();
430
431 for join in &query.joins {
433 let right_name = &join.table;
434 let right_alias = join.alias.as_deref().unwrap_or(right_name);
435
436 let (_right_schema, right_rows) = tables.get(right_name.as_str()).ok_or_else(|| {
437 MdqlError::QueryExecution(format!("Unknown table '{}'", right_name))
438 })?;
439
440 let (on_left_table, on_left_col) = resolve_dotted(&join.left_col, &aliases);
442 let (on_right_table, on_right_col) = resolve_dotted(&join.right_col, &aliases);
443
444 let (left_key, right_key) = if on_right_table == *right_name {
446 let left_alias_for_col = reverse_alias(&on_left_table, &aliases, query, &query.joins);
448 (format!("{}.{}", left_alias_for_col, on_left_col), on_right_col)
449 } else {
450 let right_alias_for_col = reverse_alias(&on_right_table, &aliases, query, &query.joins);
452 (format!("{}.{}", right_alias_for_col, on_right_col), on_left_col)
453 };
454
455 let mut right_index: HashMap<String, Vec<&Row>> = HashMap::new();
457 for r in right_rows {
458 if let Some(key) = r.get(&right_key) {
459 let key_str = key.to_display_string();
460 right_index.entry(key_str).or_default().push(r);
461 }
462 }
463
464 let mut next_rows: Vec<Row> = Vec::new();
466 for lr in ¤t_rows {
467 if let Some(key) = lr.get(&left_key) {
468 let key_str = key.to_display_string();
469 if let Some(matching) = right_index.get(&key_str) {
470 for rr in matching {
471 let mut merged = lr.clone();
472 for (k, v) in *rr {
473 merged.insert(format!("{}.{}", right_alias, k), v.clone());
474 }
475 next_rows.push(merged);
476 }
477 }
478 }
479 }
480 current_rows = next_rows;
481 }
482
483 let (mut result, columns) = execute(query, ¤t_rows, None)?;
484
485 if !result.is_empty() {
489 let mut base_counts: HashMap<String, usize> = HashMap::new();
490 for key in &columns {
491 if let Some((_prefix, base)) = key.split_once('.') {
492 *base_counts.entry(base.to_string()).or_default() += 1;
493 }
494 }
495 let unique_bases: Vec<String> = base_counts
496 .into_iter()
497 .filter(|(_, count)| *count == 1)
498 .map(|(base, _)| base)
499 .collect();
500
501 if !unique_bases.is_empty() {
502 let unique_set: std::collections::HashSet<&str> =
503 unique_bases.iter().map(|s| s.as_str()).collect();
504 for row in &mut result {
505 let additions: Vec<(String, Value)> = row
506 .iter()
507 .filter_map(|(k, v)| {
508 k.split_once('.').and_then(|(_, base)| {
509 if unique_set.contains(base) {
510 Some((base.to_string(), v.clone()))
511 } else {
512 None
513 }
514 })
515 })
516 .collect();
517 for (k, v) in additions {
518 row.insert(k, v);
519 }
520 }
521 }
522 }
523
524 Ok((result, columns))
525}
526
527fn reverse_alias(
529 table_name: &str,
530 aliases: &HashMap<String, String>,
531 query: &SelectQuery,
532 joins: &[JoinClause],
533) -> String {
534 if query.table == table_name {
536 return query.table_alias.as_deref().unwrap_or(&query.table).to_string();
537 }
538 for j in joins {
540 if j.table == table_name {
541 return j.alias.as_deref().unwrap_or(&j.table).to_string();
542 }
543 }
544 if aliases.contains_key(table_name) {
546 return table_name.to_string();
547 }
548 table_name.to_string()
549}
550
/// Splits a possibly qualified column reference into `(table, column)`.
///
/// The text before the first `.` is looked up in `aliases`; if it is not a
/// known alias it is returned as-is. An unqualified reference yields an empty
/// table name.
fn resolve_dotted(col: &str, aliases: &HashMap<String, String>) -> (String, String) {
    match col.split_once('.') {
        None => (String::new(), col.to_owned()),
        Some((qualifier, column)) => {
            let table = match aliases.get(qualifier) {
                Some(real_name) => real_name.clone(),
                None => qualifier.to_owned(),
            };
            (table, column.to_owned())
        }
    }
}
559
560fn execute(
561 query: &SelectQuery,
562 rows: &[Row],
563 index: Option<&crate::index::TableIndex>,
564) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
565 let empty_fts = HashMap::new();
566 execute_with_fts(query, rows, index, &empty_fts)
567}
568
569pub fn evaluate(clause: &WhereClause, row: &Row) -> bool {
570 match clause {
571 WhereClause::BoolOp(bop) => {
572 let left = evaluate(&bop.left, row);
573 match bop.op.as_str() {
574 "AND" => left && evaluate(&bop.right, row),
575 "OR" => left || evaluate(&bop.right, row),
576 _ => false,
577 }
578 }
579 WhereClause::Comparison(cmp) => evaluate_comparison(cmp, row),
580 }
581}
582
583pub fn evaluate_expr(expr: &Expr, row: &Row) -> Value {
585 match expr {
586 Expr::Literal(SqlValue::Int(n)) => Value::Int(*n),
587 Expr::Literal(SqlValue::Float(f)) => Value::Float(*f),
588 Expr::Literal(SqlValue::String(s)) => Value::String(s.clone()),
589 Expr::Literal(SqlValue::Null) => Value::Null,
590 Expr::Literal(SqlValue::List(_)) => Value::Null,
591 Expr::Column(name) => {
592 if let Some(val) = row.get(name) {
593 return val.clone();
594 }
595 if let Some((dict_col, dict_key)) = name.split_once('.') {
596 if let Some(Value::Dict(map)) = row.get(dict_col) {
597 return map.get(dict_key).cloned().unwrap_or(Value::Null);
598 }
599 }
600 Value::Null
601 }
602 Expr::UnaryMinus(inner) => {
603 match evaluate_expr(inner, row) {
604 Value::Int(n) => Value::Int(-n),
605 Value::Float(f) => Value::Float(-f),
606 Value::Null => Value::Null,
607 _ => Value::Null, }
609 }
610 Expr::BinaryOp { left, op, right } => {
611 let lv = evaluate_expr(left, row);
612 let rv = evaluate_expr(right, row);
613
614 if lv.is_null() || rv.is_null() {
616 return Value::Null;
617 }
618
619 match (&lv, &rv) {
621 (Value::Int(a), Value::Int(b)) => {
622 match op {
623 ArithOp::Add => Value::Int(a.wrapping_add(*b)),
624 ArithOp::Sub => Value::Int(a.wrapping_sub(*b)),
625 ArithOp::Mul => Value::Int(a.wrapping_mul(*b)),
626 ArithOp::Div => {
627 if *b == 0 { Value::Null } else { Value::Int(a / b) }
628 }
629 ArithOp::Mod => {
630 if *b == 0 { Value::Null } else { Value::Int(a % b) }
631 }
632 }
633 }
634 _ => {
635 let a = match &lv {
637 Value::Int(n) => *n as f64,
638 Value::Float(f) => *f,
639 _ => return Value::Null,
640 };
641 let b = match &rv {
642 Value::Int(n) => *n as f64,
643 Value::Float(f) => *f,
644 _ => return Value::Null,
645 };
646 match op {
647 ArithOp::Add => Value::Float(a + b),
648 ArithOp::Sub => Value::Float(a - b),
649 ArithOp::Mul => Value::Float(a * b),
650 ArithOp::Div => {
651 if b == 0.0 { Value::Null } else { Value::Float(a / b) }
652 }
653 ArithOp::Mod => {
654 if b == 0.0 { Value::Null } else { Value::Float(a % b) }
655 }
656 }
657 }
658 }
659 }
660 Expr::Case { whens, else_expr } => {
661 for (condition, result) in whens {
662 if evaluate(condition, row) {
663 return evaluate_expr(result, row);
664 }
665 }
666 match else_expr {
667 Some(e) => evaluate_expr(e, row),
668 None => Value::Null,
669 }
670 }
671 }
672}
673
674fn evaluate_comparison(cmp: &Comparison, row: &Row) -> bool {
675 if let (Some(left_expr), Some(right_expr)) = (&cmp.left_expr, &cmp.right_expr) {
677 if ["=", "!=", "<", ">", "<=", ">="].contains(&cmp.op.as_str()) {
678 let left_val = evaluate_expr(left_expr, row);
679 let right_val = evaluate_expr(right_expr, row);
680
681 if left_val.is_null() || right_val.is_null() {
683 return false;
684 }
685
686 let ord = compare_model_values(&left_val, &right_val);
688
689 return match cmp.op.as_str() {
690 "=" => ord == Some(Ordering::Equal),
691 "!=" => ord != Some(Ordering::Equal),
692 "<" => ord == Some(Ordering::Less),
693 ">" => ord == Some(Ordering::Greater),
694 "<=" => matches!(ord, Some(Ordering::Less | Ordering::Equal)),
695 ">=" => matches!(ord, Some(Ordering::Greater | Ordering::Equal)),
696 _ => false,
697 };
698 }
699 }
700
701 let actual = row.get(&cmp.column);
703
704 if cmp.op == "IS NULL" {
705 return actual.map_or(true, |v| v.is_null());
706 }
707 if cmp.op == "IS NOT NULL" {
708 return actual.map_or(false, |v| !v.is_null());
709 }
710
711 let actual = match actual {
712 Some(v) if !v.is_null() => v,
713 _ => return false,
714 };
715
716 let expected = match &cmp.value {
717 Some(v) => v,
718 None => return false,
719 };
720
721 match cmp.op.as_str() {
722 "=" => eq_match(actual, expected),
723 "!=" => !eq_match(actual, expected),
724 "<" => compare_values(actual, expected) == Some(Ordering::Less),
725 ">" => compare_values(actual, expected) == Some(Ordering::Greater),
726 "<=" => matches!(compare_values(actual, expected), Some(Ordering::Less | Ordering::Equal)),
727 ">=" => matches!(compare_values(actual, expected), Some(Ordering::Greater | Ordering::Equal)),
728 "LIKE" => like_match(actual, expected),
729 "NOT LIKE" => !like_match(actual, expected),
730 "IN" => {
731 if let SqlValue::List(items) = expected {
732 items.iter().any(|v| eq_match(actual, v))
733 } else {
734 eq_match(actual, expected)
735 }
736 }
737 _ => false,
738 }
739}
740
741fn compare_model_values(a: &Value, b: &Value) -> Option<Ordering> {
743 match (a, b) {
744 (Value::Int(x), Value::Float(y)) => (*x as f64).partial_cmp(y),
745 (Value::Float(x), Value::Int(y)) => x.partial_cmp(&(*y as f64)),
746 _ => a.partial_cmp(b),
747 }
748}
749
750fn coerce_sql_to_value(sql_val: &SqlValue, target: &Value) -> Value {
751 match sql_val {
752 SqlValue::Null => Value::Null,
753 SqlValue::String(s) => {
754 match target {
755 Value::Int(_) => s.parse::<i64>().map(Value::Int).unwrap_or(Value::String(s.clone())),
756 Value::Float(_) => s.parse::<f64>().map(Value::Float).unwrap_or(Value::String(s.clone())),
757 Value::Date(_) => {
758 chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d")
759 .map(Value::Date)
760 .unwrap_or(Value::String(s.clone()))
761 }
762 Value::DateTime(_) => {
763 chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S")
764 .or_else(|_| chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f"))
765 .map(Value::DateTime)
766 .unwrap_or(Value::String(s.clone()))
767 }
768 _ => Value::String(s.clone()),
769 }
770 }
771 SqlValue::Int(n) => {
772 match target {
773 Value::Float(_) => Value::Float(*n as f64),
774 _ => Value::Int(*n),
775 }
776 }
777 SqlValue::Float(f) => Value::Float(*f),
778 SqlValue::List(_) => Value::Null, }
780}
781
782fn eq_match(actual: &Value, expected: &SqlValue) -> bool {
783 if let Value::List(items) = actual {
785 if let SqlValue::String(s) = expected {
786 return items.contains(s);
787 }
788 }
789
790 let coerced = coerce_sql_to_value(expected, actual);
791 actual == &coerced
792}
793
794fn like_match(actual: &Value, pattern: &SqlValue) -> bool {
795 let pattern_str = match pattern {
796 SqlValue::String(s) => s,
797 _ => return false,
798 };
799
800 let mut regex_str = String::from("(?is)^");
802 for ch in pattern_str.chars() {
803 match ch {
804 '%' => regex_str.push_str(".*"),
805 '_' => regex_str.push('.'),
806 c => {
807 if regex::escape(&c.to_string()) != c.to_string() {
808 regex_str.push_str(®ex::escape(&c.to_string()));
809 } else {
810 regex_str.push(c);
811 }
812 }
813 }
814 }
815 regex_str.push('$');
816
817 let re = match Regex::new(®ex_str) {
818 Ok(r) => r,
819 Err(_) => return false,
820 };
821
822 match actual {
823 Value::List(items) => items.iter().any(|item| re.is_match(item)),
824 _ => re.is_match(&actual.to_display_string()),
825 }
826}
827
828fn compare_values(actual: &Value, expected: &SqlValue) -> Option<Ordering> {
829 let coerced = coerce_sql_to_value(expected, actual);
830 actual.partial_cmp(&coerced).map(|o| o)
831}
832
833fn sql_value_to_index_value(sv: &SqlValue) -> Value {
835 match sv {
836 SqlValue::String(s) => {
837 if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S") {
839 return Value::DateTime(dt);
840 }
841 if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f") {
842 return Value::DateTime(dt);
843 }
844 if let Ok(d) = chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d") {
846 return Value::Date(d);
847 }
848 Value::String(s.clone())
849 }
850 SqlValue::Int(n) => Value::Int(*n),
851 SqlValue::Float(f) => Value::Float(*f),
852 SqlValue::Null => Value::Null,
853 SqlValue::List(_) => Value::Null,
854 }
855}
856
857fn try_index_filter(
861 clause: &WhereClause,
862 index: &crate::index::TableIndex,
863) -> Option<std::collections::HashSet<String>> {
864 match clause {
865 WhereClause::Comparison(cmp) => {
866 if !index.has_index(&cmp.column) {
867 return None;
868 }
869 match cmp.op.as_str() {
870 "=" => {
871 let val = sql_value_to_index_value(cmp.value.as_ref()?);
872 let paths = index.lookup_eq(&cmp.column, &val);
873 Some(paths.into_iter().map(|s| s.to_string()).collect())
874 }
875 "<" => {
876 let val = sql_value_to_index_value(cmp.value.as_ref()?);
877 let range_paths = index.lookup_range(&cmp.column, None, Some(&val));
880 let eq_paths: std::collections::HashSet<&str> = index.lookup_eq(&cmp.column, &val).into_iter().collect();
881 Some(range_paths.into_iter().filter(|p| !eq_paths.contains(p)).map(|s| s.to_string()).collect())
882 }
883 ">" => {
884 let val = sql_value_to_index_value(cmp.value.as_ref()?);
885 let range_paths = index.lookup_range(&cmp.column, Some(&val), None);
886 let eq_paths: std::collections::HashSet<&str> = index.lookup_eq(&cmp.column, &val).into_iter().collect();
887 Some(range_paths.into_iter().filter(|p| !eq_paths.contains(p)).map(|s| s.to_string()).collect())
888 }
889 "<=" => {
890 let val = sql_value_to_index_value(cmp.value.as_ref()?);
891 let paths = index.lookup_range(&cmp.column, None, Some(&val));
892 Some(paths.into_iter().map(|s| s.to_string()).collect())
893 }
894 ">=" => {
895 let val = sql_value_to_index_value(cmp.value.as_ref()?);
896 let paths = index.lookup_range(&cmp.column, Some(&val), None);
897 Some(paths.into_iter().map(|s| s.to_string()).collect())
898 }
899 "IN" => {
900 if let Some(SqlValue::List(items)) = &cmp.value {
901 let vals: Vec<Value> = items.iter().map(sql_value_to_index_value).collect();
902 let paths = index.lookup_in(&cmp.column, &vals);
903 Some(paths.into_iter().map(|s| s.to_string()).collect())
904 } else {
905 None
906 }
907 }
908 _ => None, }
910 }
911 WhereClause::BoolOp(bop) => {
912 let left = try_index_filter(&bop.left, index);
913 let right = try_index_filter(&bop.right, index);
914 match bop.op.as_str() {
915 "AND" => {
916 match (left, right) {
917 (Some(l), Some(r)) => Some(l.intersection(&r).cloned().collect()),
918 (Some(l), None) => Some(l), (None, Some(r)) => Some(r),
920 (None, None) => None,
921 }
922 }
923 "OR" => {
924 match (left, right) {
925 (Some(l), Some(r)) => Some(l.union(&r).cloned().collect()),
926 _ => None, }
928 }
929 _ => None,
930 }
931 }
932 }
933}
934
935fn resolve_order_aliases(specs: &[OrderSpec], columns: &ColumnList) -> Vec<OrderSpec> {
938 let named = match columns {
939 ColumnList::Named(exprs) => exprs,
940 _ => return specs.to_vec(),
941 };
942
943 let alias_map: HashMap<String, &Expr> = named
945 .iter()
946 .filter_map(|se| match se {
947 SelectExpr::Expr { expr, alias: Some(a) } => Some((a.clone(), expr)),
948 _ => None,
949 })
950 .collect();
951
952 specs
953 .iter()
954 .map(|spec| {
955 if let Some(expr) = alias_map.get(&spec.column) {
957 OrderSpec {
958 column: spec.column.clone(),
959 expr: Some((*expr).clone()),
960 descending: spec.descending,
961 }
962 } else {
963 spec.clone()
964 }
965 })
966 .collect()
967}
968
969fn sort_rows(rows: &mut Vec<Row>, specs: &[OrderSpec]) {
970 rows.sort_by(|a, b| {
971 for spec in specs {
972 let (va, vb) = if let Some(ref expr) = spec.expr {
973 (evaluate_expr(expr, a), evaluate_expr(expr, b))
974 } else {
975 (
976 a.get(&spec.column).cloned().unwrap_or(Value::Null),
977 b.get(&spec.column).cloned().unwrap_or(Value::Null),
978 )
979 };
980
981 let ordering = match (&va, &vb) {
983 (Value::Null, Value::Null) => Ordering::Equal,
984 (Value::Null, _) => Ordering::Greater,
985 (_, Value::Null) => Ordering::Less,
986 (a_val, b_val) => {
987 compare_model_values(a_val, b_val).unwrap_or(Ordering::Equal)
988 }
989 };
990
991 let ordering = if spec.descending {
992 ordering.reverse()
993 } else {
994 ordering
995 };
996
997 if ordering != Ordering::Equal {
998 return ordering;
999 }
1000 }
1001 Ordering::Equal
1002 });
1003}
1004
1005pub fn sql_value_to_value(sql_val: &SqlValue) -> Value {
1007 match sql_val {
1008 SqlValue::Null => Value::Null,
1009 SqlValue::String(s) => Value::String(s.clone()),
1010 SqlValue::Int(n) => Value::Int(*n),
1011 SqlValue::Float(f) => Value::Float(*f),
1012 SqlValue::List(items) => {
1013 let strings: Vec<String> = items
1014 .iter()
1015 .filter_map(|v| match v {
1016 SqlValue::String(s) => Some(s.clone()),
1017 _ => None,
1018 })
1019 .collect();
1020 Value::List(strings)
1021 }
1022 }
1023}
1024
1025#[cfg(test)]
1026mod tests {
1027 use super::*;
1028
1029 fn make_rows() -> Vec<Row> {
1030 vec![
1031 Row::from([
1032 ("path".into(), Value::String("a.md".into())),
1033 ("title".into(), Value::String("Alpha".into())),
1034 ("count".into(), Value::Int(10)),
1035 ]),
1036 Row::from([
1037 ("path".into(), Value::String("b.md".into())),
1038 ("title".into(), Value::String("Beta".into())),
1039 ("count".into(), Value::Int(5)),
1040 ]),
1041 Row::from([
1042 ("path".into(), Value::String("c.md".into())),
1043 ("title".into(), Value::String("Gamma".into())),
1044 ("count".into(), Value::Int(20)),
1045 ]),
1046 ]
1047 }
1048
/// `SELECT *` without any clause returns every fixture row.
#[test]
fn test_select_all() {
    let query = SelectQuery {
        columns: ColumnList::All,
        table: "test".into(),
        table_alias: None,
        joins: Vec::new(),
        where_clause: None,
        group_by: None,
        order_by: None,
        limit: None,
    };
    let fixture = make_rows();
    let (selected, _columns) = execute(&query, &fixture, None).unwrap();
    assert_eq!(selected.len(), 3);
}
1064
/// `count > 5` keeps the two rows whose count is 10 and 20.
#[test]
fn test_where_gt() {
    let predicate = WhereClause::Comparison(Comparison {
        column: "count".into(),
        op: ">".into(),
        value: Some(SqlValue::Int(5)),
        left_expr: Some(Expr::Column("count".into())),
        right_expr: Some(Expr::Literal(SqlValue::Int(5))),
    });
    let query = SelectQuery {
        columns: ColumnList::All,
        table: "test".into(),
        table_alias: None,
        joins: Vec::new(),
        where_clause: Some(predicate),
        group_by: None,
        order_by: None,
        limit: None,
    };
    let fixture = make_rows();
    let (matched, _columns) = execute(&query, &fixture, None).unwrap();
    assert_eq!(matched.len(), 2);
}
1086
/// `ORDER BY count DESC` puts the largest count first and the smallest last.
#[test]
fn test_order_by_desc() {
    let by_count_desc = OrderSpec {
        column: "count".into(),
        expr: Some(Expr::Column("count".into())),
        descending: true,
    };
    let query = SelectQuery {
        columns: ColumnList::All,
        table: "test".into(),
        table_alias: None,
        joins: Vec::new(),
        where_clause: None,
        group_by: None,
        order_by: Some(vec![by_count_desc]),
        limit: None,
    };
    let fixture = make_rows();
    let (sorted, _columns) = execute(&query, &fixture, None).unwrap();
    assert_eq!(sorted[0]["count"], Value::Int(20));
    assert_eq!(sorted[2]["count"], Value::Int(5));
}
1107
/// `LIMIT 2` truncates the three fixture rows to two.
#[test]
fn test_limit() {
    let query = SelectQuery {
        columns: ColumnList::All,
        table: "test".into(),
        table_alias: None,
        joins: Vec::new(),
        where_clause: None,
        group_by: None,
        order_by: None,
        limit: Some(2),
    };
    let fixture = make_rows();
    let (limited, _columns) = execute(&query, &fixture, None).unwrap();
    assert_eq!(limited.len(), 2);
}
1123
/// `title LIKE '%lph%'` matches only "Alpha".
#[test]
fn test_like() {
    let predicate = WhereClause::Comparison(Comparison {
        column: "title".into(),
        op: "LIKE".into(),
        value: Some(SqlValue::String("%lph%".into())),
        left_expr: Some(Expr::Column("title".into())),
        right_expr: None,
    });
    let query = SelectQuery {
        columns: ColumnList::All,
        table: "test".into(),
        table_alias: None,
        joins: Vec::new(),
        where_clause: Some(predicate),
        group_by: None,
        order_by: None,
        limit: None,
    };
    let fixture = make_rows();
    let (matched, _columns) = execute(&query, &fixture, None).unwrap();
    assert_eq!(matched.len(), 1);
    assert_eq!(matched[0]["title"], Value::String("Alpha".into()));
}
1146
/// `optional IS NULL` matches all three rows: one holds an explicit null and
/// the other two do not have the column at all.
#[test]
fn test_is_null() {
    let mut fixture = make_rows();
    fixture[1].insert("optional".into(), Value::Null);

    let predicate = WhereClause::Comparison(Comparison {
        column: "optional".into(),
        op: "IS NULL".into(),
        value: None,
        left_expr: Some(Expr::Column("optional".into())),
        right_expr: None,
    });
    let query = SelectQuery {
        columns: ColumnList::All,
        table: "test".into(),
        table_alias: None,
        joins: Vec::new(),
        where_clause: Some(predicate),
        group_by: None,
        order_by: None,
        limit: None,
    };
    let (matched, _columns) = execute(&query, &fixture, None).unwrap();
    assert_eq!(matched.len(), 3);
}
1172
/// Integer, float and NULL literals evaluate to the matching `Value`.
#[test]
fn test_evaluate_expr_literal() {
    let row = Row::new();
    assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Int(42)), &row), Value::Int(42));
    // 2.5 is exactly representable and is not mistaken for an approximation of PI.
    assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Float(2.5)), &row), Value::Float(2.5));
    assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Null), &row), Value::Null);
}
1182
/// A present column yields its value; an absent one yields `Null`.
#[test]
fn test_evaluate_expr_column() {
    let row = Row::from([("x".into(), Value::Int(10))]);

    let present = evaluate_expr(&Expr::Column("x".into()), &row);
    assert_eq!(present, Value::Int(10));

    let absent = evaluate_expr(&Expr::Column("missing".into()), &row);
    assert_eq!(absent, Value::Null);
}
1189
/// Checks each arithmetic operator on two integer columns (`a = 10`,
/// `b = 3`); every result is expected to stay an `Int`.
#[test]
fn test_evaluate_expr_int_arithmetic() {
    let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Int(3))]);

    // Builds `a <op> b`; only the operator differs between the cases below.
    let a_op_b = |op: ArithOp| Expr::BinaryOp {
        left: Box::new(Expr::Column("a".into())),
        op,
        right: Box::new(Expr::Column("b".into())),
    };

    assert_eq!(evaluate_expr(&a_op_b(ArithOp::Add), &row), Value::Int(13));
    assert_eq!(evaluate_expr(&a_op_b(ArithOp::Sub), &row), Value::Int(7));
    assert_eq!(evaluate_expr(&a_op_b(ArithOp::Mul), &row), Value::Int(30));
    // 10 / 3 is expected to be 3: the quotient stays an integer.
    assert_eq!(evaluate_expr(&a_op_b(ArithOp::Div), &row), Value::Int(3));
    assert_eq!(evaluate_expr(&a_op_b(ArithOp::Mod), &row), Value::Int(1));
}
1228
/// Adding an `Int` column to a `Float` column is expected to produce a
/// `Float` result.
#[test]
fn test_evaluate_expr_float_coercion() {
    let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Float(3.0))]);

    let int_operand = Box::new(Expr::Column("a".into()));
    let float_operand = Box::new(Expr::Column("b".into()));
    let sum = Expr::BinaryOp {
        left: int_operand,
        op: ArithOp::Add,
        right: float_operand,
    };

    assert_eq!(evaluate_expr(&sum, &row), Value::Float(13.0));
}
1239
/// Arithmetic with an operand that is absent from the row is expected to
/// evaluate to `Value::Null`.
#[test]
fn test_evaluate_expr_null_propagation() {
    let row = Row::from([("a".into(), Value::Int(10))]);

    let present = Box::new(Expr::Column("a".into()));
    // "missing" is not a key of `row`.
    let absent = Box::new(Expr::Column("missing".into()));
    let sum = Expr::BinaryOp {
        left: present,
        op: ArithOp::Add,
        right: absent,
    };

    assert_eq!(evaluate_expr(&sum, &row), Value::Null);
}
1250
/// Integer division by a zero-valued column is expected to evaluate to
/// `Value::Null`.
#[test]
fn test_evaluate_expr_div_by_zero() {
    let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Int(0))]);

    let dividend = Box::new(Expr::Column("a".into()));
    let zero_divisor = Box::new(Expr::Column("b".into()));
    let quotient = Expr::BinaryOp {
        left: dividend,
        op: ArithOp::Div,
        right: zero_divisor,
    };

    assert_eq!(evaluate_expr(&quotient, &row), Value::Null);
}
1261
/// Unary minus applied to an integer column negates its value.
#[test]
fn test_evaluate_expr_unary_minus() {
    let row = Row::from([("x".into(), Value::Int(5))]);

    let operand = Expr::Column("x".into());
    let negated = Expr::UnaryMinus(Box::new(operand));

    assert_eq!(evaluate_expr(&negated, &row), Value::Int(-5));
}
1268
/// An aliased arithmetic expression in the select list produces a single
/// output column named after the alias, computed for every fixture row.
#[test]
fn test_select_with_expression() {
    let stmt = crate::query_parser::parse_query("SELECT count * 2 AS doubled FROM test").unwrap();
    let q = match stmt {
        crate::query_parser::Statement::Select(q) => q,
        _ => panic!("Expected Select"),
    };

    let (rows, cols) = execute(&q, &make_rows(), None).unwrap();
    assert_eq!(cols, vec!["doubled"]);
    assert_eq!(rows.len(), 3);

    // Row order is not asserted here, only that each doubled value is present.
    let doubled: Vec<Value> = rows.iter().map(|r| r["doubled"].clone()).collect();
    for expected in [Value::Int(20), Value::Int(10), Value::Int(40)] {
        assert!(doubled.contains(&expected));
    }
}
1288
/// An arithmetic expression on the left of a WHERE comparison is evaluated
/// per row; two fixture rows are expected to satisfy `count * 2 > 15`.
#[test]
fn test_where_with_expression() {
    let stmt = crate::query_parser::parse_query("SELECT * FROM test WHERE count * 2 > 15").unwrap();
    let q = match stmt {
        crate::query_parser::Statement::Select(q) => q,
        _ => panic!("Expected Select"),
    };

    let (matched, _columns) = execute(&q, &make_rows(), None).unwrap();
    assert_eq!(matched.len(), 2);
}
1303
/// ORDER BY accepts an expression: sorting by `count * -1` ascending is
/// expected to list the rows by `count` descending.
#[test]
fn test_order_by_expression() {
    let stmt = crate::query_parser::parse_query(
        "SELECT title, count FROM test ORDER BY count * -1 ASC",
    )
    .unwrap();
    let q = match stmt {
        crate::query_parser::Statement::Select(q) => q,
        _ => panic!("Expected Select"),
    };

    let (sorted, _columns) = execute(&q, &make_rows(), None).unwrap();
    assert_eq!(sorted[0]["count"], Value::Int(20));
    assert_eq!(sorted[1]["count"], Value::Int(10));
    assert_eq!(sorted[2]["count"], Value::Int(5));
}
1320
/// A CASE expression whose single WHEN condition holds evaluates to that
/// branch's THEN value.
#[test]
fn test_case_when_eval_basic() {
    let row = Row::from([("status".into(), Value::String("ACTIVE".into()))]);

    // WHEN status = 'ACTIVE'
    let status_is_active = WhereClause::Comparison(Comparison {
        column: "status".into(),
        op: "=".into(),
        value: Some(SqlValue::String("ACTIVE".into())),
        left_expr: Some(Expr::Column("status".into())),
        right_expr: Some(Expr::Literal(SqlValue::String("ACTIVE".into()))),
    });
    // ... THEN 1 ELSE 0
    let case_expr = Expr::Case {
        whens: vec![(status_is_active, Box::new(Expr::Literal(SqlValue::Int(1))))],
        else_expr: Some(Box::new(Expr::Literal(SqlValue::Int(0)))),
    };

    assert_eq!(evaluate_expr(&case_expr, &row), Value::Int(1));
}
1341
/// A CASE expression whose WHEN condition does not hold falls through to
/// the ELSE value.
#[test]
fn test_case_when_eval_else() {
    // The row's status differs from the 'ACTIVE' the condition looks for.
    let row = Row::from([("status".into(), Value::String("KILLED".into()))]);

    // WHEN status = 'ACTIVE'
    let status_is_active = WhereClause::Comparison(Comparison {
        column: "status".into(),
        op: "=".into(),
        value: Some(SqlValue::String("ACTIVE".into())),
        left_expr: Some(Expr::Column("status".into())),
        right_expr: Some(Expr::Literal(SqlValue::String("ACTIVE".into()))),
    });
    // ... THEN 1 ELSE 0
    let case_expr = Expr::Case {
        whens: vec![(status_is_active, Box::new(Expr::Literal(SqlValue::Int(1))))],
        else_expr: Some(Box::new(Expr::Literal(SqlValue::Int(0)))),
    };

    assert_eq!(evaluate_expr(&case_expr, &row), Value::Int(0));
}
1360
/// A CASE expression with no ELSE branch evaluates to `Value::Null` when
/// no WHEN condition holds.
#[test]
fn test_case_when_eval_no_else_null() {
    let row = Row::from([("x".into(), Value::Int(99))]);

    // WHEN x = 1 -- does not hold for x = 99.
    let x_is_one = WhereClause::Comparison(Comparison {
        column: "x".into(),
        op: "=".into(),
        value: Some(SqlValue::Int(1)),
        left_expr: Some(Expr::Column("x".into())),
        right_expr: Some(Expr::Literal(SqlValue::Int(1))),
    });
    // ... THEN 'one', with no ELSE.
    let case_expr = Expr::Case {
        whens: vec![(x_is_one, Box::new(Expr::Literal(SqlValue::String("one".into()))))],
        else_expr: None,
    };

    assert_eq!(evaluate_expr(&case_expr, &row), Value::Null);
}
1379
/// A CASE expression can be the argument of SUM: the query collapses to a
/// single row whose `total` column is expected to be the float 30.0.
#[test]
fn test_case_when_in_aggregate_query() {
    let stmt = crate::query_parser::parse_query(
        "SELECT SUM(CASE WHEN count > 5 THEN count ELSE 0 END) AS total FROM test",
    )
    .unwrap();
    let q = match stmt {
        crate::query_parser::Statement::Select(q) => q,
        _ => panic!("Expected Select"),
    };

    let (summary, cols) = execute(&q, &make_rows(), None).unwrap();
    assert_eq!(cols, vec!["total"]);
    assert_eq!(summary.len(), 1);
    assert_eq!(summary[0]["total"], Value::Float(30.0));
}
1396
/// Unary minus inside a CASE branch works under SUM: "Alpha" contributes
/// `count`, every other row `-count`, and the net is expected to be -15.0.
#[test]
fn test_case_when_with_unary_minus_in_aggregate() {
    let stmt = crate::query_parser::parse_query(
        "SELECT SUM(CASE WHEN title = 'Alpha' THEN count ELSE -count END) AS net FROM test",
    )
    .unwrap();
    let q = match stmt {
        crate::query_parser::Statement::Select(q) => q,
        _ => panic!("Expected Select"),
    };

    let (summary, _columns) = execute(&q, &make_rows(), None).unwrap();
    assert_eq!(summary.len(), 1);
    assert_eq!(summary[0]["net"], Value::Float(-15.0));
}
1412}