1use std::cmp::Ordering;
4use std::collections::HashMap;
5
6use regex::Regex;
7
8use crate::errors::MdqlError;
9use crate::model::{Row, Value};
10use crate::query_parser::*;
11use crate::schema::Schema;
12
13pub fn execute_query(
14 query: &SelectQuery,
15 rows: &[Row],
16 _schema: &Schema,
17) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
18 execute(query, rows, None)
19}
20
21pub fn execute_query_indexed(
23 query: &SelectQuery,
24 rows: &[Row],
25 schema: &Schema,
26 index: Option<&crate::index::TableIndex>,
27 searcher: Option<&crate::search::TableSearcher>,
28) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
29 let fts_results = if let (Some(ref wc), Some(searcher)) = (&query.where_clause, searcher) {
31 collect_fts_results(wc, schema, searcher)
32 } else {
33 HashMap::new()
34 };
35
36 execute_with_fts(query, rows, index, &fts_results)
37}
38
39fn collect_fts_results(
42 clause: &WhereClause,
43 schema: &Schema,
44 searcher: &crate::search::TableSearcher,
45) -> HashMap<(String, String), std::collections::HashSet<String>> {
46 let mut results = HashMap::new();
47 collect_fts_results_inner(clause, schema, searcher, &mut results);
48 results
49}
50
51fn collect_fts_results_inner(
52 clause: &WhereClause,
53 schema: &Schema,
54 searcher: &crate::search::TableSearcher,
55 results: &mut HashMap<(String, String), std::collections::HashSet<String>>,
56) {
57 match clause {
58 WhereClause::Comparison(cmp) => {
59 if (cmp.op == "LIKE" || cmp.op == "NOT LIKE") && schema.sections.contains_key(&cmp.column) {
60 if let Some(SqlValue::String(pattern)) = &cmp.value {
61 let search_term = pattern.replace('%', " ").replace('_', " ").trim().to_string();
63 if !search_term.is_empty() {
64 if let Ok(paths) = searcher.search(&search_term, Some(&cmp.column)) {
65 let key = (cmp.column.clone(), pattern.clone());
66 results.insert(key, paths.into_iter().collect());
67 }
68 }
69 }
70 }
71 }
72 WhereClause::BoolOp(bop) => {
73 collect_fts_results_inner(&bop.left, schema, searcher, results);
74 collect_fts_results_inner(&bop.right, schema, searcher, results);
75 }
76 }
77}
78
/// Full-text search hits keyed by `(column, LIKE pattern)`; each value is the
/// set of row `path`s that matched that comparison.
type FtsResults = std::collections::HashMap<(String, String), std::collections::HashSet<String>>;
80
81fn execute_with_fts(
82 query: &SelectQuery,
83 rows: &[Row],
84 index: Option<&crate::index::TableIndex>,
85 fts: &FtsResults,
86) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
87 let mut all_columns: Vec<String> = Vec::new();
89 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
90 for r in rows {
91 for k in r.keys() {
92 if seen.insert(k.clone()) {
93 all_columns.push(k.clone());
94 }
95 }
96 }
97
98 let has_aggregates = match &query.columns {
100 ColumnList::Named(exprs) => exprs.iter().any(|e| e.is_aggregate()),
101 _ => false,
102 };
103
104 let columns: Vec<String> = match &query.columns {
106 ColumnList::All => all_columns,
107 ColumnList::Named(exprs) => exprs.iter().map(|e| e.output_name()).collect(),
108 };
109
110 let filtered: Vec<Row> = if let Some(ref wc) = query.where_clause {
112 let candidate_paths = index.and_then(|idx| try_index_filter(wc, idx));
113 if let Some(paths) = candidate_paths {
114 rows.iter()
115 .filter(|r| {
116 r.get("path")
117 .and_then(|v| v.as_str())
118 .map_or(false, |p| paths.contains(p))
119 })
120 .filter(|r| evaluate_with_fts(wc, r, fts))
121 .cloned()
122 .collect()
123 } else {
124 rows.iter()
125 .filter(|r| evaluate_with_fts(wc, r, fts))
126 .cloned()
127 .collect()
128 }
129 } else {
130 rows.to_vec()
131 };
132
133 let mut result = if has_aggregates || query.group_by.is_some() {
135 let exprs = match &query.columns {
136 ColumnList::Named(exprs) => exprs.clone(),
137 _ => return Err(MdqlError::QueryExecution(
138 "SELECT * with GROUP BY is not supported".into(),
139 )),
140 };
141 let group_keys = query.group_by.as_deref().unwrap_or(&[]);
142 aggregate_rows(&filtered, &exprs, group_keys)?
143 } else {
144 filtered
145 };
146
147 if let Some(ref order_by) = query.order_by {
149 let resolved = resolve_order_aliases(order_by, &query.columns);
150 sort_rows(&mut result, &resolved);
151 }
152
153 if let Some(limit) = query.limit {
155 result.truncate(limit as usize);
156 }
157
158 if !matches!(query.columns, ColumnList::All) {
160 let named_exprs = match &query.columns {
161 ColumnList::Named(exprs) => exprs,
162 _ => unreachable!(),
163 };
164
165 let has_expr_cols = named_exprs.iter().any(|e| matches!(e, SelectExpr::Expr { .. }));
167 if has_expr_cols {
168 for row in &mut result {
169 for expr in named_exprs {
170 if let SelectExpr::Expr { expr: e, alias } = expr {
171 let name = alias.clone().unwrap_or_else(|| e.display_name());
172 let val = evaluate_expr(e, row);
173 row.insert(name, val);
174 }
175 }
176 }
177 }
178
179 let col_set: std::collections::HashSet<&str> =
180 columns.iter().map(|s| s.as_str()).collect();
181 for row in &mut result {
182 row.retain(|k, _| col_set.contains(k.as_str()));
183 }
184 }
185
186 Ok((result, columns))
187}
188
189fn aggregate_rows(
190 rows: &[Row],
191 exprs: &[SelectExpr],
192 group_keys: &[String],
193) -> crate::errors::Result<Vec<Row>> {
194 let mut groups: Vec<(Vec<Value>, Vec<&Row>)> = Vec::new();
196 let mut key_index: HashMap<Vec<String>, usize> = HashMap::new();
197
198 if group_keys.is_empty() {
199 let all_refs: Vec<&Row> = rows.iter().collect();
201 groups.push((vec![], all_refs));
202 } else {
203 for row in rows {
204 let key: Vec<String> = group_keys
205 .iter()
206 .map(|k| {
207 row.get(k)
208 .map(|v| v.to_display_string())
209 .unwrap_or_default()
210 })
211 .collect();
212 let key_vals: Vec<Value> = group_keys
213 .iter()
214 .map(|k| row.get(k).cloned().unwrap_or(Value::Null))
215 .collect();
216 if let Some(&idx) = key_index.get(&key) {
217 groups[idx].1.push(row);
218 } else {
219 let idx = groups.len();
220 key_index.insert(key, idx);
221 groups.push((key_vals, vec![row]));
222 }
223 }
224 }
225
226 let mut result = Vec::new();
228 for (key_vals, group_rows) in &groups {
229 let mut out = Row::new();
230
231 for (i, k) in group_keys.iter().enumerate() {
233 out.insert(k.clone(), key_vals[i].clone());
234 }
235
236 for expr in exprs {
238 match expr {
239 SelectExpr::Column(name) => {
240 if !out.contains_key(name) {
242 if let Some(first) = group_rows.first() {
243 out.insert(
244 name.clone(),
245 first.get(name).cloned().unwrap_or(Value::Null),
246 );
247 }
248 }
249 }
250 SelectExpr::Aggregate { func, arg, arg_expr, alias } => {
251 let out_name = alias
252 .clone()
253 .unwrap_or_else(|| expr.output_name());
254 let val = compute_aggregate(func, arg, arg_expr.as_ref(), group_rows);
255 out.insert(out_name, val);
256 }
257 SelectExpr::Expr { expr: e, alias } => {
258 let out_name = alias.clone().unwrap_or_else(|| e.display_name());
259 if let Some(first) = group_rows.first() {
260 let val = evaluate_expr(e, first);
261 out.insert(out_name, val);
262 }
263 }
264 }
265 }
266
267 result.push(out);
268 }
269
270 Ok(result)
271}
272
273fn resolve_agg_value<'a>(arg: &str, arg_expr: Option<&Expr>, row: &'a Row) -> Value {
276 if let Some(expr) = arg_expr {
277 evaluate_expr(expr, row)
278 } else {
279 row.get(arg).cloned().unwrap_or(Value::Null)
280 }
281}
282
283fn compute_aggregate(func: &AggFunc, arg: &str, arg_expr: Option<&Expr>, rows: &[&Row]) -> Value {
284 match func {
285 AggFunc::Count => {
286 if arg == "*" && arg_expr.is_none() {
287 Value::Int(rows.len() as i64)
288 } else {
289 let count = rows
290 .iter()
291 .filter(|r| {
292 let v = resolve_agg_value(arg, arg_expr, r);
293 !v.is_null()
294 })
295 .count();
296 Value::Int(count as i64)
297 }
298 }
299 AggFunc::Sum => {
300 let mut total = 0.0f64;
301 let mut has_any = false;
302 for r in rows {
303 let v = resolve_agg_value(arg, arg_expr, r);
304 match v {
305 Value::Int(n) => { total += n as f64; has_any = true; }
306 Value::Float(f) => { total += f; has_any = true; }
307 _ => {}
308 }
309 }
310 if has_any { Value::Float(total) } else { Value::Null }
311 }
312 AggFunc::Avg => {
313 let mut total = 0.0f64;
314 let mut count = 0usize;
315 for r in rows {
316 let v = resolve_agg_value(arg, arg_expr, r);
317 match v {
318 Value::Int(n) => { total += n as f64; count += 1; }
319 Value::Float(f) => { total += f; count += 1; }
320 _ => {}
321 }
322 }
323 if count > 0 { Value::Float(total / count as f64) } else { Value::Null }
324 }
325 AggFunc::Min => {
326 let mut min_val: Option<Value> = None;
327 for r in rows {
328 let v = resolve_agg_value(arg, arg_expr, r);
329 if v.is_null() { continue; }
330 min_val = Some(match min_val {
331 None => v,
332 Some(ref current) => {
333 if v.partial_cmp(current) == Some(std::cmp::Ordering::Less) {
334 v
335 } else {
336 current.clone()
337 }
338 }
339 });
340 }
341 min_val.unwrap_or(Value::Null)
342 }
343 AggFunc::Max => {
344 let mut max_val: Option<Value> = None;
345 for r in rows {
346 let v = resolve_agg_value(arg, arg_expr, r);
347 if v.is_null() { continue; }
348 max_val = Some(match max_val {
349 None => v,
350 Some(ref current) => {
351 if v.partial_cmp(current) == Some(std::cmp::Ordering::Greater) {
352 v
353 } else {
354 current.clone()
355 }
356 }
357 });
358 }
359 max_val.unwrap_or(Value::Null)
360 }
361 }
362}
363
364fn evaluate_with_fts(clause: &WhereClause, row: &Row, fts: &FtsResults) -> bool {
365 match clause {
366 WhereClause::BoolOp(bop) => {
367 let left = evaluate_with_fts(&bop.left, row, fts);
368 match bop.op.as_str() {
369 "AND" => left && evaluate_with_fts(&bop.right, row, fts),
370 "OR" => left || evaluate_with_fts(&bop.right, row, fts),
371 _ => false,
372 }
373 }
374 WhereClause::Comparison(cmp) => {
375 if cmp.op == "LIKE" || cmp.op == "NOT LIKE" {
377 if let Some(SqlValue::String(pattern)) = &cmp.value {
378 let key = (cmp.column.clone(), pattern.clone());
379 if let Some(matching_paths) = fts.get(&key) {
380 let row_path = row.get("path").and_then(|v| v.as_str()).unwrap_or("");
381 let matched = matching_paths.contains(row_path);
382 return if cmp.op == "LIKE" { matched } else { !matched };
383 }
384 }
385 }
386 evaluate_comparison(cmp, row)
387 }
388 }
389}
390
391pub fn execute_join_query(
392 query: &SelectQuery,
393 tables: &HashMap<String, (Schema, Vec<Row>)>,
394) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
395 if query.joins.is_empty() {
396 return Err(MdqlError::QueryExecution("No JOIN clause in query".into()));
397 }
398
399 let left_name = &query.table;
400 let left_alias = query.table_alias.as_deref().unwrap_or(left_name);
401
402 let mut aliases: HashMap<String, String> = HashMap::new();
404 aliases.insert(left_name.clone(), left_name.clone());
405 if let Some(ref a) = query.table_alias {
406 aliases.insert(a.clone(), left_name.clone());
407 }
408 for join in &query.joins {
409 aliases.insert(join.table.clone(), join.table.clone());
410 if let Some(ref a) = join.alias {
411 aliases.insert(a.clone(), join.table.clone());
412 }
413 }
414
415 let (_left_schema, left_rows) = tables.get(left_name.as_str()).ok_or_else(|| {
417 MdqlError::QueryExecution(format!("Unknown table '{}'", left_name))
418 })?;
419
420 let mut current_rows: Vec<Row> = left_rows
421 .iter()
422 .map(|r| {
423 let mut prefixed = Row::new();
424 for (k, v) in r {
425 prefixed.insert(format!("{}.{}", left_alias, k), v.clone());
426 }
427 prefixed
428 })
429 .collect();
430
431 for join in &query.joins {
433 let right_name = &join.table;
434 let right_alias = join.alias.as_deref().unwrap_or(right_name);
435
436 let (_right_schema, right_rows) = tables.get(right_name.as_str()).ok_or_else(|| {
437 MdqlError::QueryExecution(format!("Unknown table '{}'", right_name))
438 })?;
439
440 let (on_left_table, on_left_col) = resolve_dotted(&join.left_col, &aliases);
442 let (on_right_table, on_right_col) = resolve_dotted(&join.right_col, &aliases);
443
444 let (left_key, right_key) = if on_right_table == *right_name {
446 let left_alias_for_col = reverse_alias(&on_left_table, &aliases, query, &query.joins);
448 (format!("{}.{}", left_alias_for_col, on_left_col), on_right_col)
449 } else {
450 let right_alias_for_col = reverse_alias(&on_right_table, &aliases, query, &query.joins);
452 (format!("{}.{}", right_alias_for_col, on_right_col), on_left_col)
453 };
454
455 let mut right_index: HashMap<String, Vec<&Row>> = HashMap::new();
457 for r in right_rows {
458 if let Some(key) = r.get(&right_key) {
459 let key_str = key.to_display_string();
460 right_index.entry(key_str).or_default().push(r);
461 }
462 }
463
464 let mut next_rows: Vec<Row> = Vec::new();
466 for lr in ¤t_rows {
467 if let Some(key) = lr.get(&left_key) {
468 let key_str = key.to_display_string();
469 if let Some(matching) = right_index.get(&key_str) {
470 for rr in matching {
471 let mut merged = lr.clone();
472 for (k, v) in *rr {
473 merged.insert(format!("{}.{}", right_alias, k), v.clone());
474 }
475 next_rows.push(merged);
476 }
477 }
478 }
479 }
480 current_rows = next_rows;
481 }
482
483 let (mut result, columns) = execute(query, ¤t_rows, None)?;
484
485 if !result.is_empty() {
489 let mut base_counts: HashMap<String, usize> = HashMap::new();
490 for key in &columns {
491 if let Some((_prefix, base)) = key.split_once('.') {
492 *base_counts.entry(base.to_string()).or_default() += 1;
493 }
494 }
495 let unique_bases: Vec<String> = base_counts
496 .into_iter()
497 .filter(|(_, count)| *count == 1)
498 .map(|(base, _)| base)
499 .collect();
500
501 if !unique_bases.is_empty() {
502 let unique_set: std::collections::HashSet<&str> =
503 unique_bases.iter().map(|s| s.as_str()).collect();
504 for row in &mut result {
505 let additions: Vec<(String, Value)> = row
506 .iter()
507 .filter_map(|(k, v)| {
508 k.split_once('.').and_then(|(_, base)| {
509 if unique_set.contains(base) {
510 Some((base.to_string(), v.clone()))
511 } else {
512 None
513 }
514 })
515 })
516 .collect();
517 for (k, v) in additions {
518 row.insert(k, v);
519 }
520 }
521 }
522 }
523
524 Ok((result, columns))
525}
526
527fn reverse_alias(
529 table_name: &str,
530 aliases: &HashMap<String, String>,
531 query: &SelectQuery,
532 joins: &[JoinClause],
533) -> String {
534 if query.table == table_name {
536 return query.table_alias.as_deref().unwrap_or(&query.table).to_string();
537 }
538 for j in joins {
540 if j.table == table_name {
541 return j.alias.as_deref().unwrap_or(&j.table).to_string();
542 }
543 }
544 if aliases.contains_key(table_name) {
546 return table_name.to_string();
547 }
548 table_name.to_string()
549}
550
/// Splits an optionally qualified column reference `"prefix.column"`.
///
/// The prefix is translated through `aliases` (alias or table name to table
/// name); an unknown prefix is returned unchanged. Only the first `.` splits,
/// so `"a.b.c"` yields column `"b.c"`. An unqualified name yields an empty
/// table string.
fn resolve_dotted(col: &str, aliases: &HashMap<String, String>) -> (String, String) {
    match col.split_once('.') {
        None => (String::new(), col.to_owned()),
        Some((prefix, column)) => {
            let table = match aliases.get(prefix) {
                Some(resolved) => resolved.clone(),
                None => prefix.to_owned(),
            };
            (table, column.to_owned())
        }
    }
}
559
560fn execute(
561 query: &SelectQuery,
562 rows: &[Row],
563 index: Option<&crate::index::TableIndex>,
564) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
565 let empty_fts = HashMap::new();
566 execute_with_fts(query, rows, index, &empty_fts)
567}
568
569pub fn evaluate(clause: &WhereClause, row: &Row) -> bool {
570 match clause {
571 WhereClause::BoolOp(bop) => {
572 let left = evaluate(&bop.left, row);
573 match bop.op.as_str() {
574 "AND" => left && evaluate(&bop.right, row),
575 "OR" => left || evaluate(&bop.right, row),
576 _ => false,
577 }
578 }
579 WhereClause::Comparison(cmp) => evaluate_comparison(cmp, row),
580 }
581}
582
583pub fn evaluate_expr(expr: &Expr, row: &Row) -> Value {
585 match expr {
586 Expr::Literal(SqlValue::Int(n)) => Value::Int(*n),
587 Expr::Literal(SqlValue::Float(f)) => Value::Float(*f),
588 Expr::Literal(SqlValue::String(s)) => Value::String(s.clone()),
589 Expr::Literal(SqlValue::Null) => Value::Null,
590 Expr::Literal(SqlValue::List(_)) => Value::Null,
591 Expr::Column(name) => row.get(name).cloned().unwrap_or(Value::Null),
592 Expr::UnaryMinus(inner) => {
593 match evaluate_expr(inner, row) {
594 Value::Int(n) => Value::Int(-n),
595 Value::Float(f) => Value::Float(-f),
596 Value::Null => Value::Null,
597 _ => Value::Null, }
599 }
600 Expr::BinaryOp { left, op, right } => {
601 let lv = evaluate_expr(left, row);
602 let rv = evaluate_expr(right, row);
603
604 if lv.is_null() || rv.is_null() {
606 return Value::Null;
607 }
608
609 match (&lv, &rv) {
611 (Value::Int(a), Value::Int(b)) => {
612 match op {
613 ArithOp::Add => Value::Int(a.wrapping_add(*b)),
614 ArithOp::Sub => Value::Int(a.wrapping_sub(*b)),
615 ArithOp::Mul => Value::Int(a.wrapping_mul(*b)),
616 ArithOp::Div => {
617 if *b == 0 { Value::Null } else { Value::Int(a / b) }
618 }
619 ArithOp::Mod => {
620 if *b == 0 { Value::Null } else { Value::Int(a % b) }
621 }
622 }
623 }
624 _ => {
625 let a = match &lv {
627 Value::Int(n) => *n as f64,
628 Value::Float(f) => *f,
629 _ => return Value::Null,
630 };
631 let b = match &rv {
632 Value::Int(n) => *n as f64,
633 Value::Float(f) => *f,
634 _ => return Value::Null,
635 };
636 match op {
637 ArithOp::Add => Value::Float(a + b),
638 ArithOp::Sub => Value::Float(a - b),
639 ArithOp::Mul => Value::Float(a * b),
640 ArithOp::Div => {
641 if b == 0.0 { Value::Null } else { Value::Float(a / b) }
642 }
643 ArithOp::Mod => {
644 if b == 0.0 { Value::Null } else { Value::Float(a % b) }
645 }
646 }
647 }
648 }
649 }
650 Expr::Case { whens, else_expr } => {
651 for (condition, result) in whens {
652 if evaluate(condition, row) {
653 return evaluate_expr(result, row);
654 }
655 }
656 match else_expr {
657 Some(e) => evaluate_expr(e, row),
658 None => Value::Null,
659 }
660 }
661 }
662}
663
664fn evaluate_comparison(cmp: &Comparison, row: &Row) -> bool {
665 if let (Some(left_expr), Some(right_expr)) = (&cmp.left_expr, &cmp.right_expr) {
667 if ["=", "!=", "<", ">", "<=", ">="].contains(&cmp.op.as_str()) {
668 let left_val = evaluate_expr(left_expr, row);
669 let right_val = evaluate_expr(right_expr, row);
670
671 if left_val.is_null() || right_val.is_null() {
673 return false;
674 }
675
676 let ord = compare_model_values(&left_val, &right_val);
678
679 return match cmp.op.as_str() {
680 "=" => ord == Some(Ordering::Equal),
681 "!=" => ord != Some(Ordering::Equal),
682 "<" => ord == Some(Ordering::Less),
683 ">" => ord == Some(Ordering::Greater),
684 "<=" => matches!(ord, Some(Ordering::Less | Ordering::Equal)),
685 ">=" => matches!(ord, Some(Ordering::Greater | Ordering::Equal)),
686 _ => false,
687 };
688 }
689 }
690
691 let actual = row.get(&cmp.column);
693
694 if cmp.op == "IS NULL" {
695 return actual.map_or(true, |v| v.is_null());
696 }
697 if cmp.op == "IS NOT NULL" {
698 return actual.map_or(false, |v| !v.is_null());
699 }
700
701 let actual = match actual {
702 Some(v) if !v.is_null() => v,
703 _ => return false,
704 };
705
706 let expected = match &cmp.value {
707 Some(v) => v,
708 None => return false,
709 };
710
711 match cmp.op.as_str() {
712 "=" => eq_match(actual, expected),
713 "!=" => !eq_match(actual, expected),
714 "<" => compare_values(actual, expected) == Some(Ordering::Less),
715 ">" => compare_values(actual, expected) == Some(Ordering::Greater),
716 "<=" => matches!(compare_values(actual, expected), Some(Ordering::Less | Ordering::Equal)),
717 ">=" => matches!(compare_values(actual, expected), Some(Ordering::Greater | Ordering::Equal)),
718 "LIKE" => like_match(actual, expected),
719 "NOT LIKE" => !like_match(actual, expected),
720 "IN" => {
721 if let SqlValue::List(items) = expected {
722 items.iter().any(|v| eq_match(actual, v))
723 } else {
724 eq_match(actual, expected)
725 }
726 }
727 _ => false,
728 }
729}
730
731fn compare_model_values(a: &Value, b: &Value) -> Option<Ordering> {
733 match (a, b) {
734 (Value::Int(x), Value::Float(y)) => (*x as f64).partial_cmp(y),
735 (Value::Float(x), Value::Int(y)) => x.partial_cmp(&(*y as f64)),
736 _ => a.partial_cmp(b),
737 }
738}
739
740fn coerce_sql_to_value(sql_val: &SqlValue, target: &Value) -> Value {
741 match sql_val {
742 SqlValue::Null => Value::Null,
743 SqlValue::String(s) => {
744 match target {
745 Value::Int(_) => s.parse::<i64>().map(Value::Int).unwrap_or(Value::String(s.clone())),
746 Value::Float(_) => s.parse::<f64>().map(Value::Float).unwrap_or(Value::String(s.clone())),
747 Value::Date(_) => {
748 chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d")
749 .map(Value::Date)
750 .unwrap_or(Value::String(s.clone()))
751 }
752 _ => Value::String(s.clone()),
753 }
754 }
755 SqlValue::Int(n) => {
756 match target {
757 Value::Float(_) => Value::Float(*n as f64),
758 _ => Value::Int(*n),
759 }
760 }
761 SqlValue::Float(f) => Value::Float(*f),
762 SqlValue::List(_) => Value::Null, }
764}
765
766fn eq_match(actual: &Value, expected: &SqlValue) -> bool {
767 if let Value::List(items) = actual {
769 if let SqlValue::String(s) = expected {
770 return items.contains(s);
771 }
772 }
773
774 let coerced = coerce_sql_to_value(expected, actual);
775 actual == &coerced
776}
777
778fn like_match(actual: &Value, pattern: &SqlValue) -> bool {
779 let pattern_str = match pattern {
780 SqlValue::String(s) => s,
781 _ => return false,
782 };
783
784 let mut regex_str = String::from("(?is)^");
786 for ch in pattern_str.chars() {
787 match ch {
788 '%' => regex_str.push_str(".*"),
789 '_' => regex_str.push('.'),
790 c => {
791 if regex::escape(&c.to_string()) != c.to_string() {
792 regex_str.push_str(®ex::escape(&c.to_string()));
793 } else {
794 regex_str.push(c);
795 }
796 }
797 }
798 }
799 regex_str.push('$');
800
801 let re = match Regex::new(®ex_str) {
802 Ok(r) => r,
803 Err(_) => return false,
804 };
805
806 match actual {
807 Value::List(items) => items.iter().any(|item| re.is_match(item)),
808 _ => re.is_match(&actual.to_display_string()),
809 }
810}
811
812fn compare_values(actual: &Value, expected: &SqlValue) -> Option<Ordering> {
813 let coerced = coerce_sql_to_value(expected, actual);
814 actual.partial_cmp(&coerced).map(|o| o)
815}
816
817fn sql_value_to_index_value(sv: &SqlValue) -> Value {
819 match sv {
820 SqlValue::String(s) => {
821 if let Ok(d) = chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d") {
823 return Value::Date(d);
824 }
825 Value::String(s.clone())
826 }
827 SqlValue::Int(n) => Value::Int(*n),
828 SqlValue::Float(f) => Value::Float(*f),
829 SqlValue::Null => Value::Null,
830 SqlValue::List(_) => Value::Null,
831 }
832}
833
834fn try_index_filter(
838 clause: &WhereClause,
839 index: &crate::index::TableIndex,
840) -> Option<std::collections::HashSet<String>> {
841 match clause {
842 WhereClause::Comparison(cmp) => {
843 if !index.has_index(&cmp.column) {
844 return None;
845 }
846 match cmp.op.as_str() {
847 "=" => {
848 let val = sql_value_to_index_value(cmp.value.as_ref()?);
849 let paths = index.lookup_eq(&cmp.column, &val);
850 Some(paths.into_iter().map(|s| s.to_string()).collect())
851 }
852 "<" => {
853 let val = sql_value_to_index_value(cmp.value.as_ref()?);
854 let range_paths = index.lookup_range(&cmp.column, None, Some(&val));
857 let eq_paths: std::collections::HashSet<&str> = index.lookup_eq(&cmp.column, &val).into_iter().collect();
858 Some(range_paths.into_iter().filter(|p| !eq_paths.contains(p)).map(|s| s.to_string()).collect())
859 }
860 ">" => {
861 let val = sql_value_to_index_value(cmp.value.as_ref()?);
862 let range_paths = index.lookup_range(&cmp.column, Some(&val), None);
863 let eq_paths: std::collections::HashSet<&str> = index.lookup_eq(&cmp.column, &val).into_iter().collect();
864 Some(range_paths.into_iter().filter(|p| !eq_paths.contains(p)).map(|s| s.to_string()).collect())
865 }
866 "<=" => {
867 let val = sql_value_to_index_value(cmp.value.as_ref()?);
868 let paths = index.lookup_range(&cmp.column, None, Some(&val));
869 Some(paths.into_iter().map(|s| s.to_string()).collect())
870 }
871 ">=" => {
872 let val = sql_value_to_index_value(cmp.value.as_ref()?);
873 let paths = index.lookup_range(&cmp.column, Some(&val), None);
874 Some(paths.into_iter().map(|s| s.to_string()).collect())
875 }
876 "IN" => {
877 if let Some(SqlValue::List(items)) = &cmp.value {
878 let vals: Vec<Value> = items.iter().map(sql_value_to_index_value).collect();
879 let paths = index.lookup_in(&cmp.column, &vals);
880 Some(paths.into_iter().map(|s| s.to_string()).collect())
881 } else {
882 None
883 }
884 }
885 _ => None, }
887 }
888 WhereClause::BoolOp(bop) => {
889 let left = try_index_filter(&bop.left, index);
890 let right = try_index_filter(&bop.right, index);
891 match bop.op.as_str() {
892 "AND" => {
893 match (left, right) {
894 (Some(l), Some(r)) => Some(l.intersection(&r).cloned().collect()),
895 (Some(l), None) => Some(l), (None, Some(r)) => Some(r),
897 (None, None) => None,
898 }
899 }
900 "OR" => {
901 match (left, right) {
902 (Some(l), Some(r)) => Some(l.union(&r).cloned().collect()),
903 _ => None, }
905 }
906 _ => None,
907 }
908 }
909 }
910}
911
912fn resolve_order_aliases(specs: &[OrderSpec], columns: &ColumnList) -> Vec<OrderSpec> {
915 let named = match columns {
916 ColumnList::Named(exprs) => exprs,
917 _ => return specs.to_vec(),
918 };
919
920 let alias_map: HashMap<String, &Expr> = named
922 .iter()
923 .filter_map(|se| match se {
924 SelectExpr::Expr { expr, alias: Some(a) } => Some((a.clone(), expr)),
925 _ => None,
926 })
927 .collect();
928
929 specs
930 .iter()
931 .map(|spec| {
932 if let Some(expr) = alias_map.get(&spec.column) {
934 OrderSpec {
935 column: spec.column.clone(),
936 expr: Some((*expr).clone()),
937 descending: spec.descending,
938 }
939 } else {
940 spec.clone()
941 }
942 })
943 .collect()
944}
945
946fn sort_rows(rows: &mut Vec<Row>, specs: &[OrderSpec]) {
947 rows.sort_by(|a, b| {
948 for spec in specs {
949 let (va, vb) = if let Some(ref expr) = spec.expr {
950 (evaluate_expr(expr, a), evaluate_expr(expr, b))
951 } else {
952 (
953 a.get(&spec.column).cloned().unwrap_or(Value::Null),
954 b.get(&spec.column).cloned().unwrap_or(Value::Null),
955 )
956 };
957
958 let ordering = match (&va, &vb) {
960 (Value::Null, Value::Null) => Ordering::Equal,
961 (Value::Null, _) => Ordering::Greater,
962 (_, Value::Null) => Ordering::Less,
963 (a_val, b_val) => {
964 compare_model_values(a_val, b_val).unwrap_or(Ordering::Equal)
965 }
966 };
967
968 let ordering = if spec.descending {
969 ordering.reverse()
970 } else {
971 ordering
972 };
973
974 if ordering != Ordering::Equal {
975 return ordering;
976 }
977 }
978 Ordering::Equal
979 });
980}
981
982pub fn sql_value_to_value(sql_val: &SqlValue) -> Value {
984 match sql_val {
985 SqlValue::Null => Value::Null,
986 SqlValue::String(s) => Value::String(s.clone()),
987 SqlValue::Int(n) => Value::Int(*n),
988 SqlValue::Float(f) => Value::Float(*f),
989 SqlValue::List(items) => {
990 let strings: Vec<String> = items
991 .iter()
992 .filter_map(|v| match v {
993 SqlValue::String(s) => Some(s.clone()),
994 _ => None,
995 })
996 .collect();
997 Value::List(strings)
998 }
999 }
1000}
1001
1002#[cfg(test)]
1003mod tests {
1004 use super::*;
1005
1006 fn make_rows() -> Vec<Row> {
1007 vec![
1008 Row::from([
1009 ("path".into(), Value::String("a.md".into())),
1010 ("title".into(), Value::String("Alpha".into())),
1011 ("count".into(), Value::Int(10)),
1012 ]),
1013 Row::from([
1014 ("path".into(), Value::String("b.md".into())),
1015 ("title".into(), Value::String("Beta".into())),
1016 ("count".into(), Value::Int(5)),
1017 ]),
1018 Row::from([
1019 ("path".into(), Value::String("c.md".into())),
1020 ("title".into(), Value::String("Gamma".into())),
1021 ("count".into(), Value::Int(20)),
1022 ]),
1023 ]
1024 }
1025
1026 #[test]
1027 fn test_select_all() {
1028 let q = SelectQuery {
1029 columns: ColumnList::All,
1030 table: "test".into(),
1031 table_alias: None,
1032 joins: vec![],
1033 where_clause: None,
1034 group_by: None,
1035 order_by: None,
1036 limit: None,
1037 };
1038 let (rows, _cols) = execute(&q, &make_rows(), None).unwrap();
1039 assert_eq!(rows.len(), 3);
1040 }
1041
1042 #[test]
1043 fn test_where_gt() {
1044 let q = SelectQuery {
1045 columns: ColumnList::All,
1046 table: "test".into(),
1047 table_alias: None,
1048 joins: vec![],
1049 where_clause: Some(WhereClause::Comparison(Comparison {
1050 column: "count".into(),
1051 op: ">".into(),
1052 value: Some(SqlValue::Int(5)),
1053 left_expr: Some(Expr::Column("count".into())),
1054 right_expr: Some(Expr::Literal(SqlValue::Int(5))),
1055 })),
1056 group_by: None,
1057 order_by: None,
1058 limit: None,
1059 };
1060 let (rows, _) = execute(&q, &make_rows(), None).unwrap();
1061 assert_eq!(rows.len(), 2);
1062 }
1063
1064 #[test]
1065 fn test_order_by_desc() {
1066 let q = SelectQuery {
1067 columns: ColumnList::All,
1068 table: "test".into(),
1069 table_alias: None,
1070 joins: vec![],
1071 where_clause: None,
1072 group_by: None,
1073 order_by: Some(vec![OrderSpec {
1074 column: "count".into(),
1075 expr: Some(Expr::Column("count".into())),
1076 descending: true,
1077 }]),
1078 limit: None,
1079 };
1080 let (rows, _) = execute(&q, &make_rows(), None).unwrap();
1081 assert_eq!(rows[0]["count"], Value::Int(20));
1082 assert_eq!(rows[2]["count"], Value::Int(5));
1083 }
1084
1085 #[test]
1086 fn test_limit() {
1087 let q = SelectQuery {
1088 columns: ColumnList::All,
1089 table: "test".into(),
1090 table_alias: None,
1091 joins: vec![],
1092 where_clause: None,
1093 group_by: None,
1094 order_by: None,
1095 limit: Some(2),
1096 };
1097 let (rows, _) = execute(&q, &make_rows(), None).unwrap();
1098 assert_eq!(rows.len(), 2);
1099 }
1100
1101 #[test]
1102 fn test_like() {
1103 let q = SelectQuery {
1104 columns: ColumnList::All,
1105 table: "test".into(),
1106 table_alias: None,
1107 joins: vec![],
1108 where_clause: Some(WhereClause::Comparison(Comparison {
1109 column: "title".into(),
1110 op: "LIKE".into(),
1111 value: Some(SqlValue::String("%lph%".into())),
1112 left_expr: Some(Expr::Column("title".into())),
1113 right_expr: None,
1114 })),
1115 group_by: None,
1116 order_by: None,
1117 limit: None,
1118 };
1119 let (rows, _) = execute(&q, &make_rows(), None).unwrap();
1120 assert_eq!(rows.len(), 1);
1121 assert_eq!(rows[0]["title"], Value::String("Alpha".into()));
1122 }
1123
1124 #[test]
1125 fn test_is_null() {
1126 let mut rows = make_rows();
1127 rows[1].insert("optional".into(), Value::Null);
1128
1129 let q = SelectQuery {
1130 columns: ColumnList::All,
1131 table: "test".into(),
1132 table_alias: None,
1133 joins: vec![],
1134 where_clause: Some(WhereClause::Comparison(Comparison {
1135 column: "optional".into(),
1136 op: "IS NULL".into(),
1137 value: None,
1138 left_expr: Some(Expr::Column("optional".into())),
1139 right_expr: None,
1140 })),
1141 group_by: None,
1142 order_by: None,
1143 limit: None,
1144 };
1145 let (result, _) = execute(&q, &rows, None).unwrap();
1146 assert_eq!(result.len(), 3);
1148 }
1149
1150 #[test]
1153 fn test_evaluate_expr_literal() {
1154 let row = Row::new();
1155 assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Int(42)), &row), Value::Int(42));
1156 assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Float(3.14)), &row), Value::Float(3.14));
1157 assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Null), &row), Value::Null);
1158 }
1159
1160 #[test]
1161 fn test_evaluate_expr_column() {
1162 let row = Row::from([("x".into(), Value::Int(10))]);
1163 assert_eq!(evaluate_expr(&Expr::Column("x".into()), &row), Value::Int(10));
1164 assert_eq!(evaluate_expr(&Expr::Column("missing".into()), &row), Value::Null);
1165 }
1166
1167 #[test]
1168 fn test_evaluate_expr_int_arithmetic() {
1169 let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Int(3))]);
1170 let add = Expr::BinaryOp {
1171 left: Box::new(Expr::Column("a".into())),
1172 op: ArithOp::Add,
1173 right: Box::new(Expr::Column("b".into())),
1174 };
1175 assert_eq!(evaluate_expr(&add, &row), Value::Int(13));
1176
1177 let sub = Expr::BinaryOp {
1178 left: Box::new(Expr::Column("a".into())),
1179 op: ArithOp::Sub,
1180 right: Box::new(Expr::Column("b".into())),
1181 };
1182 assert_eq!(evaluate_expr(&sub, &row), Value::Int(7));
1183
1184 let mul = Expr::BinaryOp {
1185 left: Box::new(Expr::Column("a".into())),
1186 op: ArithOp::Mul,
1187 right: Box::new(Expr::Column("b".into())),
1188 };
1189 assert_eq!(evaluate_expr(&mul, &row), Value::Int(30));
1190
1191 let div = Expr::BinaryOp {
1192 left: Box::new(Expr::Column("a".into())),
1193 op: ArithOp::Div,
1194 right: Box::new(Expr::Column("b".into())),
1195 };
1196 assert_eq!(evaluate_expr(&div, &row), Value::Int(3)); let modulo = Expr::BinaryOp {
1199 left: Box::new(Expr::Column("a".into())),
1200 op: ArithOp::Mod,
1201 right: Box::new(Expr::Column("b".into())),
1202 };
1203 assert_eq!(evaluate_expr(&modulo, &row), Value::Int(1));
1204 }
1205
1206 #[test]
1207 fn test_evaluate_expr_float_coercion() {
1208 let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Float(3.0))]);
1209 let add = Expr::BinaryOp {
1210 left: Box::new(Expr::Column("a".into())),
1211 op: ArithOp::Add,
1212 right: Box::new(Expr::Column("b".into())),
1213 };
1214 assert_eq!(evaluate_expr(&add, &row), Value::Float(13.0));
1215 }
1216
1217 #[test]
1218 fn test_evaluate_expr_null_propagation() {
1219 let row = Row::from([("a".into(), Value::Int(10))]);
1220 let add = Expr::BinaryOp {
1221 left: Box::new(Expr::Column("a".into())),
1222 op: ArithOp::Add,
1223 right: Box::new(Expr::Column("missing".into())),
1224 };
1225 assert_eq!(evaluate_expr(&add, &row), Value::Null);
1226 }
1227
1228 #[test]
1229 fn test_evaluate_expr_div_by_zero() {
1230 let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Int(0))]);
1231 let div = Expr::BinaryOp {
1232 left: Box::new(Expr::Column("a".into())),
1233 op: ArithOp::Div,
1234 right: Box::new(Expr::Column("b".into())),
1235 };
1236 assert_eq!(evaluate_expr(&div, &row), Value::Null);
1237 }
1238
1239 #[test]
1240 fn test_evaluate_expr_unary_minus() {
1241 let row = Row::from([("x".into(), Value::Int(5))]);
1242 let neg = Expr::UnaryMinus(Box::new(Expr::Column("x".into())));
1243 assert_eq!(evaluate_expr(&neg, &row), Value::Int(-5));
1244 }
1245
1246 #[test]
1247 fn test_select_with_expression() {
1248 let stmt = crate::query_parser::parse_query(
1250 "SELECT count * 2 AS doubled FROM test"
1251 ).unwrap();
1252 if let crate::query_parser::Statement::Select(q) = stmt {
1253 let (rows, cols) = execute(&q, &make_rows(), None).unwrap();
1254 assert_eq!(cols, vec!["doubled"]);
1255 assert_eq!(rows.len(), 3);
1256 let values: Vec<Value> = rows.iter().map(|r| r["doubled"].clone()).collect();
1258 assert!(values.contains(&Value::Int(20)));
1259 assert!(values.contains(&Value::Int(10)));
1260 assert!(values.contains(&Value::Int(40)));
1261 } else {
1262 panic!("Expected Select");
1263 }
1264 }
1265
1266 #[test]
1267 fn test_where_with_expression() {
1268 let stmt = crate::query_parser::parse_query(
1270 "SELECT * FROM test WHERE count * 2 > 15"
1271 ).unwrap();
1272 if let crate::query_parser::Statement::Select(q) = stmt {
1273 let (rows, _) = execute(&q, &make_rows(), None).unwrap();
1274 assert_eq!(rows.len(), 2);
1276 } else {
1277 panic!("Expected Select");
1278 }
1279 }
1280
1281 #[test]
1282 fn test_order_by_expression() {
1283 let stmt = crate::query_parser::parse_query(
1285 "SELECT title, count FROM test ORDER BY count * -1 ASC"
1286 ).unwrap();
1287 if let crate::query_parser::Statement::Select(q) = stmt {
1288 let (rows, _) = execute(&q, &make_rows(), None).unwrap();
1289 assert_eq!(rows[0]["count"], Value::Int(20));
1291 assert_eq!(rows[1]["count"], Value::Int(10));
1292 assert_eq!(rows[2]["count"], Value::Int(5));
1293 } else {
1294 panic!("Expected Select");
1295 }
1296 }
1297
1298 #[test]
1301 fn test_case_when_eval_basic() {
1302 let row = Row::from([("status".into(), Value::String("ACTIVE".into()))]);
1303 let expr = Expr::Case {
1304 whens: vec![(
1305 WhereClause::Comparison(Comparison {
1306 column: "status".into(),
1307 op: "=".into(),
1308 value: Some(SqlValue::String("ACTIVE".into())),
1309 left_expr: Some(Expr::Column("status".into())),
1310 right_expr: Some(Expr::Literal(SqlValue::String("ACTIVE".into()))),
1311 }),
1312 Box::new(Expr::Literal(SqlValue::Int(1))),
1313 )],
1314 else_expr: Some(Box::new(Expr::Literal(SqlValue::Int(0)))),
1315 };
1316 assert_eq!(evaluate_expr(&expr, &row), Value::Int(1));
1317 }
1318
1319 #[test]
1320 fn test_case_when_eval_else() {
1321 let row = Row::from([("status".into(), Value::String("KILLED".into()))]);
1322 let expr = Expr::Case {
1323 whens: vec![(
1324 WhereClause::Comparison(Comparison {
1325 column: "status".into(),
1326 op: "=".into(),
1327 value: Some(SqlValue::String("ACTIVE".into())),
1328 left_expr: Some(Expr::Column("status".into())),
1329 right_expr: Some(Expr::Literal(SqlValue::String("ACTIVE".into()))),
1330 }),
1331 Box::new(Expr::Literal(SqlValue::Int(1))),
1332 )],
1333 else_expr: Some(Box::new(Expr::Literal(SqlValue::Int(0)))),
1334 };
1335 assert_eq!(evaluate_expr(&expr, &row), Value::Int(0));
1336 }
1337
1338 #[test]
1339 fn test_case_when_eval_no_else_null() {
1340 let row = Row::from([("x".into(), Value::Int(99))]);
1341 let expr = Expr::Case {
1342 whens: vec![(
1343 WhereClause::Comparison(Comparison {
1344 column: "x".into(),
1345 op: "=".into(),
1346 value: Some(SqlValue::Int(1)),
1347 left_expr: Some(Expr::Column("x".into())),
1348 right_expr: Some(Expr::Literal(SqlValue::Int(1))),
1349 }),
1350 Box::new(Expr::Literal(SqlValue::String("one".into()))),
1351 )],
1352 else_expr: None,
1353 };
1354 assert_eq!(evaluate_expr(&expr, &row), Value::Null);
1355 }
1356
1357 #[test]
1358 fn test_case_when_in_aggregate_query() {
1359 let stmt = crate::query_parser::parse_query(
1362 "SELECT SUM(CASE WHEN count > 5 THEN count ELSE 0 END) AS total FROM test"
1363 ).unwrap();
1364 if let crate::query_parser::Statement::Select(q) = stmt {
1365 let (rows, cols) = execute(&q, &make_rows(), None).unwrap();
1366 assert_eq!(cols, vec!["total"]);
1367 assert_eq!(rows.len(), 1);
1368 assert_eq!(rows[0]["total"], Value::Float(30.0));
1369 } else {
1370 panic!("Expected Select");
1371 }
1372 }
1373
1374 #[test]
1375 fn test_case_when_with_unary_minus_in_aggregate() {
1376 let stmt = crate::query_parser::parse_query(
1379 "SELECT SUM(CASE WHEN title = 'Alpha' THEN count ELSE -count END) AS net FROM test"
1380 ).unwrap();
1381 if let crate::query_parser::Statement::Select(q) = stmt {
1382 let (rows, _) = execute(&q, &make_rows(), None).unwrap();
1383 assert_eq!(rows.len(), 1);
1384 assert_eq!(rows[0]["net"], Value::Float(-15.0));
1385 } else {
1386 panic!("Expected Select");
1387 }
1388 }
1389}