Skip to main content

sqlrite/sql/
executor.rs

1//! Query executors — evaluate parsed SQL statements against the in-memory
2//! storage and produce formatted output.
3
4use std::cmp::Ordering;
5
6use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
7use sqlparser::ast::{
8    AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr, FromTable, FunctionArg,
9    FunctionArgExpr, FunctionArguments, IndexType, ObjectNamePart, Statement, TableFactor,
10    TableWithJoins, UnaryOperator, Update,
11};
12
13use crate::error::{Result, SQLRiteError};
14use crate::sql::db::database::Database;
15use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
16use crate::sql::db::table::{DataType, HnswIndexEntry, Table, Value, parse_vector_literal};
17use crate::sql::hnsw::{DistanceMetric, HnswIndex};
18use crate::sql::parser::select::{OrderByClause, Projection, SelectQuery};
19
20/// Executes a parsed `SelectQuery` against the database and returns a
21/// human-readable rendering of the result set (prettytable). Also returns
22/// the number of rows produced, for the top-level status message.
23/// Structured result of a SELECT: column names in projection order,
24/// and each matching row as a `Vec<Value>` aligned with the columns.
25/// Phase 5a introduced this so the public `Connection` / `Statement`
26/// API has typed rows to yield; the existing `execute_select` that
27/// returns pre-rendered text is now a thin wrapper on top.
28pub struct SelectResult {
29    pub columns: Vec<String>,
30    pub rows: Vec<Vec<Value>>,
31}
32
33/// Executes a SELECT and returns structured rows. The typed rows are
34/// what the new public API streams to callers; the REPL / Tauri app
35/// pre-render into a prettytable via `execute_select`.
36pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
37    let table = db
38        .get_table(query.table_name.clone())
39        .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
40
41    // Resolve projection to a concrete ordered column list.
42    let projected_cols: Vec<String> = match &query.projection {
43        Projection::All => table.column_names(),
44        Projection::Columns(cols) => {
45            for c in cols {
46                if !table.contains_column(c.to_string()) {
47                    return Err(SQLRiteError::Internal(format!(
48                        "Column '{c}' does not exist on table '{}'",
49                        query.table_name
50                    )));
51                }
52            }
53            cols.clone()
54        }
55    };
56
57    // Collect matching rowids. If the WHERE is the shape `col = literal`
58    // and `col` has a secondary index, probe the index for an O(log N)
59    // seek; otherwise fall back to the full table scan.
60    let matching = match select_rowids(table, query.selection.as_ref())? {
61        RowidSource::IndexProbe(rowids) => rowids,
62        RowidSource::FullScan => {
63            let mut out = Vec::new();
64            for rowid in table.rowids() {
65                if let Some(expr) = &query.selection {
66                    if !eval_predicate(expr, table, rowid)? {
67                        continue;
68                    }
69                }
70                out.push(rowid);
71            }
72            out
73        }
74    };
75    let mut matching = matching;
76
77    // Phase 7c — bounded-heap top-k optimization.
78    //
79    // The naive "ORDER BY <expr>" path (Phase 7b) sorts every matching
80    // rowid: O(N log N) sort_by + a truncate. For KNN queries
81    //
82    //     SELECT id FROM docs
83    //     ORDER BY vec_distance_l2(embedding, [...])
84    //     LIMIT 10;
85    //
86    // N is the table row count and k is the LIMIT. With a bounded
87    // max-heap of size k we can find the top-k in O(N log k) — same
88    // sort_by-per-row cost on the heap operations, but k is typically
89    // 10-100 while N can be millions.
90    //
91    // Phase 7d.2 — HNSW ANN probe.
92    //
93    // Even better than the bounded heap: if the ORDER BY expression is
94    // exactly `vec_distance_l2(<col>, <bracket-array literal>)` AND
95    // `<col>` has an HNSW index attached, skip the linear scan
96    // entirely and probe the graph in O(log N). Approximate but
97    // typically ≥ 0.95 recall (verified by the recall tests in
98    // src/sql/hnsw.rs).
99    //
100    // We branch in cases:
101    //   1. ORDER BY + LIMIT k matches the HNSW probe pattern  → graph probe.
102    //   2. ORDER BY + LIMIT k where k < |matching|            → bounded heap (7c).
103    //   3. ORDER BY without LIMIT, or LIMIT >= |matching|     → full sort.
104    //   4. LIMIT without ORDER BY                              → just truncate.
105    match (&query.order_by, query.limit) {
106        (Some(order), Some(k)) if try_hnsw_probe(table, &order.expr, k).is_some() => {
107            matching = try_hnsw_probe(table, &order.expr, k).unwrap();
108        }
109        (Some(order), Some(k)) if k < matching.len() => {
110            matching = select_topk(&matching, table, order, k)?;
111        }
112        (Some(order), _) => {
113            sort_rowids(&mut matching, table, order)?;
114            if let Some(k) = query.limit {
115                matching.truncate(k);
116            }
117        }
118        (None, Some(k)) => {
119            matching.truncate(k);
120        }
121        (None, None) => {}
122    }
123
124    // Build typed rows. Missing cells surface as `Value::Null` — that
125    // maps a column-not-present-for-this-rowid case onto the public
126    // `Row::get` → `Option<T>` surface cleanly.
127    let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
128    for rowid in &matching {
129        let row: Vec<Value> = projected_cols
130            .iter()
131            .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
132            .collect();
133        rows.push(row);
134    }
135
136    Ok(SelectResult {
137        columns: projected_cols,
138        rows,
139    })
140}
141
142/// Executes a SELECT and returns `(rendered_table, row_count)`. The
143/// REPL and Tauri app use this to keep the table-printing behaviour
144/// the engine has always shipped. Structured callers use
145/// `execute_select_rows` instead.
146pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
147    let result = execute_select_rows(query, db)?;
148    let row_count = result.rows.len();
149
150    let mut print_table = PrintTable::new();
151    let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
152    print_table.add_row(PrintRow::new(header_cells));
153
154    for row in &result.rows {
155        let cells: Vec<PrintCell> = row
156            .iter()
157            .map(|v| PrintCell::new(&v.to_display_string()))
158            .collect();
159        print_table.add_row(PrintRow::new(cells));
160    }
161
162    Ok((print_table.to_string(), row_count))
163}
164
165/// Executes a DELETE statement. Returns the number of rows removed.
166pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
167    let Statement::Delete(Delete {
168        from, selection, ..
169    }) = stmt
170    else {
171        return Err(SQLRiteError::Internal(
172            "execute_delete called on a non-DELETE statement".to_string(),
173        ));
174    };
175
176    let tables = match from {
177        FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
178    };
179    let table_name = extract_single_table_name(tables)?;
180
181    // Phase 7d.2 limitation: HNSW lacks an in-place delete-node operation.
182    // True deletion needs either soft-delete + tombstones or a graph rebuild
183    // — both nontrivial. Until 7d.3 lands persistence we don't have a
184    // natural rebuild trigger either. So: refuse DELETE on tables carrying
185    // any HNSW index, with a message that points at the workaround
186    // (DROP the index, DELETE, recreate).
187    {
188        let table = db.get_table(table_name.clone()).map_err(|_| {
189            SQLRiteError::General(format!("DELETE references unknown table '{table_name}'"))
190        })?;
191        if !table.hnsw_indexes.is_empty() {
192            let names: Vec<&str> = table.hnsw_indexes.iter().map(|e| e.name.as_str()).collect();
193            return Err(SQLRiteError::NotImplemented(format!(
194                "DELETE on tables with HNSW indexes is not supported yet \
195                 (Phase 7d.3 follow-up). DROP the index first, then DELETE, then re-CREATE. \
196                 Table '{table_name}' currently has: {names:?}"
197            )));
198        }
199    }
200
201    // Compute matching rowids with an immutable borrow, then mutate.
202    let matching: Vec<i64> = {
203        let table = db
204            .get_table(table_name.clone())
205            .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
206        match select_rowids(table, selection.as_ref())? {
207            RowidSource::IndexProbe(rowids) => rowids,
208            RowidSource::FullScan => {
209                let mut out = Vec::new();
210                for rowid in table.rowids() {
211                    if let Some(expr) = selection {
212                        if !eval_predicate(expr, table, rowid)? {
213                            continue;
214                        }
215                    }
216                    out.push(rowid);
217                }
218                out
219            }
220        }
221    };
222
223    let table = db.get_table_mut(table_name)?;
224    for rowid in &matching {
225        table.delete_row(*rowid);
226    }
227    Ok(matching.len())
228}
229
230/// Executes an UPDATE statement. Returns the number of rows updated.
231pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
232    let Statement::Update(Update {
233        table,
234        assignments,
235        from,
236        selection,
237        ..
238    }) = stmt
239    else {
240        return Err(SQLRiteError::Internal(
241            "execute_update called on a non-UPDATE statement".to_string(),
242        ));
243    };
244
245    if from.is_some() {
246        return Err(SQLRiteError::NotImplemented(
247            "UPDATE ... FROM is not supported yet".to_string(),
248        ));
249    }
250
251    let table_name = extract_table_name(table)?;
252
253    // Phase 7d.2 limitation (same shape as DELETE above): we have no
254    // in-place UPDATE-an-HNSW-node primitive. UPDATE on a column NOT
255    // covered by HNSW is fine in principle, but the simplest MVP is
256    // refuse-everything-when-HNSW-is-present. Re-evaluate in 7d.3 once
257    // persistence + rebuild is in.
258    {
259        let tbl = db.get_table(table_name.clone()).map_err(|_| {
260            SQLRiteError::General(format!("UPDATE references unknown table '{table_name}'"))
261        })?;
262        if !tbl.hnsw_indexes.is_empty() {
263            let names: Vec<&str> = tbl.hnsw_indexes.iter().map(|e| e.name.as_str()).collect();
264            return Err(SQLRiteError::NotImplemented(format!(
265                "UPDATE on tables with HNSW indexes is not supported yet \
266                 (Phase 7d.3 follow-up). DROP the index first if you need to mutate. \
267                 Table '{table_name}' currently has: {names:?}"
268            )));
269        }
270    }
271
272    // Resolve assignment targets to plain column names and verify they exist.
273    let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
274    {
275        let tbl = db
276            .get_table(table_name.clone())
277            .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
278        for a in assignments {
279            let col = match &a.target {
280                AssignmentTarget::ColumnName(name) => name
281                    .0
282                    .last()
283                    .map(|p| p.to_string())
284                    .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
285                AssignmentTarget::Tuple(_) => {
286                    return Err(SQLRiteError::NotImplemented(
287                        "tuple assignment targets are not supported".to_string(),
288                    ));
289                }
290            };
291            if !tbl.contains_column(col.clone()) {
292                return Err(SQLRiteError::Internal(format!(
293                    "UPDATE references unknown column '{col}'"
294                )));
295            }
296            parsed_assignments.push((col, a.value.clone()));
297        }
298    }
299
300    // Gather matching rowids + the new values to write for each assignment, under
301    // an immutable borrow. Uses the index-probe fast path when the WHERE is
302    // `col = literal` on an indexed column.
303    let work: Vec<(i64, Vec<(String, Value)>)> = {
304        let tbl = db.get_table(table_name.clone())?;
305        let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
306            RowidSource::IndexProbe(rowids) => rowids,
307            RowidSource::FullScan => {
308                let mut out = Vec::new();
309                for rowid in tbl.rowids() {
310                    if let Some(expr) = selection {
311                        if !eval_predicate(expr, tbl, rowid)? {
312                            continue;
313                        }
314                    }
315                    out.push(rowid);
316                }
317                out
318            }
319        };
320        let mut rows_to_update = Vec::new();
321        for rowid in matched_rowids {
322            let mut values = Vec::with_capacity(parsed_assignments.len());
323            for (col, expr) in &parsed_assignments {
324                // UPDATE's RHS is evaluated in the context of the row being updated,
325                // so column references on the right resolve to the current row's values.
326                let v = eval_expr(expr, tbl, rowid)?;
327                values.push((col.clone(), v));
328            }
329            rows_to_update.push((rowid, values));
330        }
331        rows_to_update
332    };
333
334    let tbl = db.get_table_mut(table_name)?;
335    for (rowid, values) in &work {
336        for (col, v) in values {
337            tbl.set_value(col, *rowid, v.clone())?;
338        }
339    }
340    Ok(work.len())
341}
342
343/// Handles `CREATE INDEX [UNIQUE] <name> ON <table> [USING <method>] (<column>)`.
344/// Single-column indexes only.
345///
346/// Two flavours, branching on the optional `USING <method>` clause:
347///   - **No USING, or `USING btree`**: regular B-Tree secondary index
348///     (Phase 3e). Indexable types: Integer, Text.
349///   - **`USING hnsw`**: HNSW ANN index (Phase 7d.2). Indexable types:
350///     Vector(N) only. Distance metric is L2 by default; cosine and
351///     dot variants are deferred to Phase 7d.x.
352///
353/// Returns the (possibly synthesized) index name for the status message.
354pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
355    let Statement::CreateIndex(CreateIndex {
356        name,
357        table_name,
358        columns,
359        using,
360        unique,
361        if_not_exists,
362        predicate,
363        ..
364    }) = stmt
365    else {
366        return Err(SQLRiteError::Internal(
367            "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
368        ));
369    };
370
371    if predicate.is_some() {
372        return Err(SQLRiteError::NotImplemented(
373            "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
374        ));
375    }
376
377    if columns.len() != 1 {
378        return Err(SQLRiteError::NotImplemented(format!(
379            "multi-column indexes are not supported yet ({} columns given)",
380            columns.len()
381        )));
382    }
383
384    let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
385        SQLRiteError::NotImplemented(
386            "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
387        )
388    })?;
389
390    // Detect USING <method>. The `using` field on CreateIndex covers the
391    // pre-column form `CREATE INDEX … USING hnsw (col)`. (sqlparser also
392    // accepts a post-column form `… (col) USING hnsw` and parks that in
393    // `index_options`; we don't bother with it — the canonical form is
394    // pre-column and matches PG/pgvector convention.)
395    let method = match using {
396        Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("hnsw") => {
397            IndexMethod::Hnsw
398        }
399        Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("btree") => {
400            IndexMethod::Btree
401        }
402        Some(other) => {
403            return Err(SQLRiteError::NotImplemented(format!(
404                "CREATE INDEX … USING {other:?} is not supported (try `hnsw` or no USING clause)"
405            )));
406        }
407        None => IndexMethod::Btree,
408    };
409
410    let table_name_str = table_name.to_string();
411    let column_name = match &columns[0].column.expr {
412        Expr::Identifier(ident) => ident.value.clone(),
413        Expr::CompoundIdentifier(parts) => parts
414            .last()
415            .map(|p| p.value.clone())
416            .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
417        other => {
418            return Err(SQLRiteError::NotImplemented(format!(
419                "CREATE INDEX only supports simple column references, got {other:?}"
420            )));
421        }
422    };
423
424    // Validate: table exists, column exists, type matches the index method,
425    // name is unique across both index kinds. Snapshot (rowid, value) pairs
426    // up front under the immutable borrow so the mutable attach later
427    // doesn't fight over `self`.
428    let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
429        let table = db.get_table(table_name_str.clone()).map_err(|_| {
430            SQLRiteError::General(format!(
431                "CREATE INDEX references unknown table '{table_name_str}'"
432            ))
433        })?;
434        if !table.contains_column(column_name.clone()) {
435            return Err(SQLRiteError::General(format!(
436                "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
437            )));
438        }
439        let col = table
440            .columns
441            .iter()
442            .find(|c| c.column_name == column_name)
443            .expect("we just verified the column exists");
444
445        // Name uniqueness check spans BOTH index kinds — a btree and an
446        // hnsw can't share a name.
447        if table.index_by_name(&index_name).is_some()
448            || table.hnsw_indexes.iter().any(|i| i.name == index_name)
449        {
450            if *if_not_exists {
451                return Ok(index_name);
452            }
453            return Err(SQLRiteError::General(format!(
454                "index '{index_name}' already exists"
455            )));
456        }
457        let datatype = clone_datatype(&col.datatype);
458
459        let mut pairs = Vec::new();
460        for rowid in table.rowids() {
461            if let Some(v) = table.get_value(&column_name, rowid) {
462                pairs.push((rowid, v));
463            }
464        }
465        (datatype, pairs)
466    };
467
468    match method {
469        IndexMethod::Btree => create_btree_index(
470            db,
471            &table_name_str,
472            &index_name,
473            &column_name,
474            &datatype,
475            *unique,
476            &existing_rowids_and_values,
477        ),
478        IndexMethod::Hnsw => create_hnsw_index(
479            db,
480            &table_name_str,
481            &index_name,
482            &column_name,
483            &datatype,
484            *unique,
485            &existing_rowids_and_values,
486        ),
487    }
488}
489
/// Index implementation selected by the optional `USING <method>` clause
/// of `CREATE INDEX`.
///
/// When USING is omitted, `execute_create_index` picks `Btree`, so
/// statements written before HNSW existed (Phase 3e) keep working as-is.
#[derive(Clone, Copy, Debug)]
enum IndexMethod {
    /// Ordered secondary index (Phase 3e).
    Btree,
    /// Approximate-nearest-neighbour graph over a VECTOR column (Phase 7d.2).
    Hnsw,
}
498
499/// Builds a Phase 3e B-Tree secondary index and attaches it to the table.
500fn create_btree_index(
501    db: &mut Database,
502    table_name: &str,
503    index_name: &str,
504    column_name: &str,
505    datatype: &DataType,
506    unique: bool,
507    existing: &[(i64, Value)],
508) -> Result<String> {
509    let mut idx = SecondaryIndex::new(
510        index_name.to_string(),
511        table_name.to_string(),
512        column_name.to_string(),
513        datatype,
514        unique,
515        IndexOrigin::Explicit,
516    )?;
517
518    // Populate from existing rows. UNIQUE violations here mean the
519    // existing data already breaks the new index's constraint — a
520    // common source of user confusion, so be explicit.
521    for (rowid, v) in existing {
522        if unique && idx.would_violate_unique(v) {
523            return Err(SQLRiteError::General(format!(
524                "cannot create UNIQUE index '{index_name}': column '{column_name}' \
525                 already contains the duplicate value {}",
526                v.to_display_string()
527            )));
528        }
529        idx.insert(v, *rowid)?;
530    }
531
532    let table_mut = db.get_table_mut(table_name.to_string())?;
533    table_mut.secondary_indexes.push(idx);
534    Ok(index_name.to_string())
535}
536
537/// Builds a Phase 7d.2 HNSW index and attaches it to the table.
538fn create_hnsw_index(
539    db: &mut Database,
540    table_name: &str,
541    index_name: &str,
542    column_name: &str,
543    datatype: &DataType,
544    unique: bool,
545    existing: &[(i64, Value)],
546) -> Result<String> {
547    // HNSW only makes sense on VECTOR columns. Reject anything else
548    // with a clear message — this is the most likely user error.
549    let dim = match datatype {
550        DataType::Vector(d) => *d,
551        other => {
552            return Err(SQLRiteError::General(format!(
553                "USING hnsw requires a VECTOR column; '{column_name}' is {other}"
554            )));
555        }
556    };
557
558    if unique {
559        return Err(SQLRiteError::General(
560            "UNIQUE has no meaning for HNSW indexes".to_string(),
561        ));
562    }
563
564    // Build the in-memory graph. Distance metric is L2 by default
565    // (Phase 7d.2 doesn't yet expose a knob for picking cosine/dot —
566    // see `docs/phase-7-plan.md` for the deferral).
567    //
568    // Seed: hash the index name so different indexes get different
569    // graph topologies, but the same index always gets the same one
570    // — useful when debugging recall / index size.
571    let seed = hash_str_to_seed(index_name);
572    let mut idx = HnswIndex::new(DistanceMetric::L2, seed);
573
574    // Snapshot the (rowid, vector) pairs into a side map so the
575    // get_vec closure below can serve them by id without re-borrowing
576    // the table (we're already holding `existing` — flatten it).
577    let mut vec_map: std::collections::HashMap<i64, Vec<f32>> =
578        std::collections::HashMap::with_capacity(existing.len());
579    for (rowid, v) in existing {
580        match v {
581            Value::Vector(vec) => {
582                if vec.len() != dim {
583                    return Err(SQLRiteError::Internal(format!(
584                        "row {rowid} stores a {}-dim vector in column '{column_name}' \
585                         declared as VECTOR({dim}) — schema invariant violated",
586                        vec.len()
587                    )));
588                }
589                vec_map.insert(*rowid, vec.clone());
590            }
591            // Non-vector values (theoretical NULL, type coercion bug)
592            // get skipped — they wouldn't have a sensible graph
593            // position anyway.
594            _ => continue,
595        }
596    }
597
598    for (rowid, _) in existing {
599        if let Some(v) = vec_map.get(rowid) {
600            let v_clone = v.clone();
601            idx.insert(*rowid, &v_clone, |id| {
602                vec_map.get(&id).cloned().unwrap_or_default()
603            });
604        }
605    }
606
607    let table_mut = db.get_table_mut(table_name.to_string())?;
608    table_mut.hnsw_indexes.push(HnswIndexEntry {
609        name: index_name.to_string(),
610        column_name: column_name.to_string(),
611        index: idx,
612    });
613    Ok(index_name.to_string())
614}
615
/// Deterministically hashes `s` into a `u64` RNG seed using 64-bit FNV-1a.
///
/// The std hashers are deliberately avoided. `RandomState` is seeded
/// randomly per process, and `DefaultHasher`'s algorithm is unspecified and
/// may change between Rust releases. Either would make the seed, and
/// therefore the HNSW graph layout, unstable.
fn hash_str_to_seed(s: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xCBF2_9CE4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01B3;

    s.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
627
628/// Cheap clone helper — `DataType` intentionally doesn't derive `Clone`
629/// because the enum has no ergonomic reason to be cloneable elsewhere.
630fn clone_datatype(dt: &DataType) -> DataType {
631    match dt {
632        DataType::Integer => DataType::Integer,
633        DataType::Text => DataType::Text,
634        DataType::Real => DataType::Real,
635        DataType::Bool => DataType::Bool,
636        DataType::Vector(dim) => DataType::Vector(*dim),
637        DataType::None => DataType::None,
638        DataType::Invalid => DataType::Invalid,
639    }
640}
641
642fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
643    if tables.len() != 1 {
644        return Err(SQLRiteError::NotImplemented(
645            "multi-table DELETE is not supported yet".to_string(),
646        ));
647    }
648    extract_table_name(&tables[0])
649}
650
651fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
652    if !twj.joins.is_empty() {
653        return Err(SQLRiteError::NotImplemented(
654            "JOIN is not supported yet".to_string(),
655        ));
656    }
657    match &twj.relation {
658        TableFactor::Table { name, .. } => Ok(name.to_string()),
659        _ => Err(SQLRiteError::NotImplemented(
660            "only plain table references are supported".to_string(),
661        )),
662    }
663}
664
/// How an executor should obtain its candidate rowids for a WHERE clause.
enum RowidSource {
    /// An index probe already answered the WHERE exactly: these are
    /// precisely the matching rowids, and the predicate need not be
    /// re-evaluated.
    IndexProbe(Vec<i64>),
    /// No index applies: walk `table.rowids()` and evaluate the WHERE on
    /// every row.
    FullScan,
}
675
676/// Try to satisfy `WHERE` with an index probe. Currently supports the
677/// simplest shape: a single `col = literal` (or `literal = col`) where
678/// `col` is on a secondary index. AND/OR/range predicates fall back to
679/// full scan — those can be layered on later without changing the caller.
680fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
681    let Some(expr) = selection else {
682        return Ok(RowidSource::FullScan);
683    };
684    let Some((col, literal)) = try_extract_equality(expr) else {
685        return Ok(RowidSource::FullScan);
686    };
687    let Some(idx) = table.index_for_column(&col) else {
688        return Ok(RowidSource::FullScan);
689    };
690
691    // Convert the literal into a runtime Value. If the literal type doesn't
692    // match the column's index we still need correct semantics — evaluate
693    // the WHERE against every row. Fall back to full scan.
694    let literal_value = match convert_literal(&literal) {
695        Ok(v) => v,
696        Err(_) => return Ok(RowidSource::FullScan),
697    };
698
699    // Index lookup returns the full list of rowids matching this equality
700    // predicate. For unique indexes that's at most one; for non-unique it
701    // can be many.
702    let mut rowids = idx.lookup(&literal_value);
703    rowids.sort_unstable();
704    Ok(RowidSource::IndexProbe(rowids))
705}
706
707/// Recognizes `expr` as a simple equality on a column reference against a
708/// literal. Returns `(column_name, literal_value)` if the shape matches;
709/// `None` otherwise. Accepts both `col = literal` and `literal = col`.
710fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
711    // Peel off Nested parens so `WHERE (x = 1)` is recognized too.
712    let peeled = match expr {
713        Expr::Nested(inner) => inner.as_ref(),
714        other => other,
715    };
716    let Expr::BinaryOp { left, op, right } = peeled else {
717        return None;
718    };
719    if !matches!(op, BinaryOperator::Eq) {
720        return None;
721    }
722    let col_from = |e: &Expr| -> Option<String> {
723        match e {
724            Expr::Identifier(ident) => Some(ident.value.clone()),
725            Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
726            _ => None,
727        }
728    };
729    let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
730        if let Expr::Value(v) = e {
731            Some(v.value.clone())
732        } else {
733            None
734        }
735    };
736    if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
737        return Some((c, l));
738    }
739    if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
740        return Some((c, l));
741    }
742    None
743}
744
745/// Recognizes the HNSW-probable query pattern and probes the graph
746/// if a matching index exists.
747///
748/// Looks for ORDER BY `vec_distance_l2(<col>, <bracket-array literal>)`
749/// where the table has an HNSW index attached to `<col>`. On a match,
750/// returns the top-k rowids straight from the graph (O(log N)). On
751/// any miss — different function name, no matching index, query
752/// dimension wrong, etc. — returns `None` and the caller falls through
753/// to the bounded-heap brute-force path (7c) or the full sort (7b),
754/// preserving correct results regardless of whether the HNSW pathway
755/// kicked in.
756///
757/// Phase 7d.2 caveats:
758/// - Only `vec_distance_l2` is recognized. Cosine and dot fall through
759///   to brute-force because we don't yet expose a per-index distance
760///   knob (deferred to Phase 7d.x — see `docs/phase-7-plan.md`).
761/// - Only ASCENDING order makes sense for "k nearest" — DESC ORDER BY
762///   `vec_distance_l2(...) LIMIT k` would mean "k farthest", which
763///   isn't what the index is built for. We don't bother to detect
764///   `ascending == false` here; the optimizer just skips and the
765///   fallback path handles it correctly (slower).
766fn try_hnsw_probe(table: &Table, order_expr: &Expr, k: usize) -> Option<Vec<i64>> {
767    if k == 0 {
768        return None;
769    }
770
771    // Pattern-match: order expr must be a function call vec_distance_l2(a, b).
772    let func = match order_expr {
773        Expr::Function(f) => f,
774        _ => return None,
775    };
776    let fname = match func.name.0.as_slice() {
777        [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
778        _ => return None,
779    };
780    if fname != "vec_distance_l2" {
781        return None;
782    }
783
784    // Extract the two args as raw Exprs.
785    let arg_list = match &func.args {
786        FunctionArguments::List(l) => &l.args,
787        _ => return None,
788    };
789    if arg_list.len() != 2 {
790        return None;
791    }
792    let exprs: Vec<&Expr> = arg_list
793        .iter()
794        .filter_map(|a| match a {
795            FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
796            _ => None,
797        })
798        .collect();
799    if exprs.len() != 2 {
800        return None;
801    }
802
803    // One arg must be a column reference (the indexed col); the other
804    // must be a bracket-array literal (the query vector). Try both
805    // orderings — pgvector's idiom puts the column on the left, but
806    // SQL is commutative for distance.
807    let (col_name, query_vec) = match identify_indexed_arg_and_literal(exprs[0], exprs[1]) {
808        Some(v) => v,
809        None => match identify_indexed_arg_and_literal(exprs[1], exprs[0]) {
810            Some(v) => v,
811            None => return None,
812        },
813    };
814
815    // Find the HNSW index on this column.
816    let entry = table
817        .hnsw_indexes
818        .iter()
819        .find(|e| e.column_name == col_name)?;
820
821    // Dimension sanity check — the query vector must match the
822    // indexed column's declared dimension. If it doesn't, the brute-
823    // force fallback would also error at the vec_distance_l2 dim-check;
824    // returning None here lets that path produce the user-visible
825    // error message.
826    let declared_dim = match table.columns.iter().find(|c| c.column_name == col_name) {
827        Some(c) => match &c.datatype {
828            DataType::Vector(d) => *d,
829            _ => return None,
830        },
831        None => return None,
832    };
833    if query_vec.len() != declared_dim {
834        return None;
835    }
836
837    // Probe the graph. Vectors are looked up from the table's row
838    // storage — a closure rather than a `&Table` so the algorithm
839    // module stays decoupled from the SQL types.
840    let column_for_closure = col_name.clone();
841    let table_ref = table;
842    let result = entry.index.search(&query_vec, k, |id| {
843        match table_ref.get_value(&column_for_closure, id) {
844            Some(Value::Vector(v)) => v,
845            _ => Vec::new(),
846        }
847    });
848    Some(result)
849}
850
851/// Helper for `try_hnsw_probe`: given two function args, identify which
852/// one is a bare column identifier (the indexed column) and which is a
853/// bracket-array literal (the query vector). Returns
854/// `Some((column_name, query_vec))` on a match, `None` otherwise.
855fn identify_indexed_arg_and_literal(a: &Expr, b: &Expr) -> Option<(String, Vec<f32>)> {
856    let col_name = match a {
857        Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
858        _ => return None,
859    };
860    let lit_str = match b {
861        Expr::Identifier(ident) if ident.quote_style == Some('[') => {
862            format!("[{}]", ident.value)
863        }
864        _ => return None,
865    };
866    let v = parse_vector_literal(&lit_str).ok()?;
867    Some((col_name, v))
868}
869
870/// One entry in the bounded-heap top-k path. Holds a pre-evaluated
871/// sort key + the rowid it came from. The `asc` flag inverts `Ord`
872/// so a single `BinaryHeap<HeapEntry>` works for both ASC and DESC
873/// without wrapping in `std::cmp::Reverse` at the call site:
874///
875///   - ASC LIMIT k = "k smallest": natural Ord. Max-heap top is the
876///     largest currently kept; new items smaller than top displace.
877///   - DESC LIMIT k = "k largest": Ord reversed. Max-heap top is now
878///     the smallest currently kept (under reversed Ord, smallest
879///     looks largest); new items larger than top displace.
880///
881/// In both cases the displacement test reduces to "new entry < heap top".
882struct HeapEntry {
883    key: Value,
884    rowid: i64,
885    asc: bool,
886}
887
888impl PartialEq for HeapEntry {
889    fn eq(&self, other: &Self) -> bool {
890        self.cmp(other) == Ordering::Equal
891    }
892}
893
894impl Eq for HeapEntry {}
895
896impl PartialOrd for HeapEntry {
897    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
898        Some(self.cmp(other))
899    }
900}
901
902impl Ord for HeapEntry {
903    fn cmp(&self, other: &Self) -> Ordering {
904        let raw = compare_values(Some(&self.key), Some(&other.key));
905        if self.asc { raw } else { raw.reverse() }
906    }
907}
908
909/// Bounded-heap top-k selection. Returns at most `k` rowids in the
910/// caller's desired order (ascending key for `order.ascending`,
911/// descending otherwise).
912///
913/// O(N log k) where N = `matching.len()`. Caller must check
914/// `k < matching.len()` for this to be a win — for k ≥ N the
915/// `sort_rowids` full-sort path is the same asymptotic cost without
916/// the heap overhead.
917fn select_topk(
918    matching: &[i64],
919    table: &Table,
920    order: &OrderByClause,
921    k: usize,
922) -> Result<Vec<i64>> {
923    use std::collections::BinaryHeap;
924
925    if k == 0 || matching.is_empty() {
926        return Ok(Vec::new());
927    }
928
929    let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
930
931    for &rowid in matching {
932        let key = eval_expr(&order.expr, table, rowid)?;
933        let entry = HeapEntry {
934            key,
935            rowid,
936            asc: order.ascending,
937        };
938
939        if heap.len() < k {
940            heap.push(entry);
941        } else {
942            // peek() returns the largest under our direction-aware Ord
943            // — the worst entry currently kept. Displace it iff the
944            // new entry is "better" (i.e. compares Less).
945            if entry < *heap.peek().unwrap() {
946                heap.pop();
947                heap.push(entry);
948            }
949        }
950    }
951
952    // `into_sorted_vec` returns ascending under our direction-aware Ord:
953    //   ASC: ascending by raw key (what we want)
954    //   DESC: ascending under reversed Ord = descending by raw key (what
955    //         we want for an ORDER BY DESC LIMIT k result)
956    Ok(heap
957        .into_sorted_vec()
958        .into_iter()
959        .map(|e| e.rowid)
960        .collect())
961}
962
963fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
964    // Phase 7b: ORDER BY now accepts any expression (column ref,
965    // arithmetic, function call, …). Pre-compute the sort key for
966    // every rowid up front so the comparator is called O(N log N)
967    // times against pre-evaluated Values rather than re-evaluating
968    // the expression O(N log N) times. Not strictly necessary today,
969    // but vital once 7d's HNSW index lands and this same code path
970    // could be running tens of millions of distance computations.
971    let mut keys: Vec<(i64, Result<Value>)> = rowids
972        .iter()
973        .map(|r| (*r, eval_expr(&order.expr, table, *r)))
974        .collect();
975
976    // Surface the FIRST evaluation error if any. We could be lazy
977    // and let sort_by encounter it, but `Ord::cmp` can't return a
978    // Result and we'd have to swallow errors silently.
979    for (_, k) in &keys {
980        if let Err(e) = k {
981            return Err(SQLRiteError::General(format!(
982                "ORDER BY expression failed: {e}"
983            )));
984        }
985    }
986
987    keys.sort_by(|(_, ka), (_, kb)| {
988        // Both unwrap()s are safe — we just verified above that
989        // every key Result is Ok.
990        let va = ka.as_ref().unwrap();
991        let vb = kb.as_ref().unwrap();
992        let ord = compare_values(Some(va), Some(vb));
993        if order.ascending { ord } else { ord.reverse() }
994    });
995
996    // Write the sorted rowids back into the caller's slice.
997    for (i, (rowid, _)) in keys.into_iter().enumerate() {
998        rowids[i] = rowid;
999    }
1000    Ok(())
1001}
1002
1003fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
1004    match (a, b) {
1005        (None, None) => Ordering::Equal,
1006        (None, _) => Ordering::Less,
1007        (_, None) => Ordering::Greater,
1008        (Some(a), Some(b)) => match (a, b) {
1009            (Value::Null, Value::Null) => Ordering::Equal,
1010            (Value::Null, _) => Ordering::Less,
1011            (_, Value::Null) => Ordering::Greater,
1012            (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
1013            (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
1014            (Value::Integer(x), Value::Real(y)) => {
1015                (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
1016            }
1017            (Value::Real(x), Value::Integer(y)) => {
1018                x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
1019            }
1020            (Value::Text(x), Value::Text(y)) => x.cmp(y),
1021            (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
1022            // Cross-type fallback: stringify and compare; keeps ORDER BY total.
1023            (x, y) => x.to_display_string().cmp(&y.to_display_string()),
1024        },
1025    }
1026}
1027
1028/// Returns `true` if the row at `rowid` matches the predicate expression.
1029pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
1030    let v = eval_expr(expr, table, rowid)?;
1031    match v {
1032        Value::Bool(b) => Ok(b),
1033        Value::Null => Ok(false), // SQL NULL in a WHERE is treated as false
1034        Value::Integer(i) => Ok(i != 0),
1035        other => Err(SQLRiteError::Internal(format!(
1036            "WHERE clause must evaluate to boolean, got {}",
1037            other.to_display_string()
1038        ))),
1039    }
1040}
1041
1042fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
1043    match expr {
1044        Expr::Nested(inner) => eval_expr(inner, table, rowid),
1045
1046        Expr::Identifier(ident) => {
1047            // Phase 7b — sqlparser parses bracket-array literals like
1048            // `[0.1, 0.2, 0.3]` as bracket-quoted identifiers (it inherits
1049            // MSSQL `[name]` syntax). When we see `quote_style == Some('[')`
1050            // in expression-evaluation position (SELECT projection, WHERE,
1051            // ORDER BY, function args), parse the bracketed content as a
1052            // vector literal so the rest of the executor can compare /
1053            // distance-compute against it. Same trick the INSERT parser
1054            // uses; the executor needed its own copy because expression
1055            // eval runs on a different code path.
1056            if ident.quote_style == Some('[') {
1057                let raw = format!("[{}]", ident.value);
1058                let v = parse_vector_literal(&raw)?;
1059                return Ok(Value::Vector(v));
1060            }
1061            Ok(table.get_value(&ident.value, rowid).unwrap_or(Value::Null))
1062        }
1063
1064        Expr::CompoundIdentifier(parts) => {
1065            // Accept `table.col` — we only have one table in scope, so ignore the qualifier.
1066            let col = parts
1067                .last()
1068                .map(|i| i.value.as_str())
1069                .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?;
1070            Ok(table.get_value(col, rowid).unwrap_or(Value::Null))
1071        }
1072
1073        Expr::Value(v) => convert_literal(&v.value),
1074
1075        Expr::UnaryOp { op, expr } => {
1076            let inner = eval_expr(expr, table, rowid)?;
1077            match op {
1078                UnaryOperator::Not => match inner {
1079                    Value::Bool(b) => Ok(Value::Bool(!b)),
1080                    Value::Null => Ok(Value::Null),
1081                    other => Err(SQLRiteError::Internal(format!(
1082                        "NOT applied to non-boolean value: {}",
1083                        other.to_display_string()
1084                    ))),
1085                },
1086                UnaryOperator::Minus => match inner {
1087                    Value::Integer(i) => Ok(Value::Integer(-i)),
1088                    Value::Real(f) => Ok(Value::Real(-f)),
1089                    Value::Null => Ok(Value::Null),
1090                    other => Err(SQLRiteError::Internal(format!(
1091                        "unary minus on non-numeric value: {}",
1092                        other.to_display_string()
1093                    ))),
1094                },
1095                UnaryOperator::Plus => Ok(inner),
1096                other => Err(SQLRiteError::NotImplemented(format!(
1097                    "unary operator {other:?} is not supported"
1098                ))),
1099            }
1100        }
1101
1102        Expr::BinaryOp { left, op, right } => match op {
1103            BinaryOperator::And => {
1104                let l = eval_expr(left, table, rowid)?;
1105                let r = eval_expr(right, table, rowid)?;
1106                Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
1107            }
1108            BinaryOperator::Or => {
1109                let l = eval_expr(left, table, rowid)?;
1110                let r = eval_expr(right, table, rowid)?;
1111                Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
1112            }
1113            cmp @ (BinaryOperator::Eq
1114            | BinaryOperator::NotEq
1115            | BinaryOperator::Lt
1116            | BinaryOperator::LtEq
1117            | BinaryOperator::Gt
1118            | BinaryOperator::GtEq) => {
1119                let l = eval_expr(left, table, rowid)?;
1120                let r = eval_expr(right, table, rowid)?;
1121                // Any comparison involving NULL is unknown → false in a WHERE.
1122                if matches!(l, Value::Null) || matches!(r, Value::Null) {
1123                    return Ok(Value::Bool(false));
1124                }
1125                let ord = compare_values(Some(&l), Some(&r));
1126                let result = match cmp {
1127                    BinaryOperator::Eq => ord == Ordering::Equal,
1128                    BinaryOperator::NotEq => ord != Ordering::Equal,
1129                    BinaryOperator::Lt => ord == Ordering::Less,
1130                    BinaryOperator::LtEq => ord != Ordering::Greater,
1131                    BinaryOperator::Gt => ord == Ordering::Greater,
1132                    BinaryOperator::GtEq => ord != Ordering::Less,
1133                    _ => unreachable!(),
1134                };
1135                Ok(Value::Bool(result))
1136            }
1137            arith @ (BinaryOperator::Plus
1138            | BinaryOperator::Minus
1139            | BinaryOperator::Multiply
1140            | BinaryOperator::Divide
1141            | BinaryOperator::Modulo) => {
1142                let l = eval_expr(left, table, rowid)?;
1143                let r = eval_expr(right, table, rowid)?;
1144                eval_arith(arith, &l, &r)
1145            }
1146            BinaryOperator::StringConcat => {
1147                let l = eval_expr(left, table, rowid)?;
1148                let r = eval_expr(right, table, rowid)?;
1149                if matches!(l, Value::Null) || matches!(r, Value::Null) {
1150                    return Ok(Value::Null);
1151                }
1152                Ok(Value::Text(format!(
1153                    "{}{}",
1154                    l.to_display_string(),
1155                    r.to_display_string()
1156                )))
1157            }
1158            other => Err(SQLRiteError::NotImplemented(format!(
1159                "binary operator {other:?} is not supported yet"
1160            ))),
1161        },
1162
1163        // Phase 7b — function-call dispatch. Currently only the three
1164        // vector-distance functions; this match arm becomes the single
1165        // place to register more SQL functions later (e.g. abs(),
1166        // length(), …) without re-touching the rest of the executor.
1167        //
1168        // Operator forms (`<->` `<=>` `<#>`) are NOT plumbed here: two
1169        // of three don't parse natively in sqlparser (we'd need a
1170        // string-preprocessing pass or a sqlparser fork). Deferred to
1171        // a follow-up sub-phase; see docs/phase-7-plan.md's "Scope
1172        // corrections" note.
1173        Expr::Function(func) => eval_function(func, table, rowid),
1174
1175        other => Err(SQLRiteError::NotImplemented(format!(
1176            "unsupported expression in WHERE/projection: {other:?}"
1177        ))),
1178    }
1179}
1180
1181/// Dispatches an `Expr::Function` to its built-in implementation.
1182/// Currently only the three vec_distance_* functions; other functions
1183/// surface as `NotImplemented` errors with the function name in the
1184/// message so users see what they tried.
1185fn eval_function(func: &sqlparser::ast::Function, table: &Table, rowid: i64) -> Result<Value> {
1186    // Function name lives in `name.0[0]` for unqualified calls. Anything
1187    // qualified (e.g. `pkg.fn(...)`) falls through to NotImplemented.
1188    let name = match func.name.0.as_slice() {
1189        [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
1190        _ => {
1191            return Err(SQLRiteError::NotImplemented(format!(
1192                "qualified function names not supported: {:?}",
1193                func.name
1194            )));
1195        }
1196    };
1197
1198    match name.as_str() {
1199        "vec_distance_l2" | "vec_distance_cosine" | "vec_distance_dot" => {
1200            let (a, b) = extract_two_vector_args(&name, &func.args, table, rowid)?;
1201            let dist = match name.as_str() {
1202                "vec_distance_l2" => vec_distance_l2(&a, &b),
1203                "vec_distance_cosine" => vec_distance_cosine(&a, &b)?,
1204                "vec_distance_dot" => vec_distance_dot(&a, &b),
1205                _ => unreachable!(),
1206            };
1207            // Widen f32 → f64 for the runtime Value. Vectors are stored
1208            // as f32 (consistent with industry convention for embeddings),
1209            // but the executor's numeric type is f64 so distances slot
1210            // into Value::Real cleanly and can be compared / ordered with
1211            // other reals via the existing arithmetic + comparison paths.
1212            Ok(Value::Real(dist as f64))
1213        }
1214        other => Err(SQLRiteError::NotImplemented(format!(
1215            "unknown function: {other}(...)"
1216        ))),
1217    }
1218}
1219
1220/// Extracts exactly two `Vec<f32>` arguments from a function call,
1221/// validating arity and that both sides are Vector-typed with matching
1222/// dimensions. Used by all three vec_distance_* functions.
1223fn extract_two_vector_args(
1224    fn_name: &str,
1225    args: &FunctionArguments,
1226    table: &Table,
1227    rowid: i64,
1228) -> Result<(Vec<f32>, Vec<f32>)> {
1229    let arg_list = match args {
1230        FunctionArguments::List(l) => &l.args,
1231        _ => {
1232            return Err(SQLRiteError::General(format!(
1233                "{fn_name}() expects exactly two vector arguments"
1234            )));
1235        }
1236    };
1237    if arg_list.len() != 2 {
1238        return Err(SQLRiteError::General(format!(
1239            "{fn_name}() expects exactly 2 arguments, got {}",
1240            arg_list.len()
1241        )));
1242    }
1243    let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
1244    for (i, arg) in arg_list.iter().enumerate() {
1245        let expr = match arg {
1246            FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
1247            other => {
1248                return Err(SQLRiteError::NotImplemented(format!(
1249                    "{fn_name}() argument {i} has unsupported shape: {other:?}"
1250                )));
1251            }
1252        };
1253        let val = eval_expr(expr, table, rowid)?;
1254        match val {
1255            Value::Vector(v) => out.push(v),
1256            other => {
1257                return Err(SQLRiteError::General(format!(
1258                    "{fn_name}() argument {i} is not a vector: got {}",
1259                    other.to_display_string()
1260                )));
1261            }
1262        }
1263    }
1264    let b = out.pop().unwrap();
1265    let a = out.pop().unwrap();
1266    if a.len() != b.len() {
1267        return Err(SQLRiteError::General(format!(
1268            "{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
1269            a.len(),
1270            b.len()
1271        )));
1272    }
1273    Ok((a, b))
1274}
1275
/// Euclidean (L2) distance between two equal-length vectors: √Σ(aᵢ − bᵢ)².
/// Smaller means closer; identical vectors give 0.0.
///
/// Callers are expected to have checked that the lengths match. This is
/// only `debug_assert`ed here.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Slicing `b` once hoists the bounds check out of the loop.
    let squared = a
        .iter()
        .zip(&b[..a.len()])
        .fold(0.0f32, |acc, (x, y)| {
            let diff = x - y;
            acc + diff * diff
        });
    squared.sqrt()
}
1287
1288/// Cosine distance: 1 − (a·b) / (‖a‖·‖b‖).
1289/// Smaller-is-closer; identical (non-zero) vectors return 0.0,
1290/// orthogonal vectors return 1.0, opposite-direction vectors return 2.0.
1291///
1292/// Errors if either vector has zero magnitude — cosine similarity is
1293/// undefined for the zero vector and silently returning NaN would
1294/// poison `ORDER BY` ranking. Callers who want the silent-NaN
1295/// behavior can compute `vec_distance_dot(a, b) / (norm(a) * norm(b))`
1296/// themselves.
1297pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
1298    debug_assert_eq!(a.len(), b.len());
1299    let mut dot = 0.0f32;
1300    let mut norm_a_sq = 0.0f32;
1301    let mut norm_b_sq = 0.0f32;
1302    for i in 0..a.len() {
1303        dot += a[i] * b[i];
1304        norm_a_sq += a[i] * a[i];
1305        norm_b_sq += b[i] * b[i];
1306    }
1307    let denom = (norm_a_sq * norm_b_sq).sqrt();
1308    if denom == 0.0 {
1309        return Err(SQLRiteError::General(
1310            "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
1311        ));
1312    }
1313    Ok(1.0 - dot / denom)
1314}
1315
/// Negated dot product of two equal-length vectors: −(a·b).
///
/// This follows the pgvector convention: negating makes smaller mean
/// closer, as with L2 and cosine. For unit-norm vectors,
/// `vec_distance_dot(a, b) == vec_distance_cosine(a, b) - 1`.
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let dot = a
        .iter()
        .zip(&b[..a.len()])
        .fold(0.0f32, |acc, (x, y)| acc + x * y);
    -dot
}
1327
1328/// Evaluates an integer/real arithmetic op. NULL on either side propagates.
1329/// Mixed Integer/Real promotes to Real. Divide/Modulo by zero → error.
1330fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
1331    if matches!(l, Value::Null) || matches!(r, Value::Null) {
1332        return Ok(Value::Null);
1333    }
1334    match (l, r) {
1335        (Value::Integer(a), Value::Integer(b)) => match op {
1336            BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
1337            BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
1338            BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
1339            BinaryOperator::Divide => {
1340                if *b == 0 {
1341                    Err(SQLRiteError::General("division by zero".to_string()))
1342                } else {
1343                    Ok(Value::Integer(a / b))
1344                }
1345            }
1346            BinaryOperator::Modulo => {
1347                if *b == 0 {
1348                    Err(SQLRiteError::General("modulo by zero".to_string()))
1349                } else {
1350                    Ok(Value::Integer(a % b))
1351                }
1352            }
1353            _ => unreachable!(),
1354        },
1355        // Anything involving a Real promotes both sides to f64.
1356        (a, b) => {
1357            let af = as_number(a)?;
1358            let bf = as_number(b)?;
1359            match op {
1360                BinaryOperator::Plus => Ok(Value::Real(af + bf)),
1361                BinaryOperator::Minus => Ok(Value::Real(af - bf)),
1362                BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
1363                BinaryOperator::Divide => {
1364                    if bf == 0.0 {
1365                        Err(SQLRiteError::General("division by zero".to_string()))
1366                    } else {
1367                        Ok(Value::Real(af / bf))
1368                    }
1369                }
1370                BinaryOperator::Modulo => {
1371                    if bf == 0.0 {
1372                        Err(SQLRiteError::General("modulo by zero".to_string()))
1373                    } else {
1374                        Ok(Value::Real(af % bf))
1375                    }
1376                }
1377                _ => unreachable!(),
1378            }
1379        }
1380    }
1381}
1382
1383fn as_number(v: &Value) -> Result<f64> {
1384    match v {
1385        Value::Integer(i) => Ok(*i as f64),
1386        Value::Real(f) => Ok(*f),
1387        Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
1388        other => Err(SQLRiteError::General(format!(
1389            "arithmetic on non-numeric value '{}'",
1390            other.to_display_string()
1391        ))),
1392    }
1393}
1394
1395fn as_bool(v: &Value) -> Result<bool> {
1396    match v {
1397        Value::Bool(b) => Ok(*b),
1398        Value::Null => Ok(false),
1399        Value::Integer(i) => Ok(*i != 0),
1400        other => Err(SQLRiteError::Internal(format!(
1401            "expected boolean, got {}",
1402            other.to_display_string()
1403        ))),
1404    }
1405}
1406
1407fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
1408    use sqlparser::ast::Value as AstValue;
1409    match v {
1410        AstValue::Number(n, _) => {
1411            if let Ok(i) = n.parse::<i64>() {
1412                Ok(Value::Integer(i))
1413            } else if let Ok(f) = n.parse::<f64>() {
1414                Ok(Value::Real(f))
1415            } else {
1416                Err(SQLRiteError::Internal(format!(
1417                    "could not parse numeric literal '{n}'"
1418                )))
1419            }
1420        }
1421        AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
1422        AstValue::Boolean(b) => Ok(Value::Bool(*b)),
1423        AstValue::Null => Ok(Value::Null),
1424        other => Err(SQLRiteError::NotImplemented(format!(
1425            "unsupported literal value: {other:?}"
1426        ))),
1427    }
1428}
1429
1430#[cfg(test)]
1431mod tests {
1432    use super::*;
1433
1434    // -----------------------------------------------------------------
1435    // Phase 7b — Vector distance function math
1436    // -----------------------------------------------------------------
1437
1438    /// Float comparison helper — distance results need a small epsilon
1439    /// because we accumulate sums across many f32 multiplies.
1440    fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
1441        (a - b).abs() < eps
1442    }
1443
1444    #[test]
1445    fn vec_distance_l2_identical_is_zero() {
1446        let v = vec![0.1, 0.2, 0.3];
1447        assert_eq!(vec_distance_l2(&v, &v), 0.0);
1448    }
1449
1450    #[test]
1451    fn vec_distance_l2_unit_basis_is_sqrt2() {
1452        // [1, 0] vs [0, 1]: distance = √((1-0)² + (0-1)²) = √2 ≈ 1.414
1453        let a = vec![1.0, 0.0];
1454        let b = vec![0.0, 1.0];
1455        assert!(approx_eq(vec_distance_l2(&a, &b), 2.0_f32.sqrt(), 1e-6));
1456    }
1457
1458    #[test]
1459    fn vec_distance_l2_known_value() {
1460        // [0, 0, 0] vs [3, 4, 0]: √(9 + 16 + 0) = 5 (the classic 3-4-5 triangle).
1461        let a = vec![0.0, 0.0, 0.0];
1462        let b = vec![3.0, 4.0, 0.0];
1463        assert!(approx_eq(vec_distance_l2(&a, &b), 5.0, 1e-6));
1464    }
1465
1466    #[test]
1467    fn vec_distance_cosine_identical_is_zero() {
1468        let v = vec![0.1, 0.2, 0.3];
1469        let d = vec_distance_cosine(&v, &v).unwrap();
1470        assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
1471    }
1472
1473    #[test]
1474    fn vec_distance_cosine_orthogonal_is_one() {
1475        // Two orthogonal unit vectors should have cosine distance = 1.0
1476        // (cosine similarity = 0 → distance = 1 - 0 = 1).
1477        let a = vec![1.0, 0.0];
1478        let b = vec![0.0, 1.0];
1479        assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 1.0, 1e-6));
1480    }
1481
1482    #[test]
1483    fn vec_distance_cosine_opposite_is_two() {
1484        // a and -a have cosine similarity = -1 → distance = 1 - (-1) = 2.
1485        let a = vec![1.0, 0.0, 0.0];
1486        let b = vec![-1.0, 0.0, 0.0];
1487        assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 2.0, 1e-6));
1488    }
1489
1490    #[test]
1491    fn vec_distance_cosine_zero_magnitude_errors() {
1492        // Cosine is undefined for the zero vector — error rather than NaN.
1493        let a = vec![0.0, 0.0];
1494        let b = vec![1.0, 0.0];
1495        let err = vec_distance_cosine(&a, &b).unwrap_err();
1496        assert!(format!("{err}").contains("zero-magnitude"));
1497    }
1498
1499    #[test]
1500    fn vec_distance_dot_negates() {
1501        // a·b = 1*4 + 2*5 + 3*6 = 32. Negated → -32.
1502        let a = vec![1.0, 2.0, 3.0];
1503        let b = vec![4.0, 5.0, 6.0];
1504        assert!(approx_eq(vec_distance_dot(&a, &b), -32.0, 1e-6));
1505    }
1506
1507    #[test]
1508    fn vec_distance_dot_orthogonal_is_zero() {
1509        // Orthogonal vectors have dot product 0 → negated is also 0.
1510        let a = vec![1.0, 0.0];
1511        let b = vec![0.0, 1.0];
1512        assert_eq!(vec_distance_dot(&a, &b), 0.0);
1513    }
1514
1515    #[test]
1516    fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
1517        // For unit-norm vectors: dot(a,b) = cos(a,b)
1518        // → -dot(a,b) = -cos(a,b) = (1 - cos(a,b)) - 1 = vec_distance_cosine(a,b) - 1.
1519        // Useful sanity check that the two functions agree on unit vectors.
1520        let a = vec![0.6f32, 0.8]; // unit norm: √(0.36+0.64) = 1
1521        let b = vec![0.8f32, 0.6]; // unit norm too
1522        let dot = vec_distance_dot(&a, &b);
1523        let cos = vec_distance_cosine(&a, &b).unwrap();
1524        assert!(approx_eq(dot, cos - 1.0, 1e-5));
1525    }
1526
1527    // -----------------------------------------------------------------
1528    // Phase 7c — bounded-heap top-k correctness + benchmark
1529    // -----------------------------------------------------------------
1530
1531    use crate::sql::db::database::Database;
1532    use crate::sql::parser::select::SelectQuery;
1533    use sqlparser::dialect::SQLiteDialect;
1534    use sqlparser::parser::Parser;
1535
1536    /// Builds a `docs(id INTEGER PK, score REAL)` table with N rows of
1537    /// distinct positive scores so top-k tests aren't sensitive to
1538    /// tie-breaking (heap is unstable; full-sort is stable; we want
1539    /// both to agree without arguing about equal-score row order).
1540    ///
1541    /// **Why positive scores:** the INSERT parser doesn't currently
1542    /// handle `Expr::UnaryOp(Minus, …)` for negative number literals
1543    /// (it would parse `-3.14` as a unary expression and the value
1544    /// extractor would skip it). That's a pre-existing bug, out of
1545    /// scope for 7c. Using the Knuth multiplicative hash gives us
1546    /// distinct positive scrambled values without dancing around the
1547    /// negative-literal limitation.
1548    fn seed_score_table(n: usize) -> Database {
1549        let mut db = Database::new("tempdb".to_string());
1550        crate::sql::process_command(
1551            "CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
1552            &mut db,
1553        )
1554        .expect("create");
1555        for i in 0..n {
1556            // Knuth multiplicative hash mod 1_000_000 — distinct,
1557            // dense in [0, 999_999], no collisions for n up to ~tens
1558            // of thousands.
1559            let score = ((i as u64).wrapping_mul(2_654_435_761) % 1_000_000) as f64;
1560            let sql = format!("INSERT INTO docs (score) VALUES ({score});");
1561            crate::sql::process_command(&sql, &mut db).expect("insert");
1562        }
1563        db
1564    }
1565
1566    /// Helper: parses an SQL SELECT into a SelectQuery so we can drive
1567    /// `select_topk` / `sort_rowids` directly without the rest of the
1568    /// process_command pipeline.
1569    fn parse_select(sql: &str) -> SelectQuery {
1570        let dialect = SQLiteDialect {};
1571        let mut ast = Parser::parse_sql(&dialect, sql).expect("parse");
1572        let stmt = ast.pop().expect("one statement");
1573        SelectQuery::new(&stmt).expect("select-query")
1574    }
1575
1576    #[test]
1577    fn topk_matches_full_sort_asc() {
1578        // Build N=200, top-k=10. Bounded heap output must equal
1579        // full-sort-then-truncate output (both produce ASC order).
1580        let db = seed_score_table(200);
1581        let table = db.get_table("docs".to_string()).unwrap();
1582        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
1583        let order = q.order_by.as_ref().unwrap();
1584        let all_rowids = table.rowids();
1585
1586        // Full-sort path
1587        let mut full = all_rowids.clone();
1588        sort_rowids(&mut full, table, order).unwrap();
1589        full.truncate(10);
1590
1591        // Bounded-heap path
1592        let topk = select_topk(&all_rowids, table, order, 10).unwrap();
1593
1594        assert_eq!(topk, full, "top-k via heap should match full-sort+truncate");
1595    }
1596
1597    #[test]
1598    fn topk_matches_full_sort_desc() {
1599        // Same with DESC — verifies the direction-aware Ord wrapper.
1600        let db = seed_score_table(200);
1601        let table = db.get_table("docs".to_string()).unwrap();
1602        let q = parse_select("SELECT * FROM docs ORDER BY score DESC LIMIT 10;");
1603        let order = q.order_by.as_ref().unwrap();
1604        let all_rowids = table.rowids();
1605
1606        let mut full = all_rowids.clone();
1607        sort_rowids(&mut full, table, order).unwrap();
1608        full.truncate(10);
1609
1610        let topk = select_topk(&all_rowids, table, order, 10).unwrap();
1611
1612        assert_eq!(
1613            topk, full,
1614            "top-k DESC via heap should match full-sort+truncate"
1615        );
1616    }
1617
1618    #[test]
1619    fn topk_k_larger_than_n_returns_everything_sorted() {
1620        // The executor branches off to the full-sort path when k >= N,
1621        // but if a caller invokes select_topk directly with k > N, it
1622        // should still produce all-sorted output (no truncation
1623        // because we don't have N items to truncate to k).
1624        let db = seed_score_table(50);
1625        let table = db.get_table("docs".to_string()).unwrap();
1626        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
1627        let order = q.order_by.as_ref().unwrap();
1628        let topk = select_topk(&table.rowids(), table, order, 1000).unwrap();
1629        assert_eq!(topk.len(), 50);
1630        // All scores in ascending order.
1631        let scores: Vec<f64> = topk
1632            .iter()
1633            .filter_map(|r| match table.get_value("score", *r) {
1634                Some(Value::Real(f)) => Some(f),
1635                _ => None,
1636            })
1637            .collect();
1638        assert!(scores.windows(2).all(|w| w[0] <= w[1]));
1639    }
1640
1641    #[test]
1642    fn topk_k_zero_returns_empty() {
1643        let db = seed_score_table(10);
1644        let table = db.get_table("docs".to_string()).unwrap();
1645        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
1646        let order = q.order_by.as_ref().unwrap();
1647        let topk = select_topk(&table.rowids(), table, order, 0).unwrap();
1648        assert!(topk.is_empty());
1649    }
1650
1651    #[test]
1652    fn topk_empty_input_returns_empty() {
1653        let db = seed_score_table(0);
1654        let table = db.get_table("docs".to_string()).unwrap();
1655        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
1656        let order = q.order_by.as_ref().unwrap();
1657        let topk = select_topk(&[], table, order, 5).unwrap();
1658        assert!(topk.is_empty());
1659    }
1660
1661    #[test]
1662    fn topk_works_through_select_executor_with_distance_function() {
1663        // Integration check that the executor actually picks the
1664        // bounded-heap path on a KNN-shaped query and produces the
1665        // correct top-k.
1666        let mut db = Database::new("tempdb".to_string());
1667        crate::sql::process_command(
1668            "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
1669            &mut db,
1670        )
1671        .unwrap();
1672        // Five rows with distinct distances from probe [1.0, 0.0]:
1673        //   id=1 [1.0, 0.0]   distance=0
1674        //   id=2 [2.0, 0.0]   distance=1
1675        //   id=3 [0.0, 3.0]   distance=√(1+9) = √10 ≈ 3.16
1676        //   id=4 [1.0, 4.0]   distance=4
1677        //   id=5 [10.0, 10.0] distance=√(81+100) ≈ 13.45
1678        for v in &[
1679            "[1.0, 0.0]",
1680            "[2.0, 0.0]",
1681            "[0.0, 3.0]",
1682            "[1.0, 4.0]",
1683            "[10.0, 10.0]",
1684        ] {
1685            crate::sql::process_command(&format!("INSERT INTO docs (e) VALUES ({v});"), &mut db)
1686                .unwrap();
1687        }
1688        let resp = crate::sql::process_command(
1689            "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;",
1690            &mut db,
1691        )
1692        .unwrap();
1693        // Top-3 closest to [1.0, 0.0] are id=1, id=2, id=3 (in that order).
1694        // The status message tells us how many rows came back.
1695        assert!(resp.contains("3 rows returned"), "got: {resp}");
1696    }
1697
1698    /// Manual benchmark — not run by default. Recommended invocation:
1699    ///
1700    ///     cargo test -p sqlrite-engine --lib topk_benchmark --release \
1701    ///         -- --ignored --nocapture
1702    ///
1703    /// (`--release` matters: Rust's optimized sort gets very fast under
1704    /// optimization, so the heap's relative advantage is best observed
1705    /// against a sort that's also been optimized.)
1706    ///
1707    /// Measured numbers on an Apple Silicon laptop with N=10_000 + k=10:
1708    ///   - bounded heap:    ~820µs
1709    ///   - full sort+trunc: ~1.5ms
1710    ///   - ratio:           ~1.8×
1711    ///
1712    /// The advantage is real but moderate at this size because the sort
1713    /// key here is a single REAL column read (cheap) and Rust's sort_by
1714    /// has a very low constant factor. The asymptotic O(N log k) vs
1715    /// O(N log N) advantage scales with N and with per-row work — KNN
1716    /// queries where the sort key is `vec_distance_l2(col, [...])` are
1717    /// where this path really pays off, because each key evaluation is
1718    /// itself O(dim) and the heap path skips the per-row evaluation
1719    /// in the comparator (see `sort_rowids` for the contrast).
1720    #[test]
1721    #[ignore]
1722    fn topk_benchmark() {
1723        use std::time::Instant;
1724        const N: usize = 10_000;
1725        const K: usize = 10;
1726
1727        let db = seed_score_table(N);
1728        let table = db.get_table("docs".to_string()).unwrap();
1729        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
1730        let order = q.order_by.as_ref().unwrap();
1731        let all_rowids = table.rowids();
1732
1733        // Time bounded heap.
1734        let t0 = Instant::now();
1735        let _topk = select_topk(&all_rowids, table, order, K).unwrap();
1736        let heap_dur = t0.elapsed();
1737
1738        // Time full sort + truncate.
1739        let t1 = Instant::now();
1740        let mut full = all_rowids.clone();
1741        sort_rowids(&mut full, table, order).unwrap();
1742        full.truncate(K);
1743        let sort_dur = t1.elapsed();
1744
1745        let ratio = sort_dur.as_secs_f64() / heap_dur.as_secs_f64().max(1e-9);
1746        println!("\n--- topk_benchmark (N={N}, k={K}) ---");
1747        println!("  bounded heap:   {heap_dur:?}");
1748        println!("  full sort+trunc: {sort_dur:?}");
1749        println!("  speedup ratio:  {ratio:.2}×");
1750
1751        // Soft assertion. Floor is 1.4× because the cheap-key
1752        // benchmark hovers around 1.8× empirically; setting this too
1753        // close to the measured value risks flaky CI on slower
1754        // runners. Floor of 1.4× still catches an actual regression
1755        // (e.g., if select_topk became O(N²) or stopped using the
1756        // heap entirely).
1757        assert!(
1758            ratio > 1.4,
1759            "bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
1760        );
1761    }
1762}