Skip to main content

mdql_core/
query_engine.rs

1//! Execute parsed queries over in-memory rows.
2
3use std::cmp::Ordering;
4use std::collections::HashMap;
5
6use regex::Regex;
7
8use crate::errors::MdqlError;
9use crate::model::{Row, Value};
10use crate::query_parser::*;
11use crate::schema::Schema;
12
13pub fn execute_query(
14    query: &SelectQuery,
15    rows: &[Row],
16    _schema: &Schema,
17) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
18    if let Some(ref sub) = query.subquery {
19        let (sub_rows, _sub_cols) = execute_inner(sub, rows, None)?;
20        return execute_inner(query, &sub_rows, None);
21    }
22    execute_inner(query, rows, None)
23}
24
25#[allow(dead_code)]
26pub(crate) fn execute_query_indexed(
27    query: &SelectQuery,
28    rows: &[Row],
29    schema: &Schema,
30    index: Option<&crate::index::TableIndex>,
31    searcher: Option<&crate::search::TableSearcher>,
32) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
33    // Pre-compute FTS results for any LIKE clauses on section columns
34    let fts_results = if let (Some(ref wc), Some(searcher)) = (&query.where_clause, searcher) {
35        collect_fts_results(wc, schema, searcher)
36    } else {
37        HashMap::new()
38    };
39
40    execute_with_fts(query, rows, index, &fts_results)
41}
42
43#[allow(dead_code)]
44fn collect_fts_results(
45    clause: &WhereClause,
46    schema: &Schema,
47    searcher: &crate::search::TableSearcher,
48) -> HashMap<(String, String), std::collections::HashSet<String>> {
49    let mut results = HashMap::new();
50    collect_fts_results_inner(clause, schema, searcher, &mut results);
51    results
52}
53
54#[allow(dead_code)]
55fn collect_fts_results_inner(
56    clause: &WhereClause,
57    schema: &Schema,
58    searcher: &crate::search::TableSearcher,
59    results: &mut HashMap<(String, String), std::collections::HashSet<String>>,
60) {
61    match clause {
62        WhereClause::Comparison(cmp) => {
63            if (cmp.op == CmpOp::Like || cmp.op == CmpOp::NotLike) && schema.sections.contains_key(&cmp.column) {
64                if let Some(SqlValue::String(pattern)) = &cmp.value {
65                    // Strip SQL wildcards for Tantivy query
66                    let search_term = pattern.replace('%', " ").replace('_', " ").trim().to_string();
67                    if !search_term.is_empty() {
68                        if let Ok(paths) = searcher.search(&search_term, Some(&cmp.column)) {
69                            let key = (cmp.column.clone(), pattern.clone());
70                            results.insert(key, paths.into_iter().collect());
71                        }
72                    }
73                }
74            }
75        }
76        WhereClause::BoolOp(bop) => {
77            collect_fts_results_inner(&bop.left, schema, searcher, results);
78            collect_fts_results_inner(&bop.right, schema, searcher, results);
79        }
80    }
81}
82
/// Full-text-search hits keyed by `(column, LIKE pattern)`: the set of row
/// `path` values the searcher matched for that comparison.
type FtsResults = std::collections::HashMap<(String, String), std::collections::HashSet<String>>;
84
85fn execute_with_fts(
86    query: &SelectQuery,
87    rows: &[Row],
88    index: Option<&crate::index::TableIndex>,
89    fts: &FtsResults,
90) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
91    // Determine available columns
92    let mut all_columns: Vec<String> = Vec::new();
93    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
94    for r in rows {
95        for k in r.keys() {
96            if seen.insert(k.clone()) {
97                all_columns.push(k.clone());
98            }
99        }
100    }
101
102    // Check if query has aggregates
103    let has_aggregates = match &query.columns {
104        ColumnList::Named(exprs) => exprs.iter().any(|e| e.is_aggregate()),
105        _ => false,
106    };
107
108    // Output column names
109    let columns: Vec<String> = match &query.columns {
110        ColumnList::All => all_columns,
111        ColumnList::Named(exprs) => exprs.iter().map(|e| e.output_name()).collect(),
112    };
113
114    // Filter — try index first, fall back to full scan
115    let filtered: Vec<Row> = if let Some(ref wc) = query.where_clause {
116        let candidate_paths = index.and_then(|idx| try_index_filter(wc, idx));
117        if let Some(paths) = candidate_paths {
118            rows.iter()
119                .filter(|r| {
120                    r.get("path")
121                        .and_then(|v| v.as_str())
122                        .map_or(false, |p| paths.contains(p))
123                })
124                .filter(|r| evaluate_with_fts(wc, r, fts))
125                .cloned()
126                .collect()
127        } else {
128            rows.iter()
129                .filter(|r| evaluate_with_fts(wc, r, fts))
130                .cloned()
131                .collect()
132        }
133    } else {
134        rows.to_vec()
135    };
136
137    // Aggregate if needed
138    let mut result = if has_aggregates || query.group_by.is_some() {
139        let exprs = match &query.columns {
140            ColumnList::Named(exprs) => exprs.clone(),
141            _ => return Err(MdqlError::QueryExecution(
142                "SELECT * with GROUP BY is not supported".into(),
143            )),
144        };
145        let group_keys = query.group_by.as_deref().unwrap_or(&[]);
146        aggregate_rows(&filtered, &exprs, group_keys)?
147    } else {
148        filtered
149    };
150
151    // HAVING filter — apply after aggregation
152    if let Some(ref having) = query.having {
153        result.retain(|row| evaluate(having, row));
154    }
155
156    // Sort — resolve ORDER BY aliases against SELECT list
157    if let Some(ref order_by) = query.order_by {
158        let resolved = resolve_order_aliases(order_by, &query.columns);
159        sort_rows(&mut result, &resolved);
160    }
161
162    // Limit
163    if let Some(limit) = query.limit {
164        result.truncate(limit as usize);
165    }
166
167    // Project — evaluate expressions and strip to requested columns
168    if !matches!(query.columns, ColumnList::All) {
169        let named_exprs = match &query.columns {
170            ColumnList::Named(exprs) => exprs,
171            _ => unreachable!(),
172        };
173
174        // Compute expression columns first, then retain only requested columns.
175        // Skip if aggregation already computed them (re-evaluating would lose
176        // columns that only existed in pre-aggregation rows, e.g. dict fields).
177        let has_expr_cols = named_exprs.iter().any(|e| matches!(e, SelectExpr::Expr { .. }));
178        let already_aggregated = has_aggregates || query.group_by.is_some();
179        if has_expr_cols && !already_aggregated {
180            for row in &mut result {
181                for expr in named_exprs {
182                    if let SelectExpr::Expr { expr: e, alias } = expr {
183                        let name = alias.clone().unwrap_or_else(|| e.display_name());
184                        let val = evaluate_expr(e, row);
185                        row.insert(name, val);
186                    }
187                }
188            }
189        }
190
191        let col_set: std::collections::HashSet<&str> =
192            columns.iter().map(|s| s.as_str()).collect();
193        for row in &mut result {
194            row.retain(|k, _| col_set.contains(k.as_str()));
195        }
196    }
197
198    Ok((result, columns))
199}
200
201fn aggregate_rows(
202    rows: &[Row],
203    exprs: &[SelectExpr],
204    group_keys: &[String],
205) -> crate::errors::Result<Vec<Row>> {
206    // Group rows by group_keys
207    let mut groups: Vec<(Vec<Value>, Vec<&Row>)> = Vec::new();
208    let mut key_index: HashMap<Vec<String>, usize> = HashMap::new();
209
210    if group_keys.is_empty() {
211        // No GROUP BY — all rows are one group
212        let all_refs: Vec<&Row> = rows.iter().collect();
213        groups.push((vec![], all_refs));
214    } else {
215        for row in rows {
216            let key: Vec<String> = group_keys
217                .iter()
218                .map(|k| {
219                    row.get(k)
220                        .map(|v| v.to_display_string())
221                        .unwrap_or_default()
222                })
223                .collect();
224            let key_vals: Vec<Value> = group_keys
225                .iter()
226                .map(|k| row.get(k).cloned().unwrap_or(Value::Null))
227                .collect();
228            if let Some(&idx) = key_index.get(&key) {
229                groups[idx].1.push(row);
230            } else {
231                let idx = groups.len();
232                key_index.insert(key, idx);
233                groups.push((key_vals, vec![row]));
234            }
235        }
236    }
237
238    // Compute aggregates per group
239    let mut result = Vec::new();
240    for (key_vals, group_rows) in &groups {
241        let mut out = Row::new();
242
243        // Fill in group key values
244        for (i, k) in group_keys.iter().enumerate() {
245            out.insert(k.clone(), key_vals[i].clone());
246        }
247
248        // Compute each expression
249        for expr in exprs {
250            match expr {
251                SelectExpr::Column(name) => {
252                    // Already filled if it's a group key; otherwise take first row's value
253                    if !out.contains_key(name) {
254                        if let Some(first) = group_rows.first() {
255                            out.insert(
256                                name.clone(),
257                                first.get(name).cloned().unwrap_or(Value::Null),
258                            );
259                        }
260                    }
261                }
262                SelectExpr::Aggregate { func, arg, arg_expr, alias } => {
263                    let out_name = alias
264                        .clone()
265                        .unwrap_or_else(|| expr.output_name());
266                    let val = compute_aggregate(func, arg, arg_expr.as_ref(), group_rows);
267                    out.insert(out_name, val);
268                }
269                SelectExpr::Expr { expr: e, alias } => {
270                    let out_name = alias.clone().unwrap_or_else(|| e.display_name());
271                    if e.contains_aggregate() {
272                        let val = evaluate_agg_expr(e, group_rows);
273                        out.insert(out_name, val);
274                    } else if let Some(first) = group_rows.first() {
275                        let val = evaluate_expr(e, first);
276                        out.insert(out_name, val);
277                    }
278                }
279            }
280        }
281
282        result.push(out);
283    }
284
285    Ok(result)
286}
287
288/// Resolve a per-row value for an aggregate argument.
289/// If `arg_expr` is set, evaluate it; otherwise look up `arg` as a column name.
290fn resolve_agg_value<'a>(arg: &str, arg_expr: Option<&Expr>, row: &'a Row) -> Value {
291    if let Some(expr) = arg_expr {
292        evaluate_expr(expr, row)
293    } else {
294        row.get(arg).cloned().unwrap_or(Value::Null)
295    }
296}
297
298fn compute_aggregate(func: &AggFunc, arg: &str, arg_expr: Option<&Expr>, rows: &[&Row]) -> Value {
299    match func {
300        AggFunc::Count => {
301            if arg == "*" && arg_expr.is_none() {
302                Value::Int(rows.len() as i64)
303            } else {
304                let count = rows
305                    .iter()
306                    .filter(|r| {
307                        let v = resolve_agg_value(arg, arg_expr, r);
308                        !v.is_null()
309                    })
310                    .count();
311                Value::Int(count as i64)
312            }
313        }
314        AggFunc::Sum => {
315            let mut total = 0.0f64;
316            let mut has_any = false;
317            for r in rows {
318                let v = resolve_agg_value(arg, arg_expr, r);
319                match v {
320                    Value::Int(n) => { total += n as f64; has_any = true; }
321                    Value::Float(f) => { total += f; has_any = true; }
322                    _ => {}
323                }
324            }
325            if has_any { Value::Float(total) } else { Value::Null }
326        }
327        AggFunc::Avg => {
328            let mut total = 0.0f64;
329            let mut count = 0usize;
330            for r in rows {
331                let v = resolve_agg_value(arg, arg_expr, r);
332                match v {
333                    Value::Int(n) => { total += n as f64; count += 1; }
334                    Value::Float(f) => { total += f; count += 1; }
335                    _ => {}
336                }
337            }
338            if count > 0 { Value::Float(total / count as f64) } else { Value::Null }
339        }
340        AggFunc::Min => {
341            let mut min_val: Option<Value> = None;
342            for r in rows {
343                let v = resolve_agg_value(arg, arg_expr, r);
344                if v.is_null() { continue; }
345                min_val = Some(match min_val {
346                    None => v,
347                    Some(ref current) => {
348                        if v.partial_cmp(current) == Some(std::cmp::Ordering::Less) {
349                            v
350                        } else {
351                            current.clone()
352                        }
353                    }
354                });
355            }
356            min_val.unwrap_or(Value::Null)
357        }
358        AggFunc::Max => {
359            let mut max_val: Option<Value> = None;
360            for r in rows {
361                let v = resolve_agg_value(arg, arg_expr, r);
362                if v.is_null() { continue; }
363                max_val = Some(match max_val {
364                    None => v,
365                    Some(ref current) => {
366                        if v.partial_cmp(current) == Some(std::cmp::Ordering::Greater) {
367                            v
368                        } else {
369                            current.clone()
370                        }
371                    }
372                });
373            }
374            max_val.unwrap_or(Value::Null)
375        }
376    }
377}
378
379fn evaluate_with_fts(clause: &WhereClause, row: &Row, fts: &FtsResults) -> bool {
380    match clause {
381        WhereClause::BoolOp(bop) => {
382            let left = evaluate_with_fts(&bop.left, row, fts);
383            match bop.op {
384                BoolOpKind::And => left && evaluate_with_fts(&bop.right, row, fts),
385                BoolOpKind::Or => left || evaluate_with_fts(&bop.right, row, fts),
386            }
387        }
388        WhereClause::Comparison(cmp) => {
389            // Check if we have FTS results for this comparison
390            if cmp.op == CmpOp::Like || cmp.op == CmpOp::NotLike {
391                if let Some(SqlValue::String(pattern)) = &cmp.value {
392                    let key = (cmp.column.clone(), pattern.clone());
393                    if let Some(matching_paths) = fts.get(&key) {
394                        let row_path = row.get("path").and_then(|v| v.as_str()).unwrap_or("");
395                        let matched = matching_paths.contains(row_path);
396                        return if cmp.op == CmpOp::Like { matched } else { !matched };
397                    }
398                }
399            }
400            evaluate_comparison(cmp, row)
401        }
402    }
403}
404
405pub use crate::query_join::execute_join_query;
406
407pub(crate) fn execute_inner(
408    query: &SelectQuery,
409    rows: &[Row],
410    index: Option<&crate::index::TableIndex>,
411) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
412    let empty_fts = HashMap::new();
413    execute_with_fts(query, rows, index, &empty_fts)
414}
415
416pub fn evaluate(clause: &WhereClause, row: &Row) -> bool {
417    match clause {
418        WhereClause::BoolOp(bop) => {
419            let left = evaluate(&bop.left, row);
420            match bop.op {
421                BoolOpKind::And => left && evaluate(&bop.right, row),
422                BoolOpKind::Or => left || evaluate(&bop.right, row),
423            }
424        }
425        WhereClause::Comparison(cmp) => evaluate_comparison(cmp, row),
426    }
427}
428
429/// Evaluate an Expr against a row, returning a Value.
430pub(crate) fn evaluate_expr(expr: &Expr, row: &Row) -> Value {
431    match expr {
432        Expr::Literal(SqlValue::Int(n)) => Value::Int(*n),
433        Expr::Literal(SqlValue::Float(f)) => Value::Float(*f),
434        Expr::Literal(SqlValue::String(s)) => Value::String(s.clone()),
435        Expr::Literal(SqlValue::Null) => Value::Null,
436        Expr::Literal(SqlValue::List(_)) => Value::Null,
437        Expr::Column(name) => {
438            if let Some(val) = row.get(name) {
439                return val.clone();
440            }
441            // Try all possible dot splits for dict access (e.g. "s.params.key")
442            for (i, _) in name.match_indices('.') {
443                let dict_col = &name[..i];
444                let dict_key = &name[i + 1..];
445                if let Some(Value::Dict(map)) = row.get(dict_col) {
446                    return map.get(dict_key).cloned().unwrap_or(Value::Null);
447                }
448            }
449            Value::Null
450        }
451        Expr::UnaryMinus(inner) => {
452            match evaluate_expr(inner, row) {
453                Value::Int(n) => Value::Int(-n),
454                Value::Float(f) => Value::Float(-f),
455                Value::Null => Value::Null,
456                _ => Value::Null, // non-numeric → NULL
457            }
458        }
459        Expr::BinaryOp { left, op, right } => {
460            let lv = evaluate_expr(left, row);
461            let rv = evaluate_expr(right, row);
462
463            // NULL propagation: any NULL operand → NULL
464            if lv.is_null() || rv.is_null() {
465                return Value::Null;
466            }
467
468            // Extract numeric values with int→float coercion
469            match (&lv, &rv) {
470                (Value::Int(a), Value::Int(b)) => {
471                    match op {
472                        ArithOp::Add => Value::Int(a.wrapping_add(*b)),
473                        ArithOp::Sub => Value::Int(a.wrapping_sub(*b)),
474                        ArithOp::Mul => Value::Int(a.wrapping_mul(*b)),
475                        ArithOp::Div => {
476                            if *b == 0 { Value::Null } else { Value::Int(a / b) }
477                        }
478                        ArithOp::Mod => {
479                            if *b == 0 { Value::Null } else { Value::Int(a % b) }
480                        }
481                    }
482                }
483                _ => {
484                    // Coerce to float
485                    let a = match &lv {
486                        Value::Int(n) => *n as f64,
487                        Value::Float(f) => *f,
488                        _ => return Value::Null,
489                    };
490                    let b = match &rv {
491                        Value::Int(n) => *n as f64,
492                        Value::Float(f) => *f,
493                        _ => return Value::Null,
494                    };
495                    match op {
496                        ArithOp::Add => Value::Float(a + b),
497                        ArithOp::Sub => Value::Float(a - b),
498                        ArithOp::Mul => Value::Float(a * b),
499                        ArithOp::Div => {
500                            if b == 0.0 { Value::Null } else { Value::Float(a / b) }
501                        }
502                        ArithOp::Mod => {
503                            if b == 0.0 { Value::Null } else { Value::Float(a % b) }
504                        }
505                    }
506                }
507            }
508        }
509        Expr::Case { whens, else_expr } => {
510            for (condition, result) in whens {
511                if evaluate(condition, row) {
512                    return evaluate_expr(result, row);
513                }
514            }
515            match else_expr {
516                Some(e) => evaluate_expr(e, row),
517                None => Value::Null,
518            }
519        }
520        Expr::CurrentDate => {
521            Value::Date(chrono::Local::now().naive_local().date())
522        }
523        Expr::CurrentTimestamp => {
524            Value::DateTime(chrono::Local::now().naive_local())
525        }
526        Expr::DateAdd { date, days } => {
527            let date_val = evaluate_expr(date, row);
528            let days_val = evaluate_expr(days, row);
529            let n = match &days_val {
530                Value::Int(n) => *n,
531                Value::Float(f) => *f as i64,
532                _ => return Value::Null,
533            };
534            let duration = chrono::Duration::days(n);
535            match date_val {
536                Value::Date(d) => {
537                    match d.checked_add_signed(duration) {
538                        Some(result) => Value::Date(result),
539                        None => Value::Null,
540                    }
541                }
542                Value::DateTime(dt) => {
543                    match dt.checked_add_signed(duration) {
544                        Some(result) => Value::DateTime(result),
545                        None => Value::Null,
546                    }
547                }
548                _ => Value::Null,
549            }
550        }
551        Expr::DateDiff { left, right } => {
552            let lv = evaluate_expr(left, row);
553            let rv = evaluate_expr(right, row);
554            let left_date = match &lv {
555                Value::Date(d) => d.and_hms_opt(0, 0, 0).unwrap(),
556                Value::DateTime(dt) => *dt,
557                _ => return Value::Null,
558            };
559            let right_date = match &rv {
560                Value::Date(d) => d.and_hms_opt(0, 0, 0).unwrap(),
561                Value::DateTime(dt) => *dt,
562                _ => return Value::Null,
563            };
564            Value::Int((left_date - right_date).num_days())
565        }
566        Expr::Aggregate { func, arg, .. } => {
567            // Post-aggregation: look up the pre-computed column name
568            let func_name = match func {
569                AggFunc::Count => "COUNT",
570                AggFunc::Sum => "SUM",
571                AggFunc::Avg => "AVG",
572                AggFunc::Min => "MIN",
573                AggFunc::Max => "MAX",
574            };
575            let col = format!("{}({})", func_name, arg);
576            row.get(&col).cloned().unwrap_or(Value::Null)
577        }
578        Expr::Subquery(_) => Value::Null,
579    }
580}
581
582fn evaluate_agg_expr(expr: &Expr, group_rows: &[&Row]) -> Value {
583    match expr {
584        Expr::Aggregate { func, arg, arg_expr } => {
585            compute_aggregate(func, arg, arg_expr.as_deref(), group_rows)
586        }
587        Expr::BinaryOp { left, op, right } => {
588            let lv = evaluate_agg_expr(left, group_rows);
589            let rv = evaluate_agg_expr(right, group_rows);
590            apply_arith_op(op, &lv, &rv)
591        }
592        Expr::UnaryMinus(inner) => {
593            match evaluate_agg_expr(inner, group_rows) {
594                Value::Int(n) => Value::Int(-n),
595                Value::Float(f) => Value::Float(-f),
596                _ => Value::Null,
597            }
598        }
599        other => {
600            if let Some(first) = group_rows.first() {
601                evaluate_expr(other, first)
602            } else {
603                Value::Null
604            }
605        }
606    }
607}
608
609fn apply_arith_op(op: &ArithOp, lv: &Value, rv: &Value) -> Value {
610    if lv.is_null() || rv.is_null() {
611        return Value::Null;
612    }
613    match (lv, rv) {
614        (Value::Int(a), Value::Int(b)) => match op {
615            ArithOp::Add => Value::Int(a.wrapping_add(*b)),
616            ArithOp::Sub => Value::Int(a.wrapping_sub(*b)),
617            ArithOp::Mul => Value::Int(a.wrapping_mul(*b)),
618            ArithOp::Div => if *b == 0 { Value::Null } else { Value::Int(a / b) },
619            ArithOp::Mod => if *b == 0 { Value::Null } else { Value::Int(a % b) },
620        },
621        _ => {
622            let a = match lv {
623                Value::Int(n) => *n as f64,
624                Value::Float(f) => *f,
625                _ => return Value::Null,
626            };
627            let b = match rv {
628                Value::Int(n) => *n as f64,
629                Value::Float(f) => *f,
630                _ => return Value::Null,
631            };
632            match op {
633                ArithOp::Add => Value::Float(a + b),
634                ArithOp::Sub => Value::Float(a - b),
635                ArithOp::Mul => Value::Float(a * b),
636                ArithOp::Div => if b == 0.0 { Value::Null } else { Value::Float(a / b) },
637                ArithOp::Mod => if b == 0.0 { Value::Null } else { Value::Float(a % b) },
638            }
639        }
640    }
641}
642
643fn evaluate_comparison(cmp: &Comparison, row: &Row) -> bool {
644    // If we have expression-based comparison (new path), use it for standard ops
645    if let (Some(left_expr), Some(right_expr)) = (&cmp.left_expr, &cmp.right_expr) {
646        if matches!(cmp.op, CmpOp::Eq | CmpOp::Ne | CmpOp::Lt | CmpOp::Gt | CmpOp::Le | CmpOp::Ge) {
647            let left_val = evaluate_expr(left_expr, row);
648            let right_val = evaluate_expr(right_expr, row);
649
650            // NULL comparison: always false (except IS NULL handled below)
651            if left_val.is_null() || right_val.is_null() {
652                return false;
653            }
654
655            // Coerce for comparison: if types differ, try int→float
656            let ord = compare_model_values(&left_val, &right_val);
657
658            return match cmp.op {
659                CmpOp::Eq => ord == Some(Ordering::Equal),
660                CmpOp::Ne => ord != Some(Ordering::Equal),
661                CmpOp::Lt => ord == Some(Ordering::Less),
662                CmpOp::Gt => ord == Some(Ordering::Greater),
663                CmpOp::Le => matches!(ord, Some(Ordering::Less | Ordering::Equal)),
664                CmpOp::Ge => matches!(ord, Some(Ordering::Greater | Ordering::Equal)),
665                _ => false,
666            };
667        }
668    }
669
670    // Fall back to legacy column-based comparison for IS NULL, IN, LIKE, etc.
671    let actual = row.get(&cmp.column);
672
673    if cmp.op == CmpOp::IsNull {
674        return actual.map_or(true, |v| v.is_null());
675    }
676    if cmp.op == CmpOp::IsNotNull {
677        return actual.map_or(false, |v| !v.is_null());
678    }
679
680    let actual = match actual {
681        Some(v) if !v.is_null() => v,
682        _ => return false,
683    };
684
685    let expected = match &cmp.value {
686        Some(v) => v,
687        None => return false,
688    };
689
690    match cmp.op {
691        CmpOp::Eq => eq_match(actual, expected),
692        CmpOp::Ne => !eq_match(actual, expected),
693        CmpOp::Lt => compare_values(actual, expected) == Some(Ordering::Less),
694        CmpOp::Gt => compare_values(actual, expected) == Some(Ordering::Greater),
695        CmpOp::Le => matches!(compare_values(actual, expected), Some(Ordering::Less | Ordering::Equal)),
696        CmpOp::Ge => matches!(compare_values(actual, expected), Some(Ordering::Greater | Ordering::Equal)),
697        CmpOp::Like => like_match(actual, expected),
698        CmpOp::NotLike => !like_match(actual, expected),
699        CmpOp::In => {
700            if let SqlValue::List(items) = expected {
701                items.iter().any(|v| eq_match(actual, v))
702            } else {
703                eq_match(actual, expected)
704            }
705        }
706        CmpOp::IsNull | CmpOp::IsNotNull => unreachable!(),
707    }
708}
709
710/// Compare two model::Value instances, with int↔float coercion.
711fn compare_model_values(a: &Value, b: &Value) -> Option<Ordering> {
712    match (a, b) {
713        (Value::Int(x), Value::Float(y)) => (*x as f64).partial_cmp(y),
714        (Value::Float(x), Value::Int(y)) => x.partial_cmp(&(*y as f64)),
715        _ => a.partial_cmp(b),
716    }
717}
718
719fn coerce_sql_to_value(sql_val: &SqlValue, target: &Value) -> Value {
720    match sql_val {
721        SqlValue::Null => Value::Null,
722        SqlValue::String(s) => {
723            match target {
724                Value::Int(_) => s.parse::<i64>().map(Value::Int).unwrap_or(Value::String(s.clone())),
725                Value::Float(_) => s.parse::<f64>().map(Value::Float).unwrap_or(Value::String(s.clone())),
726                Value::Date(_) => {
727                    chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d")
728                        .map(Value::Date)
729                        .unwrap_or(Value::String(s.clone()))
730                }
731                Value::DateTime(_) => {
732                    chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S")
733                        .or_else(|_| chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f"))
734                        .map(Value::DateTime)
735                        .unwrap_or(Value::String(s.clone()))
736                }
737                _ => Value::String(s.clone()),
738            }
739        }
740        SqlValue::Int(n) => {
741            match target {
742                Value::Float(_) => Value::Float(*n as f64),
743                _ => Value::Int(*n),
744            }
745        }
746        SqlValue::Float(f) => Value::Float(*f),
747        SqlValue::List(_) => Value::Null, // Lists handled separately
748    }
749}
750
751fn eq_match(actual: &Value, expected: &SqlValue) -> bool {
752    // Special handling for lists (e.g., categories)
753    if let Value::List(items) = actual {
754        if let SqlValue::String(s) = expected {
755            return items.contains(s);
756        }
757    }
758
759    let coerced = coerce_sql_to_value(expected, actual);
760    actual == &coerced
761}
762
763fn like_match(actual: &Value, pattern: &SqlValue) -> bool {
764    let pattern_str = match pattern {
765        SqlValue::String(s) => s,
766        _ => return false,
767    };
768
769    // Convert SQL LIKE to regex
770    let mut regex_str = String::from("(?is)^");
771    for ch in pattern_str.chars() {
772        match ch {
773            '%' => regex_str.push_str(".*"),
774            '_' => regex_str.push('.'),
775            c => {
776                if regex::escape(&c.to_string()) != c.to_string() {
777                    regex_str.push_str(&regex::escape(&c.to_string()));
778                } else {
779                    regex_str.push(c);
780                }
781            }
782        }
783    }
784    regex_str.push('$');
785
786    let re = match Regex::new(&regex_str) {
787        Ok(r) => r,
788        Err(_) => return false,
789    };
790
791    match actual {
792        Value::List(items) => items.iter().any(|item| re.is_match(item)),
793        _ => re.is_match(&actual.to_display_string()),
794    }
795}
796
797fn compare_values(actual: &Value, expected: &SqlValue) -> Option<Ordering> {
798    let coerced = coerce_sql_to_value(expected, actual);
799    actual.partial_cmp(&coerced)
800}
801
802/// Convert a SqlValue to a Value for index lookups (without a target type for coercion).
803fn sql_value_to_index_value(sv: &SqlValue) -> Value {
804    match sv {
805        SqlValue::String(s) => {
806            // Try datetime first (more specific)
807            if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S") {
808                return Value::DateTime(dt);
809            }
810            if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f") {
811                return Value::DateTime(dt);
812            }
813            // Try date
814            if let Ok(d) = chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d") {
815                return Value::Date(d);
816            }
817            Value::String(s.clone())
818        }
819        SqlValue::Int(n) => Value::Int(*n),
820        SqlValue::Float(f) => Value::Float(*f),
821        SqlValue::Null => Value::Null,
822        SqlValue::List(_) => Value::Null,
823    }
824}
825
826/// Try to use B-tree indexes to narrow the candidate row set.
827/// Returns Some(paths) if the entire WHERE clause could be resolved via index,
828/// or None if a full scan is needed.
829fn try_index_filter(
830    clause: &WhereClause,
831    index: &crate::index::TableIndex,
832) -> Option<std::collections::HashSet<String>> {
833    match clause {
834        WhereClause::Comparison(cmp) => {
835            if !index.has_index(&cmp.column) {
836                return None;
837            }
838            match cmp.op {
839                CmpOp::Eq => {
840                    let val = sql_value_to_index_value(cmp.value.as_ref()?);
841                    let paths = index.lookup_eq(&cmp.column, &val);
842                    Some(paths.into_iter().map(|s| s.to_string()).collect())
843                }
844                CmpOp::Lt => {
845                    let val = sql_value_to_index_value(cmp.value.as_ref()?);
846                    // exclusive upper bound: use range with max < val
847                    // lookup_range is inclusive, so we get all <= val then remove exact matches
848                    let range_paths = index.lookup_range(&cmp.column, None, Some(&val));
849                    let eq_paths: std::collections::HashSet<&str> = index.lookup_eq(&cmp.column, &val).into_iter().collect();
850                    Some(range_paths.into_iter().filter(|p| !eq_paths.contains(p)).map(|s| s.to_string()).collect())
851                }
852                CmpOp::Gt => {
853                    let val = sql_value_to_index_value(cmp.value.as_ref()?);
854                    let range_paths = index.lookup_range(&cmp.column, Some(&val), None);
855                    let eq_paths: std::collections::HashSet<&str> = index.lookup_eq(&cmp.column, &val).into_iter().collect();
856                    Some(range_paths.into_iter().filter(|p| !eq_paths.contains(p)).map(|s| s.to_string()).collect())
857                }
858                CmpOp::Le => {
859                    let val = sql_value_to_index_value(cmp.value.as_ref()?);
860                    let paths = index.lookup_range(&cmp.column, None, Some(&val));
861                    Some(paths.into_iter().map(|s| s.to_string()).collect())
862                }
863                CmpOp::Ge => {
864                    let val = sql_value_to_index_value(cmp.value.as_ref()?);
865                    let paths = index.lookup_range(&cmp.column, Some(&val), None);
866                    Some(paths.into_iter().map(|s| s.to_string()).collect())
867                }
868                CmpOp::In => {
869                    if let Some(SqlValue::List(items)) = &cmp.value {
870                        let vals: Vec<Value> = items.iter().map(sql_value_to_index_value).collect();
871                        let paths = index.lookup_in(&cmp.column, &vals);
872                        Some(paths.into_iter().map(|s| s.to_string()).collect())
873                    } else {
874                        None
875                    }
876                }
877                _ => None, // LIKE, IS NULL, etc. can't use index
878            }
879        }
880        WhereClause::BoolOp(bop) => {
881            let left = try_index_filter(&bop.left, index);
882            let right = try_index_filter(&bop.right, index);
883            match bop.op {
884                BoolOpKind::And => {
885                    match (left, right) {
886                        (Some(l), Some(r)) => Some(l.intersection(&r).cloned().collect()),
887                        (Some(l), None) => Some(l), // narrow with left, scan-verify right
888                        (None, Some(r)) => Some(r),
889                        (None, None) => None,
890                    }
891                }
892                BoolOpKind::Or => {
893                    match (left, right) {
894                        (Some(l), Some(r)) => Some(l.union(&r).cloned().collect()),
895                        _ => None, // Can't use index if either side needs full scan
896                    }
897                }
898            }
899        }
900    }
901}
902
903/// If an ORDER BY column matches a SELECT alias, replace its expr with the
904/// aliased expression so sorting uses the computed value.
905fn resolve_order_aliases(specs: &[OrderSpec], columns: &ColumnList) -> Vec<OrderSpec> {
906    let named = match columns {
907        ColumnList::Named(exprs) => exprs,
908        _ => return specs.to_vec(),
909    };
910
911    // Build alias → expr map
912    let alias_map: HashMap<String, &Expr> = named
913        .iter()
914        .filter_map(|se| match se {
915            SelectExpr::Expr { expr, alias: Some(a) } => Some((a.clone(), expr)),
916            _ => None,
917        })
918        .collect();
919
920    specs
921        .iter()
922        .map(|spec| {
923            // If the ORDER BY column name matches a SELECT alias, use that expression
924            if let Some(expr) = alias_map.get(&spec.column) {
925                OrderSpec {
926                    column: spec.column.clone(),
927                    expr: Some((*expr).clone()),
928                    descending: spec.descending,
929                }
930            } else {
931                spec.clone()
932            }
933        })
934        .collect()
935}
936
937fn sort_rows(rows: &mut Vec<Row>, specs: &[OrderSpec]) {
938    rows.sort_by(|a, b| {
939        for spec in specs {
940            let (va, vb) = if let Some(ref expr) = spec.expr {
941                (evaluate_expr(expr, a), evaluate_expr(expr, b))
942            } else {
943                (
944                    a.get(&spec.column).cloned().unwrap_or(Value::Null),
945                    b.get(&spec.column).cloned().unwrap_or(Value::Null),
946                )
947            };
948
949            // NULLs sort last
950            let ordering = match (&va, &vb) {
951                (Value::Null, Value::Null) => Ordering::Equal,
952                (Value::Null, _) => Ordering::Greater,
953                (_, Value::Null) => Ordering::Less,
954                (a_val, b_val) => {
955                    compare_model_values(a_val, b_val).unwrap_or(Ordering::Equal)
956                }
957            };
958
959            let ordering = if spec.descending {
960                ordering.reverse()
961            } else {
962                ordering
963            };
964
965            if ordering != Ordering::Equal {
966                return ordering;
967            }
968        }
969        Ordering::Equal
970    });
971}
972
973/// Convert a SqlValue to our model Value (for use in insert/update).
974pub(crate) fn sql_value_to_value(sql_val: &SqlValue) -> Value {
975    match sql_val {
976        SqlValue::Null => Value::Null,
977        SqlValue::String(s) => Value::String(s.clone()),
978        SqlValue::Int(n) => Value::Int(*n),
979        SqlValue::Float(f) => Value::Float(*f),
980        SqlValue::List(items) => {
981            let strings: Vec<String> = items
982                .iter()
983                .filter_map(|v| match v {
984                    SqlValue::String(s) => Some(s.clone()),
985                    _ => None,
986                })
987                .collect();
988            Value::List(strings)
989        }
990    }
991}
992
#[cfg(test)]
mod tests {
    use super::*;

    /// Three rows shared by most tests, in order: Alpha/10, Beta/5, Gamma/20.
    fn make_rows() -> Vec<Row> {
        vec![
            Row::from([
                ("path".into(), Value::String("a.md".into())),
                ("title".into(), Value::String("Alpha".into())),
                ("count".into(), Value::Int(10)),
            ]),
            Row::from([
                ("path".into(), Value::String("b.md".into())),
                ("title".into(), Value::String("Beta".into())),
                ("count".into(), Value::Int(5)),
            ]),
            Row::from([
                ("path".into(), Value::String("c.md".into())),
                ("title".into(), Value::String("Gamma".into())),
                ("count".into(), Value::Int(20)),
            ]),
        ]
    }

    /// `SELECT * FROM test` with every optional clause unset. Tests override only
    /// the fields they exercise, using struct-update syntax.
    fn base_query() -> SelectQuery {
        SelectQuery {
            columns: ColumnList::All,
            table: "test".into(),
            table_alias: None,
            subquery: None,
            joins: vec![],
            where_clause: None,
            group_by: None,
            having: None,
            order_by: None,
            limit: None,
            ctes: vec![],
        }
    }

    /// Parse `sql`, require it to be a SELECT, and execute it over `rows`.
    fn run_select(sql: &str, rows: &[Row]) -> (Vec<Row>, Vec<String>) {
        match crate::query_parser::parse_query(sql).unwrap() {
            crate::query_parser::Statement::Select(q) => execute_inner(&q, rows, None).unwrap(),
            _ => panic!("Expected Select"),
        }
    }

    #[test]
    fn test_select_all() {
        let (rows, _cols) = execute_inner(&base_query(), &make_rows(), None).unwrap();
        assert_eq!(rows.len(), 3);
    }

    #[test]
    fn test_where_gt() {
        let q = SelectQuery {
            where_clause: Some(WhereClause::Comparison(Comparison {
                column: "count".into(),
                op: CmpOp::Gt,
                value: Some(SqlValue::Int(5)),
                left_expr: Some(Expr::Column("count".into())),
                right_expr: Some(Expr::Literal(SqlValue::Int(5))),
            })),
            ..base_query()
        };
        let (rows, _) = execute_inner(&q, &make_rows(), None).unwrap();
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn test_order_by_desc() {
        let q = SelectQuery {
            order_by: Some(vec![OrderSpec {
                column: "count".into(),
                expr: Some(Expr::Column("count".into())),
                descending: true,
            }]),
            ..base_query()
        };
        let (rows, _) = execute_inner(&q, &make_rows(), None).unwrap();
        assert_eq!(rows[0]["count"], Value::Int(20));
        assert_eq!(rows[2]["count"], Value::Int(5));
    }

    #[test]
    fn test_limit() {
        let q = SelectQuery {
            limit: Some(2),
            ..base_query()
        };
        let (rows, _) = execute_inner(&q, &make_rows(), None).unwrap();
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn test_like() {
        let q = SelectQuery {
            where_clause: Some(WhereClause::Comparison(Comparison {
                column: "title".into(),
                op: CmpOp::Like,
                value: Some(SqlValue::String("%lph%".into())),
                left_expr: Some(Expr::Column("title".into())),
                right_expr: None,
            })),
            ..base_query()
        };
        let (rows, _) = execute_inner(&q, &make_rows(), None).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0]["title"], Value::String("Alpha".into()));
    }

    #[test]
    fn test_is_null() {
        let mut rows = make_rows();
        rows[1].insert("optional".into(), Value::Null);

        let q = SelectQuery {
            where_clause: Some(WhereClause::Comparison(Comparison {
                column: "optional".into(),
                op: CmpOp::IsNull,
                value: None,
                left_expr: Some(Expr::Column("optional".into())),
                right_expr: None,
            })),
            ..base_query()
        };
        let (result, _) = execute_inner(&q, &rows, None).unwrap();
        // Row 1 holds an explicit NULL; rows 0 and 2 lack the column entirely.
        // Both count as NULL.
        assert_eq!(result.len(), 3);
    }

    // ── Expression evaluation tests ───────────────────────────────

    #[test]
    fn test_evaluate_expr_literal() {
        let row = Row::new();
        assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Int(42)), &row), Value::Int(42));
        assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Float(2.5)), &row), Value::Float(2.5));
        assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Null), &row), Value::Null);
    }

    #[test]
    fn test_evaluate_expr_column() {
        let row = Row::from([("x".into(), Value::Int(10))]);
        assert_eq!(evaluate_expr(&Expr::Column("x".into()), &row), Value::Int(10));
        assert_eq!(evaluate_expr(&Expr::Column("missing".into()), &row), Value::Null);
    }

    #[test]
    fn test_evaluate_expr_int_arithmetic() {
        let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Int(3))]);
        let add = Expr::BinaryOp {
            left: Box::new(Expr::Column("a".into())),
            op: ArithOp::Add,
            right: Box::new(Expr::Column("b".into())),
        };
        assert_eq!(evaluate_expr(&add, &row), Value::Int(13));

        let sub = Expr::BinaryOp {
            left: Box::new(Expr::Column("a".into())),
            op: ArithOp::Sub,
            right: Box::new(Expr::Column("b".into())),
        };
        assert_eq!(evaluate_expr(&sub, &row), Value::Int(7));

        let mul = Expr::BinaryOp {
            left: Box::new(Expr::Column("a".into())),
            op: ArithOp::Mul,
            right: Box::new(Expr::Column("b".into())),
        };
        assert_eq!(evaluate_expr(&mul, &row), Value::Int(30));

        let div = Expr::BinaryOp {
            left: Box::new(Expr::Column("a".into())),
            op: ArithOp::Div,
            right: Box::new(Expr::Column("b".into())),
        };
        assert_eq!(evaluate_expr(&div, &row), Value::Int(3)); // integer division

        let modulo = Expr::BinaryOp {
            left: Box::new(Expr::Column("a".into())),
            op: ArithOp::Mod,
            right: Box::new(Expr::Column("b".into())),
        };
        assert_eq!(evaluate_expr(&modulo, &row), Value::Int(1));
    }

    #[test]
    fn test_evaluate_expr_float_coercion() {
        let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Float(3.0))]);
        let add = Expr::BinaryOp {
            left: Box::new(Expr::Column("a".into())),
            op: ArithOp::Add,
            right: Box::new(Expr::Column("b".into())),
        };
        assert_eq!(evaluate_expr(&add, &row), Value::Float(13.0));
    }

    #[test]
    fn test_evaluate_expr_null_propagation() {
        let row = Row::from([("a".into(), Value::Int(10))]);
        let add = Expr::BinaryOp {
            left: Box::new(Expr::Column("a".into())),
            op: ArithOp::Add,
            right: Box::new(Expr::Column("missing".into())),
        };
        assert_eq!(evaluate_expr(&add, &row), Value::Null);
    }

    #[test]
    fn test_evaluate_expr_div_by_zero() {
        let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Int(0))]);
        let div = Expr::BinaryOp {
            left: Box::new(Expr::Column("a".into())),
            op: ArithOp::Div,
            right: Box::new(Expr::Column("b".into())),
        };
        assert_eq!(evaluate_expr(&div, &row), Value::Null);
    }

    #[test]
    fn test_evaluate_expr_unary_minus() {
        let row = Row::from([("x".into(), Value::Int(5))]);
        let neg = Expr::UnaryMinus(Box::new(Expr::Column("x".into())));
        assert_eq!(evaluate_expr(&neg, &row), Value::Int(-5));
    }

    #[test]
    fn test_select_with_expression() {
        let (rows, cols) = run_select("SELECT count * 2 AS doubled FROM test", &make_rows());
        assert_eq!(cols, vec!["doubled"]);
        assert_eq!(rows.len(), 3);
        // Source rows have count = 10, 5, 20.
        let values: Vec<Value> = rows.iter().map(|r| r["doubled"].clone()).collect();
        assert!(values.contains(&Value::Int(20)));
        assert!(values.contains(&Value::Int(10)));
        assert!(values.contains(&Value::Int(40)));
    }

    #[test]
    fn test_where_with_expression() {
        let (rows, _) = run_select("SELECT * FROM test WHERE count * 2 > 15", &make_rows());
        // count=10 → 20 > 15 ✓, count=5 → 10 > 15 ✗, count=20 → 40 > 15 ✓
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn test_order_by_expression() {
        // ORDER BY count * -1 ASC is effectively ORDER BY count DESC.
        let (rows, _) = run_select(
            "SELECT title, count FROM test ORDER BY count * -1 ASC",
            &make_rows(),
        );
        // count: 20 → -20, 10 → -10, 5 → -5; ASC gives -20, -10, -5.
        assert_eq!(rows[0]["count"], Value::Int(20));
        assert_eq!(rows[1]["count"], Value::Int(10));
        assert_eq!(rows[2]["count"], Value::Int(5));
    }

    // ── CASE WHEN evaluation tests ────────────────────────────────

    #[test]
    fn test_case_when_eval_basic() {
        let row = Row::from([("status".into(), Value::String("ACTIVE".into()))]);
        let expr = Expr::Case {
            whens: vec![(
                WhereClause::Comparison(Comparison {
                    column: "status".into(),
                    op: CmpOp::Eq,
                    value: Some(SqlValue::String("ACTIVE".into())),
                    left_expr: Some(Expr::Column("status".into())),
                    right_expr: Some(Expr::Literal(SqlValue::String("ACTIVE".into()))),
                }),
                Box::new(Expr::Literal(SqlValue::Int(1))),
            )],
            else_expr: Some(Box::new(Expr::Literal(SqlValue::Int(0)))),
        };
        assert_eq!(evaluate_expr(&expr, &row), Value::Int(1));
    }

    #[test]
    fn test_case_when_eval_else() {
        let row = Row::from([("status".into(), Value::String("KILLED".into()))]);
        let expr = Expr::Case {
            whens: vec![(
                WhereClause::Comparison(Comparison {
                    column: "status".into(),
                    op: CmpOp::Eq,
                    value: Some(SqlValue::String("ACTIVE".into())),
                    left_expr: Some(Expr::Column("status".into())),
                    right_expr: Some(Expr::Literal(SqlValue::String("ACTIVE".into()))),
                }),
                Box::new(Expr::Literal(SqlValue::Int(1))),
            )],
            else_expr: Some(Box::new(Expr::Literal(SqlValue::Int(0)))),
        };
        assert_eq!(evaluate_expr(&expr, &row), Value::Int(0));
    }

    #[test]
    fn test_case_when_eval_no_else_null() {
        let row = Row::from([("x".into(), Value::Int(99))]);
        let expr = Expr::Case {
            whens: vec![(
                WhereClause::Comparison(Comparison {
                    column: "x".into(),
                    op: CmpOp::Eq,
                    value: Some(SqlValue::Int(1)),
                    left_expr: Some(Expr::Column("x".into())),
                    right_expr: Some(Expr::Literal(SqlValue::Int(1))),
                }),
                Box::new(Expr::Literal(SqlValue::String("one".into()))),
            )],
            else_expr: None,
        };
        assert_eq!(evaluate_expr(&expr, &row), Value::Null);
    }

    #[test]
    fn test_case_when_in_aggregate_query() {
        // Rows: count=10, count=5, count=20 → should sum 10 + 0 + 20 = 30
        let (rows, cols) = run_select(
            "SELECT SUM(CASE WHEN count > 5 THEN count ELSE 0 END) AS total FROM test",
            &make_rows(),
        );
        assert_eq!(cols, vec!["total"]);
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0]["total"], Value::Float(30.0));
    }

    #[test]
    fn test_case_when_with_unary_minus_in_aggregate() {
        // Alpha: 10, Beta: -5, Gamma: -20 → 10 - 5 - 20 = -15
        let (rows, _) = run_select(
            "SELECT SUM(CASE WHEN title = 'Alpha' THEN count ELSE -count END) AS net FROM test",
            &make_rows(),
        );
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0]["net"], Value::Float(-15.0));
    }

    #[test]
    fn test_dateadd_with_dict_in_group_by() {
        // Simulate a joined row with a dict field, then GROUP BY + DateAdd expr
        use indexmap::IndexMap;
        let mut params = IndexMap::new();
        params.insert("exit_days".to_string(), Value::Int(21));

        let rows = vec![
            Row::from([
                ("o.token".into(), Value::String("BTC".into())),
                ("o.event_date".into(), Value::Date(
                    chrono::NaiveDate::from_ymd_opt(2026, 1, 1).unwrap()
                )),
                ("o.size".into(), Value::Int(100)),
                ("s.params".into(), Value::Dict(params.clone())),
            ]),
            Row::from([
                ("o.token".into(), Value::String("BTC".into())),
                ("o.event_date".into(), Value::Date(
                    chrono::NaiveDate::from_ymd_opt(2026, 1, 1).unwrap()
                )),
                ("o.size".into(), Value::Int(50)),
                ("s.params".into(), Value::Dict(params.clone())),
            ]),
        ];

        let q = SelectQuery {
            columns: ColumnList::Named(vec![
                SelectExpr::Column("o.token".into()),
                SelectExpr::Column("o.event_date".into()),
                SelectExpr::Expr {
                    expr: Expr::DateAdd {
                        date: Box::new(Expr::Column("o.event_date".into())),
                        days: Box::new(Expr::Column("s.params.exit_days".into())),
                    },
                    alias: Some("exit_date".into()),
                },
                SelectExpr::Aggregate {
                    func: AggFunc::Sum,
                    arg: "o.size".into(),
                    arg_expr: Some(Expr::Column("o.size".into())),
                    alias: Some("total".into()),
                },
            ]),
            table: "orders".into(),
            group_by: Some(vec!["o.token".into(), "o.event_date".into()]),
            ..base_query()
        };

        let (rows, cols) = execute_inner(&q, &rows, None).unwrap();
        assert_eq!(rows.len(), 1);
        assert!(cols.contains(&"exit_date".to_string()));
        assert_eq!(rows[0]["total"], Value::Float(150.0));
        // The key test: exit_date should be 2026-01-22, not Null
        assert_eq!(
            rows[0]["exit_date"],
            Value::Date(chrono::NaiveDate::from_ymd_opt(2026, 1, 22).unwrap())
        );
    }

    #[test]
    fn test_aggregate_arithmetic() {
        // SUM(count) = 10 + 5 + 20 = 35 and COUNT(*) = 3.
        // SUM produces Float and COUNT produces Int, so the division is Float.
        let (rows, cols) = run_select(
            "SELECT SUM(count) / COUNT(*) AS avg_count FROM test",
            &make_rows(),
        );
        assert_eq!(cols, vec!["avg_count"]);
        assert_eq!(rows.len(), 1);
        match &rows[0]["avg_count"] {
            Value::Float(f) => assert!((f - 11.666666666666666).abs() < 0.001),
            other => panic!("Expected Float, got {:?}", other),
        }
    }

    #[test]
    fn test_aggregate_subtraction_with_group_by() {
        let rows = vec![
            Row::from([
                ("token".into(), Value::String("BTC".into())),
                ("side".into(), Value::String("BUY".into())),
                ("size".into(), Value::Float(100.0)),
            ]),
            Row::from([
                ("token".into(), Value::String("BTC".into())),
                ("side".into(), Value::String("SELL".into())),
                ("size".into(), Value::Float(60.0)),
            ]),
        ];
        let (result, _) = run_select(
            "SELECT token, SUM(CASE WHEN side = 'BUY' THEN size ELSE 0 END) - SUM(CASE WHEN side = 'SELL' THEN size ELSE 0 END) AS net FROM test GROUP BY token",
            &rows,
        );
        assert_eq!(result.len(), 1);
        assert_eq!(result[0]["net"], Value::Float(40.0));
    }

    // ── Issue #42: Aggregate subtraction without GROUP BY ──

    #[test]
    fn test_aggregate_subtraction_no_group() {
        // SUM(count) = 10 + 5 + 20 = 35, COUNT(*) = 3, diff = 35 - 3 = 32
        let (rows, cols) = run_select(
            "SELECT SUM(count) - COUNT(*) as diff FROM test",
            &make_rows(),
        );
        assert_eq!(cols, vec!["diff"]);
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0]["diff"], Value::Float(32.0));
    }

    // ── Issue #42: Aggregate division with GROUP BY ──

    #[test]
    fn test_aggregate_division_with_group_by() {
        let rows = vec![
            Row::from([
                ("category".into(), Value::String("A".into())),
                ("count".into(), Value::Int(10)),
            ]),
            Row::from([
                ("category".into(), Value::String("A".into())),
                ("count".into(), Value::Int(20)),
            ]),
            Row::from([
                ("category".into(), Value::String("B".into())),
                ("count".into(), Value::Int(6)),
            ]),
        ];
        // Group A: SUM(count)=30, COUNT(*)=2, ratio=15.0
        // Group B: SUM(count)=6, COUNT(*)=1, ratio=6.0
        let (result, cols) = run_select(
            "SELECT category, SUM(count) / COUNT(*) as ratio FROM test GROUP BY category",
            &rows,
        );
        assert!(cols.contains(&"ratio".to_string()));
        assert_eq!(result.len(), 2);
        // Group output order is not asserted; look each group up by category.
        let group_a = result.iter().find(|r| r["category"] == Value::String("A".into())).unwrap();
        let group_b = result.iter().find(|r| r["category"] == Value::String("B".into())).unwrap();
        match &group_a["ratio"] {
            Value::Float(f) => assert!((f - 15.0).abs() < 0.001),
            other => panic!("Expected Float for group A ratio, got {:?}", other),
        }
        match &group_b["ratio"] {
            Value::Float(f) => assert!((f - 6.0).abs() < 0.001),
            other => panic!("Expected Float for group B ratio, got {:?}", other),
        }
    }
}