Skip to main content

mdql_core/
query_engine.rs

1//! Execute parsed queries over in-memory rows.
2
3use std::cmp::Ordering;
4use std::collections::HashMap;
5
6use regex::Regex;
7
8use crate::errors::MdqlError;
9use crate::model::{Row, Value};
10use crate::query_parser::*;
11use crate::schema::Schema;
12
13pub fn execute_query(
14    query: &SelectQuery,
15    rows: &[Row],
16    _schema: &Schema,
17) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
18    execute(query, rows, None)
19}
20
21/// Execute a query with optional B-tree index and FTS searcher.
22pub fn execute_query_indexed(
23    query: &SelectQuery,
24    rows: &[Row],
25    schema: &Schema,
26    index: Option<&crate::index::TableIndex>,
27    searcher: Option<&crate::search::TableSearcher>,
28) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
29    // Pre-compute FTS results for any LIKE clauses on section columns
30    let fts_results = if let (Some(ref wc), Some(searcher)) = (&query.where_clause, searcher) {
31        collect_fts_results(wc, schema, searcher)
32    } else {
33        HashMap::new()
34    };
35
36    execute_with_fts(query, rows, index, &fts_results)
37}
38
39/// Collect FTS results for LIKE comparisons on section columns.
40/// Returns a map from (column, pattern) → set of matching paths.
41fn collect_fts_results(
42    clause: &WhereClause,
43    schema: &Schema,
44    searcher: &crate::search::TableSearcher,
45) -> HashMap<(String, String), std::collections::HashSet<String>> {
46    let mut results = HashMap::new();
47    collect_fts_results_inner(clause, schema, searcher, &mut results);
48    results
49}
50
51fn collect_fts_results_inner(
52    clause: &WhereClause,
53    schema: &Schema,
54    searcher: &crate::search::TableSearcher,
55    results: &mut HashMap<(String, String), std::collections::HashSet<String>>,
56) {
57    match clause {
58        WhereClause::Comparison(cmp) => {
59            if (cmp.op == "LIKE" || cmp.op == "NOT LIKE") && schema.sections.contains_key(&cmp.column) {
60                if let Some(SqlValue::String(pattern)) = &cmp.value {
61                    // Strip SQL wildcards for Tantivy query
62                    let search_term = pattern.replace('%', " ").replace('_', " ").trim().to_string();
63                    if !search_term.is_empty() {
64                        if let Ok(paths) = searcher.search(&search_term, Some(&cmp.column)) {
65                            let key = (cmp.column.clone(), pattern.clone());
66                            results.insert(key, paths.into_iter().collect());
67                        }
68                    }
69                }
70            }
71        }
72        WhereClause::BoolOp(bop) => {
73            collect_fts_results_inner(&bop.left, schema, searcher, results);
74            collect_fts_results_inner(&bop.right, schema, searcher, results);
75        }
76    }
77}
78
/// Precomputed full-text hits, keyed by `(column, LIKE pattern)`; each entry
/// holds the row paths the searcher matched for that comparison.
type FtsResults =
    std::collections::HashMap<(String, String), std::collections::HashSet<String>>;
80
81fn execute_with_fts(
82    query: &SelectQuery,
83    rows: &[Row],
84    index: Option<&crate::index::TableIndex>,
85    fts: &FtsResults,
86) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
87    // Determine available columns
88    let mut all_columns: Vec<String> = Vec::new();
89    let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
90    for r in rows {
91        for k in r.keys() {
92            if seen.insert(k.clone()) {
93                all_columns.push(k.clone());
94            }
95        }
96    }
97
98    // Check if query has aggregates
99    let has_aggregates = match &query.columns {
100        ColumnList::Named(exprs) => exprs.iter().any(|e| e.is_aggregate()),
101        _ => false,
102    };
103
104    // Output column names
105    let columns: Vec<String> = match &query.columns {
106        ColumnList::All => all_columns,
107        ColumnList::Named(exprs) => exprs.iter().map(|e| e.output_name()).collect(),
108    };
109
110    // Filter — try index first, fall back to full scan
111    let filtered: Vec<Row> = if let Some(ref wc) = query.where_clause {
112        let candidate_paths = index.and_then(|idx| try_index_filter(wc, idx));
113        if let Some(paths) = candidate_paths {
114            rows.iter()
115                .filter(|r| {
116                    r.get("path")
117                        .and_then(|v| v.as_str())
118                        .map_or(false, |p| paths.contains(p))
119                })
120                .filter(|r| evaluate_with_fts(wc, r, fts))
121                .cloned()
122                .collect()
123        } else {
124            rows.iter()
125                .filter(|r| evaluate_with_fts(wc, r, fts))
126                .cloned()
127                .collect()
128        }
129    } else {
130        rows.to_vec()
131    };
132
133    // Aggregate if needed
134    let mut result = if has_aggregates || query.group_by.is_some() {
135        let exprs = match &query.columns {
136            ColumnList::Named(exprs) => exprs.clone(),
137            _ => return Err(MdqlError::QueryExecution(
138                "SELECT * with GROUP BY is not supported".into(),
139            )),
140        };
141        let group_keys = query.group_by.as_deref().unwrap_or(&[]);
142        aggregate_rows(&filtered, &exprs, group_keys)?
143    } else {
144        filtered
145    };
146
147    // Sort
148    if let Some(ref order_by) = query.order_by {
149        sort_rows(&mut result, order_by);
150    }
151
152    // Limit
153    if let Some(limit) = query.limit {
154        result.truncate(limit as usize);
155    }
156
157    // Project — strip row dicts to only the requested columns
158    if !matches!(query.columns, ColumnList::All) {
159        let col_set: std::collections::HashSet<&str> =
160            columns.iter().map(|s| s.as_str()).collect();
161        for row in &mut result {
162            row.retain(|k, _| col_set.contains(k.as_str()));
163        }
164    }
165
166    Ok((result, columns))
167}
168
169fn aggregate_rows(
170    rows: &[Row],
171    exprs: &[SelectExpr],
172    group_keys: &[String],
173) -> crate::errors::Result<Vec<Row>> {
174    // Group rows by group_keys
175    let mut groups: Vec<(Vec<Value>, Vec<&Row>)> = Vec::new();
176    let mut key_index: HashMap<Vec<String>, usize> = HashMap::new();
177
178    if group_keys.is_empty() {
179        // No GROUP BY — all rows are one group
180        let all_refs: Vec<&Row> = rows.iter().collect();
181        groups.push((vec![], all_refs));
182    } else {
183        for row in rows {
184            let key: Vec<String> = group_keys
185                .iter()
186                .map(|k| {
187                    row.get(k)
188                        .map(|v| v.to_display_string())
189                        .unwrap_or_default()
190                })
191                .collect();
192            let key_vals: Vec<Value> = group_keys
193                .iter()
194                .map(|k| row.get(k).cloned().unwrap_or(Value::Null))
195                .collect();
196            if let Some(&idx) = key_index.get(&key) {
197                groups[idx].1.push(row);
198            } else {
199                let idx = groups.len();
200                key_index.insert(key, idx);
201                groups.push((key_vals, vec![row]));
202            }
203        }
204    }
205
206    // Compute aggregates per group
207    let mut result = Vec::new();
208    for (key_vals, group_rows) in &groups {
209        let mut out = Row::new();
210
211        // Fill in group key values
212        for (i, k) in group_keys.iter().enumerate() {
213            out.insert(k.clone(), key_vals[i].clone());
214        }
215
216        // Compute each expression
217        for expr in exprs {
218            match expr {
219                SelectExpr::Column(name) => {
220                    // Already filled if it's a group key; otherwise take first row's value
221                    if !out.contains_key(name) {
222                        if let Some(first) = group_rows.first() {
223                            out.insert(
224                                name.clone(),
225                                first.get(name).cloned().unwrap_or(Value::Null),
226                            );
227                        }
228                    }
229                }
230                SelectExpr::Aggregate { func, arg, alias } => {
231                    let out_name = alias
232                        .clone()
233                        .unwrap_or_else(|| expr.output_name());
234                    let val = compute_aggregate(func, arg, group_rows);
235                    out.insert(out_name, val);
236                }
237            }
238        }
239
240        result.push(out);
241    }
242
243    Ok(result)
244}
245
246fn compute_aggregate(func: &AggFunc, arg: &str, rows: &[&Row]) -> Value {
247    match func {
248        AggFunc::Count => {
249            if arg == "*" {
250                Value::Int(rows.len() as i64)
251            } else {
252                let count = rows
253                    .iter()
254                    .filter(|r| {
255                        r.get(arg)
256                            .map_or(false, |v| !v.is_null())
257                    })
258                    .count();
259                Value::Int(count as i64)
260            }
261        }
262        AggFunc::Sum => {
263            let mut total = 0.0f64;
264            let mut has_any = false;
265            for r in rows {
266                if let Some(v) = r.get(arg) {
267                    match v {
268                        Value::Int(n) => { total += *n as f64; has_any = true; }
269                        Value::Float(f) => { total += f; has_any = true; }
270                        _ => {}
271                    }
272                }
273            }
274            if has_any { Value::Float(total) } else { Value::Null }
275        }
276        AggFunc::Avg => {
277            let mut total = 0.0f64;
278            let mut count = 0usize;
279            for r in rows {
280                if let Some(v) = r.get(arg) {
281                    match v {
282                        Value::Int(n) => { total += *n as f64; count += 1; }
283                        Value::Float(f) => { total += f; count += 1; }
284                        _ => {}
285                    }
286                }
287            }
288            if count > 0 { Value::Float(total / count as f64) } else { Value::Null }
289        }
290        AggFunc::Min => {
291            let mut min_val: Option<Value> = None;
292            for r in rows {
293                if let Some(v) = r.get(arg) {
294                    if v.is_null() { continue; }
295                    min_val = Some(match min_val {
296                        None => v.clone(),
297                        Some(ref current) => {
298                            if v.partial_cmp(current) == Some(std::cmp::Ordering::Less) {
299                                v.clone()
300                            } else {
301                                current.clone()
302                            }
303                        }
304                    });
305                }
306            }
307            min_val.unwrap_or(Value::Null)
308        }
309        AggFunc::Max => {
310            let mut max_val: Option<Value> = None;
311            for r in rows {
312                if let Some(v) = r.get(arg) {
313                    if v.is_null() { continue; }
314                    max_val = Some(match max_val {
315                        None => v.clone(),
316                        Some(ref current) => {
317                            if v.partial_cmp(current) == Some(std::cmp::Ordering::Greater) {
318                                v.clone()
319                            } else {
320                                current.clone()
321                            }
322                        }
323                    });
324                }
325            }
326            max_val.unwrap_or(Value::Null)
327        }
328    }
329}
330
331fn evaluate_with_fts(clause: &WhereClause, row: &Row, fts: &FtsResults) -> bool {
332    match clause {
333        WhereClause::BoolOp(bop) => {
334            let left = evaluate_with_fts(&bop.left, row, fts);
335            match bop.op.as_str() {
336                "AND" => left && evaluate_with_fts(&bop.right, row, fts),
337                "OR" => left || evaluate_with_fts(&bop.right, row, fts),
338                _ => false,
339            }
340        }
341        WhereClause::Comparison(cmp) => {
342            // Check if we have FTS results for this comparison
343            if cmp.op == "LIKE" || cmp.op == "NOT LIKE" {
344                if let Some(SqlValue::String(pattern)) = &cmp.value {
345                    let key = (cmp.column.clone(), pattern.clone());
346                    if let Some(matching_paths) = fts.get(&key) {
347                        let row_path = row.get("path").and_then(|v| v.as_str()).unwrap_or("");
348                        let matched = matching_paths.contains(row_path);
349                        return if cmp.op == "LIKE" { matched } else { !matched };
350                    }
351                }
352            }
353            evaluate_comparison(cmp, row)
354        }
355    }
356}
357
358pub fn execute_join_query(
359    query: &SelectQuery,
360    tables: &HashMap<String, (Schema, Vec<Row>)>,
361) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
362    if query.joins.is_empty() {
363        return Err(MdqlError::QueryExecution("No JOIN clause in query".into()));
364    }
365
366    let left_name = &query.table;
367    let left_alias = query.table_alias.as_deref().unwrap_or(left_name);
368
369    // Build alias→table mapping for all tables
370    let mut aliases: HashMap<String, String> = HashMap::new();
371    aliases.insert(left_name.clone(), left_name.clone());
372    if let Some(ref a) = query.table_alias {
373        aliases.insert(a.clone(), left_name.clone());
374    }
375    for join in &query.joins {
376        aliases.insert(join.table.clone(), join.table.clone());
377        if let Some(ref a) = join.alias {
378            aliases.insert(a.clone(), join.table.clone());
379        }
380    }
381
382    // Start with the left table rows, prefixed with alias
383    let (_left_schema, left_rows) = tables.get(left_name.as_str()).ok_or_else(|| {
384        MdqlError::QueryExecution(format!("Unknown table '{}'", left_name))
385    })?;
386
387    let mut current_rows: Vec<Row> = left_rows
388        .iter()
389        .map(|r| {
390            let mut prefixed = Row::new();
391            for (k, v) in r {
392                prefixed.insert(format!("{}.{}", left_alias, k), v.clone());
393            }
394            prefixed
395        })
396        .collect();
397
398    // Process each JOIN sequentially
399    for join in &query.joins {
400        let right_name = &join.table;
401        let right_alias = join.alias.as_deref().unwrap_or(right_name);
402
403        let (_right_schema, right_rows) = tables.get(right_name.as_str()).ok_or_else(|| {
404            MdqlError::QueryExecution(format!("Unknown table '{}'", right_name))
405        })?;
406
407        // Resolve ON columns to determine which is left vs right
408        let (on_left_table, on_left_col) = resolve_dotted(&join.left_col, &aliases);
409        let (on_right_table, on_right_col) = resolve_dotted(&join.right_col, &aliases);
410
411        // Figure out which ON column refers to the new right table
412        let (left_key, right_key) = if on_right_table == *right_name {
413            // left_col is from the left side, right_col is from the right table
414            let left_alias_for_col = reverse_alias(&on_left_table, &aliases, query, &query.joins);
415            (format!("{}.{}", left_alias_for_col, on_left_col), on_right_col)
416        } else {
417            // right_col is from the left side, left_col is from the right table
418            let right_alias_for_col = reverse_alias(&on_right_table, &aliases, query, &query.joins);
419            (format!("{}.{}", right_alias_for_col, on_right_col), on_left_col)
420        };
421
422        // Build index on right table
423        let mut right_index: HashMap<String, Vec<&Row>> = HashMap::new();
424        for r in right_rows {
425            if let Some(key) = r.get(&right_key) {
426                let key_str = key.to_display_string();
427                right_index.entry(key_str).or_default().push(r);
428            }
429        }
430
431        // Join current rows with right table
432        let mut next_rows: Vec<Row> = Vec::new();
433        for lr in &current_rows {
434            if let Some(key) = lr.get(&left_key) {
435                let key_str = key.to_display_string();
436                if let Some(matching) = right_index.get(&key_str) {
437                    for rr in matching {
438                        let mut merged = lr.clone();
439                        for (k, v) in *rr {
440                            merged.insert(format!("{}.{}", right_alias, k), v.clone());
441                        }
442                        next_rows.push(merged);
443                    }
444                }
445            }
446        }
447        current_rows = next_rows;
448    }
449
450    let (mut result, columns) = execute(query, &current_rows, None)?;
451
452    // Add unprefixed aliases for non-colliding column names in the output.
453    // e.g., if result has s.title and b.sharpe (no other "title" or "sharpe"),
454    // add "title" and "sharpe" as shorthand keys.
455    if !result.is_empty() {
456        let mut base_counts: HashMap<String, usize> = HashMap::new();
457        for key in &columns {
458            if let Some((_prefix, base)) = key.split_once('.') {
459                *base_counts.entry(base.to_string()).or_default() += 1;
460            }
461        }
462        let unique_bases: Vec<String> = base_counts
463            .into_iter()
464            .filter(|(_, count)| *count == 1)
465            .map(|(base, _)| base)
466            .collect();
467
468        if !unique_bases.is_empty() {
469            let unique_set: std::collections::HashSet<&str> =
470                unique_bases.iter().map(|s| s.as_str()).collect();
471            for row in &mut result {
472                let additions: Vec<(String, Value)> = row
473                    .iter()
474                    .filter_map(|(k, v)| {
475                        k.split_once('.').and_then(|(_, base)| {
476                            if unique_set.contains(base) {
477                                Some((base.to_string(), v.clone()))
478                            } else {
479                                None
480                            }
481                        })
482                    })
483                    .collect();
484                for (k, v) in additions {
485                    row.insert(k, v);
486                }
487            }
488        }
489    }
490
491    Ok((result, columns))
492}
493
494/// Given a table name, find the alias used for it.
495fn reverse_alias(
496    table_name: &str,
497    aliases: &HashMap<String, String>,
498    query: &SelectQuery,
499    joins: &[JoinClause],
500) -> String {
501    // Check if the FROM table matches
502    if query.table == table_name {
503        return query.table_alias.as_deref().unwrap_or(&query.table).to_string();
504    }
505    // Check join tables
506    for j in joins {
507        if j.table == table_name {
508            return j.alias.as_deref().unwrap_or(&j.table).to_string();
509        }
510    }
511    // Fall back: check if table_name is itself an alias
512    if aliases.contains_key(table_name) {
513        return table_name.to_string();
514    }
515    table_name.to_string()
516}
517
/// Split a possibly-qualified column reference into `(table, column)`.
///
/// * `"alias.col"`: the qualifier is resolved through `aliases`; an unknown
///   qualifier is returned as-is.
/// * `"col"` (unqualified): yields an empty table name.
///
/// Only the first `.` separates the qualifier; any later dots belong to the column.
fn resolve_dotted(col: &str, aliases: &HashMap<String, String>) -> (String, String) {
    match col.split_once('.') {
        None => (String::new(), col.to_owned()),
        Some((qualifier, name)) => {
            let table = match aliases.get(qualifier) {
                Some(t) => t.clone(),
                None => qualifier.to_owned(),
            };
            (table, name.to_owned())
        }
    }
}
526
527fn execute(
528    query: &SelectQuery,
529    rows: &[Row],
530    index: Option<&crate::index::TableIndex>,
531) -> crate::errors::Result<(Vec<Row>, Vec<String>)> {
532    let empty_fts = HashMap::new();
533    execute_with_fts(query, rows, index, &empty_fts)
534}
535
536pub fn evaluate(clause: &WhereClause, row: &Row) -> bool {
537    match clause {
538        WhereClause::BoolOp(bop) => {
539            let left = evaluate(&bop.left, row);
540            match bop.op.as_str() {
541                "AND" => left && evaluate(&bop.right, row),
542                "OR" => left || evaluate(&bop.right, row),
543                _ => false,
544            }
545        }
546        WhereClause::Comparison(cmp) => evaluate_comparison(cmp, row),
547    }
548}
549
550fn evaluate_comparison(cmp: &Comparison, row: &Row) -> bool {
551    let actual = row.get(&cmp.column);
552
553    if cmp.op == "IS NULL" {
554        return actual.map_or(true, |v| v.is_null());
555    }
556    if cmp.op == "IS NOT NULL" {
557        return actual.map_or(false, |v| !v.is_null());
558    }
559
560    let actual = match actual {
561        Some(v) if !v.is_null() => v,
562        _ => return false,
563    };
564
565    let expected = match &cmp.value {
566        Some(v) => v,
567        None => return false,
568    };
569
570    match cmp.op.as_str() {
571        "=" => eq_match(actual, expected),
572        "!=" => !eq_match(actual, expected),
573        "<" => compare_values(actual, expected) == Some(Ordering::Less),
574        ">" => compare_values(actual, expected) == Some(Ordering::Greater),
575        "<=" => matches!(compare_values(actual, expected), Some(Ordering::Less | Ordering::Equal)),
576        ">=" => matches!(compare_values(actual, expected), Some(Ordering::Greater | Ordering::Equal)),
577        "LIKE" => like_match(actual, expected),
578        "NOT LIKE" => !like_match(actual, expected),
579        "IN" => {
580            if let SqlValue::List(items) = expected {
581                items.iter().any(|v| eq_match(actual, v))
582            } else {
583                eq_match(actual, expected)
584            }
585        }
586        _ => false,
587    }
588}
589
590fn coerce_sql_to_value(sql_val: &SqlValue, target: &Value) -> Value {
591    match sql_val {
592        SqlValue::Null => Value::Null,
593        SqlValue::String(s) => {
594            match target {
595                Value::Int(_) => s.parse::<i64>().map(Value::Int).unwrap_or(Value::String(s.clone())),
596                Value::Float(_) => s.parse::<f64>().map(Value::Float).unwrap_or(Value::String(s.clone())),
597                Value::Date(_) => {
598                    chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d")
599                        .map(Value::Date)
600                        .unwrap_or(Value::String(s.clone()))
601                }
602                _ => Value::String(s.clone()),
603            }
604        }
605        SqlValue::Int(n) => {
606            match target {
607                Value::Float(_) => Value::Float(*n as f64),
608                _ => Value::Int(*n),
609            }
610        }
611        SqlValue::Float(f) => Value::Float(*f),
612        SqlValue::List(_) => Value::Null, // Lists handled separately
613    }
614}
615
616fn eq_match(actual: &Value, expected: &SqlValue) -> bool {
617    // Special handling for lists (e.g., categories)
618    if let Value::List(items) = actual {
619        if let SqlValue::String(s) = expected {
620            return items.contains(s);
621        }
622    }
623
624    let coerced = coerce_sql_to_value(expected, actual);
625    actual == &coerced
626}
627
628fn like_match(actual: &Value, pattern: &SqlValue) -> bool {
629    let pattern_str = match pattern {
630        SqlValue::String(s) => s,
631        _ => return false,
632    };
633
634    // Convert SQL LIKE to regex
635    let mut regex_str = String::from("(?is)^");
636    for ch in pattern_str.chars() {
637        match ch {
638            '%' => regex_str.push_str(".*"),
639            '_' => regex_str.push('.'),
640            c => {
641                if regex::escape(&c.to_string()) != c.to_string() {
642                    regex_str.push_str(&regex::escape(&c.to_string()));
643                } else {
644                    regex_str.push(c);
645                }
646            }
647        }
648    }
649    regex_str.push('$');
650
651    let re = match Regex::new(&regex_str) {
652        Ok(r) => r,
653        Err(_) => return false,
654    };
655
656    match actual {
657        Value::List(items) => items.iter().any(|item| re.is_match(item)),
658        _ => re.is_match(&actual.to_display_string()),
659    }
660}
661
662fn compare_values(actual: &Value, expected: &SqlValue) -> Option<Ordering> {
663    let coerced = coerce_sql_to_value(expected, actual);
664    actual.partial_cmp(&coerced).map(|o| o)
665}
666
667/// Convert a SqlValue to a Value for index lookups (without a target type for coercion).
668fn sql_value_to_index_value(sv: &SqlValue) -> Value {
669    match sv {
670        SqlValue::String(s) => {
671            // Try date
672            if let Ok(d) = chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d") {
673                return Value::Date(d);
674            }
675            Value::String(s.clone())
676        }
677        SqlValue::Int(n) => Value::Int(*n),
678        SqlValue::Float(f) => Value::Float(*f),
679        SqlValue::Null => Value::Null,
680        SqlValue::List(_) => Value::Null,
681    }
682}
683
684/// Try to use B-tree indexes to narrow the candidate row set.
685/// Returns Some(paths) if the entire WHERE clause could be resolved via index,
686/// or None if a full scan is needed.
687fn try_index_filter(
688    clause: &WhereClause,
689    index: &crate::index::TableIndex,
690) -> Option<std::collections::HashSet<String>> {
691    match clause {
692        WhereClause::Comparison(cmp) => {
693            if !index.has_index(&cmp.column) {
694                return None;
695            }
696            match cmp.op.as_str() {
697                "=" => {
698                    let val = sql_value_to_index_value(cmp.value.as_ref()?);
699                    let paths = index.lookup_eq(&cmp.column, &val);
700                    Some(paths.into_iter().map(|s| s.to_string()).collect())
701                }
702                "<" => {
703                    let val = sql_value_to_index_value(cmp.value.as_ref()?);
704                    // exclusive upper bound: use range with max < val
705                    // lookup_range is inclusive, so we get all <= val then remove exact matches
706                    let range_paths = index.lookup_range(&cmp.column, None, Some(&val));
707                    let eq_paths: std::collections::HashSet<&str> = index.lookup_eq(&cmp.column, &val).into_iter().collect();
708                    Some(range_paths.into_iter().filter(|p| !eq_paths.contains(p)).map(|s| s.to_string()).collect())
709                }
710                ">" => {
711                    let val = sql_value_to_index_value(cmp.value.as_ref()?);
712                    let range_paths = index.lookup_range(&cmp.column, Some(&val), None);
713                    let eq_paths: std::collections::HashSet<&str> = index.lookup_eq(&cmp.column, &val).into_iter().collect();
714                    Some(range_paths.into_iter().filter(|p| !eq_paths.contains(p)).map(|s| s.to_string()).collect())
715                }
716                "<=" => {
717                    let val = sql_value_to_index_value(cmp.value.as_ref()?);
718                    let paths = index.lookup_range(&cmp.column, None, Some(&val));
719                    Some(paths.into_iter().map(|s| s.to_string()).collect())
720                }
721                ">=" => {
722                    let val = sql_value_to_index_value(cmp.value.as_ref()?);
723                    let paths = index.lookup_range(&cmp.column, Some(&val), None);
724                    Some(paths.into_iter().map(|s| s.to_string()).collect())
725                }
726                "IN" => {
727                    if let Some(SqlValue::List(items)) = &cmp.value {
728                        let vals: Vec<Value> = items.iter().map(sql_value_to_index_value).collect();
729                        let paths = index.lookup_in(&cmp.column, &vals);
730                        Some(paths.into_iter().map(|s| s.to_string()).collect())
731                    } else {
732                        None
733                    }
734                }
735                _ => None, // LIKE, IS NULL, etc. can't use index
736            }
737        }
738        WhereClause::BoolOp(bop) => {
739            let left = try_index_filter(&bop.left, index);
740            let right = try_index_filter(&bop.right, index);
741            match bop.op.as_str() {
742                "AND" => {
743                    match (left, right) {
744                        (Some(l), Some(r)) => Some(l.intersection(&r).cloned().collect()),
745                        (Some(l), None) => Some(l), // narrow with left, scan-verify right
746                        (None, Some(r)) => Some(r),
747                        (None, None) => None,
748                    }
749                }
750                "OR" => {
751                    match (left, right) {
752                        (Some(l), Some(r)) => Some(l.union(&r).cloned().collect()),
753                        _ => None, // Can't use index if either side needs full scan
754                    }
755                }
756                _ => None,
757            }
758        }
759    }
760}
761
762fn sort_rows(rows: &mut Vec<Row>, specs: &[OrderSpec]) {
763    rows.sort_by(|a, b| {
764        for spec in specs {
765            let va = a.get(&spec.column);
766            let vb = b.get(&spec.column);
767
768            // NULLs sort last
769            let ordering = match (va, vb) {
770                (None, None) | (Some(Value::Null), Some(Value::Null)) => Ordering::Equal,
771                (None, _) | (Some(Value::Null), _) => Ordering::Greater,
772                (_, None) | (_, Some(Value::Null)) => Ordering::Less,
773                (Some(a_val), Some(b_val)) => {
774                    a_val.partial_cmp(b_val).unwrap_or(Ordering::Equal)
775                }
776            };
777
778            let ordering = if spec.descending {
779                ordering.reverse()
780            } else {
781                ordering
782            };
783
784            if ordering != Ordering::Equal {
785                return ordering;
786            }
787        }
788        Ordering::Equal
789    });
790}
791
792/// Convert a SqlValue to our model Value (for use in insert/update).
793pub fn sql_value_to_value(sql_val: &SqlValue) -> Value {
794    match sql_val {
795        SqlValue::Null => Value::Null,
796        SqlValue::String(s) => Value::String(s.clone()),
797        SqlValue::Int(n) => Value::Int(*n),
798        SqlValue::Float(f) => Value::Float(*f),
799        SqlValue::List(items) => {
800            let strings: Vec<String> = items
801                .iter()
802                .filter_map(|v| match v {
803                    SqlValue::String(s) => Some(s.clone()),
804                    _ => None,
805                })
806                .collect();
807            Value::List(strings)
808        }
809    }
810}
811
#[cfg(test)]
mod tests {
    use super::*;

    /// Three rows with distinct `count` values: a.md = 10, b.md = 5, c.md = 20.
    fn make_rows() -> Vec<Row> {
        vec![
            Row::from([
                ("path".into(), Value::String("a.md".into())),
                ("title".into(), Value::String("Alpha".into())),
                ("count".into(), Value::Int(10)),
            ]),
            Row::from([
                ("path".into(), Value::String("b.md".into())),
                ("title".into(), Value::String("Beta".into())),
                ("count".into(), Value::Int(5)),
            ]),
            Row::from([
                ("path".into(), Value::String("c.md".into())),
                ("title".into(), Value::String("Gamma".into())),
                ("count".into(), Value::Int(20)),
            ]),
        ]
    }

    /// `SELECT * FROM test` with every optional clause unset.
    ///
    /// Individual tests override only the clause they exercise via struct-update
    /// syntax (`SelectQuery { order_by: ..., ..select_all() }`).
    fn select_all() -> SelectQuery {
        SelectQuery {
            columns: ColumnList::All,
            table: "test".into(),
            table_alias: None,
            joins: vec![],
            where_clause: None,
            group_by: None,
            order_by: None,
            limit: None,
        }
    }

    #[test]
    fn test_select_all() {
        let (rows, _cols) = execute(&select_all(), &make_rows(), None).unwrap();
        assert_eq!(rows.len(), 3);
    }

    #[test]
    fn test_where_gt() {
        let q = SelectQuery {
            where_clause: Some(WhereClause::Comparison(Comparison {
                column: "count".into(),
                op: ">".into(),
                value: Some(SqlValue::Int(5)),
            })),
            ..select_all()
        };
        let (rows, _) = execute(&q, &make_rows(), None).unwrap();
        // `>` is strict: the row with count == 5 must be excluded, leaving 10 and 20.
        assert_eq!(rows.len(), 2);
        assert!(rows
            .iter()
            .all(|r| matches!(&r["count"], Value::Int(n) if *n > 5)));
    }

    #[test]
    fn test_order_by_asc() {
        let q = SelectQuery {
            order_by: Some(vec![OrderSpec {
                column: "count".into(),
                descending: false,
            }]),
            ..select_all()
        };
        let (rows, _) = execute(&q, &make_rows(), None).unwrap();
        let counts: Vec<&Value> = rows.iter().map(|r| &r["count"]).collect();
        assert_eq!(counts, [&Value::Int(5), &Value::Int(10), &Value::Int(20)]);
    }

    #[test]
    fn test_order_by_desc() {
        let q = SelectQuery {
            order_by: Some(vec![OrderSpec {
                column: "count".into(),
                descending: true,
            }]),
            ..select_all()
        };
        let (rows, _) = execute(&q, &make_rows(), None).unwrap();
        // Check every position, not just the ends.
        assert_eq!(rows[0]["count"], Value::Int(20));
        assert_eq!(rows[1]["count"], Value::Int(10));
        assert_eq!(rows[2]["count"], Value::Int(5));
    }

    #[test]
    fn test_limit() {
        let q = SelectQuery {
            limit: Some(2),
            ..select_all()
        };
        let (rows, _) = execute(&q, &make_rows(), None).unwrap();
        assert_eq!(rows.len(), 2);
    }

    #[test]
    fn test_like() {
        let q = SelectQuery {
            where_clause: Some(WhereClause::Comparison(Comparison {
                column: "title".into(),
                op: "LIKE".into(),
                value: Some(SqlValue::String("%lph%".into())),
            })),
            ..select_all()
        };
        let (rows, _) = execute(&q, &make_rows(), None).unwrap();
        // Only "Alpha" contains the substring "lph".
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0]["title"], Value::String("Alpha".into()));
    }

    #[test]
    fn test_is_null() {
        let mut rows = make_rows();
        rows[1].insert("optional".into(), Value::Null);

        let q = SelectQuery {
            where_clause: Some(WhereClause::Comparison(Comparison {
                column: "optional".into(),
                op: "IS NULL".into(),
                value: None,
            })),
            ..select_all()
        };
        let (result, _) = execute(&q, &rows, None).unwrap();
        // Row b has an explicit NULL. Rows a and c lack the column entirely.
        // Both cases match IS NULL, so all three rows qualify.
        assert_eq!(result.len(), 3);
    }
}