Skip to main content

mdql_core/
query_engine.rs

1//! Execute parsed queries over in-memory rows.
2
3use std::cmp::Ordering;
4use std::collections::HashMap;
5
6use regex::Regex;
7
8use crate::errors::MdqlError;
9use crate::model::{Row, Value};
10use crate::query_parser::*;
11use crate::schema::Schema;
12
13pub fn execute_query(
14    query: &SelectQuery,
15    rows: &[Row],
16    _schema: &Schema,
17) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
18    execute(query, rows, None)
19}
20
21/// Execute a query with optional B-tree index and FTS searcher.
22pub fn execute_query_indexed(
23    query: &SelectQuery,
24    rows: &[Row],
25    schema: &Schema,
26    index: Option<&crate::index::TableIndex>,
27    searcher: Option<&crate::search::TableSearcher>,
28) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
29    // Pre-compute FTS results for any LIKE clauses on section columns
30    let fts_results = if let (Some(ref wc), Some(searcher)) = (&query.where_clause, searcher) {
31        collect_fts_results(wc, schema, searcher)
32    } else {
33        HashMap::new()
34    };
35
36    execute_with_fts(query, rows, index, &fts_results)
37}
38
39/// Collect FTS results for LIKE comparisons on section columns.
40/// Returns a map from (column, pattern) → set of matching paths.
41fn collect_fts_results(
42    clause: &WhereClause,
43    schema: &Schema,
44    searcher: &crate::search::TableSearcher,
45) -> HashMap<(String, String), std::collections::HashSet<String>> {
46    let mut results = HashMap::new();
47    collect_fts_results_inner(clause, schema, searcher, &mut results);
48    results
49}
50
51fn collect_fts_results_inner(
52    clause: &WhereClause,
53    schema: &Schema,
54    searcher: &crate::search::TableSearcher,
55    results: &mut HashMap<(String, String), std::collections::HashSet<String>>,
56) {
57    match clause {
58        WhereClause::Comparison(cmp) => {
59            if (cmp.op == "LIKE" || cmp.op == "NOT LIKE") && schema.sections.contains_key(&cmp.column) {
60                if let Some(SqlValue::String(pattern)) = &cmp.value {
61                    // Strip SQL wildcards for Tantivy query
62                    let search_term = pattern.replace('%', " ").replace('_', " ").trim().to_string();
63                    if !search_term.is_empty() {
64                        if let Ok(paths) = searcher.search(&search_term, Some(&cmp.column)) {
65                            let key = (cmp.column.clone(), pattern.clone());
66                            results.insert(key, paths.into_iter().collect());
67                        }
68                    }
69                }
70            }
71        }
72        WhereClause::BoolOp(bop) => {
73            collect_fts_results_inner(&bop.left, schema, searcher, results);
74            collect_fts_results_inner(&bop.right, schema, searcher, results);
75        }
76    }
77}
78
79type FtsResults = HashMap<(String, String), std::collections::HashSet<String>>;
80
81fn execute_with_fts(
82    query: &SelectQuery,
83    rows: &[Row],
84    index: Option<&crate::index::TableIndex>,
85    fts: &FtsResults,
86) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
87    // Determine available columns
88    let mut all_columns: Vec<String> = Vec::new();
89    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
90    for r in rows {
91        for k in r.keys() {
92            if seen.insert(k.clone()) {
93                all_columns.push(k.clone());
94            }
95        }
96    }
97
98    // Check if query has aggregates
99    let has_aggregates = match &query.columns {
100        ColumnList::Named(exprs) => exprs.iter().any(|e| e.is_aggregate()),
101        _ => false,
102    };
103
104    // Output column names
105    let columns: Vec<String> = match &query.columns {
106        ColumnList::All => all_columns,
107        ColumnList::Named(exprs) => exprs.iter().map(|e| e.output_name()).collect(),
108    };
109
110    // Filter — try index first, fall back to full scan
111    let filtered: Vec<Row> = if let Some(ref wc) = query.where_clause {
112        let candidate_paths = index.and_then(|idx| try_index_filter(wc, idx));
113        if let Some(paths) = candidate_paths {
114            rows.iter()
115                .filter(|r| {
116                    r.get("path")
117                        .and_then(|v| v.as_str())
118                        .map_or(false, |p| paths.contains(p))
119                })
120                .filter(|r| evaluate_with_fts(wc, r, fts))
121                .cloned()
122                .collect()
123        } else {
124            rows.iter()
125                .filter(|r| evaluate_with_fts(wc, r, fts))
126                .cloned()
127                .collect()
128        }
129    } else {
130        rows.to_vec()
131    };
132
133    // Aggregate if needed
134    let mut result = if has_aggregates || query.group_by.is_some() {
135        let exprs = match &query.columns {
136            ColumnList::Named(exprs) => exprs.clone(),
137            _ => return Err(MdqlError::QueryExecution(
138                "SELECT * with GROUP BY is not supported".into(),
139            )),
140        };
141        let group_keys = query.group_by.as_deref().unwrap_or(&[]);
142        aggregate_rows(&filtered, &exprs, group_keys)?
143    } else {
144        filtered
145    };
146
147    // HAVING filter — apply after aggregation
148    if let Some(ref having) = query.having {
149        result.retain(|row| evaluate(having, row));
150    }
151
152    // Sort — resolve ORDER BY aliases against SELECT list
153    if let Some(ref order_by) = query.order_by {
154        let resolved = resolve_order_aliases(order_by, &query.columns);
155        sort_rows(&mut result, &resolved);
156    }
157
158    // Limit
159    if let Some(limit) = query.limit {
160        result.truncate(limit as usize);
161    }
162
163    // Project — evaluate expressions and strip to requested columns
164    if !matches!(query.columns, ColumnList::All) {
165        let named_exprs = match &query.columns {
166            ColumnList::Named(exprs) => exprs,
167            _ => unreachable!(),
168        };
169
170        // Compute expression columns first, then retain only requested columns
171        let has_expr_cols = named_exprs.iter().any(|e| matches!(e, SelectExpr::Expr { .. }));
172        if has_expr_cols {
173            for row in &mut result {
174                for expr in named_exprs {
175                    if let SelectExpr::Expr { expr: e, alias } = expr {
176                        let name = alias.clone().unwrap_or_else(|| e.display_name());
177                        let val = evaluate_expr(e, row);
178                        row.insert(name, val);
179                    }
180                }
181            }
182        }
183
184        let col_set: std::collections::HashSet<&str> =
185            columns.iter().map(|s| s.as_str()).collect();
186        for row in &mut result {
187            row.retain(|k, _| col_set.contains(k.as_str()));
188        }
189    }
190
191    Ok((result, columns))
192}
193
194fn aggregate_rows(
195    rows: &[Row],
196    exprs: &[SelectExpr],
197    group_keys: &[String],
198) -> crate::errors::Result<Vec<Row>> {
199    // Group rows by group_keys
200    let mut groups: Vec<(Vec<Value>, Vec<&Row>)> = Vec::new();
201    let mut key_index: HashMap<Vec<String>, usize> = HashMap::new();
202
203    if group_keys.is_empty() {
204        // No GROUP BY — all rows are one group
205        let all_refs: Vec<&Row> = rows.iter().collect();
206        groups.push((vec![], all_refs));
207    } else {
208        for row in rows {
209            let key: Vec<String> = group_keys
210                .iter()
211                .map(|k| {
212                    row.get(k)
213                        .map(|v| v.to_display_string())
214                        .unwrap_or_default()
215                })
216                .collect();
217            let key_vals: Vec<Value> = group_keys
218                .iter()
219                .map(|k| row.get(k).cloned().unwrap_or(Value::Null))
220                .collect();
221            if let Some(&idx) = key_index.get(&key) {
222                groups[idx].1.push(row);
223            } else {
224                let idx = groups.len();
225                key_index.insert(key, idx);
226                groups.push((key_vals, vec![row]));
227            }
228        }
229    }
230
231    // Compute aggregates per group
232    let mut result = Vec::new();
233    for (key_vals, group_rows) in &groups {
234        let mut out = Row::new();
235
236        // Fill in group key values
237        for (i, k) in group_keys.iter().enumerate() {
238            out.insert(k.clone(), key_vals[i].clone());
239        }
240
241        // Compute each expression
242        for expr in exprs {
243            match expr {
244                SelectExpr::Column(name) => {
245                    // Already filled if it's a group key; otherwise take first row's value
246                    if !out.contains_key(name) {
247                        if let Some(first) = group_rows.first() {
248                            out.insert(
249                                name.clone(),
250                                first.get(name).cloned().unwrap_or(Value::Null),
251                            );
252                        }
253                    }
254                }
255                SelectExpr::Aggregate { func, arg, arg_expr, alias } => {
256                    let out_name = alias
257                        .clone()
258                        .unwrap_or_else(|| expr.output_name());
259                    let val = compute_aggregate(func, arg, arg_expr.as_ref(), group_rows);
260                    out.insert(out_name, val);
261                }
262                SelectExpr::Expr { expr: e, alias } => {
263                    let out_name = alias.clone().unwrap_or_else(|| e.display_name());
264                    if let Some(first) = group_rows.first() {
265                        let val = evaluate_expr(e, first);
266                        out.insert(out_name, val);
267                    }
268                }
269            }
270        }
271
272        result.push(out);
273    }
274
275    Ok(result)
276}
277
278/// Resolve a per-row value for an aggregate argument.
279/// If `arg_expr` is set, evaluate it; otherwise look up `arg` as a column name.
280fn resolve_agg_value<'a>(arg: &str, arg_expr: Option<&Expr>, row: &'a Row) -> Value {
281    if let Some(expr) = arg_expr {
282        evaluate_expr(expr, row)
283    } else {
284        row.get(arg).cloned().unwrap_or(Value::Null)
285    }
286}
287
288fn compute_aggregate(func: &AggFunc, arg: &str, arg_expr: Option<&Expr>, rows: &[&Row]) -> Value {
289    match func {
290        AggFunc::Count => {
291            if arg == "*" && arg_expr.is_none() {
292                Value::Int(rows.len() as i64)
293            } else {
294                let count = rows
295                    .iter()
296                    .filter(|r| {
297                        let v = resolve_agg_value(arg, arg_expr, r);
298                        !v.is_null()
299                    })
300                    .count();
301                Value::Int(count as i64)
302            }
303        }
304        AggFunc::Sum => {
305            let mut total = 0.0f64;
306            let mut has_any = false;
307            for r in rows {
308                let v = resolve_agg_value(arg, arg_expr, r);
309                match v {
310                    Value::Int(n) => { total += n as f64; has_any = true; }
311                    Value::Float(f) => { total += f; has_any = true; }
312                    _ => {}
313                }
314            }
315            if has_any { Value::Float(total) } else { Value::Null }
316        }
317        AggFunc::Avg => {
318            let mut total = 0.0f64;
319            let mut count = 0usize;
320            for r in rows {
321                let v = resolve_agg_value(arg, arg_expr, r);
322                match v {
323                    Value::Int(n) => { total += n as f64; count += 1; }
324                    Value::Float(f) => { total += f; count += 1; }
325                    _ => {}
326                }
327            }
328            if count > 0 { Value::Float(total / count as f64) } else { Value::Null }
329        }
330        AggFunc::Min => {
331            let mut min_val: Option<Value> = None;
332            for r in rows {
333                let v = resolve_agg_value(arg, arg_expr, r);
334                if v.is_null() { continue; }
335                min_val = Some(match min_val {
336                    None => v,
337                    Some(ref current) => {
338                        if v.partial_cmp(current) == Some(std::cmp::Ordering::Less) {
339                            v
340                        } else {
341                            current.clone()
342                        }
343                    }
344                });
345            }
346            min_val.unwrap_or(Value::Null)
347        }
348        AggFunc::Max => {
349            let mut max_val: Option<Value> = None;
350            for r in rows {
351                let v = resolve_agg_value(arg, arg_expr, r);
352                if v.is_null() { continue; }
353                max_val = Some(match max_val {
354                    None => v,
355                    Some(ref current) => {
356                        if v.partial_cmp(current) == Some(std::cmp::Ordering::Greater) {
357                            v
358                        } else {
359                            current.clone()
360                        }
361                    }
362                });
363            }
364            max_val.unwrap_or(Value::Null)
365        }
366    }
367}
368
369fn evaluate_with_fts(clause: &WhereClause, row: &Row, fts: &FtsResults) -> bool {
370    match clause {
371        WhereClause::BoolOp(bop) => {
372            let left = evaluate_with_fts(&bop.left, row, fts);
373            match bop.op.as_str() {
374                "AND" => left && evaluate_with_fts(&bop.right, row, fts),
375                "OR" => left || evaluate_with_fts(&bop.right, row, fts),
376                _ => false,
377            }
378        }
379        WhereClause::Comparison(cmp) => {
380            // Check if we have FTS results for this comparison
381            if cmp.op == "LIKE" || cmp.op == "NOT LIKE" {
382                if let Some(SqlValue::String(pattern)) = &cmp.value {
383                    let key = (cmp.column.clone(), pattern.clone());
384                    if let Some(matching_paths) = fts.get(&key) {
385                        let row_path = row.get("path").and_then(|v| v.as_str()).unwrap_or("");
386                        let matched = matching_paths.contains(row_path);
387                        return if cmp.op == "LIKE" { matched } else { !matched };
388                    }
389                }
390            }
391            evaluate_comparison(cmp, row)
392        }
393    }
394}
395
396pub fn execute_join_query(
397    query: &SelectQuery,
398    tables: &HashMap<String, (Schema, Vec<Row>)>,
399) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
400    if query.joins.is_empty() {
401        return Err(MdqlError::QueryExecution("No JOIN clause in query".into()));
402    }
403
404    let left_name = &query.table;
405    let left_alias = query.table_alias.as_deref().unwrap_or(left_name);
406
407    // Build alias→table mapping for all tables
408    let mut aliases: HashMap<String, String> = HashMap::new();
409    aliases.insert(left_name.clone(), left_name.clone());
410    if let Some(ref a) = query.table_alias {
411        aliases.insert(a.clone(), left_name.clone());
412    }
413    for join in &query.joins {
414        aliases.insert(join.table.clone(), join.table.clone());
415        if let Some(ref a) = join.alias {
416            aliases.insert(a.clone(), join.table.clone());
417        }
418    }
419
420    // Start with the left table rows, prefixed with alias
421    let (_left_schema, left_rows) = tables.get(left_name.as_str()).ok_or_else(|| {
422        MdqlError::QueryExecution(format!("Unknown table '{}'", left_name))
423    })?;
424
425    let mut current_rows: Vec<Row> = left_rows
426        .iter()
427        .map(|r| {
428            let mut prefixed = Row::new();
429            for (k, v) in r {
430                prefixed.insert(format!("{}.{}", left_alias, k), v.clone());
431            }
432            prefixed
433        })
434        .collect();
435
436    // Process each JOIN sequentially
437    for join in &query.joins {
438        let right_name = &join.table;
439        let right_alias = join.alias.as_deref().unwrap_or(right_name);
440
441        let (_right_schema, right_rows) = tables.get(right_name.as_str()).ok_or_else(|| {
442            MdqlError::QueryExecution(format!("Unknown table '{}'", right_name))
443        })?;
444
445        // Resolve ON columns to determine which is left vs right
446        let (on_left_table, on_left_col) = resolve_dotted(&join.left_col, &aliases);
447        let (on_right_table, on_right_col) = resolve_dotted(&join.right_col, &aliases);
448
449        // Figure out which ON column refers to the new right table
450        let (left_key, right_key) = if on_right_table == *right_name {
451            // left_col is from the left side, right_col is from the right table
452            let left_alias_for_col = reverse_alias(&on_left_table, &aliases, query, &query.joins);
453            (format!("{}.{}", left_alias_for_col, on_left_col), on_right_col)
454        } else {
455            // right_col is from the left side, left_col is from the right table
456            let right_alias_for_col = reverse_alias(&on_right_table, &aliases, query, &query.joins);
457            (format!("{}.{}", right_alias_for_col, on_right_col), on_left_col)
458        };
459
460        // Build index on right table
461        let mut right_index: HashMap<String, Vec<&Row>> = HashMap::new();
462        for r in right_rows {
463            if let Some(key) = r.get(&right_key) {
464                let key_str = key.to_display_string();
465                right_index.entry(key_str).or_default().push(r);
466            }
467        }
468
469        // Join current rows with right table
470        let mut next_rows: Vec<Row> = Vec::new();
471        for lr in &current_rows {
472            if let Some(key) = lr.get(&left_key) {
473                let key_str = key.to_display_string();
474                if let Some(matching) = right_index.get(&key_str) {
475                    for rr in matching {
476                        let mut merged = lr.clone();
477                        for (k, v) in *rr {
478                            merged.insert(format!("{}.{}", right_alias, k), v.clone());
479                        }
480                        next_rows.push(merged);
481                    }
482                }
483            }
484        }
485        current_rows = next_rows;
486    }
487
488    let (mut result, columns) = execute(query, &current_rows, None)?;
489
490    // Add unprefixed aliases for non-colliding column names in the output.
491    // e.g., if result has s.title and b.sharpe (no other "title" or "sharpe"),
492    // add "title" and "sharpe" as shorthand keys.
493    if !result.is_empty() {
494        let mut base_counts: HashMap<String, usize> = HashMap::new();
495        for key in &columns {
496            if let Some((_prefix, base)) = key.split_once('.') {
497                *base_counts.entry(base.to_string()).or_default() += 1;
498            }
499        }
500        let unique_bases: Vec<String> = base_counts
501            .into_iter()
502            .filter(|(_, count)| *count == 1)
503            .map(|(base, _)| base)
504            .collect();
505
506        if !unique_bases.is_empty() {
507            let unique_set: std::collections::HashSet<&str> =
508                unique_bases.iter().map(|s| s.as_str()).collect();
509            for row in &mut result {
510                let additions: Vec<(String, Value)> = row
511                    .iter()
512                    .filter_map(|(k, v)| {
513                        k.split_once('.').and_then(|(_, base)| {
514                            if unique_set.contains(base) {
515                                Some((base.to_string(), v.clone()))
516                            } else {
517                                None
518                            }
519                        })
520                    })
521                    .collect();
522                for (k, v) in additions {
523                    row.insert(k, v);
524                }
525            }
526        }
527    }
528
529    Ok((result, columns))
530}
531
532/// Given a table name, find the alias used for it.
533fn reverse_alias(
534    table_name: &str,
535    aliases: &HashMap<String, String>,
536    query: &SelectQuery,
537    joins: &[JoinClause],
538) -> String {
539    // Check if the FROM table matches
540    if query.table == table_name {
541        return query.table_alias.as_deref().unwrap_or(&query.table).to_string();
542    }
543    // Check join tables
544    for j in joins {
545        if j.table == table_name {
546            return j.alias.as_deref().unwrap_or(&j.table).to_string();
547        }
548    }
549    // Fall back: check if table_name is itself an alias
550    if aliases.contains_key(table_name) {
551        return table_name.to_string();
552    }
553    table_name.to_string()
554}
555
/// Split a possibly qualified column reference into `(table, column)`.
///
/// * For `qual.rest`, `qual` is mapped through `aliases`, falling back to
///   `qual` itself when it is unknown. `rest` may still contain dots.
/// * An unqualified name yields an empty table string.
fn resolve_dotted(col: &str, aliases: &HashMap<String, String>) -> (String, String) {
    match col.split_once('.') {
        None => (String::new(), col.to_owned()),
        Some((qual, rest)) => {
            let table = match aliases.get(qual) {
                Some(t) => t.clone(),
                None => qual.to_owned(),
            };
            (table, rest.to_owned())
        }
    }
}
564
565fn execute(
566    query: &SelectQuery,
567    rows: &[Row],
568    index: Option<&crate::index::TableIndex>,
569) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
570    let empty_fts = HashMap::new();
571    execute_with_fts(query, rows, index, &empty_fts)
572}
573
574pub fn evaluate(clause: &WhereClause, row: &Row) -> bool {
575    match clause {
576        WhereClause::BoolOp(bop) => {
577            let left = evaluate(&bop.left, row);
578            match bop.op.as_str() {
579                "AND" => left && evaluate(&bop.right, row),
580                "OR" => left || evaluate(&bop.right, row),
581                _ => false,
582            }
583        }
584        WhereClause::Comparison(cmp) => evaluate_comparison(cmp, row),
585    }
586}
587
588/// Evaluate an Expr against a row, returning a Value.
589pub fn evaluate_expr(expr: &Expr, row: &Row) -> Value {
590    match expr {
591        Expr::Literal(SqlValue::Int(n)) => Value::Int(*n),
592        Expr::Literal(SqlValue::Float(f)) => Value::Float(*f),
593        Expr::Literal(SqlValue::String(s)) => Value::String(s.clone()),
594        Expr::Literal(SqlValue::Null) => Value::Null,
595        Expr::Literal(SqlValue::List(_)) => Value::Null,
596        Expr::Column(name) => {
597            if let Some(val) = row.get(name) {
598                return val.clone();
599            }
600            // Try all possible dot splits for dict access (e.g. "s.params.key")
601            for (i, _) in name.match_indices('.') {
602                let dict_col = &name[..i];
603                let dict_key = &name[i + 1..];
604                if let Some(Value::Dict(map)) = row.get(dict_col) {
605                    return map.get(dict_key).cloned().unwrap_or(Value::Null);
606                }
607            }
608            Value::Null
609        }
610        Expr::UnaryMinus(inner) => {
611            match evaluate_expr(inner, row) {
612                Value::Int(n) => Value::Int(-n),
613                Value::Float(f) => Value::Float(-f),
614                Value::Null => Value::Null,
615                _ => Value::Null, // non-numeric → NULL
616            }
617        }
618        Expr::BinaryOp { left, op, right } => {
619            let lv = evaluate_expr(left, row);
620            let rv = evaluate_expr(right, row);
621
622            // NULL propagation: any NULL operand → NULL
623            if lv.is_null() || rv.is_null() {
624                return Value::Null;
625            }
626
627            // Extract numeric values with int→float coercion
628            match (&lv, &rv) {
629                (Value::Int(a), Value::Int(b)) => {
630                    match op {
631                        ArithOp::Add => Value::Int(a.wrapping_add(*b)),
632                        ArithOp::Sub => Value::Int(a.wrapping_sub(*b)),
633                        ArithOp::Mul => Value::Int(a.wrapping_mul(*b)),
634                        ArithOp::Div => {
635                            if *b == 0 { Value::Null } else { Value::Int(a / b) }
636                        }
637                        ArithOp::Mod => {
638                            if *b == 0 { Value::Null } else { Value::Int(a % b) }
639                        }
640                    }
641                }
642                _ => {
643                    // Coerce to float
644                    let a = match &lv {
645                        Value::Int(n) => *n as f64,
646                        Value::Float(f) => *f,
647                        _ => return Value::Null,
648                    };
649                    let b = match &rv {
650                        Value::Int(n) => *n as f64,
651                        Value::Float(f) => *f,
652                        _ => return Value::Null,
653                    };
654                    match op {
655                        ArithOp::Add => Value::Float(a + b),
656                        ArithOp::Sub => Value::Float(a - b),
657                        ArithOp::Mul => Value::Float(a * b),
658                        ArithOp::Div => {
659                            if b == 0.0 { Value::Null } else { Value::Float(a / b) }
660                        }
661                        ArithOp::Mod => {
662                            if b == 0.0 { Value::Null } else { Value::Float(a % b) }
663                        }
664                    }
665                }
666            }
667        }
668        Expr::Case { whens, else_expr } => {
669            for (condition, result) in whens {
670                if evaluate(condition, row) {
671                    return evaluate_expr(result, row);
672                }
673            }
674            match else_expr {
675                Some(e) => evaluate_expr(e, row),
676                None => Value::Null,
677            }
678        }
679        Expr::CurrentDate => {
680            Value::Date(chrono::Local::now().naive_local().date())
681        }
682        Expr::CurrentTimestamp => {
683            Value::DateTime(chrono::Local::now().naive_local())
684        }
685        Expr::DateAdd { date, days } => {
686            let date_val = evaluate_expr(date, row);
687            let days_val = evaluate_expr(days, row);
688            let n = match &days_val {
689                Value::Int(n) => *n,
690                Value::Float(f) => *f as i64,
691                _ => return Value::Null,
692            };
693            let duration = chrono::Duration::days(n);
694            match date_val {
695                Value::Date(d) => {
696                    match d.checked_add_signed(duration) {
697                        Some(result) => Value::Date(result),
698                        None => Value::Null,
699                    }
700                }
701                Value::DateTime(dt) => {
702                    match dt.checked_add_signed(duration) {
703                        Some(result) => Value::DateTime(result),
704                        None => Value::Null,
705                    }
706                }
707                _ => Value::Null,
708            }
709        }
710        Expr::DateDiff { left, right } => {
711            let lv = evaluate_expr(left, row);
712            let rv = evaluate_expr(right, row);
713            let left_date = match &lv {
714                Value::Date(d) => d.and_hms_opt(0, 0, 0).unwrap(),
715                Value::DateTime(dt) => *dt,
716                _ => return Value::Null,
717            };
718            let right_date = match &rv {
719                Value::Date(d) => d.and_hms_opt(0, 0, 0).unwrap(),
720                Value::DateTime(dt) => *dt,
721                _ => return Value::Null,
722            };
723            Value::Int((left_date - right_date).num_days())
724        }
725    }
726}
727
728fn evaluate_comparison(cmp: &Comparison, row: &Row) -> bool {
729    // If we have expression-based comparison (new path), use it for standard ops
730    if let (Some(left_expr), Some(right_expr)) = (&cmp.left_expr, &cmp.right_expr) {
731        if ["=", "!=", "<", ">", "<=", ">="].contains(&cmp.op.as_str()) {
732            let left_val = evaluate_expr(left_expr, row);
733            let right_val = evaluate_expr(right_expr, row);
734
735            // NULL comparison: always false (except IS NULL handled below)
736            if left_val.is_null() || right_val.is_null() {
737                return false;
738            }
739
740            // Coerce for comparison: if types differ, try int→float
741            let ord = compare_model_values(&left_val, &right_val);
742
743            return match cmp.op.as_str() {
744                "=" => ord == Some(Ordering::Equal),
745                "!=" => ord != Some(Ordering::Equal),
746                "<" => ord == Some(Ordering::Less),
747                ">" => ord == Some(Ordering::Greater),
748                "<=" => matches!(ord, Some(Ordering::Less | Ordering::Equal)),
749                ">=" => matches!(ord, Some(Ordering::Greater | Ordering::Equal)),
750                _ => false,
751            };
752        }
753    }
754
755    // Fall back to legacy column-based comparison for IS NULL, IN, LIKE, etc.
756    let actual = row.get(&cmp.column);
757
758    if cmp.op == "IS NULL" {
759        return actual.map_or(true, |v| v.is_null());
760    }
761    if cmp.op == "IS NOT NULL" {
762        return actual.map_or(false, |v| !v.is_null());
763    }
764
765    let actual = match actual {
766        Some(v) if !v.is_null() => v,
767        _ => return false,
768    };
769
770    let expected = match &cmp.value {
771        Some(v) => v,
772        None => return false,
773    };
774
775    match cmp.op.as_str() {
776        "=" => eq_match(actual, expected),
777        "!=" => !eq_match(actual, expected),
778        "<" => compare_values(actual, expected) == Some(Ordering::Less),
779        ">" => compare_values(actual, expected) == Some(Ordering::Greater),
780        "<=" => matches!(compare_values(actual, expected), Some(Ordering::Less | Ordering::Equal)),
781        ">=" => matches!(compare_values(actual, expected), Some(Ordering::Greater | Ordering::Equal)),
782        "LIKE" => like_match(actual, expected),
783        "NOT LIKE" => !like_match(actual, expected),
784        "IN" => {
785            if let SqlValue::List(items) = expected {
786                items.iter().any(|v| eq_match(actual, v))
787            } else {
788                eq_match(actual, expected)
789            }
790        }
791        _ => false,
792    }
793}
794
795/// Compare two model::Value instances, with int↔float coercion.
796fn compare_model_values(a: &Value, b: &Value) -> Option<Ordering> {
797    match (a, b) {
798        (Value::Int(x), Value::Float(y)) => (*x as f64).partial_cmp(y),
799        (Value::Float(x), Value::Int(y)) => x.partial_cmp(&(*y as f64)),
800        _ => a.partial_cmp(b),
801    }
802}
803
804fn coerce_sql_to_value(sql_val: &SqlValue, target: &Value) -> Value {
805    match sql_val {
806        SqlValue::Null => Value::Null,
807        SqlValue::String(s) => {
808            match target {
809                Value::Int(_) => s.parse::<i64>().map(Value::Int).unwrap_or(Value::String(s.clone())),
810                Value::Float(_) => s.parse::<f64>().map(Value::Float).unwrap_or(Value::String(s.clone())),
811                Value::Date(_) => {
812                    chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d")
813                        .map(Value::Date)
814                        .unwrap_or(Value::String(s.clone()))
815                }
816                Value::DateTime(_) => {
817                    chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S")
818                        .or_else(|_| chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f"))
819                        .map(Value::DateTime)
820                        .unwrap_or(Value::String(s.clone()))
821                }
822                _ => Value::String(s.clone()),
823            }
824        }
825        SqlValue::Int(n) => {
826            match target {
827                Value::Float(_) => Value::Float(*n as f64),
828                _ => Value::Int(*n),
829            }
830        }
831        SqlValue::Float(f) => Value::Float(*f),
832        SqlValue::List(_) => Value::Null, // Lists handled separately
833    }
834}
835
836fn eq_match(actual: &Value, expected: &SqlValue) -> bool {
837    // Special handling for lists (e.g., categories)
838    if let Value::List(items) = actual {
839        if let SqlValue::String(s) = expected {
840            return items.contains(s);
841        }
842    }
843
844    let coerced = coerce_sql_to_value(expected, actual);
845    actual == &coerced
846}
847
848fn like_match(actual: &Value, pattern: &SqlValue) -> bool {
849    let pattern_str = match pattern {
850        SqlValue::String(s) => s,
851        _ => return false,
852    };
853
854    // Convert SQL LIKE to regex
855    let mut regex_str = String::from("(?is)^");
856    for ch in pattern_str.chars() {
857        match ch {
858            '%' => regex_str.push_str(".*"),
859            '_' => regex_str.push('.'),
860            c => {
861                if regex::escape(&c.to_string()) != c.to_string() {
862                    regex_str.push_str(&regex::escape(&c.to_string()));
863                } else {
864                    regex_str.push(c);
865                }
866            }
867        }
868    }
869    regex_str.push('$');
870
871    let re = match Regex::new(&regex_str) {
872        Ok(r) => r,
873        Err(_) => return false,
874    };
875
876    match actual {
877        Value::List(items) => items.iter().any(|item| re.is_match(item)),
878        _ => re.is_match(&actual.to_display_string()),
879    }
880}
881
882fn compare_values(actual: &Value, expected: &SqlValue) -> Option<Ordering> {
883    let coerced = coerce_sql_to_value(expected, actual);
884    actual.partial_cmp(&coerced).map(|o| o)
885}
886
887/// Convert a SqlValue to a Value for index lookups (without a target type for coercion).
888fn sql_value_to_index_value(sv: &SqlValue) -> Value {
889    match sv {
890        SqlValue::String(s) => {
891            // Try datetime first (more specific)
892            if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S") {
893                return Value::DateTime(dt);
894            }
895            if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f") {
896                return Value::DateTime(dt);
897            }
898            // Try date
899            if let Ok(d) = chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d") {
900                return Value::Date(d);
901            }
902            Value::String(s.clone())
903        }
904        SqlValue::Int(n) => Value::Int(*n),
905        SqlValue::Float(f) => Value::Float(*f),
906        SqlValue::Null => Value::Null,
907        SqlValue::List(_) => Value::Null,
908    }
909}
910
911/// Try to use B-tree indexes to narrow the candidate row set.
912/// Returns Some(paths) if the entire WHERE clause could be resolved via index,
913/// or None if a full scan is needed.
914fn try_index_filter(
915    clause: &WhereClause,
916    index: &crate::index::TableIndex,
917) -> Option<std::collections::HashSet<String>> {
918    match clause {
919        WhereClause::Comparison(cmp) => {
920            if !index.has_index(&cmp.column) {
921                return None;
922            }
923            match cmp.op.as_str() {
924                "=" => {
925                    let val = sql_value_to_index_value(cmp.value.as_ref()?);
926                    let paths = index.lookup_eq(&cmp.column, &val);
927                    Some(paths.into_iter().map(|s| s.to_string()).collect())
928                }
929                "<" => {
930                    let val = sql_value_to_index_value(cmp.value.as_ref()?);
931                    // exclusive upper bound: use range with max < val
932                    // lookup_range is inclusive, so we get all <= val then remove exact matches
933                    let range_paths = index.lookup_range(&cmp.column, None, Some(&val));
934                    let eq_paths: std::collections::HashSet<&str> = index.lookup_eq(&cmp.column, &val).into_iter().collect();
935                    Some(range_paths.into_iter().filter(|p| !eq_paths.contains(p)).map(|s| s.to_string()).collect())
936                }
937                ">" => {
938                    let val = sql_value_to_index_value(cmp.value.as_ref()?);
939                    let range_paths = index.lookup_range(&cmp.column, Some(&val), None);
940                    let eq_paths: std::collections::HashSet<&str> = index.lookup_eq(&cmp.column, &val).into_iter().collect();
941                    Some(range_paths.into_iter().filter(|p| !eq_paths.contains(p)).map(|s| s.to_string()).collect())
942                }
943                "<=" => {
944                    let val = sql_value_to_index_value(cmp.value.as_ref()?);
945                    let paths = index.lookup_range(&cmp.column, None, Some(&val));
946                    Some(paths.into_iter().map(|s| s.to_string()).collect())
947                }
948                ">=" => {
949                    let val = sql_value_to_index_value(cmp.value.as_ref()?);
950                    let paths = index.lookup_range(&cmp.column, Some(&val), None);
951                    Some(paths.into_iter().map(|s| s.to_string()).collect())
952                }
953                "IN" => {
954                    if let Some(SqlValue::List(items)) = &cmp.value {
955                        let vals: Vec<Value> = items.iter().map(sql_value_to_index_value).collect();
956                        let paths = index.lookup_in(&cmp.column, &vals);
957                        Some(paths.into_iter().map(|s| s.to_string()).collect())
958                    } else {
959                        None
960                    }
961                }
962                _ => None, // LIKE, IS NULL, etc. can't use index
963            }
964        }
965        WhereClause::BoolOp(bop) => {
966            let left = try_index_filter(&bop.left, index);
967            let right = try_index_filter(&bop.right, index);
968            match bop.op.as_str() {
969                "AND" => {
970                    match (left, right) {
971                        (Some(l), Some(r)) => Some(l.intersection(&r).cloned().collect()),
972                        (Some(l), None) => Some(l), // narrow with left, scan-verify right
973                        (None, Some(r)) => Some(r),
974                        (None, None) => None,
975                    }
976                }
977                "OR" => {
978                    match (left, right) {
979                        (Some(l), Some(r)) => Some(l.union(&r).cloned().collect()),
980                        _ => None, // Can't use index if either side needs full scan
981                    }
982                }
983                _ => None,
984            }
985        }
986    }
987}
988
989/// If an ORDER BY column matches a SELECT alias, replace its expr with the
990/// aliased expression so sorting uses the computed value.
991fn resolve_order_aliases(specs: &[OrderSpec], columns: &ColumnList) -> Vec<OrderSpec> {
992    let named = match columns {
993        ColumnList::Named(exprs) => exprs,
994        _ => return specs.to_vec(),
995    };
996
997    // Build alias → expr map
998    let alias_map: HashMap<String, &Expr> = named
999        .iter()
1000        .filter_map(|se| match se {
1001            SelectExpr::Expr { expr, alias: Some(a) } => Some((a.clone(), expr)),
1002            _ => None,
1003        })
1004        .collect();
1005
1006    specs
1007        .iter()
1008        .map(|spec| {
1009            // If the ORDER BY column name matches a SELECT alias, use that expression
1010            if let Some(expr) = alias_map.get(&spec.column) {
1011                OrderSpec {
1012                    column: spec.column.clone(),
1013                    expr: Some((*expr).clone()),
1014                    descending: spec.descending,
1015                }
1016            } else {
1017                spec.clone()
1018            }
1019        })
1020        .collect()
1021}
1022
1023fn sort_rows(rows: &mut Vec<Row>, specs: &[OrderSpec]) {
1024    rows.sort_by(|a, b| {
1025        for spec in specs {
1026            let (va, vb) = if let Some(ref expr) = spec.expr {
1027                (evaluate_expr(expr, a), evaluate_expr(expr, b))
1028            } else {
1029                (
1030                    a.get(&spec.column).cloned().unwrap_or(Value::Null),
1031                    b.get(&spec.column).cloned().unwrap_or(Value::Null),
1032                )
1033            };
1034
1035            // NULLs sort last
1036            let ordering = match (&va, &vb) {
1037                (Value::Null, Value::Null) => Ordering::Equal,
1038                (Value::Null, _) => Ordering::Greater,
1039                (_, Value::Null) => Ordering::Less,
1040                (a_val, b_val) => {
1041                    compare_model_values(a_val, b_val).unwrap_or(Ordering::Equal)
1042                }
1043            };
1044
1045            let ordering = if spec.descending {
1046                ordering.reverse()
1047            } else {
1048                ordering
1049            };
1050
1051            if ordering != Ordering::Equal {
1052                return ordering;
1053            }
1054        }
1055        Ordering::Equal
1056    });
1057}
1058
1059/// Convert a SqlValue to our model Value (for use in insert/update).
1060pub fn sql_value_to_value(sql_val: &SqlValue) -> Value {
1061    match sql_val {
1062        SqlValue::Null => Value::Null,
1063        SqlValue::String(s) => Value::String(s.clone()),
1064        SqlValue::Int(n) => Value::Int(*n),
1065        SqlValue::Float(f) => Value::Float(*f),
1066        SqlValue::List(items) => {
1067            let strings: Vec<String> = items
1068                .iter()
1069                .filter_map(|v| match v {
1070                    SqlValue::String(s) => Some(s.clone()),
1071                    _ => None,
1072                })
1073                .collect();
1074            Value::List(strings)
1075        }
1076    }
1077}
1078
// Unit tests for the in-memory query executor: WHERE filtering, ORDER BY,
// LIMIT, LIKE, IS NULL, arithmetic expressions and CASE WHEN evaluation.
#[cfg(test)]
mod tests {
    use super::*;

    /// Three sample rows: Alpha (count 10), Beta (count 5), Gamma (count 20).
    fn make_rows() -> Vec<Row> {
        vec![
            Row::from([
                ("path".into(), Value::String("a.md".into())),
                ("title".into(), Value::String("Alpha".into())),
                ("count".into(), Value::Int(10)),
            ]),
            Row::from([
                ("path".into(), Value::String("b.md".into())),
                ("title".into(), Value::String("Beta".into())),
                ("count".into(), Value::Int(5)),
            ]),
            Row::from([
                ("path".into(), Value::String("c.md".into())),
                ("title".into(), Value::String("Gamma".into())),
                ("count".into(), Value::Int(20)),
            ]),
        ]
    }

    // ── Basic SELECT / WHERE / ORDER BY / LIMIT tests ─────────────

    #[test]
    fn test_select_all() {
        let q = SelectQuery {
            columns: ColumnList::All,
            table: "test".into(),
            table_alias: None,
            joins: vec![],
            where_clause: None,
            group_by: None,
            having: None,
            order_by: None,
            limit: None,
        };
        let (rows, _cols) = execute(&q, &make_rows(), None).unwrap();
        assert_eq!(rows.len(), 3);
    }

    #[test]
    fn test_where_gt() {
        let q = SelectQuery {
            columns: ColumnList::All,
            table: "test".into(),
            table_alias: None,
            joins: vec![],
            where_clause: Some(WhereClause::Comparison(Comparison {
                column: "count".into(),
                op: ">".into(),
                value: Some(SqlValue::Int(5)),
                left_expr: Some(Expr::Column("count".into())),
                right_expr: Some(Expr::Literal(SqlValue::Int(5))),
            })),
            group_by: None,
            having: None,
            order_by: None,
            limit: None,
        };
        let (rows, _) = execute(&q, &make_rows(), None).unwrap();
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn test_order_by_desc() {
        let q = SelectQuery {
            columns: ColumnList::All,
            table: "test".into(),
            table_alias: None,
            joins: vec![],
            where_clause: None,
            group_by: None,
            having: None,
            order_by: Some(vec![OrderSpec {
                column: "count".into(),
                expr: Some(Expr::Column("count".into())),
                descending: true,
            }]),
            limit: None,
        };
        let (rows, _) = execute(&q, &make_rows(), None).unwrap();
        assert_eq!(rows[0]["count"], Value::Int(20));
        assert_eq!(rows[2]["count"], Value::Int(5));
    }

    #[test]
    fn test_limit() {
        let q = SelectQuery {
            columns: ColumnList::All,
            table: "test".into(),
            table_alias: None,
            joins: vec![],
            where_clause: None,
            group_by: None,
            having: None,
            order_by: None,
            limit: Some(2),
        };
        let (rows, _) = execute(&q, &make_rows(), None).unwrap();
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn test_like() {
        let q = SelectQuery {
            columns: ColumnList::All,
            table: "test".into(),
            table_alias: None,
            joins: vec![],
            where_clause: Some(WhereClause::Comparison(Comparison {
                column: "title".into(),
                op: "LIKE".into(),
                value: Some(SqlValue::String("%lph%".into())),
                left_expr: Some(Expr::Column("title".into())),
                right_expr: None,
            })),
            group_by: None,
            having: None,
            order_by: None,
            limit: None,
        };
        let (rows, _) = execute(&q, &make_rows(), None).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0]["title"], Value::String("Alpha".into()));
    }

    #[test]
    fn test_is_null() {
        let mut rows = make_rows();
        rows[1].insert("optional".into(), Value::Null);

        let q = SelectQuery {
            columns: ColumnList::All,
            table: "test".into(),
            table_alias: None,
            joins: vec![],
            where_clause: Some(WhereClause::Comparison(Comparison {
                column: "optional".into(),
                op: "IS NULL".into(),
                value: None,
                left_expr: Some(Expr::Column("optional".into())),
                right_expr: None,
            })),
            group_by: None,
            having: None,
            order_by: None,
            limit: None,
        };
        let (result, _) = execute(&q, &rows, None).unwrap();
        // All rows where optional is NULL or missing
        assert_eq!(result.len(), 3);
    }

    // ── Expression evaluation tests ───────────────────────────────

    #[test]
    fn test_evaluate_expr_literal() {
        let row = Row::new();
        assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Int(42)), &row), Value::Int(42));
        assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Float(3.14)), &row), Value::Float(3.14));
        assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Null), &row), Value::Null);
    }

    #[test]
    fn test_evaluate_expr_column() {
        let row = Row::from([("x".into(), Value::Int(10))]);
        assert_eq!(evaluate_expr(&Expr::Column("x".into()), &row), Value::Int(10));
        assert_eq!(evaluate_expr(&Expr::Column("missing".into()), &row), Value::Null);
    }

    #[test]
    fn test_evaluate_expr_int_arithmetic() {
        let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Int(3))]);
        let add = Expr::BinaryOp {
            left: Box::new(Expr::Column("a".into())),
            op: ArithOp::Add,
            right: Box::new(Expr::Column("b".into())),
        };
        assert_eq!(evaluate_expr(&add, &row), Value::Int(13));

        let sub = Expr::BinaryOp {
            left: Box::new(Expr::Column("a".into())),
            op: ArithOp::Sub,
            right: Box::new(Expr::Column("b".into())),
        };
        assert_eq!(evaluate_expr(&sub, &row), Value::Int(7));

        let mul = Expr::BinaryOp {
            left: Box::new(Expr::Column("a".into())),
            op: ArithOp::Mul,
            right: Box::new(Expr::Column("b".into())),
        };
        assert_eq!(evaluate_expr(&mul, &row), Value::Int(30));

        let div = Expr::BinaryOp {
            left: Box::new(Expr::Column("a".into())),
            op: ArithOp::Div,
            right: Box::new(Expr::Column("b".into())),
        };
        assert_eq!(evaluate_expr(&div, &row), Value::Int(3)); // integer division

        let modulo = Expr::BinaryOp {
            left: Box::new(Expr::Column("a".into())),
            op: ArithOp::Mod,
            right: Box::new(Expr::Column("b".into())),
        };
        assert_eq!(evaluate_expr(&modulo, &row), Value::Int(1));
    }

    #[test]
    fn test_evaluate_expr_float_coercion() {
        let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Float(3.0))]);
        let add = Expr::BinaryOp {
            left: Box::new(Expr::Column("a".into())),
            op: ArithOp::Add,
            right: Box::new(Expr::Column("b".into())),
        };
        assert_eq!(evaluate_expr(&add, &row), Value::Float(13.0));
    }

    #[test]
    fn test_evaluate_expr_null_propagation() {
        let row = Row::from([("a".into(), Value::Int(10))]);
        let add = Expr::BinaryOp {
            left: Box::new(Expr::Column("a".into())),
            op: ArithOp::Add,
            right: Box::new(Expr::Column("missing".into())),
        };
        assert_eq!(evaluate_expr(&add, &row), Value::Null);
    }

    #[test]
    fn test_evaluate_expr_div_by_zero() {
        let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Int(0))]);
        let div = Expr::BinaryOp {
            left: Box::new(Expr::Column("a".into())),
            op: ArithOp::Div,
            right: Box::new(Expr::Column("b".into())),
        };
        assert_eq!(evaluate_expr(&div, &row), Value::Null);
    }

    #[test]
    fn test_evaluate_expr_unary_minus() {
        let row = Row::from([("x".into(), Value::Int(5))]);
        let neg = Expr::UnaryMinus(Box::new(Expr::Column("x".into())));
        assert_eq!(evaluate_expr(&neg, &row), Value::Int(-5));
    }

    #[test]
    fn test_select_with_expression() {
        // Integration test: SELECT count * 2 AS doubled FROM test
        let stmt = crate::query_parser::parse_query(
            "SELECT count * 2 AS doubled FROM test"
        ).unwrap();
        if let crate::query_parser::Statement::Select(q) = stmt {
            let (rows, cols) = execute(&q, &make_rows(), None).unwrap();
            assert_eq!(cols, vec!["doubled"]);
            assert_eq!(rows.len(), 3);
            // Rows are: count=10, count=5, count=20
            let values: Vec<Value> = rows.iter().map(|r| r["doubled"].clone()).collect();
            assert!(values.contains(&Value::Int(20)));
            assert!(values.contains(&Value::Int(10)));
            assert!(values.contains(&Value::Int(40)));
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_where_with_expression() {
        // SELECT * FROM test WHERE count * 2 > 15
        let stmt = crate::query_parser::parse_query(
            "SELECT * FROM test WHERE count * 2 > 15"
        ).unwrap();
        if let crate::query_parser::Statement::Select(q) = stmt {
            let (rows, _) = execute(&q, &make_rows(), None).unwrap();
            // count=10 → 20 > 15 ✓, count=5 → 10 > 15 ✗, count=20 → 40 > 15 ✓
            assert_eq!(rows.len(), 2);
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_order_by_expression() {
        // SELECT * FROM test ORDER BY count * -1 ASC (effectively DESC by count)
        let stmt = crate::query_parser::parse_query(
            "SELECT title, count FROM test ORDER BY count * -1 ASC"
        ).unwrap();
        if let crate::query_parser::Statement::Select(q) = stmt {
            let (rows, _) = execute(&q, &make_rows(), None).unwrap();
            // count: 20 → -20, 10 → -10, 5 → -5, ASC means -20, -10, -5
            assert_eq!(rows[0]["count"], Value::Int(20));
            assert_eq!(rows[1]["count"], Value::Int(10));
            assert_eq!(rows[2]["count"], Value::Int(5));
        } else {
            panic!("Expected Select");
        }
    }

    // ── CASE WHEN evaluation tests ────────────────────────────────

    #[test]
    fn test_case_when_eval_basic() {
        let row = Row::from([("status".into(), Value::String("ACTIVE".into()))]);
        let expr = Expr::Case {
            whens: vec![(
                WhereClause::Comparison(Comparison {
                    column: "status".into(),
                    op: "=".into(),
                    value: Some(SqlValue::String("ACTIVE".into())),
                    left_expr: Some(Expr::Column("status".into())),
                    right_expr: Some(Expr::Literal(SqlValue::String("ACTIVE".into()))),
                }),
                Box::new(Expr::Literal(SqlValue::Int(1))),
            )],
            else_expr: Some(Box::new(Expr::Literal(SqlValue::Int(0)))),
        };
        assert_eq!(evaluate_expr(&expr, &row), Value::Int(1));
    }

    #[test]
    fn test_case_when_eval_else() {
        let row = Row::from([("status".into(), Value::String("KILLED".into()))]);
        let expr = Expr::Case {
            whens: vec![(
                WhereClause::Comparison(Comparison {
                    column: "status".into(),
                    op: "=".into(),
                    value: Some(SqlValue::String("ACTIVE".into())),
                    left_expr: Some(Expr::Column("status".into())),
                    right_expr: Some(Expr::Literal(SqlValue::String("ACTIVE".into()))),
                }),
                Box::new(Expr::Literal(SqlValue::Int(1))),
            )],
            else_expr: Some(Box::new(Expr::Literal(SqlValue::Int(0)))),
        };
        assert_eq!(evaluate_expr(&expr, &row), Value::Int(0));
    }

    #[test]
    fn test_case_when_eval_no_else_null() {
        // With no ELSE branch and no matching WHEN, CASE evaluates to NULL.
        let row = Row::from([("x".into(), Value::Int(99))]);
        let expr = Expr::Case {
            whens: vec![(
                WhereClause::Comparison(Comparison {
                    column: "x".into(),
                    op: "=".into(),
                    value: Some(SqlValue::Int(1)),
                    left_expr: Some(Expr::Column("x".into())),
                    right_expr: Some(Expr::Literal(SqlValue::Int(1))),
                }),
                Box::new(Expr::Literal(SqlValue::String("one".into()))),
            )],
            else_expr: None,
        };
        assert_eq!(evaluate_expr(&expr, &row), Value::Null);
    }

    #[test]
    fn test_case_when_in_aggregate_query() {
        // SUM(CASE WHEN count > 5 THEN count ELSE 0 END)
        // Rows: count=10, count=5, count=20 → should sum 10 + 0 + 20 = 30
        let stmt = crate::query_parser::parse_query(
            "SELECT SUM(CASE WHEN count > 5 THEN count ELSE 0 END) AS total FROM test"
        ).unwrap();
        if let crate::query_parser::Statement::Select(q) = stmt {
            let (rows, cols) = execute(&q, &make_rows(), None).unwrap();
            assert_eq!(cols, vec!["total"]);
            assert_eq!(rows.len(), 1);
            assert_eq!(rows[0]["total"], Value::Float(30.0));
        } else {
            panic!("Expected Select");
        }
    }

    #[test]
    fn test_case_when_with_unary_minus_in_aggregate() {
        // SUM(CASE WHEN title = 'Alpha' THEN count ELSE -count END)
        // Alpha: 10, Beta: -5, Gamma: -20 → 10 - 5 - 20 = -15
        let stmt = crate::query_parser::parse_query(
            "SELECT SUM(CASE WHEN title = 'Alpha' THEN count ELSE -count END) AS net FROM test"
        ).unwrap();
        if let crate::query_parser::Statement::Select(q) = stmt {
            let (rows, _) = execute(&q, &make_rows(), None).unwrap();
            assert_eq!(rows.len(), 1);
            assert_eq!(rows[0]["net"], Value::Float(-15.0));
        } else {
            panic!("Expected Select");
        }
    }
}