1use std::cmp::Ordering;
4use std::collections::HashMap;
5
6use regex::Regex;
7
8use crate::errors::MdqlError;
9use crate::model::{Row, Value};
10use crate::query_parser::*;
11use crate::schema::Schema;
12
13pub fn execute_query(
14 query: &SelectQuery,
15 rows: &[Row],
16 _schema: &Schema,
17) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
18 execute(query, rows, None)
19}
20
21pub fn execute_query_indexed(
23 query: &SelectQuery,
24 rows: &[Row],
25 schema: &Schema,
26 index: Option<&crate::index::TableIndex>,
27 searcher: Option<&crate::search::TableSearcher>,
28) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
29 let fts_results = if let (Some(ref wc), Some(searcher)) = (&query.where_clause, searcher) {
31 collect_fts_results(wc, schema, searcher)
32 } else {
33 HashMap::new()
34 };
35
36 execute_with_fts(query, rows, index, &fts_results)
37}
38
39fn collect_fts_results(
42 clause: &WhereClause,
43 schema: &Schema,
44 searcher: &crate::search::TableSearcher,
45) -> HashMap<(String, String), std::collections::HashSet<String>> {
46 let mut results = HashMap::new();
47 collect_fts_results_inner(clause, schema, searcher, &mut results);
48 results
49}
50
51fn collect_fts_results_inner(
52 clause: &WhereClause,
53 schema: &Schema,
54 searcher: &crate::search::TableSearcher,
55 results: &mut HashMap<(String, String), std::collections::HashSet<String>>,
56) {
57 match clause {
58 WhereClause::Comparison(cmp) => {
59 if (cmp.op == "LIKE" || cmp.op == "NOT LIKE") && schema.sections.contains_key(&cmp.column) {
60 if let Some(SqlValue::String(pattern)) = &cmp.value {
61 let search_term = pattern.replace('%', " ").replace('_', " ").trim().to_string();
63 if !search_term.is_empty() {
64 if let Ok(paths) = searcher.search(&search_term, Some(&cmp.column)) {
65 let key = (cmp.column.clone(), pattern.clone());
66 results.insert(key, paths.into_iter().collect());
67 }
68 }
69 }
70 }
71 }
72 WhereClause::BoolOp(bop) => {
73 collect_fts_results_inner(&bop.left, schema, searcher, results);
74 collect_fts_results_inner(&bop.right, schema, searcher, results);
75 }
76 }
77}
78
/// Full-text-search hits: `(column, LIKE pattern)` -> set of paths the searcher returned.
type FtsResults =
    HashMap<(String, String), std::collections::HashSet<String>>;
80
81fn execute_with_fts(
82 query: &SelectQuery,
83 rows: &[Row],
84 index: Option<&crate::index::TableIndex>,
85 fts: &FtsResults,
86) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
87 let mut all_columns: Vec<String> = Vec::new();
89 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
90 for r in rows {
91 for k in r.keys() {
92 if seen.insert(k.clone()) {
93 all_columns.push(k.clone());
94 }
95 }
96 }
97
98 let has_aggregates = match &query.columns {
100 ColumnList::Named(exprs) => exprs.iter().any(|e| e.is_aggregate()),
101 _ => false,
102 };
103
104 let columns: Vec<String> = match &query.columns {
106 ColumnList::All => all_columns,
107 ColumnList::Named(exprs) => exprs.iter().map(|e| e.output_name()).collect(),
108 };
109
110 let filtered: Vec<Row> = if let Some(ref wc) = query.where_clause {
112 let candidate_paths = index.and_then(|idx| try_index_filter(wc, idx));
113 if let Some(paths) = candidate_paths {
114 rows.iter()
115 .filter(|r| {
116 r.get("path")
117 .and_then(|v| v.as_str())
118 .map_or(false, |p| paths.contains(p))
119 })
120 .filter(|r| evaluate_with_fts(wc, r, fts))
121 .cloned()
122 .collect()
123 } else {
124 rows.iter()
125 .filter(|r| evaluate_with_fts(wc, r, fts))
126 .cloned()
127 .collect()
128 }
129 } else {
130 rows.to_vec()
131 };
132
133 let mut result = if has_aggregates || query.group_by.is_some() {
135 let exprs = match &query.columns {
136 ColumnList::Named(exprs) => exprs.clone(),
137 _ => return Err(MdqlError::QueryExecution(
138 "SELECT * with GROUP BY is not supported".into(),
139 )),
140 };
141 let group_keys = query.group_by.as_deref().unwrap_or(&[]);
142 aggregate_rows(&filtered, &exprs, group_keys)?
143 } else {
144 filtered
145 };
146
147 if let Some(ref order_by) = query.order_by {
149 let resolved = resolve_order_aliases(order_by, &query.columns);
150 sort_rows(&mut result, &resolved);
151 }
152
153 if let Some(limit) = query.limit {
155 result.truncate(limit as usize);
156 }
157
158 if !matches!(query.columns, ColumnList::All) {
160 let named_exprs = match &query.columns {
161 ColumnList::Named(exprs) => exprs,
162 _ => unreachable!(),
163 };
164
165 let has_expr_cols = named_exprs.iter().any(|e| matches!(e, SelectExpr::Expr { .. }));
167 if has_expr_cols {
168 for row in &mut result {
169 for expr in named_exprs {
170 if let SelectExpr::Expr { expr: e, alias } = expr {
171 let name = alias.clone().unwrap_or_else(|| e.display_name());
172 let val = evaluate_expr(e, row);
173 row.insert(name, val);
174 }
175 }
176 }
177 }
178
179 let col_set: std::collections::HashSet<&str> =
180 columns.iter().map(|s| s.as_str()).collect();
181 for row in &mut result {
182 row.retain(|k, _| col_set.contains(k.as_str()));
183 }
184 }
185
186 Ok((result, columns))
187}
188
189fn aggregate_rows(
190 rows: &[Row],
191 exprs: &[SelectExpr],
192 group_keys: &[String],
193) -> crate::errors::Result<Vec<Row>> {
194 let mut groups: Vec<(Vec<Value>, Vec<&Row>)> = Vec::new();
196 let mut key_index: HashMap<Vec<String>, usize> = HashMap::new();
197
198 if group_keys.is_empty() {
199 let all_refs: Vec<&Row> = rows.iter().collect();
201 groups.push((vec![], all_refs));
202 } else {
203 for row in rows {
204 let key: Vec<String> = group_keys
205 .iter()
206 .map(|k| {
207 row.get(k)
208 .map(|v| v.to_display_string())
209 .unwrap_or_default()
210 })
211 .collect();
212 let key_vals: Vec<Value> = group_keys
213 .iter()
214 .map(|k| row.get(k).cloned().unwrap_or(Value::Null))
215 .collect();
216 if let Some(&idx) = key_index.get(&key) {
217 groups[idx].1.push(row);
218 } else {
219 let idx = groups.len();
220 key_index.insert(key, idx);
221 groups.push((key_vals, vec![row]));
222 }
223 }
224 }
225
226 let mut result = Vec::new();
228 for (key_vals, group_rows) in &groups {
229 let mut out = Row::new();
230
231 for (i, k) in group_keys.iter().enumerate() {
233 out.insert(k.clone(), key_vals[i].clone());
234 }
235
236 for expr in exprs {
238 match expr {
239 SelectExpr::Column(name) => {
240 if !out.contains_key(name) {
242 if let Some(first) = group_rows.first() {
243 out.insert(
244 name.clone(),
245 first.get(name).cloned().unwrap_or(Value::Null),
246 );
247 }
248 }
249 }
250 SelectExpr::Aggregate { func, arg, arg_expr, alias } => {
251 let out_name = alias
252 .clone()
253 .unwrap_or_else(|| expr.output_name());
254 let val = compute_aggregate(func, arg, arg_expr.as_ref(), group_rows);
255 out.insert(out_name, val);
256 }
257 SelectExpr::Expr { expr: e, alias } => {
258 let out_name = alias.clone().unwrap_or_else(|| e.display_name());
259 if let Some(first) = group_rows.first() {
260 let val = evaluate_expr(e, first);
261 out.insert(out_name, val);
262 }
263 }
264 }
265 }
266
267 result.push(out);
268 }
269
270 Ok(result)
271}
272
273fn resolve_agg_value<'a>(arg: &str, arg_expr: Option<&Expr>, row: &'a Row) -> Value {
276 if let Some(expr) = arg_expr {
277 evaluate_expr(expr, row)
278 } else {
279 row.get(arg).cloned().unwrap_or(Value::Null)
280 }
281}
282
283fn compute_aggregate(func: &AggFunc, arg: &str, arg_expr: Option<&Expr>, rows: &[&Row]) -> Value {
284 match func {
285 AggFunc::Count => {
286 if arg == "*" && arg_expr.is_none() {
287 Value::Int(rows.len() as i64)
288 } else {
289 let count = rows
290 .iter()
291 .filter(|r| {
292 let v = resolve_agg_value(arg, arg_expr, r);
293 !v.is_null()
294 })
295 .count();
296 Value::Int(count as i64)
297 }
298 }
299 AggFunc::Sum => {
300 let mut total = 0.0f64;
301 let mut has_any = false;
302 for r in rows {
303 let v = resolve_agg_value(arg, arg_expr, r);
304 match v {
305 Value::Int(n) => { total += n as f64; has_any = true; }
306 Value::Float(f) => { total += f; has_any = true; }
307 _ => {}
308 }
309 }
310 if has_any { Value::Float(total) } else { Value::Null }
311 }
312 AggFunc::Avg => {
313 let mut total = 0.0f64;
314 let mut count = 0usize;
315 for r in rows {
316 let v = resolve_agg_value(arg, arg_expr, r);
317 match v {
318 Value::Int(n) => { total += n as f64; count += 1; }
319 Value::Float(f) => { total += f; count += 1; }
320 _ => {}
321 }
322 }
323 if count > 0 { Value::Float(total / count as f64) } else { Value::Null }
324 }
325 AggFunc::Min => {
326 let mut min_val: Option<Value> = None;
327 for r in rows {
328 let v = resolve_agg_value(arg, arg_expr, r);
329 if v.is_null() { continue; }
330 min_val = Some(match min_val {
331 None => v,
332 Some(ref current) => {
333 if v.partial_cmp(current) == Some(std::cmp::Ordering::Less) {
334 v
335 } else {
336 current.clone()
337 }
338 }
339 });
340 }
341 min_val.unwrap_or(Value::Null)
342 }
343 AggFunc::Max => {
344 let mut max_val: Option<Value> = None;
345 for r in rows {
346 let v = resolve_agg_value(arg, arg_expr, r);
347 if v.is_null() { continue; }
348 max_val = Some(match max_val {
349 None => v,
350 Some(ref current) => {
351 if v.partial_cmp(current) == Some(std::cmp::Ordering::Greater) {
352 v
353 } else {
354 current.clone()
355 }
356 }
357 });
358 }
359 max_val.unwrap_or(Value::Null)
360 }
361 }
362}
363
364fn evaluate_with_fts(clause: &WhereClause, row: &Row, fts: &FtsResults) -> bool {
365 match clause {
366 WhereClause::BoolOp(bop) => {
367 let left = evaluate_with_fts(&bop.left, row, fts);
368 match bop.op.as_str() {
369 "AND" => left && evaluate_with_fts(&bop.right, row, fts),
370 "OR" => left || evaluate_with_fts(&bop.right, row, fts),
371 _ => false,
372 }
373 }
374 WhereClause::Comparison(cmp) => {
375 if cmp.op == "LIKE" || cmp.op == "NOT LIKE" {
377 if let Some(SqlValue::String(pattern)) = &cmp.value {
378 let key = (cmp.column.clone(), pattern.clone());
379 if let Some(matching_paths) = fts.get(&key) {
380 let row_path = row.get("path").and_then(|v| v.as_str()).unwrap_or("");
381 let matched = matching_paths.contains(row_path);
382 return if cmp.op == "LIKE" { matched } else { !matched };
383 }
384 }
385 }
386 evaluate_comparison(cmp, row)
387 }
388 }
389}
390
391pub fn execute_join_query(
392 query: &SelectQuery,
393 tables: &HashMap<String, (Schema, Vec<Row>)>,
394) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
395 if query.joins.is_empty() {
396 return Err(MdqlError::QueryExecution("No JOIN clause in query".into()));
397 }
398
399 let left_name = &query.table;
400 let left_alias = query.table_alias.as_deref().unwrap_or(left_name);
401
402 let mut aliases: HashMap<String, String> = HashMap::new();
404 aliases.insert(left_name.clone(), left_name.clone());
405 if let Some(ref a) = query.table_alias {
406 aliases.insert(a.clone(), left_name.clone());
407 }
408 for join in &query.joins {
409 aliases.insert(join.table.clone(), join.table.clone());
410 if let Some(ref a) = join.alias {
411 aliases.insert(a.clone(), join.table.clone());
412 }
413 }
414
415 let (_left_schema, left_rows) = tables.get(left_name.as_str()).ok_or_else(|| {
417 MdqlError::QueryExecution(format!("Unknown table '{}'", left_name))
418 })?;
419
420 let mut current_rows: Vec<Row> = left_rows
421 .iter()
422 .map(|r| {
423 let mut prefixed = Row::new();
424 for (k, v) in r {
425 prefixed.insert(format!("{}.{}", left_alias, k), v.clone());
426 }
427 prefixed
428 })
429 .collect();
430
431 for join in &query.joins {
433 let right_name = &join.table;
434 let right_alias = join.alias.as_deref().unwrap_or(right_name);
435
436 let (_right_schema, right_rows) = tables.get(right_name.as_str()).ok_or_else(|| {
437 MdqlError::QueryExecution(format!("Unknown table '{}'", right_name))
438 })?;
439
440 let (on_left_table, on_left_col) = resolve_dotted(&join.left_col, &aliases);
442 let (on_right_table, on_right_col) = resolve_dotted(&join.right_col, &aliases);
443
444 let (left_key, right_key) = if on_right_table == *right_name {
446 let left_alias_for_col = reverse_alias(&on_left_table, &aliases, query, &query.joins);
448 (format!("{}.{}", left_alias_for_col, on_left_col), on_right_col)
449 } else {
450 let right_alias_for_col = reverse_alias(&on_right_table, &aliases, query, &query.joins);
452 (format!("{}.{}", right_alias_for_col, on_right_col), on_left_col)
453 };
454
455 let mut right_index: HashMap<String, Vec<&Row>> = HashMap::new();
457 for r in right_rows {
458 if let Some(key) = r.get(&right_key) {
459 let key_str = key.to_display_string();
460 right_index.entry(key_str).or_default().push(r);
461 }
462 }
463
464 let mut next_rows: Vec<Row> = Vec::new();
466 for lr in ¤t_rows {
467 if let Some(key) = lr.get(&left_key) {
468 let key_str = key.to_display_string();
469 if let Some(matching) = right_index.get(&key_str) {
470 for rr in matching {
471 let mut merged = lr.clone();
472 for (k, v) in *rr {
473 merged.insert(format!("{}.{}", right_alias, k), v.clone());
474 }
475 next_rows.push(merged);
476 }
477 }
478 }
479 }
480 current_rows = next_rows;
481 }
482
483 let (mut result, columns) = execute(query, ¤t_rows, None)?;
484
485 if !result.is_empty() {
489 let mut base_counts: HashMap<String, usize> = HashMap::new();
490 for key in &columns {
491 if let Some((_prefix, base)) = key.split_once('.') {
492 *base_counts.entry(base.to_string()).or_default() += 1;
493 }
494 }
495 let unique_bases: Vec<String> = base_counts
496 .into_iter()
497 .filter(|(_, count)| *count == 1)
498 .map(|(base, _)| base)
499 .collect();
500
501 if !unique_bases.is_empty() {
502 let unique_set: std::collections::HashSet<&str> =
503 unique_bases.iter().map(|s| s.as_str()).collect();
504 for row in &mut result {
505 let additions: Vec<(String, Value)> = row
506 .iter()
507 .filter_map(|(k, v)| {
508 k.split_once('.').and_then(|(_, base)| {
509 if unique_set.contains(base) {
510 Some((base.to_string(), v.clone()))
511 } else {
512 None
513 }
514 })
515 })
516 .collect();
517 for (k, v) in additions {
518 row.insert(k, v);
519 }
520 }
521 }
522 }
523
524 Ok((result, columns))
525}
526
527fn reverse_alias(
529 table_name: &str,
530 aliases: &HashMap<String, String>,
531 query: &SelectQuery,
532 joins: &[JoinClause],
533) -> String {
534 if query.table == table_name {
536 return query.table_alias.as_deref().unwrap_or(&query.table).to_string();
537 }
538 for j in joins {
540 if j.table == table_name {
541 return j.alias.as_deref().unwrap_or(&j.table).to_string();
542 }
543 }
544 if aliases.contains_key(table_name) {
546 return table_name.to_string();
547 }
548 table_name.to_string()
549}
550
/// Splits a possibly qualified column reference into `(table, column)`.
///
/// - The text before the first `.` is resolved through `aliases`. An unknown prefix is
///   returned verbatim as the table.
/// - An unqualified column yields an empty table name.
fn resolve_dotted(col: &str, aliases: &HashMap<String, String>) -> (String, String) {
    match col.split_once('.') {
        None => (String::new(), col.to_string()),
        Some((prefix, rest)) => {
            let table = match aliases.get(prefix) {
                Some(resolved) => resolved.clone(),
                None => prefix.to_string(),
            };
            (table, rest.to_string())
        }
    }
}
559
560fn execute(
561 query: &SelectQuery,
562 rows: &[Row],
563 index: Option<&crate::index::TableIndex>,
564) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
565 let empty_fts = HashMap::new();
566 execute_with_fts(query, rows, index, &empty_fts)
567}
568
569pub fn evaluate(clause: &WhereClause, row: &Row) -> bool {
570 match clause {
571 WhereClause::BoolOp(bop) => {
572 let left = evaluate(&bop.left, row);
573 match bop.op.as_str() {
574 "AND" => left && evaluate(&bop.right, row),
575 "OR" => left || evaluate(&bop.right, row),
576 _ => false,
577 }
578 }
579 WhereClause::Comparison(cmp) => evaluate_comparison(cmp, row),
580 }
581}
582
583pub fn evaluate_expr(expr: &Expr, row: &Row) -> Value {
585 match expr {
586 Expr::Literal(SqlValue::Int(n)) => Value::Int(*n),
587 Expr::Literal(SqlValue::Float(f)) => Value::Float(*f),
588 Expr::Literal(SqlValue::String(s)) => Value::String(s.clone()),
589 Expr::Literal(SqlValue::Null) => Value::Null,
590 Expr::Literal(SqlValue::List(_)) => Value::Null,
591 Expr::Column(name) => row.get(name).cloned().unwrap_or(Value::Null),
592 Expr::UnaryMinus(inner) => {
593 match evaluate_expr(inner, row) {
594 Value::Int(n) => Value::Int(-n),
595 Value::Float(f) => Value::Float(-f),
596 Value::Null => Value::Null,
597 _ => Value::Null, }
599 }
600 Expr::BinaryOp { left, op, right } => {
601 let lv = evaluate_expr(left, row);
602 let rv = evaluate_expr(right, row);
603
604 if lv.is_null() || rv.is_null() {
606 return Value::Null;
607 }
608
609 match (&lv, &rv) {
611 (Value::Int(a), Value::Int(b)) => {
612 match op {
613 ArithOp::Add => Value::Int(a.wrapping_add(*b)),
614 ArithOp::Sub => Value::Int(a.wrapping_sub(*b)),
615 ArithOp::Mul => Value::Int(a.wrapping_mul(*b)),
616 ArithOp::Div => {
617 if *b == 0 { Value::Null } else { Value::Int(a / b) }
618 }
619 ArithOp::Mod => {
620 if *b == 0 { Value::Null } else { Value::Int(a % b) }
621 }
622 }
623 }
624 _ => {
625 let a = match &lv {
627 Value::Int(n) => *n as f64,
628 Value::Float(f) => *f,
629 _ => return Value::Null,
630 };
631 let b = match &rv {
632 Value::Int(n) => *n as f64,
633 Value::Float(f) => *f,
634 _ => return Value::Null,
635 };
636 match op {
637 ArithOp::Add => Value::Float(a + b),
638 ArithOp::Sub => Value::Float(a - b),
639 ArithOp::Mul => Value::Float(a * b),
640 ArithOp::Div => {
641 if b == 0.0 { Value::Null } else { Value::Float(a / b) }
642 }
643 ArithOp::Mod => {
644 if b == 0.0 { Value::Null } else { Value::Float(a % b) }
645 }
646 }
647 }
648 }
649 }
650 Expr::Case { whens, else_expr } => {
651 for (condition, result) in whens {
652 if evaluate(condition, row) {
653 return evaluate_expr(result, row);
654 }
655 }
656 match else_expr {
657 Some(e) => evaluate_expr(e, row),
658 None => Value::Null,
659 }
660 }
661 }
662}
663
664fn evaluate_comparison(cmp: &Comparison, row: &Row) -> bool {
665 if let (Some(left_expr), Some(right_expr)) = (&cmp.left_expr, &cmp.right_expr) {
667 if ["=", "!=", "<", ">", "<=", ">="].contains(&cmp.op.as_str()) {
668 let left_val = evaluate_expr(left_expr, row);
669 let right_val = evaluate_expr(right_expr, row);
670
671 if left_val.is_null() || right_val.is_null() {
673 return false;
674 }
675
676 let ord = compare_model_values(&left_val, &right_val);
678
679 return match cmp.op.as_str() {
680 "=" => ord == Some(Ordering::Equal),
681 "!=" => ord != Some(Ordering::Equal),
682 "<" => ord == Some(Ordering::Less),
683 ">" => ord == Some(Ordering::Greater),
684 "<=" => matches!(ord, Some(Ordering::Less | Ordering::Equal)),
685 ">=" => matches!(ord, Some(Ordering::Greater | Ordering::Equal)),
686 _ => false,
687 };
688 }
689 }
690
691 let actual = row.get(&cmp.column);
693
694 if cmp.op == "IS NULL" {
695 return actual.map_or(true, |v| v.is_null());
696 }
697 if cmp.op == "IS NOT NULL" {
698 return actual.map_or(false, |v| !v.is_null());
699 }
700
701 let actual = match actual {
702 Some(v) if !v.is_null() => v,
703 _ => return false,
704 };
705
706 let expected = match &cmp.value {
707 Some(v) => v,
708 None => return false,
709 };
710
711 match cmp.op.as_str() {
712 "=" => eq_match(actual, expected),
713 "!=" => !eq_match(actual, expected),
714 "<" => compare_values(actual, expected) == Some(Ordering::Less),
715 ">" => compare_values(actual, expected) == Some(Ordering::Greater),
716 "<=" => matches!(compare_values(actual, expected), Some(Ordering::Less | Ordering::Equal)),
717 ">=" => matches!(compare_values(actual, expected), Some(Ordering::Greater | Ordering::Equal)),
718 "LIKE" => like_match(actual, expected),
719 "NOT LIKE" => !like_match(actual, expected),
720 "IN" => {
721 if let SqlValue::List(items) = expected {
722 items.iter().any(|v| eq_match(actual, v))
723 } else {
724 eq_match(actual, expected)
725 }
726 }
727 _ => false,
728 }
729}
730
731fn compare_model_values(a: &Value, b: &Value) -> Option<Ordering> {
733 match (a, b) {
734 (Value::Int(x), Value::Float(y)) => (*x as f64).partial_cmp(y),
735 (Value::Float(x), Value::Int(y)) => x.partial_cmp(&(*y as f64)),
736 _ => a.partial_cmp(b),
737 }
738}
739
740fn coerce_sql_to_value(sql_val: &SqlValue, target: &Value) -> Value {
741 match sql_val {
742 SqlValue::Null => Value::Null,
743 SqlValue::String(s) => {
744 match target {
745 Value::Int(_) => s.parse::<i64>().map(Value::Int).unwrap_or(Value::String(s.clone())),
746 Value::Float(_) => s.parse::<f64>().map(Value::Float).unwrap_or(Value::String(s.clone())),
747 Value::Date(_) => {
748 chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d")
749 .map(Value::Date)
750 .unwrap_or(Value::String(s.clone()))
751 }
752 Value::DateTime(_) => {
753 chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S")
754 .or_else(|_| chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f"))
755 .map(Value::DateTime)
756 .unwrap_or(Value::String(s.clone()))
757 }
758 _ => Value::String(s.clone()),
759 }
760 }
761 SqlValue::Int(n) => {
762 match target {
763 Value::Float(_) => Value::Float(*n as f64),
764 _ => Value::Int(*n),
765 }
766 }
767 SqlValue::Float(f) => Value::Float(*f),
768 SqlValue::List(_) => Value::Null, }
770}
771
772fn eq_match(actual: &Value, expected: &SqlValue) -> bool {
773 if let Value::List(items) = actual {
775 if let SqlValue::String(s) = expected {
776 return items.contains(s);
777 }
778 }
779
780 let coerced = coerce_sql_to_value(expected, actual);
781 actual == &coerced
782}
783
784fn like_match(actual: &Value, pattern: &SqlValue) -> bool {
785 let pattern_str = match pattern {
786 SqlValue::String(s) => s,
787 _ => return false,
788 };
789
790 let mut regex_str = String::from("(?is)^");
792 for ch in pattern_str.chars() {
793 match ch {
794 '%' => regex_str.push_str(".*"),
795 '_' => regex_str.push('.'),
796 c => {
797 if regex::escape(&c.to_string()) != c.to_string() {
798 regex_str.push_str(®ex::escape(&c.to_string()));
799 } else {
800 regex_str.push(c);
801 }
802 }
803 }
804 }
805 regex_str.push('$');
806
807 let re = match Regex::new(®ex_str) {
808 Ok(r) => r,
809 Err(_) => return false,
810 };
811
812 match actual {
813 Value::List(items) => items.iter().any(|item| re.is_match(item)),
814 _ => re.is_match(&actual.to_display_string()),
815 }
816}
817
818fn compare_values(actual: &Value, expected: &SqlValue) -> Option<Ordering> {
819 let coerced = coerce_sql_to_value(expected, actual);
820 actual.partial_cmp(&coerced).map(|o| o)
821}
822
823fn sql_value_to_index_value(sv: &SqlValue) -> Value {
825 match sv {
826 SqlValue::String(s) => {
827 if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S") {
829 return Value::DateTime(dt);
830 }
831 if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f") {
832 return Value::DateTime(dt);
833 }
834 if let Ok(d) = chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d") {
836 return Value::Date(d);
837 }
838 Value::String(s.clone())
839 }
840 SqlValue::Int(n) => Value::Int(*n),
841 SqlValue::Float(f) => Value::Float(*f),
842 SqlValue::Null => Value::Null,
843 SqlValue::List(_) => Value::Null,
844 }
845}
846
847fn try_index_filter(
851 clause: &WhereClause,
852 index: &crate::index::TableIndex,
853) -> Option<std::collections::HashSet<String>> {
854 match clause {
855 WhereClause::Comparison(cmp) => {
856 if !index.has_index(&cmp.column) {
857 return None;
858 }
859 match cmp.op.as_str() {
860 "=" => {
861 let val = sql_value_to_index_value(cmp.value.as_ref()?);
862 let paths = index.lookup_eq(&cmp.column, &val);
863 Some(paths.into_iter().map(|s| s.to_string()).collect())
864 }
865 "<" => {
866 let val = sql_value_to_index_value(cmp.value.as_ref()?);
867 let range_paths = index.lookup_range(&cmp.column, None, Some(&val));
870 let eq_paths: std::collections::HashSet<&str> = index.lookup_eq(&cmp.column, &val).into_iter().collect();
871 Some(range_paths.into_iter().filter(|p| !eq_paths.contains(p)).map(|s| s.to_string()).collect())
872 }
873 ">" => {
874 let val = sql_value_to_index_value(cmp.value.as_ref()?);
875 let range_paths = index.lookup_range(&cmp.column, Some(&val), None);
876 let eq_paths: std::collections::HashSet<&str> = index.lookup_eq(&cmp.column, &val).into_iter().collect();
877 Some(range_paths.into_iter().filter(|p| !eq_paths.contains(p)).map(|s| s.to_string()).collect())
878 }
879 "<=" => {
880 let val = sql_value_to_index_value(cmp.value.as_ref()?);
881 let paths = index.lookup_range(&cmp.column, None, Some(&val));
882 Some(paths.into_iter().map(|s| s.to_string()).collect())
883 }
884 ">=" => {
885 let val = sql_value_to_index_value(cmp.value.as_ref()?);
886 let paths = index.lookup_range(&cmp.column, Some(&val), None);
887 Some(paths.into_iter().map(|s| s.to_string()).collect())
888 }
889 "IN" => {
890 if let Some(SqlValue::List(items)) = &cmp.value {
891 let vals: Vec<Value> = items.iter().map(sql_value_to_index_value).collect();
892 let paths = index.lookup_in(&cmp.column, &vals);
893 Some(paths.into_iter().map(|s| s.to_string()).collect())
894 } else {
895 None
896 }
897 }
898 _ => None, }
900 }
901 WhereClause::BoolOp(bop) => {
902 let left = try_index_filter(&bop.left, index);
903 let right = try_index_filter(&bop.right, index);
904 match bop.op.as_str() {
905 "AND" => {
906 match (left, right) {
907 (Some(l), Some(r)) => Some(l.intersection(&r).cloned().collect()),
908 (Some(l), None) => Some(l), (None, Some(r)) => Some(r),
910 (None, None) => None,
911 }
912 }
913 "OR" => {
914 match (left, right) {
915 (Some(l), Some(r)) => Some(l.union(&r).cloned().collect()),
916 _ => None, }
918 }
919 _ => None,
920 }
921 }
922 }
923}
924
925fn resolve_order_aliases(specs: &[OrderSpec], columns: &ColumnList) -> Vec<OrderSpec> {
928 let named = match columns {
929 ColumnList::Named(exprs) => exprs,
930 _ => return specs.to_vec(),
931 };
932
933 let alias_map: HashMap<String, &Expr> = named
935 .iter()
936 .filter_map(|se| match se {
937 SelectExpr::Expr { expr, alias: Some(a) } => Some((a.clone(), expr)),
938 _ => None,
939 })
940 .collect();
941
942 specs
943 .iter()
944 .map(|spec| {
945 if let Some(expr) = alias_map.get(&spec.column) {
947 OrderSpec {
948 column: spec.column.clone(),
949 expr: Some((*expr).clone()),
950 descending: spec.descending,
951 }
952 } else {
953 spec.clone()
954 }
955 })
956 .collect()
957}
958
959fn sort_rows(rows: &mut Vec<Row>, specs: &[OrderSpec]) {
960 rows.sort_by(|a, b| {
961 for spec in specs {
962 let (va, vb) = if let Some(ref expr) = spec.expr {
963 (evaluate_expr(expr, a), evaluate_expr(expr, b))
964 } else {
965 (
966 a.get(&spec.column).cloned().unwrap_or(Value::Null),
967 b.get(&spec.column).cloned().unwrap_or(Value::Null),
968 )
969 };
970
971 let ordering = match (&va, &vb) {
973 (Value::Null, Value::Null) => Ordering::Equal,
974 (Value::Null, _) => Ordering::Greater,
975 (_, Value::Null) => Ordering::Less,
976 (a_val, b_val) => {
977 compare_model_values(a_val, b_val).unwrap_or(Ordering::Equal)
978 }
979 };
980
981 let ordering = if spec.descending {
982 ordering.reverse()
983 } else {
984 ordering
985 };
986
987 if ordering != Ordering::Equal {
988 return ordering;
989 }
990 }
991 Ordering::Equal
992 });
993}
994
995pub fn sql_value_to_value(sql_val: &SqlValue) -> Value {
997 match sql_val {
998 SqlValue::Null => Value::Null,
999 SqlValue::String(s) => Value::String(s.clone()),
1000 SqlValue::Int(n) => Value::Int(*n),
1001 SqlValue::Float(f) => Value::Float(*f),
1002 SqlValue::List(items) => {
1003 let strings: Vec<String> = items
1004 .iter()
1005 .filter_map(|v| match v {
1006 SqlValue::String(s) => Some(s.clone()),
1007 _ => None,
1008 })
1009 .collect();
1010 Value::List(strings)
1011 }
1012 }
1013}
1014
1015#[cfg(test)]
1016mod tests {
1017 use super::*;
1018
1019 fn make_rows() -> Vec<Row> {
1020 vec![
1021 Row::from([
1022 ("path".into(), Value::String("a.md".into())),
1023 ("title".into(), Value::String("Alpha".into())),
1024 ("count".into(), Value::Int(10)),
1025 ]),
1026 Row::from([
1027 ("path".into(), Value::String("b.md".into())),
1028 ("title".into(), Value::String("Beta".into())),
1029 ("count".into(), Value::Int(5)),
1030 ]),
1031 Row::from([
1032 ("path".into(), Value::String("c.md".into())),
1033 ("title".into(), Value::String("Gamma".into())),
1034 ("count".into(), Value::Int(20)),
1035 ]),
1036 ]
1037 }
1038
1039 #[test]
1040 fn test_select_all() {
1041 let q = SelectQuery {
1042 columns: ColumnList::All,
1043 table: "test".into(),
1044 table_alias: None,
1045 joins: vec![],
1046 where_clause: None,
1047 group_by: None,
1048 order_by: None,
1049 limit: None,
1050 };
1051 let (rows, _cols) = execute(&q, &make_rows(), None).unwrap();
1052 assert_eq!(rows.len(), 3);
1053 }
1054
1055 #[test]
1056 fn test_where_gt() {
1057 let q = SelectQuery {
1058 columns: ColumnList::All,
1059 table: "test".into(),
1060 table_alias: None,
1061 joins: vec![],
1062 where_clause: Some(WhereClause::Comparison(Comparison {
1063 column: "count".into(),
1064 op: ">".into(),
1065 value: Some(SqlValue::Int(5)),
1066 left_expr: Some(Expr::Column("count".into())),
1067 right_expr: Some(Expr::Literal(SqlValue::Int(5))),
1068 })),
1069 group_by: None,
1070 order_by: None,
1071 limit: None,
1072 };
1073 let (rows, _) = execute(&q, &make_rows(), None).unwrap();
1074 assert_eq!(rows.len(), 2);
1075 }
1076
1077 #[test]
1078 fn test_order_by_desc() {
1079 let q = SelectQuery {
1080 columns: ColumnList::All,
1081 table: "test".into(),
1082 table_alias: None,
1083 joins: vec![],
1084 where_clause: None,
1085 group_by: None,
1086 order_by: Some(vec![OrderSpec {
1087 column: "count".into(),
1088 expr: Some(Expr::Column("count".into())),
1089 descending: true,
1090 }]),
1091 limit: None,
1092 };
1093 let (rows, _) = execute(&q, &make_rows(), None).unwrap();
1094 assert_eq!(rows[0]["count"], Value::Int(20));
1095 assert_eq!(rows[2]["count"], Value::Int(5));
1096 }
1097
1098 #[test]
1099 fn test_limit() {
1100 let q = SelectQuery {
1101 columns: ColumnList::All,
1102 table: "test".into(),
1103 table_alias: None,
1104 joins: vec![],
1105 where_clause: None,
1106 group_by: None,
1107 order_by: None,
1108 limit: Some(2),
1109 };
1110 let (rows, _) = execute(&q, &make_rows(), None).unwrap();
1111 assert_eq!(rows.len(), 2);
1112 }
1113
1114 #[test]
1115 fn test_like() {
1116 let q = SelectQuery {
1117 columns: ColumnList::All,
1118 table: "test".into(),
1119 table_alias: None,
1120 joins: vec![],
1121 where_clause: Some(WhereClause::Comparison(Comparison {
1122 column: "title".into(),
1123 op: "LIKE".into(),
1124 value: Some(SqlValue::String("%lph%".into())),
1125 left_expr: Some(Expr::Column("title".into())),
1126 right_expr: None,
1127 })),
1128 group_by: None,
1129 order_by: None,
1130 limit: None,
1131 };
1132 let (rows, _) = execute(&q, &make_rows(), None).unwrap();
1133 assert_eq!(rows.len(), 1);
1134 assert_eq!(rows[0]["title"], Value::String("Alpha".into()));
1135 }
1136
1137 #[test]
1138 fn test_is_null() {
1139 let mut rows = make_rows();
1140 rows[1].insert("optional".into(), Value::Null);
1141
1142 let q = SelectQuery {
1143 columns: ColumnList::All,
1144 table: "test".into(),
1145 table_alias: None,
1146 joins: vec![],
1147 where_clause: Some(WhereClause::Comparison(Comparison {
1148 column: "optional".into(),
1149 op: "IS NULL".into(),
1150 value: None,
1151 left_expr: Some(Expr::Column("optional".into())),
1152 right_expr: None,
1153 })),
1154 group_by: None,
1155 order_by: None,
1156 limit: None,
1157 };
1158 let (result, _) = execute(&q, &rows, None).unwrap();
1159 assert_eq!(result.len(), 3);
1161 }
1162
1163 #[test]
1166 fn test_evaluate_expr_literal() {
1167 let row = Row::new();
1168 assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Int(42)), &row), Value::Int(42));
1169 assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Float(3.14)), &row), Value::Float(3.14));
1170 assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Null), &row), Value::Null);
1171 }
1172
1173 #[test]
1174 fn test_evaluate_expr_column() {
1175 let row = Row::from([("x".into(), Value::Int(10))]);
1176 assert_eq!(evaluate_expr(&Expr::Column("x".into()), &row), Value::Int(10));
1177 assert_eq!(evaluate_expr(&Expr::Column("missing".into()), &row), Value::Null);
1178 }
1179
#[test]
fn test_evaluate_expr_int_arithmetic() {
    // Each arithmetic operator applied to two Int columns is expected to
    // produce an Int. Div truncates (10 / 3 == 3) and Mod yields the
    // remainder (10 % 3 == 1).
    let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Int(3))]);

    // Builds the expression `a <op> b`, so the cases below differ only in the
    // operator under test.
    let a_op_b = |op: ArithOp| Expr::BinaryOp {
        left: Box::new(Expr::Column("a".into())),
        op,
        right: Box::new(Expr::Column("b".into())),
    };

    // (label used in the failure message, operator, expected Int result)
    let cases = [
        ("add", ArithOp::Add, 13),
        ("sub", ArithOp::Sub, 7),
        ("mul", ArithOp::Mul, 30),
        ("div", ArithOp::Div, 3),
        ("mod", ArithOp::Mod, 1),
    ];

    for (label, op, expected) in cases {
        assert_eq!(evaluate_expr(&a_op_b(op), &row), Value::Int(expected), "{}", label);
    }
}
1218
#[test]
fn test_evaluate_expr_float_coercion() {
    // An Int operand combined with a Float operand should produce a Float.
    let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Float(3.0))]);
    let lhs = Box::new(Expr::Column("a".into()));
    let rhs = Box::new(Expr::Column("b".into()));
    let sum = Expr::BinaryOp { left: lhs, op: ArithOp::Add, right: rhs };

    let result = evaluate_expr(&sum, &row);
    assert_eq!(result, Value::Float(13.0));
}
1229
#[test]
fn test_evaluate_expr_null_propagation() {
    // `missing` is not in the row. Adding it to a present Int is expected to
    // propagate NULL instead of producing a number.
    let row = Row::from([("a".into(), Value::Int(10))]);
    let known = Box::new(Expr::Column("a".into()));
    let unknown = Box::new(Expr::Column("missing".into()));
    let sum = Expr::BinaryOp { left: known, op: ArithOp::Add, right: unknown };

    assert_eq!(evaluate_expr(&sum, &row), Value::Null);
}
1240
#[test]
fn test_evaluate_expr_div_by_zero() {
    // Integer division by a zero-valued column is expected to yield NULL.
    let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Int(0))]);
    let numerator = Box::new(Expr::Column("a".into()));
    let denominator = Box::new(Expr::Column("b".into()));
    let quotient = Expr::BinaryOp { left: numerator, op: ArithOp::Div, right: denominator };

    assert_eq!(evaluate_expr(&quotient, &row), Value::Null);
}
1251
#[test]
fn test_evaluate_expr_unary_minus() {
    // Unary minus negates an Int column value.
    let negated = Expr::UnaryMinus(Box::new(Expr::Column("x".into())));
    let row = Row::from([("x".into(), Value::Int(5))]);
    assert_eq!(Value::Int(-5), evaluate_expr(&negated, &row));
}
1258
#[test]
fn test_select_with_expression() {
    // A projected arithmetic expression is returned under its alias, one
    // value per fixture row.
    let parsed = crate::query_parser::parse_query(
        "SELECT count * 2 AS doubled FROM test"
    ).unwrap();

    match parsed {
        crate::query_parser::Statement::Select(q) => {
            let (rows, cols) = execute(&q, &make_rows(), None).unwrap();
            assert_eq!(cols, vec!["doubled"]);
            assert_eq!(rows.len(), 3);

            let doubled: Vec<Value> = rows.iter().map(|r| r["doubled"].clone()).collect();
            for want in [20, 10, 40] {
                assert!(doubled.contains(&Value::Int(want)));
            }
        }
        _ => panic!("Expected Select"),
    }
}
1278
#[test]
fn test_where_with_expression() {
    // An arithmetic expression may appear on the left of a WHERE comparison.
    let sql = "SELECT * FROM test WHERE count * 2 > 15";
    let stmt = crate::query_parser::parse_query(sql).unwrap();

    let q = match stmt {
        crate::query_parser::Statement::Select(q) => q,
        _ => panic!("Expected Select"),
    };

    let (rows, _) = execute(&q, &make_rows(), None).unwrap();
    assert_eq!(rows.len(), 2);
}
1293
#[test]
fn test_order_by_expression() {
    // Sorting ascending by `count * -1` is expected to list counts in
    // descending order.
    let sql = "SELECT title, count FROM test ORDER BY count * -1 ASC";
    let stmt = crate::query_parser::parse_query(sql).unwrap();

    let q = match stmt {
        crate::query_parser::Statement::Select(q) => q,
        _ => panic!("Expected Select"),
    };

    let (rows, _) = execute(&q, &make_rows(), None).unwrap();
    // Slicing the first three rows panics on a short result, just like
    // direct indexing would.
    let leading: Vec<Value> = rows[..3].iter().map(|r| r["count"].clone()).collect();
    assert_eq!(leading, vec![Value::Int(20), Value::Int(10), Value::Int(5)]);
}
1310
#[test]
fn test_case_when_eval_basic() {
    // `status = 'ACTIVE'` holds for this row, so the THEN branch (1) should be
    // chosen over the ELSE branch (0).
    let row = Row::from([("status".into(), Value::String("ACTIVE".into()))]);

    let is_active = WhereClause::Comparison(Comparison {
        column: "status".into(),
        op: "=".into(),
        value: Some(SqlValue::String("ACTIVE".into())),
        left_expr: Some(Expr::Column("status".into())),
        right_expr: Some(Expr::Literal(SqlValue::String("ACTIVE".into()))),
    });
    let then_one = Box::new(Expr::Literal(SqlValue::Int(1)));
    let else_zero = Box::new(Expr::Literal(SqlValue::Int(0)));

    let expr = Expr::Case {
        whens: vec![(is_active, then_one)],
        else_expr: Some(else_zero),
    };
    assert_eq!(evaluate_expr(&expr, &row), Value::Int(1));
}
1331
#[test]
fn test_case_when_eval_else() {
    // The row's status is 'KILLED', so no WHEN matches and the ELSE value (0)
    // should be returned.
    let row = Row::from([("status".into(), Value::String("KILLED".into()))]);

    let is_active = WhereClause::Comparison(Comparison {
        column: "status".into(),
        op: "=".into(),
        value: Some(SqlValue::String("ACTIVE".into())),
        left_expr: Some(Expr::Column("status".into())),
        right_expr: Some(Expr::Literal(SqlValue::String("ACTIVE".into()))),
    });
    let expr = Expr::Case {
        whens: vec![(is_active, Box::new(Expr::Literal(SqlValue::Int(1))))],
        else_expr: Some(Box::new(Expr::Literal(SqlValue::Int(0)))),
    };

    let outcome = evaluate_expr(&expr, &row);
    assert_eq!(outcome, Value::Int(0));
}
1350
#[test]
fn test_case_when_eval_no_else_null() {
    // With no matching WHEN and no ELSE clause, the CASE is expected to
    // evaluate to NULL.
    let row = Row::from([("x".into(), Value::Int(99))]);

    let x_is_one = WhereClause::Comparison(Comparison {
        column: "x".into(),
        op: "=".into(),
        value: Some(SqlValue::Int(1)),
        left_expr: Some(Expr::Column("x".into())),
        right_expr: Some(Expr::Literal(SqlValue::Int(1))),
    });
    let label = Box::new(Expr::Literal(SqlValue::String("one".into())));

    let expr = Expr::Case { whens: vec![(x_is_one, label)], else_expr: None };
    assert_eq!(evaluate_expr(&expr, &row), Value::Null);
}
1369
#[test]
fn test_case_when_in_aggregate_query() {
    // Only counts above 5 contribute to the SUM. The aggregate result is
    // expected to come back as a Float.
    let sql = "SELECT SUM(CASE WHEN count > 5 THEN count ELSE 0 END) AS total FROM test";
    let stmt = crate::query_parser::parse_query(sql).unwrap();

    let q = match stmt {
        crate::query_parser::Statement::Select(q) => q,
        _ => panic!("Expected Select"),
    };

    let (rows, cols) = execute(&q, &make_rows(), None).unwrap();
    assert_eq!(cols, vec!["total"]);
    assert_eq!(rows.len(), 1);
    assert_eq!(rows[0]["total"], Value::Float(30.0));
}
1386
#[test]
fn test_case_when_with_unary_minus_in_aggregate() {
    // The Alpha row contributes +count and every other row contributes -count.
    // The signed total is expected to be -15.0.
    let sql = "SELECT SUM(CASE WHEN title = 'Alpha' THEN count ELSE -count END) AS net FROM test";
    let stmt = crate::query_parser::parse_query(sql).unwrap();

    match stmt {
        crate::query_parser::Statement::Select(q) => {
            let (rows, _) = execute(&q, &make_rows(), None).unwrap();
            assert_eq!(rows.len(), 1);
            assert_eq!(rows[0]["net"], Value::Float(-15.0));
        }
        _ => panic!("Expected Select"),
    }
}
1402}