Skip to main content

mdql_core/
query_engine.rs

1//! Execute parsed queries over in-memory rows.
2
3use std::cmp::Ordering;
4use std::collections::HashMap;
5
6use regex::Regex;
7
8use crate::errors::MdqlError;
9use crate::model::{Row, Value};
10use crate::query_parser::*;
11use crate::schema::Schema;
12
13pub fn execute_query(
14    query: &SelectQuery,
15    rows: &[Row],
16    _schema: &Schema,
17) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
18    execute(query, rows, None)
19}
20
21/// Execute a query with optional B-tree index and FTS searcher.
22pub fn execute_query_indexed(
23    query: &SelectQuery,
24    rows: &[Row],
25    schema: &Schema,
26    index: Option<&crate::index::TableIndex>,
27    searcher: Option<&crate::search::TableSearcher>,
28) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
29    // Pre-compute FTS results for any LIKE clauses on section columns
30    let fts_results = if let (Some(ref wc), Some(searcher)) = (&query.where_clause, searcher) {
31        collect_fts_results(wc, schema, searcher)
32    } else {
33        HashMap::new()
34    };
35
36    execute_with_fts(query, rows, index, &fts_results)
37}
38
39/// Collect FTS results for LIKE comparisons on section columns.
40/// Returns a map from (column, pattern) → set of matching paths.
41fn collect_fts_results(
42    clause: &WhereClause,
43    schema: &Schema,
44    searcher: &crate::search::TableSearcher,
45) -> HashMap<(String, String), std::collections::HashSet<String>> {
46    let mut results = HashMap::new();
47    collect_fts_results_inner(clause, schema, searcher, &mut results);
48    results
49}
50
51fn collect_fts_results_inner(
52    clause: &WhereClause,
53    schema: &Schema,
54    searcher: &crate::search::TableSearcher,
55    results: &mut HashMap<(String, String), std::collections::HashSet<String>>,
56) {
57    match clause {
58        WhereClause::Comparison(cmp) => {
59            if (cmp.op == "LIKE" || cmp.op == "NOT LIKE") && schema.sections.contains_key(&cmp.column) {
60                if let Some(SqlValue::String(pattern)) = &cmp.value {
61                    // Strip SQL wildcards for Tantivy query
62                    let search_term = pattern.replace('%', " ").replace('_', " ").trim().to_string();
63                    if !search_term.is_empty() {
64                        if let Ok(paths) = searcher.search(&search_term, Some(&cmp.column)) {
65                            let key = (cmp.column.clone(), pattern.clone());
66                            results.insert(key, paths.into_iter().collect());
67                        }
68                    }
69                }
70            }
71        }
72        WhereClause::BoolOp(bop) => {
73            collect_fts_results_inner(&bop.left, schema, searcher, results);
74            collect_fts_results_inner(&bop.right, schema, searcher, results);
75        }
76    }
77}
78
/// Full-text-search hits keyed by `(column, LIKE pattern)`: the set of row
/// `path`s the searcher matched for that comparison.
type FtsResults =
    std::collections::HashMap<(String, String), std::collections::HashSet<String>>;
80
81fn execute_with_fts(
82    query: &SelectQuery,
83    rows: &[Row],
84    index: Option<&crate::index::TableIndex>,
85    fts: &FtsResults,
86) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
87    // Determine available columns
88    let mut all_columns: Vec<String> = Vec::new();
89    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
90    for r in rows {
91        for k in r.keys() {
92            if seen.insert(k.clone()) {
93                all_columns.push(k.clone());
94            }
95        }
96    }
97
98    // Check if query has aggregates
99    let has_aggregates = match &query.columns {
100        ColumnList::Named(exprs) => exprs.iter().any(|e| e.is_aggregate()),
101        _ => false,
102    };
103
104    // Output column names
105    let columns: Vec<String> = match &query.columns {
106        ColumnList::All => all_columns,
107        ColumnList::Named(exprs) => exprs.iter().map(|e| e.output_name()).collect(),
108    };
109
110    // Filter — try index first, fall back to full scan
111    let filtered: Vec<Row> = if let Some(ref wc) = query.where_clause {
112        let candidate_paths = index.and_then(|idx| try_index_filter(wc, idx));
113        if let Some(paths) = candidate_paths {
114            rows.iter()
115                .filter(|r| {
116                    r.get("path")
117                        .and_then(|v| v.as_str())
118                        .map_or(false, |p| paths.contains(p))
119                })
120                .filter(|r| evaluate_with_fts(wc, r, fts))
121                .cloned()
122                .collect()
123        } else {
124            rows.iter()
125                .filter(|r| evaluate_with_fts(wc, r, fts))
126                .cloned()
127                .collect()
128        }
129    } else {
130        rows.to_vec()
131    };
132
133    // Aggregate if needed
134    let mut result = if has_aggregates || query.group_by.is_some() {
135        let exprs = match &query.columns {
136            ColumnList::Named(exprs) => exprs.clone(),
137            _ => return Err(MdqlError::QueryExecution(
138                "SELECT * with GROUP BY is not supported".into(),
139            )),
140        };
141        let group_keys = query.group_by.as_deref().unwrap_or(&[]);
142        aggregate_rows(&filtered, &exprs, group_keys)?
143    } else {
144        filtered
145    };
146
147    // Sort — resolve ORDER BY aliases against SELECT list
148    if let Some(ref order_by) = query.order_by {
149        let resolved = resolve_order_aliases(order_by, &query.columns);
150        sort_rows(&mut result, &resolved);
151    }
152
153    // Limit
154    if let Some(limit) = query.limit {
155        result.truncate(limit as usize);
156    }
157
158    // Project — evaluate expressions and strip to requested columns
159    if !matches!(query.columns, ColumnList::All) {
160        let named_exprs = match &query.columns {
161            ColumnList::Named(exprs) => exprs,
162            _ => unreachable!(),
163        };
164
165        // Compute expression columns first, then retain only requested columns
166        let has_expr_cols = named_exprs.iter().any(|e| matches!(e, SelectExpr::Expr { .. }));
167        if has_expr_cols {
168            for row in &mut result {
169                for expr in named_exprs {
170                    if let SelectExpr::Expr { expr: e, alias } = expr {
171                        let name = alias.clone().unwrap_or_else(|| e.display_name());
172                        let val = evaluate_expr(e, row);
173                        row.insert(name, val);
174                    }
175                }
176            }
177        }
178
179        let col_set: std::collections::HashSet<&str> =
180            columns.iter().map(|s| s.as_str()).collect();
181        for row in &mut result {
182            row.retain(|k, _| col_set.contains(k.as_str()));
183        }
184    }
185
186    Ok((result, columns))
187}
188
189fn aggregate_rows(
190    rows: &[Row],
191    exprs: &[SelectExpr],
192    group_keys: &[String],
193) -> crate::errors::Result<Vec<Row>> {
194    // Group rows by group_keys
195    let mut groups: Vec<(Vec<Value>, Vec<&Row>)> = Vec::new();
196    let mut key_index: HashMap<Vec<String>, usize> = HashMap::new();
197
198    if group_keys.is_empty() {
199        // No GROUP BY — all rows are one group
200        let all_refs: Vec<&Row> = rows.iter().collect();
201        groups.push((vec![], all_refs));
202    } else {
203        for row in rows {
204            let key: Vec<String> = group_keys
205                .iter()
206                .map(|k| {
207                    row.get(k)
208                        .map(|v| v.to_display_string())
209                        .unwrap_or_default()
210                })
211                .collect();
212            let key_vals: Vec<Value> = group_keys
213                .iter()
214                .map(|k| row.get(k).cloned().unwrap_or(Value::Null))
215                .collect();
216            if let Some(&idx) = key_index.get(&key) {
217                groups[idx].1.push(row);
218            } else {
219                let idx = groups.len();
220                key_index.insert(key, idx);
221                groups.push((key_vals, vec![row]));
222            }
223        }
224    }
225
226    // Compute aggregates per group
227    let mut result = Vec::new();
228    for (key_vals, group_rows) in &groups {
229        let mut out = Row::new();
230
231        // Fill in group key values
232        for (i, k) in group_keys.iter().enumerate() {
233            out.insert(k.clone(), key_vals[i].clone());
234        }
235
236        // Compute each expression
237        for expr in exprs {
238            match expr {
239                SelectExpr::Column(name) => {
240                    // Already filled if it's a group key; otherwise take first row's value
241                    if !out.contains_key(name) {
242                        if let Some(first) = group_rows.first() {
243                            out.insert(
244                                name.clone(),
245                                first.get(name).cloned().unwrap_or(Value::Null),
246                            );
247                        }
248                    }
249                }
250                SelectExpr::Aggregate { func, arg, arg_expr, alias } => {
251                    let out_name = alias
252                        .clone()
253                        .unwrap_or_else(|| expr.output_name());
254                    let val = compute_aggregate(func, arg, arg_expr.as_ref(), group_rows);
255                    out.insert(out_name, val);
256                }
257                SelectExpr::Expr { expr: e, alias } => {
258                    let out_name = alias.clone().unwrap_or_else(|| e.display_name());
259                    if let Some(first) = group_rows.first() {
260                        let val = evaluate_expr(e, first);
261                        out.insert(out_name, val);
262                    }
263                }
264            }
265        }
266
267        result.push(out);
268    }
269
270    Ok(result)
271}
272
273/// Resolve a per-row value for an aggregate argument.
274/// If `arg_expr` is set, evaluate it; otherwise look up `arg` as a column name.
275fn resolve_agg_value<'a>(arg: &str, arg_expr: Option<&Expr>, row: &'a Row) -> Value {
276    if let Some(expr) = arg_expr {
277        evaluate_expr(expr, row)
278    } else {
279        row.get(arg).cloned().unwrap_or(Value::Null)
280    }
281}
282
283fn compute_aggregate(func: &AggFunc, arg: &str, arg_expr: Option<&Expr>, rows: &[&Row]) -> Value {
284    match func {
285        AggFunc::Count => {
286            if arg == "*" && arg_expr.is_none() {
287                Value::Int(rows.len() as i64)
288            } else {
289                let count = rows
290                    .iter()
291                    .filter(|r| {
292                        let v = resolve_agg_value(arg, arg_expr, r);
293                        !v.is_null()
294                    })
295                    .count();
296                Value::Int(count as i64)
297            }
298        }
299        AggFunc::Sum => {
300            let mut total = 0.0f64;
301            let mut has_any = false;
302            for r in rows {
303                let v = resolve_agg_value(arg, arg_expr, r);
304                match v {
305                    Value::Int(n) => { total += n as f64; has_any = true; }
306                    Value::Float(f) => { total += f; has_any = true; }
307                    _ => {}
308                }
309            }
310            if has_any { Value::Float(total) } else { Value::Null }
311        }
312        AggFunc::Avg => {
313            let mut total = 0.0f64;
314            let mut count = 0usize;
315            for r in rows {
316                let v = resolve_agg_value(arg, arg_expr, r);
317                match v {
318                    Value::Int(n) => { total += n as f64; count += 1; }
319                    Value::Float(f) => { total += f; count += 1; }
320                    _ => {}
321                }
322            }
323            if count > 0 { Value::Float(total / count as f64) } else { Value::Null }
324        }
325        AggFunc::Min => {
326            let mut min_val: Option<Value> = None;
327            for r in rows {
328                let v = resolve_agg_value(arg, arg_expr, r);
329                if v.is_null() { continue; }
330                min_val = Some(match min_val {
331                    None => v,
332                    Some(ref current) => {
333                        if v.partial_cmp(current) == Some(std::cmp::Ordering::Less) {
334                            v
335                        } else {
336                            current.clone()
337                        }
338                    }
339                });
340            }
341            min_val.unwrap_or(Value::Null)
342        }
343        AggFunc::Max => {
344            let mut max_val: Option<Value> = None;
345            for r in rows {
346                let v = resolve_agg_value(arg, arg_expr, r);
347                if v.is_null() { continue; }
348                max_val = Some(match max_val {
349                    None => v,
350                    Some(ref current) => {
351                        if v.partial_cmp(current) == Some(std::cmp::Ordering::Greater) {
352                            v
353                        } else {
354                            current.clone()
355                        }
356                    }
357                });
358            }
359            max_val.unwrap_or(Value::Null)
360        }
361    }
362}
363
364fn evaluate_with_fts(clause: &WhereClause, row: &Row, fts: &FtsResults) -> bool {
365    match clause {
366        WhereClause::BoolOp(bop) => {
367            let left = evaluate_with_fts(&bop.left, row, fts);
368            match bop.op.as_str() {
369                "AND" => left && evaluate_with_fts(&bop.right, row, fts),
370                "OR" => left || evaluate_with_fts(&bop.right, row, fts),
371                _ => false,
372            }
373        }
374        WhereClause::Comparison(cmp) => {
375            // Check if we have FTS results for this comparison
376            if cmp.op == "LIKE" || cmp.op == "NOT LIKE" {
377                if let Some(SqlValue::String(pattern)) = &cmp.value {
378                    let key = (cmp.column.clone(), pattern.clone());
379                    if let Some(matching_paths) = fts.get(&key) {
380                        let row_path = row.get("path").and_then(|v| v.as_str()).unwrap_or("");
381                        let matched = matching_paths.contains(row_path);
382                        return if cmp.op == "LIKE" { matched } else { !matched };
383                    }
384                }
385            }
386            evaluate_comparison(cmp, row)
387        }
388    }
389}
390
391pub fn execute_join_query(
392    query: &SelectQuery,
393    tables: &HashMap<String, (Schema, Vec<Row>)>,
394) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
395    if query.joins.is_empty() {
396        return Err(MdqlError::QueryExecution("No JOIN clause in query".into()));
397    }
398
399    let left_name = &query.table;
400    let left_alias = query.table_alias.as_deref().unwrap_or(left_name);
401
402    // Build alias→table mapping for all tables
403    let mut aliases: HashMap<String, String> = HashMap::new();
404    aliases.insert(left_name.clone(), left_name.clone());
405    if let Some(ref a) = query.table_alias {
406        aliases.insert(a.clone(), left_name.clone());
407    }
408    for join in &query.joins {
409        aliases.insert(join.table.clone(), join.table.clone());
410        if let Some(ref a) = join.alias {
411            aliases.insert(a.clone(), join.table.clone());
412        }
413    }
414
415    // Start with the left table rows, prefixed with alias
416    let (_left_schema, left_rows) = tables.get(left_name.as_str()).ok_or_else(|| {
417        MdqlError::QueryExecution(format!("Unknown table '{}'", left_name))
418    })?;
419
420    let mut current_rows: Vec<Row> = left_rows
421        .iter()
422        .map(|r| {
423            let mut prefixed = Row::new();
424            for (k, v) in r {
425                prefixed.insert(format!("{}.{}", left_alias, k), v.clone());
426            }
427            prefixed
428        })
429        .collect();
430
431    // Process each JOIN sequentially
432    for join in &query.joins {
433        let right_name = &join.table;
434        let right_alias = join.alias.as_deref().unwrap_or(right_name);
435
436        let (_right_schema, right_rows) = tables.get(right_name.as_str()).ok_or_else(|| {
437            MdqlError::QueryExecution(format!("Unknown table '{}'", right_name))
438        })?;
439
440        // Resolve ON columns to determine which is left vs right
441        let (on_left_table, on_left_col) = resolve_dotted(&join.left_col, &aliases);
442        let (on_right_table, on_right_col) = resolve_dotted(&join.right_col, &aliases);
443
444        // Figure out which ON column refers to the new right table
445        let (left_key, right_key) = if on_right_table == *right_name {
446            // left_col is from the left side, right_col is from the right table
447            let left_alias_for_col = reverse_alias(&on_left_table, &aliases, query, &query.joins);
448            (format!("{}.{}", left_alias_for_col, on_left_col), on_right_col)
449        } else {
450            // right_col is from the left side, left_col is from the right table
451            let right_alias_for_col = reverse_alias(&on_right_table, &aliases, query, &query.joins);
452            (format!("{}.{}", right_alias_for_col, on_right_col), on_left_col)
453        };
454
455        // Build index on right table
456        let mut right_index: HashMap<String, Vec<&Row>> = HashMap::new();
457        for r in right_rows {
458            if let Some(key) = r.get(&right_key) {
459                let key_str = key.to_display_string();
460                right_index.entry(key_str).or_default().push(r);
461            }
462        }
463
464        // Join current rows with right table
465        let mut next_rows: Vec<Row> = Vec::new();
466        for lr in &current_rows {
467            if let Some(key) = lr.get(&left_key) {
468                let key_str = key.to_display_string();
469                if let Some(matching) = right_index.get(&key_str) {
470                    for rr in matching {
471                        let mut merged = lr.clone();
472                        for (k, v) in *rr {
473                            merged.insert(format!("{}.{}", right_alias, k), v.clone());
474                        }
475                        next_rows.push(merged);
476                    }
477                }
478            }
479        }
480        current_rows = next_rows;
481    }
482
483    let (mut result, columns) = execute(query, &current_rows, None)?;
484
485    // Add unprefixed aliases for non-colliding column names in the output.
486    // e.g., if result has s.title and b.sharpe (no other "title" or "sharpe"),
487    // add "title" and "sharpe" as shorthand keys.
488    if !result.is_empty() {
489        let mut base_counts: HashMap<String, usize> = HashMap::new();
490        for key in &columns {
491            if let Some((_prefix, base)) = key.split_once('.') {
492                *base_counts.entry(base.to_string()).or_default() += 1;
493            }
494        }
495        let unique_bases: Vec<String> = base_counts
496            .into_iter()
497            .filter(|(_, count)| *count == 1)
498            .map(|(base, _)| base)
499            .collect();
500
501        if !unique_bases.is_empty() {
502            let unique_set: std::collections::HashSet<&str> =
503                unique_bases.iter().map(|s| s.as_str()).collect();
504            for row in &mut result {
505                let additions: Vec<(String, Value)> = row
506                    .iter()
507                    .filter_map(|(k, v)| {
508                        k.split_once('.').and_then(|(_, base)| {
509                            if unique_set.contains(base) {
510                                Some((base.to_string(), v.clone()))
511                            } else {
512                                None
513                            }
514                        })
515                    })
516                    .collect();
517                for (k, v) in additions {
518                    row.insert(k, v);
519                }
520            }
521        }
522    }
523
524    Ok((result, columns))
525}
526
527/// Given a table name, find the alias used for it.
528fn reverse_alias(
529    table_name: &str,
530    aliases: &HashMap<String, String>,
531    query: &SelectQuery,
532    joins: &[JoinClause],
533) -> String {
534    // Check if the FROM table matches
535    if query.table == table_name {
536        return query.table_alias.as_deref().unwrap_or(&query.table).to_string();
537    }
538    // Check join tables
539    for j in joins {
540        if j.table == table_name {
541            return j.alias.as_deref().unwrap_or(&j.table).to_string();
542        }
543    }
544    // Fall back: check if table_name is itself an alias
545    if aliases.contains_key(table_name) {
546        return table_name.to_string();
547    }
548    table_name.to_string()
549}
550
/// Split a possibly qualified column reference `qualifier.column`.
///
/// The qualifier is mapped through `aliases` to a table name; an unknown
/// qualifier is returned as-is. Only the first `.` splits, so the column part
/// may itself contain dots. An unqualified reference yields an empty table
/// name.
fn resolve_dotted(col: &str, aliases: &HashMap<String, String>) -> (String, String) {
    match col.split_once('.') {
        None => (String::new(), col.to_owned()),
        Some((qualifier, column)) => {
            let table = match aliases.get(qualifier) {
                Some(table) => table.to_owned(),
                None => qualifier.to_owned(),
            };
            (table, column.to_owned())
        }
    }
}
559
560fn execute(
561    query: &SelectQuery,
562    rows: &[Row],
563    index: Option<&crate::index::TableIndex>,
564) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
565    let empty_fts = HashMap::new();
566    execute_with_fts(query, rows, index, &empty_fts)
567}
568
569pub fn evaluate(clause: &WhereClause, row: &Row) -> bool {
570    match clause {
571        WhereClause::BoolOp(bop) => {
572            let left = evaluate(&bop.left, row);
573            match bop.op.as_str() {
574                "AND" => left && evaluate(&bop.right, row),
575                "OR" => left || evaluate(&bop.right, row),
576                _ => false,
577            }
578        }
579        WhereClause::Comparison(cmp) => evaluate_comparison(cmp, row),
580    }
581}
582
583/// Evaluate an Expr against a row, returning a Value.
584pub fn evaluate_expr(expr: &Expr, row: &Row) -> Value {
585    match expr {
586        Expr::Literal(SqlValue::Int(n)) => Value::Int(*n),
587        Expr::Literal(SqlValue::Float(f)) => Value::Float(*f),
588        Expr::Literal(SqlValue::String(s)) => Value::String(s.clone()),
589        Expr::Literal(SqlValue::Null) => Value::Null,
590        Expr::Literal(SqlValue::List(_)) => Value::Null,
591        Expr::Column(name) => {
592            if let Some(val) = row.get(name) {
593                return val.clone();
594            }
595            if let Some((dict_col, dict_key)) = name.split_once('.') {
596                if let Some(Value::Dict(map)) = row.get(dict_col) {
597                    return map.get(dict_key).cloned().unwrap_or(Value::Null);
598                }
599            }
600            Value::Null
601        }
602        Expr::UnaryMinus(inner) => {
603            match evaluate_expr(inner, row) {
604                Value::Int(n) => Value::Int(-n),
605                Value::Float(f) => Value::Float(-f),
606                Value::Null => Value::Null,
607                _ => Value::Null, // non-numeric → NULL
608            }
609        }
610        Expr::BinaryOp { left, op, right } => {
611            let lv = evaluate_expr(left, row);
612            let rv = evaluate_expr(right, row);
613
614            // NULL propagation: any NULL operand → NULL
615            if lv.is_null() || rv.is_null() {
616                return Value::Null;
617            }
618
619            // Extract numeric values with int→float coercion
620            match (&lv, &rv) {
621                (Value::Int(a), Value::Int(b)) => {
622                    match op {
623                        ArithOp::Add => Value::Int(a.wrapping_add(*b)),
624                        ArithOp::Sub => Value::Int(a.wrapping_sub(*b)),
625                        ArithOp::Mul => Value::Int(a.wrapping_mul(*b)),
626                        ArithOp::Div => {
627                            if *b == 0 { Value::Null } else { Value::Int(a / b) }
628                        }
629                        ArithOp::Mod => {
630                            if *b == 0 { Value::Null } else { Value::Int(a % b) }
631                        }
632                    }
633                }
634                _ => {
635                    // Coerce to float
636                    let a = match &lv {
637                        Value::Int(n) => *n as f64,
638                        Value::Float(f) => *f,
639                        _ => return Value::Null,
640                    };
641                    let b = match &rv {
642                        Value::Int(n) => *n as f64,
643                        Value::Float(f) => *f,
644                        _ => return Value::Null,
645                    };
646                    match op {
647                        ArithOp::Add => Value::Float(a + b),
648                        ArithOp::Sub => Value::Float(a - b),
649                        ArithOp::Mul => Value::Float(a * b),
650                        ArithOp::Div => {
651                            if b == 0.0 { Value::Null } else { Value::Float(a / b) }
652                        }
653                        ArithOp::Mod => {
654                            if b == 0.0 { Value::Null } else { Value::Float(a % b) }
655                        }
656                    }
657                }
658            }
659        }
660        Expr::Case { whens, else_expr } => {
661            for (condition, result) in whens {
662                if evaluate(condition, row) {
663                    return evaluate_expr(result, row);
664                }
665            }
666            match else_expr {
667                Some(e) => evaluate_expr(e, row),
668                None => Value::Null,
669            }
670        }
671    }
672}
673
674fn evaluate_comparison(cmp: &Comparison, row: &Row) -> bool {
675    // If we have expression-based comparison (new path), use it for standard ops
676    if let (Some(left_expr), Some(right_expr)) = (&cmp.left_expr, &cmp.right_expr) {
677        if ["=", "!=", "<", ">", "<=", ">="].contains(&cmp.op.as_str()) {
678            let left_val = evaluate_expr(left_expr, row);
679            let right_val = evaluate_expr(right_expr, row);
680
681            // NULL comparison: always false (except IS NULL handled below)
682            if left_val.is_null() || right_val.is_null() {
683                return false;
684            }
685
686            // Coerce for comparison: if types differ, try int→float
687            let ord = compare_model_values(&left_val, &right_val);
688
689            return match cmp.op.as_str() {
690                "=" => ord == Some(Ordering::Equal),
691                "!=" => ord != Some(Ordering::Equal),
692                "<" => ord == Some(Ordering::Less),
693                ">" => ord == Some(Ordering::Greater),
694                "<=" => matches!(ord, Some(Ordering::Less | Ordering::Equal)),
695                ">=" => matches!(ord, Some(Ordering::Greater | Ordering::Equal)),
696                _ => false,
697            };
698        }
699    }
700
701    // Fall back to legacy column-based comparison for IS NULL, IN, LIKE, etc.
702    let actual = row.get(&cmp.column);
703
704    if cmp.op == "IS NULL" {
705        return actual.map_or(true, |v| v.is_null());
706    }
707    if cmp.op == "IS NOT NULL" {
708        return actual.map_or(false, |v| !v.is_null());
709    }
710
711    let actual = match actual {
712        Some(v) if !v.is_null() => v,
713        _ => return false,
714    };
715
716    let expected = match &cmp.value {
717        Some(v) => v,
718        None => return false,
719    };
720
721    match cmp.op.as_str() {
722        "=" => eq_match(actual, expected),
723        "!=" => !eq_match(actual, expected),
724        "<" => compare_values(actual, expected) == Some(Ordering::Less),
725        ">" => compare_values(actual, expected) == Some(Ordering::Greater),
726        "<=" => matches!(compare_values(actual, expected), Some(Ordering::Less | Ordering::Equal)),
727        ">=" => matches!(compare_values(actual, expected), Some(Ordering::Greater | Ordering::Equal)),
728        "LIKE" => like_match(actual, expected),
729        "NOT LIKE" => !like_match(actual, expected),
730        "IN" => {
731            if let SqlValue::List(items) = expected {
732                items.iter().any(|v| eq_match(actual, v))
733            } else {
734                eq_match(actual, expected)
735            }
736        }
737        _ => false,
738    }
739}
740
741/// Compare two model::Value instances, with int↔float coercion.
742fn compare_model_values(a: &Value, b: &Value) -> Option<Ordering> {
743    match (a, b) {
744        (Value::Int(x), Value::Float(y)) => (*x as f64).partial_cmp(y),
745        (Value::Float(x), Value::Int(y)) => x.partial_cmp(&(*y as f64)),
746        _ => a.partial_cmp(b),
747    }
748}
749
750fn coerce_sql_to_value(sql_val: &SqlValue, target: &Value) -> Value {
751    match sql_val {
752        SqlValue::Null => Value::Null,
753        SqlValue::String(s) => {
754            match target {
755                Value::Int(_) => s.parse::<i64>().map(Value::Int).unwrap_or(Value::String(s.clone())),
756                Value::Float(_) => s.parse::<f64>().map(Value::Float).unwrap_or(Value::String(s.clone())),
757                Value::Date(_) => {
758                    chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d")
759                        .map(Value::Date)
760                        .unwrap_or(Value::String(s.clone()))
761                }
762                Value::DateTime(_) => {
763                    chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S")
764                        .or_else(|_| chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f"))
765                        .map(Value::DateTime)
766                        .unwrap_or(Value::String(s.clone()))
767                }
768                _ => Value::String(s.clone()),
769            }
770        }
771        SqlValue::Int(n) => {
772            match target {
773                Value::Float(_) => Value::Float(*n as f64),
774                _ => Value::Int(*n),
775            }
776        }
777        SqlValue::Float(f) => Value::Float(*f),
778        SqlValue::List(_) => Value::Null, // Lists handled separately
779    }
780}
781
782fn eq_match(actual: &Value, expected: &SqlValue) -> bool {
783    // Special handling for lists (e.g., categories)
784    if let Value::List(items) = actual {
785        if let SqlValue::String(s) = expected {
786            return items.contains(s);
787        }
788    }
789
790    let coerced = coerce_sql_to_value(expected, actual);
791    actual == &coerced
792}
793
794fn like_match(actual: &Value, pattern: &SqlValue) -> bool {
795    let pattern_str = match pattern {
796        SqlValue::String(s) => s,
797        _ => return false,
798    };
799
800    // Convert SQL LIKE to regex
801    let mut regex_str = String::from("(?is)^");
802    for ch in pattern_str.chars() {
803        match ch {
804            '%' => regex_str.push_str(".*"),
805            '_' => regex_str.push('.'),
806            c => {
807                if regex::escape(&c.to_string()) != c.to_string() {
808                    regex_str.push_str(&regex::escape(&c.to_string()));
809                } else {
810                    regex_str.push(c);
811                }
812            }
813        }
814    }
815    regex_str.push('$');
816
817    let re = match Regex::new(&regex_str) {
818        Ok(r) => r,
819        Err(_) => return false,
820    };
821
822    match actual {
823        Value::List(items) => items.iter().any(|item| re.is_match(item)),
824        _ => re.is_match(&actual.to_display_string()),
825    }
826}
827
828fn compare_values(actual: &Value, expected: &SqlValue) -> Option<Ordering> {
829    let coerced = coerce_sql_to_value(expected, actual);
830    actual.partial_cmp(&coerced).map(|o| o)
831}
832
833/// Convert a SqlValue to a Value for index lookups (without a target type for coercion).
834fn sql_value_to_index_value(sv: &SqlValue) -> Value {
835    match sv {
836        SqlValue::String(s) => {
837            // Try datetime first (more specific)
838            if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S") {
839                return Value::DateTime(dt);
840            }
841            if let Ok(dt) = chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%dT%H:%M:%S%.f") {
842                return Value::DateTime(dt);
843            }
844            // Try date
845            if let Ok(d) = chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d") {
846                return Value::Date(d);
847            }
848            Value::String(s.clone())
849        }
850        SqlValue::Int(n) => Value::Int(*n),
851        SqlValue::Float(f) => Value::Float(*f),
852        SqlValue::Null => Value::Null,
853        SqlValue::List(_) => Value::Null,
854    }
855}
856
857/// Try to use B-tree indexes to narrow the candidate row set.
858/// Returns Some(paths) if the entire WHERE clause could be resolved via index,
859/// or None if a full scan is needed.
860fn try_index_filter(
861    clause: &WhereClause,
862    index: &crate::index::TableIndex,
863) -> Option<std::collections::HashSet<String>> {
864    match clause {
865        WhereClause::Comparison(cmp) => {
866            if !index.has_index(&cmp.column) {
867                return None;
868            }
869            match cmp.op.as_str() {
870                "=" => {
871                    let val = sql_value_to_index_value(cmp.value.as_ref()?);
872                    let paths = index.lookup_eq(&cmp.column, &val);
873                    Some(paths.into_iter().map(|s| s.to_string()).collect())
874                }
875                "<" => {
876                    let val = sql_value_to_index_value(cmp.value.as_ref()?);
877                    // exclusive upper bound: use range with max < val
878                    // lookup_range is inclusive, so we get all <= val then remove exact matches
879                    let range_paths = index.lookup_range(&cmp.column, None, Some(&val));
880                    let eq_paths: std::collections::HashSet<&str> = index.lookup_eq(&cmp.column, &val).into_iter().collect();
881                    Some(range_paths.into_iter().filter(|p| !eq_paths.contains(p)).map(|s| s.to_string()).collect())
882                }
883                ">" => {
884                    let val = sql_value_to_index_value(cmp.value.as_ref()?);
885                    let range_paths = index.lookup_range(&cmp.column, Some(&val), None);
886                    let eq_paths: std::collections::HashSet<&str> = index.lookup_eq(&cmp.column, &val).into_iter().collect();
887                    Some(range_paths.into_iter().filter(|p| !eq_paths.contains(p)).map(|s| s.to_string()).collect())
888                }
889                "<=" => {
890                    let val = sql_value_to_index_value(cmp.value.as_ref()?);
891                    let paths = index.lookup_range(&cmp.column, None, Some(&val));
892                    Some(paths.into_iter().map(|s| s.to_string()).collect())
893                }
894                ">=" => {
895                    let val = sql_value_to_index_value(cmp.value.as_ref()?);
896                    let paths = index.lookup_range(&cmp.column, Some(&val), None);
897                    Some(paths.into_iter().map(|s| s.to_string()).collect())
898                }
899                "IN" => {
900                    if let Some(SqlValue::List(items)) = &cmp.value {
901                        let vals: Vec<Value> = items.iter().map(sql_value_to_index_value).collect();
902                        let paths = index.lookup_in(&cmp.column, &vals);
903                        Some(paths.into_iter().map(|s| s.to_string()).collect())
904                    } else {
905                        None
906                    }
907                }
908                _ => None, // LIKE, IS NULL, etc. can't use index
909            }
910        }
911        WhereClause::BoolOp(bop) => {
912            let left = try_index_filter(&bop.left, index);
913            let right = try_index_filter(&bop.right, index);
914            match bop.op.as_str() {
915                "AND" => {
916                    match (left, right) {
917                        (Some(l), Some(r)) => Some(l.intersection(&r).cloned().collect()),
918                        (Some(l), None) => Some(l), // narrow with left, scan-verify right
919                        (None, Some(r)) => Some(r),
920                        (None, None) => None,
921                    }
922                }
923                "OR" => {
924                    match (left, right) {
925                        (Some(l), Some(r)) => Some(l.union(&r).cloned().collect()),
926                        _ => None, // Can't use index if either side needs full scan
927                    }
928                }
929                _ => None,
930            }
931        }
932    }
933}
934
935/// If an ORDER BY column matches a SELECT alias, replace its expr with the
936/// aliased expression so sorting uses the computed value.
937fn resolve_order_aliases(specs: &[OrderSpec], columns: &ColumnList) -> Vec<OrderSpec> {
938    let named = match columns {
939        ColumnList::Named(exprs) => exprs,
940        _ => return specs.to_vec(),
941    };
942
943    // Build alias → expr map
944    let alias_map: HashMap<String, &Expr> = named
945        .iter()
946        .filter_map(|se| match se {
947            SelectExpr::Expr { expr, alias: Some(a) } => Some((a.clone(), expr)),
948            _ => None,
949        })
950        .collect();
951
952    specs
953        .iter()
954        .map(|spec| {
955            // If the ORDER BY column name matches a SELECT alias, use that expression
956            if let Some(expr) = alias_map.get(&spec.column) {
957                OrderSpec {
958                    column: spec.column.clone(),
959                    expr: Some((*expr).clone()),
960                    descending: spec.descending,
961                }
962            } else {
963                spec.clone()
964            }
965        })
966        .collect()
967}
968
969fn sort_rows(rows: &mut Vec<Row>, specs: &[OrderSpec]) {
970    rows.sort_by(|a, b| {
971        for spec in specs {
972            let (va, vb) = if let Some(ref expr) = spec.expr {
973                (evaluate_expr(expr, a), evaluate_expr(expr, b))
974            } else {
975                (
976                    a.get(&spec.column).cloned().unwrap_or(Value::Null),
977                    b.get(&spec.column).cloned().unwrap_or(Value::Null),
978                )
979            };
980
981            // NULLs sort last
982            let ordering = match (&va, &vb) {
983                (Value::Null, Value::Null) => Ordering::Equal,
984                (Value::Null, _) => Ordering::Greater,
985                (_, Value::Null) => Ordering::Less,
986                (a_val, b_val) => {
987                    compare_model_values(a_val, b_val).unwrap_or(Ordering::Equal)
988                }
989            };
990
991            let ordering = if spec.descending {
992                ordering.reverse()
993            } else {
994                ordering
995            };
996
997            if ordering != Ordering::Equal {
998                return ordering;
999            }
1000        }
1001        Ordering::Equal
1002    });
1003}
1004
1005/// Convert a SqlValue to our model Value (for use in insert/update).
1006pub fn sql_value_to_value(sql_val: &SqlValue) -> Value {
1007    match sql_val {
1008        SqlValue::Null => Value::Null,
1009        SqlValue::String(s) => Value::String(s.clone()),
1010        SqlValue::Int(n) => Value::Int(*n),
1011        SqlValue::Float(f) => Value::Float(*f),
1012        SqlValue::List(items) => {
1013            let strings: Vec<String> = items
1014                .iter()
1015                .filter_map(|v| match v {
1016                    SqlValue::String(s) => Some(s.clone()),
1017                    _ => None,
1018                })
1019                .collect();
1020            Value::List(strings)
1021        }
1022    }
1023}
1024
#[cfg(test)]
mod tests {
    //! Unit tests for query execution, run over a small fixed set of rows.
    //! Covered: filtering, ordering, limits, expression evaluation and CASE WHEN.

    use super::*;

    /// Three rows shared by most tests: Alpha/10, Beta/5, Gamma/20.
    fn make_rows() -> Vec<Row> {
        vec![
            Row::from([
                ("path".into(), Value::String("a.md".into())),
                ("title".into(), Value::String("Alpha".into())),
                ("count".into(), Value::Int(10)),
            ]),
            Row::from([
                ("path".into(), Value::String("b.md".into())),
                ("title".into(), Value::String("Beta".into())),
                ("count".into(), Value::Int(5)),
            ]),
            Row::from([
                ("path".into(), Value::String("c.md".into())),
                ("title".into(), Value::String("Gamma".into())),
                ("count".into(), Value::Int(20)),
            ]),
        ]
    }

    /// `SELECT * FROM test` with every optional clause empty.
    /// Individual tests override fields with struct-update syntax (`..select_all()`).
    fn select_all() -> SelectQuery {
        SelectQuery {
            columns: ColumnList::All,
            table: "test".into(),
            table_alias: None,
            joins: vec![],
            where_clause: None,
            group_by: None,
            order_by: None,
            limit: None,
        }
    }

    /// A `column <op> value` comparison whose left side is the plain column reference.
    fn comparison(
        column: &str,
        op: &str,
        value: Option<SqlValue>,
        right_expr: Option<Expr>,
    ) -> WhereClause {
        WhereClause::Comparison(Comparison {
            column: column.into(),
            op: op.into(),
            value,
            left_expr: Some(Expr::Column(column.into())),
            right_expr,
        })
    }

    /// `CASE WHEN column = equals THEN then [ELSE otherwise] END`, built from literals.
    fn case_when_eq(
        column: &str,
        equals: SqlValue,
        then: SqlValue,
        otherwise: Option<SqlValue>,
    ) -> Expr {
        Expr::Case {
            whens: vec![(
                comparison(column, "=", Some(equals.clone()), Some(Expr::Literal(equals))),
                Box::new(Expr::Literal(then)),
            )],
            else_expr: otherwise.map(|v| Box::new(Expr::Literal(v))),
        }
    }

    /// `left <op> right` where both operands are column references.
    fn column_op(left: &str, op: ArithOp, right: &str) -> Expr {
        Expr::BinaryOp {
            left: Box::new(Expr::Column(left.into())),
            op,
            right: Box::new(Expr::Column(right.into())),
        }
    }

    /// Parse `sql`, which must be a SELECT, and run it over `make_rows()`.
    fn run_select(sql: &str) -> (Vec<Row>, Vec<String>) {
        match crate::query_parser::parse_query(sql).unwrap() {
            crate::query_parser::Statement::Select(q) => execute(&q, &make_rows(), None).unwrap(),
            _ => panic!("Expected Select"),
        }
    }

    #[test]
    fn test_select_all() {
        let (rows, _cols) = execute(&select_all(), &make_rows(), None).unwrap();
        assert_eq!(rows.len(), 3);
    }

    #[test]
    fn test_where_gt() {
        let q = SelectQuery {
            where_clause: Some(comparison(
                "count",
                ">",
                Some(SqlValue::Int(5)),
                Some(Expr::Literal(SqlValue::Int(5))),
            )),
            ..select_all()
        };
        let (rows, _) = execute(&q, &make_rows(), None).unwrap();
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn test_order_by_desc() {
        let q = SelectQuery {
            order_by: Some(vec![OrderSpec {
                column: "count".into(),
                expr: Some(Expr::Column("count".into())),
                descending: true,
            }]),
            ..select_all()
        };
        let (rows, _) = execute(&q, &make_rows(), None).unwrap();
        assert_eq!(rows[0]["count"], Value::Int(20));
        assert_eq!(rows[2]["count"], Value::Int(5));
    }

    #[test]
    fn test_limit() {
        let q = SelectQuery {
            limit: Some(2),
            ..select_all()
        };
        let (rows, _) = execute(&q, &make_rows(), None).unwrap();
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn test_like() {
        let q = SelectQuery {
            where_clause: Some(comparison(
                "title",
                "LIKE",
                Some(SqlValue::String("%lph%".into())),
                None,
            )),
            ..select_all()
        };
        let (rows, _) = execute(&q, &make_rows(), None).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0]["title"], Value::String("Alpha".into()));
    }

    #[test]
    fn test_is_null() {
        let mut rows = make_rows();
        rows[1].insert("optional".into(), Value::Null);

        let q = SelectQuery {
            where_clause: Some(comparison("optional", "IS NULL", None, None)),
            ..select_all()
        };
        let (result, _) = execute(&q, &rows, None).unwrap();
        // Row 1 holds an explicit NULL; rows 0 and 2 lack the column, which also counts as NULL.
        assert_eq!(result.len(), 3);
    }

    // ── Expression evaluation tests ───────────────────────────────

    #[test]
    fn test_evaluate_expr_literal() {
        let row = Row::new();
        assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Int(42)), &row), Value::Int(42));
        // 2.5 rather than an approximation of pi, which clippy's deny-by-default
        // `approx_constant` lint rejects.
        assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Float(2.5)), &row), Value::Float(2.5));
        assert_eq!(evaluate_expr(&Expr::Literal(SqlValue::Null), &row), Value::Null);
    }

    #[test]
    fn test_evaluate_expr_column() {
        let row = Row::from([("x".into(), Value::Int(10))]);
        assert_eq!(evaluate_expr(&Expr::Column("x".into()), &row), Value::Int(10));
        assert_eq!(evaluate_expr(&Expr::Column("missing".into()), &row), Value::Null);
    }

    #[test]
    fn test_evaluate_expr_int_arithmetic() {
        let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Int(3))]);
        assert_eq!(evaluate_expr(&column_op("a", ArithOp::Add, "b"), &row), Value::Int(13));
        assert_eq!(evaluate_expr(&column_op("a", ArithOp::Sub, "b"), &row), Value::Int(7));
        assert_eq!(evaluate_expr(&column_op("a", ArithOp::Mul, "b"), &row), Value::Int(30));
        // Integer division truncates.
        assert_eq!(evaluate_expr(&column_op("a", ArithOp::Div, "b"), &row), Value::Int(3));
        assert_eq!(evaluate_expr(&column_op("a", ArithOp::Mod, "b"), &row), Value::Int(1));
    }

    #[test]
    fn test_evaluate_expr_float_coercion() {
        let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Float(3.0))]);
        assert_eq!(evaluate_expr(&column_op("a", ArithOp::Add, "b"), &row), Value::Float(13.0));
    }

    #[test]
    fn test_evaluate_expr_null_propagation() {
        let row = Row::from([("a".into(), Value::Int(10))]);
        assert_eq!(evaluate_expr(&column_op("a", ArithOp::Add, "missing"), &row), Value::Null);
    }

    #[test]
    fn test_evaluate_expr_div_by_zero() {
        let row = Row::from([("a".into(), Value::Int(10)), ("b".into(), Value::Int(0))]);
        assert_eq!(evaluate_expr(&column_op("a", ArithOp::Div, "b"), &row), Value::Null);
    }

    #[test]
    fn test_evaluate_expr_unary_minus() {
        let row = Row::from([("x".into(), Value::Int(5))]);
        let neg = Expr::UnaryMinus(Box::new(Expr::Column("x".into())));
        assert_eq!(evaluate_expr(&neg, &row), Value::Int(-5));
    }

    #[test]
    fn test_select_with_expression() {
        // SELECT count * 2 AS doubled FROM test
        let (rows, cols) = run_select("SELECT count * 2 AS doubled FROM test");
        assert_eq!(cols, vec!["doubled"]);
        assert_eq!(rows.len(), 3);
        // Rows are: count=10, count=5, count=20
        let values: Vec<Value> = rows.iter().map(|r| r["doubled"].clone()).collect();
        assert!(values.contains(&Value::Int(20)));
        assert!(values.contains(&Value::Int(10)));
        assert!(values.contains(&Value::Int(40)));
    }

    #[test]
    fn test_where_with_expression() {
        let (rows, _) = run_select("SELECT * FROM test WHERE count * 2 > 15");
        // count=10 → 20 > 15 ✓, count=5 → 10 > 15 ✗, count=20 → 40 > 15 ✓
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn test_order_by_expression() {
        // ORDER BY count * -1 ASC is effectively DESC by count.
        let (rows, _) = run_select("SELECT title, count FROM test ORDER BY count * -1 ASC");
        // count: 20 → -20, 10 → -10, 5 → -5; ASC gives -20, -10, -5
        assert_eq!(rows[0]["count"], Value::Int(20));
        assert_eq!(rows[1]["count"], Value::Int(10));
        assert_eq!(rows[2]["count"], Value::Int(5));
    }

    // ── CASE WHEN evaluation tests ────────────────────────────────

    #[test]
    fn test_case_when_eval_basic() {
        let row = Row::from([("status".into(), Value::String("ACTIVE".into()))]);
        let expr = case_when_eq(
            "status",
            SqlValue::String("ACTIVE".into()),
            SqlValue::Int(1),
            Some(SqlValue::Int(0)),
        );
        assert_eq!(evaluate_expr(&expr, &row), Value::Int(1));
    }

    #[test]
    fn test_case_when_eval_else() {
        let row = Row::from([("status".into(), Value::String("KILLED".into()))]);
        let expr = case_when_eq(
            "status",
            SqlValue::String("ACTIVE".into()),
            SqlValue::Int(1),
            Some(SqlValue::Int(0)),
        );
        assert_eq!(evaluate_expr(&expr, &row), Value::Int(0));
    }

    #[test]
    fn test_case_when_eval_no_else_null() {
        let row = Row::from([("x".into(), Value::Int(99))]);
        let expr = case_when_eq("x", SqlValue::Int(1), SqlValue::String("one".into()), None);
        assert_eq!(evaluate_expr(&expr, &row), Value::Null);
    }

    #[test]
    fn test_case_when_in_aggregate_query() {
        // Rows: count=10, count=5, count=20 → should sum 10 + 0 + 20 = 30
        let (rows, cols) = run_select(
            "SELECT SUM(CASE WHEN count > 5 THEN count ELSE 0 END) AS total FROM test",
        );
        assert_eq!(cols, vec!["total"]);
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0]["total"], Value::Float(30.0));
    }

    #[test]
    fn test_case_when_with_unary_minus_in_aggregate() {
        // Alpha: 10, Beta: -5, Gamma: -20 → 10 - 5 - 20 = -15
        let (rows, _) = run_select(
            "SELECT SUM(CASE WHEN title = 'Alpha' THEN count ELSE -count END) AS net FROM test",
        );
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0]["net"], Value::Float(-15.0));
    }
}