Skip to main content

sqlrite/sql/
executor.rs

1//! Query executors — evaluate parsed SQL statements against the in-memory
2//! storage and produce formatted output.
3
4use std::cmp::Ordering;
5
6use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
7use sqlparser::ast::{
8    AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr, FromTable, FunctionArg,
9    FunctionArgExpr, FunctionArguments, ObjectNamePart, Statement, TableFactor, TableWithJoins,
10    UnaryOperator, Update,
11};
12
13use crate::error::{Result, SQLRiteError};
14use crate::sql::db::database::Database;
15use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
16use crate::sql::db::table::{DataType, Table, Value, parse_vector_literal};
17use crate::sql::parser::select::{OrderByClause, Projection, SelectQuery};
18
/// Structured result of a SELECT: column names in projection order,
/// and each matching row as a `Vec<Value>` aligned with the columns.
/// Phase 5a introduced this so the public `Connection` / `Statement`
/// API has typed rows to yield; the existing `execute_select` that
/// returns pre-rendered text is now a thin wrapper on top.
pub struct SelectResult {
    /// Output column names, in projection order.
    pub columns: Vec<String>,
    /// One entry per result row. Each inner `Vec` holds one `Value` per
    /// entry in `columns`, in the same order.
    pub rows: Vec<Vec<Value>>,
}
31
32/// Executes a SELECT and returns structured rows. The typed rows are
33/// what the new public API streams to callers; the REPL / Tauri app
34/// pre-render into a prettytable via `execute_select`.
35pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
36    let table = db
37        .get_table(query.table_name.clone())
38        .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
39
40    // Resolve projection to a concrete ordered column list.
41    let projected_cols: Vec<String> = match &query.projection {
42        Projection::All => table.column_names(),
43        Projection::Columns(cols) => {
44            for c in cols {
45                if !table.contains_column(c.to_string()) {
46                    return Err(SQLRiteError::Internal(format!(
47                        "Column '{c}' does not exist on table '{}'",
48                        query.table_name
49                    )));
50                }
51            }
52            cols.clone()
53        }
54    };
55
56    // Collect matching rowids. If the WHERE is the shape `col = literal`
57    // and `col` has a secondary index, probe the index for an O(log N)
58    // seek; otherwise fall back to the full table scan.
59    let matching = match select_rowids(table, query.selection.as_ref())? {
60        RowidSource::IndexProbe(rowids) => rowids,
61        RowidSource::FullScan => {
62            let mut out = Vec::new();
63            for rowid in table.rowids() {
64                if let Some(expr) = &query.selection {
65                    if !eval_predicate(expr, table, rowid)? {
66                        continue;
67                    }
68                }
69                out.push(rowid);
70            }
71            out
72        }
73    };
74    let mut matching = matching;
75
76    // Sort before applying LIMIT, matching SQL semantics.
77    if let Some(order) = &query.order_by {
78        sort_rowids(&mut matching, table, order)?;
79    }
80
81    if let Some(n) = query.limit {
82        matching.truncate(n);
83    }
84
85    // Build typed rows. Missing cells surface as `Value::Null` — that
86    // maps a column-not-present-for-this-rowid case onto the public
87    // `Row::get` → `Option<T>` surface cleanly.
88    let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
89    for rowid in &matching {
90        let row: Vec<Value> = projected_cols
91            .iter()
92            .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
93            .collect();
94        rows.push(row);
95    }
96
97    Ok(SelectResult {
98        columns: projected_cols,
99        rows,
100    })
101}
102
103/// Executes a SELECT and returns `(rendered_table, row_count)`. The
104/// REPL and Tauri app use this to keep the table-printing behaviour
105/// the engine has always shipped. Structured callers use
106/// `execute_select_rows` instead.
107pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
108    let result = execute_select_rows(query, db)?;
109    let row_count = result.rows.len();
110
111    let mut print_table = PrintTable::new();
112    let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
113    print_table.add_row(PrintRow::new(header_cells));
114
115    for row in &result.rows {
116        let cells: Vec<PrintCell> = row
117            .iter()
118            .map(|v| PrintCell::new(&v.to_display_string()))
119            .collect();
120        print_table.add_row(PrintRow::new(cells));
121    }
122
123    Ok((print_table.to_string(), row_count))
124}
125
126/// Executes a DELETE statement. Returns the number of rows removed.
127pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
128    let Statement::Delete(Delete {
129        from, selection, ..
130    }) = stmt
131    else {
132        return Err(SQLRiteError::Internal(
133            "execute_delete called on a non-DELETE statement".to_string(),
134        ));
135    };
136
137    let tables = match from {
138        FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
139    };
140    let table_name = extract_single_table_name(tables)?;
141
142    // Compute matching rowids with an immutable borrow, then mutate.
143    let matching: Vec<i64> = {
144        let table = db
145            .get_table(table_name.clone())
146            .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
147        match select_rowids(table, selection.as_ref())? {
148            RowidSource::IndexProbe(rowids) => rowids,
149            RowidSource::FullScan => {
150                let mut out = Vec::new();
151                for rowid in table.rowids() {
152                    if let Some(expr) = selection {
153                        if !eval_predicate(expr, table, rowid)? {
154                            continue;
155                        }
156                    }
157                    out.push(rowid);
158                }
159                out
160            }
161        }
162    };
163
164    let table = db.get_table_mut(table_name)?;
165    for rowid in &matching {
166        table.delete_row(*rowid);
167    }
168    Ok(matching.len())
169}
170
171/// Executes an UPDATE statement. Returns the number of rows updated.
172pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
173    let Statement::Update(Update {
174        table,
175        assignments,
176        from,
177        selection,
178        ..
179    }) = stmt
180    else {
181        return Err(SQLRiteError::Internal(
182            "execute_update called on a non-UPDATE statement".to_string(),
183        ));
184    };
185
186    if from.is_some() {
187        return Err(SQLRiteError::NotImplemented(
188            "UPDATE ... FROM is not supported yet".to_string(),
189        ));
190    }
191
192    let table_name = extract_table_name(table)?;
193
194    // Resolve assignment targets to plain column names and verify they exist.
195    let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
196    {
197        let tbl = db
198            .get_table(table_name.clone())
199            .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
200        for a in assignments {
201            let col = match &a.target {
202                AssignmentTarget::ColumnName(name) => name
203                    .0
204                    .last()
205                    .map(|p| p.to_string())
206                    .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
207                AssignmentTarget::Tuple(_) => {
208                    return Err(SQLRiteError::NotImplemented(
209                        "tuple assignment targets are not supported".to_string(),
210                    ));
211                }
212            };
213            if !tbl.contains_column(col.clone()) {
214                return Err(SQLRiteError::Internal(format!(
215                    "UPDATE references unknown column '{col}'"
216                )));
217            }
218            parsed_assignments.push((col, a.value.clone()));
219        }
220    }
221
222    // Gather matching rowids + the new values to write for each assignment, under
223    // an immutable borrow. Uses the index-probe fast path when the WHERE is
224    // `col = literal` on an indexed column.
225    let work: Vec<(i64, Vec<(String, Value)>)> = {
226        let tbl = db.get_table(table_name.clone())?;
227        let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
228            RowidSource::IndexProbe(rowids) => rowids,
229            RowidSource::FullScan => {
230                let mut out = Vec::new();
231                for rowid in tbl.rowids() {
232                    if let Some(expr) = selection {
233                        if !eval_predicate(expr, tbl, rowid)? {
234                            continue;
235                        }
236                    }
237                    out.push(rowid);
238                }
239                out
240            }
241        };
242        let mut rows_to_update = Vec::new();
243        for rowid in matched_rowids {
244            let mut values = Vec::with_capacity(parsed_assignments.len());
245            for (col, expr) in &parsed_assignments {
246                // UPDATE's RHS is evaluated in the context of the row being updated,
247                // so column references on the right resolve to the current row's values.
248                let v = eval_expr(expr, tbl, rowid)?;
249                values.push((col.clone(), v));
250            }
251            rows_to_update.push((rowid, values));
252        }
253        rows_to_update
254    };
255
256    let tbl = db.get_table_mut(table_name)?;
257    for (rowid, values) in &work {
258        for (col, v) in values {
259            tbl.set_value(col, *rowid, v.clone())?;
260        }
261    }
262    Ok(work.len())
263}
264
265/// Handles `CREATE INDEX [UNIQUE] <name> ON <table> (<column>)`. Single-
266/// column indexes only; multi-column / composite indexes are future work.
267/// Returns the (possibly synthesized) index name for the status message.
268pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
269    let Statement::CreateIndex(CreateIndex {
270        name,
271        table_name,
272        columns,
273        unique,
274        if_not_exists,
275        predicate,
276        ..
277    }) = stmt
278    else {
279        return Err(SQLRiteError::Internal(
280            "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
281        ));
282    };
283
284    if predicate.is_some() {
285        return Err(SQLRiteError::NotImplemented(
286            "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
287        ));
288    }
289
290    if columns.len() != 1 {
291        return Err(SQLRiteError::NotImplemented(format!(
292            "multi-column indexes are not supported yet ({} columns given)",
293            columns.len()
294        )));
295    }
296
297    let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
298        SQLRiteError::NotImplemented(
299            "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
300        )
301    })?;
302
303    let table_name_str = table_name.to_string();
304    let column_name = match &columns[0].column.expr {
305        Expr::Identifier(ident) => ident.value.clone(),
306        Expr::CompoundIdentifier(parts) => parts
307            .last()
308            .map(|p| p.value.clone())
309            .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
310        other => {
311            return Err(SQLRiteError::NotImplemented(format!(
312                "CREATE INDEX only supports simple column references, got {other:?}"
313            )));
314        }
315    };
316
317    // Validate: table exists, column exists, type is indexable, name is unique.
318    let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
319        let table = db.get_table(table_name_str.clone()).map_err(|_| {
320            SQLRiteError::General(format!(
321                "CREATE INDEX references unknown table '{table_name_str}'"
322            ))
323        })?;
324        if !table.contains_column(column_name.clone()) {
325            return Err(SQLRiteError::General(format!(
326                "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
327            )));
328        }
329        let col = table
330            .columns
331            .iter()
332            .find(|c| c.column_name == column_name)
333            .expect("we just verified the column exists");
334        if table.index_by_name(&index_name).is_some() {
335            if *if_not_exists {
336                return Ok(index_name);
337            }
338            return Err(SQLRiteError::General(format!(
339                "index '{index_name}' already exists"
340            )));
341        }
342        let datatype = clone_datatype(&col.datatype);
343
344        // Snapshot (rowid, value) pairs so we can populate the index after
345        // it's attached. Doing this under the immutable borrow of the table
346        // means the mutable attach below can proceed without aliasing.
347        let mut pairs = Vec::new();
348        for rowid in table.rowids() {
349            if let Some(v) = table.get_value(&column_name, rowid) {
350                pairs.push((rowid, v));
351            }
352        }
353        (datatype, pairs)
354    };
355
356    // Build the index.
357    let mut idx = SecondaryIndex::new(
358        index_name.clone(),
359        table_name_str.clone(),
360        column_name.clone(),
361        &datatype,
362        *unique,
363        IndexOrigin::Explicit,
364    )?;
365
366    // Populate from the existing rows. UNIQUE violations here mean the
367    // existing data already breaks the new index's constraint — a common
368    // source of user confusion, so be explicit.
369    for (rowid, v) in &existing_rowids_and_values {
370        if *unique && idx.would_violate_unique(v) {
371            return Err(SQLRiteError::General(format!(
372                "cannot create UNIQUE index '{index_name}': column '{column_name}' \
373                 already contains the duplicate value {}",
374                v.to_display_string()
375            )));
376        }
377        idx.insert(v, *rowid)?;
378    }
379
380    // Attach to the table.
381    let table_mut = db.get_table_mut(table_name_str)?;
382    table_mut.secondary_indexes.push(idx);
383    Ok(index_name)
384}
385
386/// Cheap clone helper — `DataType` intentionally doesn't derive `Clone`
387/// because the enum has no ergonomic reason to be cloneable elsewhere.
388fn clone_datatype(dt: &DataType) -> DataType {
389    match dt {
390        DataType::Integer => DataType::Integer,
391        DataType::Text => DataType::Text,
392        DataType::Real => DataType::Real,
393        DataType::Bool => DataType::Bool,
394        DataType::Vector(dim) => DataType::Vector(*dim),
395        DataType::None => DataType::None,
396        DataType::Invalid => DataType::Invalid,
397    }
398}
399
400fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
401    if tables.len() != 1 {
402        return Err(SQLRiteError::NotImplemented(
403            "multi-table DELETE is not supported yet".to_string(),
404        ));
405    }
406    extract_table_name(&tables[0])
407}
408
409fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
410    if !twj.joins.is_empty() {
411        return Err(SQLRiteError::NotImplemented(
412            "JOIN is not supported yet".to_string(),
413        ));
414    }
415    match &twj.relation {
416        TableFactor::Table { name, .. } => Ok(name.to_string()),
417        _ => Err(SQLRiteError::NotImplemented(
418            "only plain table references are supported".to_string(),
419        )),
420    }
421}
422
/// Tells the executor how to produce its candidate rowid list.
/// Produced by `select_rowids` and consumed by the SELECT, DELETE and
/// UPDATE executors, each of which matches on the variant.
enum RowidSource {
    /// The WHERE was simple enough to probe a secondary index directly.
    /// The `Vec` already contains exactly the rows the index matched;
    /// no further WHERE evaluation is needed (the probe is precise).
    IndexProbe(Vec<i64>),
    /// No applicable index; caller falls back to walking `table.rowids()`
    /// and evaluating the WHERE on each row.
    FullScan,
}
433
434/// Try to satisfy `WHERE` with an index probe. Currently supports the
435/// simplest shape: a single `col = literal` (or `literal = col`) where
436/// `col` is on a secondary index. AND/OR/range predicates fall back to
437/// full scan — those can be layered on later without changing the caller.
438fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
439    let Some(expr) = selection else {
440        return Ok(RowidSource::FullScan);
441    };
442    let Some((col, literal)) = try_extract_equality(expr) else {
443        return Ok(RowidSource::FullScan);
444    };
445    let Some(idx) = table.index_for_column(&col) else {
446        return Ok(RowidSource::FullScan);
447    };
448
449    // Convert the literal into a runtime Value. If the literal type doesn't
450    // match the column's index we still need correct semantics — evaluate
451    // the WHERE against every row. Fall back to full scan.
452    let literal_value = match convert_literal(&literal) {
453        Ok(v) => v,
454        Err(_) => return Ok(RowidSource::FullScan),
455    };
456
457    // Index lookup returns the full list of rowids matching this equality
458    // predicate. For unique indexes that's at most one; for non-unique it
459    // can be many.
460    let mut rowids = idx.lookup(&literal_value);
461    rowids.sort_unstable();
462    Ok(RowidSource::IndexProbe(rowids))
463}
464
465/// Recognizes `expr` as a simple equality on a column reference against a
466/// literal. Returns `(column_name, literal_value)` if the shape matches;
467/// `None` otherwise. Accepts both `col = literal` and `literal = col`.
468fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
469    // Peel off Nested parens so `WHERE (x = 1)` is recognized too.
470    let peeled = match expr {
471        Expr::Nested(inner) => inner.as_ref(),
472        other => other,
473    };
474    let Expr::BinaryOp { left, op, right } = peeled else {
475        return None;
476    };
477    if !matches!(op, BinaryOperator::Eq) {
478        return None;
479    }
480    let col_from = |e: &Expr| -> Option<String> {
481        match e {
482            Expr::Identifier(ident) => Some(ident.value.clone()),
483            Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
484            _ => None,
485        }
486    };
487    let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
488        if let Expr::Value(v) = e {
489            Some(v.value.clone())
490        } else {
491            None
492        }
493    };
494    if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
495        return Some((c, l));
496    }
497    if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
498        return Some((c, l));
499    }
500    None
501}
502
503fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
504    // Phase 7b: ORDER BY now accepts any expression (column ref,
505    // arithmetic, function call, …). Pre-compute the sort key for
506    // every rowid up front so the comparator is called O(N log N)
507    // times against pre-evaluated Values rather than re-evaluating
508    // the expression O(N log N) times. Not strictly necessary today,
509    // but vital once 7d's HNSW index lands and this same code path
510    // could be running tens of millions of distance computations.
511    let mut keys: Vec<(i64, Result<Value>)> = rowids
512        .iter()
513        .map(|r| (*r, eval_expr(&order.expr, table, *r)))
514        .collect();
515
516    // Surface the FIRST evaluation error if any. We could be lazy
517    // and let sort_by encounter it, but `Ord::cmp` can't return a
518    // Result and we'd have to swallow errors silently.
519    for (_, k) in &keys {
520        if let Err(e) = k {
521            return Err(SQLRiteError::General(format!(
522                "ORDER BY expression failed: {e}"
523            )));
524        }
525    }
526
527    keys.sort_by(|(_, ka), (_, kb)| {
528        // Both unwrap()s are safe — we just verified above that
529        // every key Result is Ok.
530        let va = ka.as_ref().unwrap();
531        let vb = kb.as_ref().unwrap();
532        let ord = compare_values(Some(va), Some(vb));
533        if order.ascending { ord } else { ord.reverse() }
534    });
535
536    // Write the sorted rowids back into the caller's slice.
537    for (i, (rowid, _)) in keys.into_iter().enumerate() {
538        rowids[i] = rowid;
539    }
540    Ok(())
541}
542
543fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
544    match (a, b) {
545        (None, None) => Ordering::Equal,
546        (None, _) => Ordering::Less,
547        (_, None) => Ordering::Greater,
548        (Some(a), Some(b)) => match (a, b) {
549            (Value::Null, Value::Null) => Ordering::Equal,
550            (Value::Null, _) => Ordering::Less,
551            (_, Value::Null) => Ordering::Greater,
552            (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
553            (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
554            (Value::Integer(x), Value::Real(y)) => {
555                (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
556            }
557            (Value::Real(x), Value::Integer(y)) => {
558                x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
559            }
560            (Value::Text(x), Value::Text(y)) => x.cmp(y),
561            (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
562            // Cross-type fallback: stringify and compare; keeps ORDER BY total.
563            (x, y) => x.to_display_string().cmp(&y.to_display_string()),
564        },
565    }
566}
567
568/// Returns `true` if the row at `rowid` matches the predicate expression.
569pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
570    let v = eval_expr(expr, table, rowid)?;
571    match v {
572        Value::Bool(b) => Ok(b),
573        Value::Null => Ok(false), // SQL NULL in a WHERE is treated as false
574        Value::Integer(i) => Ok(i != 0),
575        other => Err(SQLRiteError::Internal(format!(
576            "WHERE clause must evaluate to boolean, got {}",
577            other.to_display_string()
578        ))),
579    }
580}
581
582fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
583    match expr {
584        Expr::Nested(inner) => eval_expr(inner, table, rowid),
585
586        Expr::Identifier(ident) => {
587            // Phase 7b — sqlparser parses bracket-array literals like
588            // `[0.1, 0.2, 0.3]` as bracket-quoted identifiers (it inherits
589            // MSSQL `[name]` syntax). When we see `quote_style == Some('[')`
590            // in expression-evaluation position (SELECT projection, WHERE,
591            // ORDER BY, function args), parse the bracketed content as a
592            // vector literal so the rest of the executor can compare /
593            // distance-compute against it. Same trick the INSERT parser
594            // uses; the executor needed its own copy because expression
595            // eval runs on a different code path.
596            if ident.quote_style == Some('[') {
597                let raw = format!("[{}]", ident.value);
598                let v = parse_vector_literal(&raw)?;
599                return Ok(Value::Vector(v));
600            }
601            Ok(table.get_value(&ident.value, rowid).unwrap_or(Value::Null))
602        }
603
604        Expr::CompoundIdentifier(parts) => {
605            // Accept `table.col` — we only have one table in scope, so ignore the qualifier.
606            let col = parts
607                .last()
608                .map(|i| i.value.as_str())
609                .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?;
610            Ok(table.get_value(col, rowid).unwrap_or(Value::Null))
611        }
612
613        Expr::Value(v) => convert_literal(&v.value),
614
615        Expr::UnaryOp { op, expr } => {
616            let inner = eval_expr(expr, table, rowid)?;
617            match op {
618                UnaryOperator::Not => match inner {
619                    Value::Bool(b) => Ok(Value::Bool(!b)),
620                    Value::Null => Ok(Value::Null),
621                    other => Err(SQLRiteError::Internal(format!(
622                        "NOT applied to non-boolean value: {}",
623                        other.to_display_string()
624                    ))),
625                },
626                UnaryOperator::Minus => match inner {
627                    Value::Integer(i) => Ok(Value::Integer(-i)),
628                    Value::Real(f) => Ok(Value::Real(-f)),
629                    Value::Null => Ok(Value::Null),
630                    other => Err(SQLRiteError::Internal(format!(
631                        "unary minus on non-numeric value: {}",
632                        other.to_display_string()
633                    ))),
634                },
635                UnaryOperator::Plus => Ok(inner),
636                other => Err(SQLRiteError::NotImplemented(format!(
637                    "unary operator {other:?} is not supported"
638                ))),
639            }
640        }
641
642        Expr::BinaryOp { left, op, right } => match op {
643            BinaryOperator::And => {
644                let l = eval_expr(left, table, rowid)?;
645                let r = eval_expr(right, table, rowid)?;
646                Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
647            }
648            BinaryOperator::Or => {
649                let l = eval_expr(left, table, rowid)?;
650                let r = eval_expr(right, table, rowid)?;
651                Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
652            }
653            cmp @ (BinaryOperator::Eq
654            | BinaryOperator::NotEq
655            | BinaryOperator::Lt
656            | BinaryOperator::LtEq
657            | BinaryOperator::Gt
658            | BinaryOperator::GtEq) => {
659                let l = eval_expr(left, table, rowid)?;
660                let r = eval_expr(right, table, rowid)?;
661                // Any comparison involving NULL is unknown → false in a WHERE.
662                if matches!(l, Value::Null) || matches!(r, Value::Null) {
663                    return Ok(Value::Bool(false));
664                }
665                let ord = compare_values(Some(&l), Some(&r));
666                let result = match cmp {
667                    BinaryOperator::Eq => ord == Ordering::Equal,
668                    BinaryOperator::NotEq => ord != Ordering::Equal,
669                    BinaryOperator::Lt => ord == Ordering::Less,
670                    BinaryOperator::LtEq => ord != Ordering::Greater,
671                    BinaryOperator::Gt => ord == Ordering::Greater,
672                    BinaryOperator::GtEq => ord != Ordering::Less,
673                    _ => unreachable!(),
674                };
675                Ok(Value::Bool(result))
676            }
677            arith @ (BinaryOperator::Plus
678            | BinaryOperator::Minus
679            | BinaryOperator::Multiply
680            | BinaryOperator::Divide
681            | BinaryOperator::Modulo) => {
682                let l = eval_expr(left, table, rowid)?;
683                let r = eval_expr(right, table, rowid)?;
684                eval_arith(arith, &l, &r)
685            }
686            BinaryOperator::StringConcat => {
687                let l = eval_expr(left, table, rowid)?;
688                let r = eval_expr(right, table, rowid)?;
689                if matches!(l, Value::Null) || matches!(r, Value::Null) {
690                    return Ok(Value::Null);
691                }
692                Ok(Value::Text(format!(
693                    "{}{}",
694                    l.to_display_string(),
695                    r.to_display_string()
696                )))
697            }
698            other => Err(SQLRiteError::NotImplemented(format!(
699                "binary operator {other:?} is not supported yet"
700            ))),
701        },
702
703        // Phase 7b — function-call dispatch. Currently only the three
704        // vector-distance functions; this match arm becomes the single
705        // place to register more SQL functions later (e.g. abs(),
706        // length(), …) without re-touching the rest of the executor.
707        //
708        // Operator forms (`<->` `<=>` `<#>`) are NOT plumbed here: two
709        // of three don't parse natively in sqlparser (we'd need a
710        // string-preprocessing pass or a sqlparser fork). Deferred to
711        // a follow-up sub-phase; see docs/phase-7-plan.md's "Scope
712        // corrections" note.
713        Expr::Function(func) => eval_function(func, table, rowid),
714
715        other => Err(SQLRiteError::NotImplemented(format!(
716            "unsupported expression in WHERE/projection: {other:?}"
717        ))),
718    }
719}
720
721/// Dispatches an `Expr::Function` to its built-in implementation.
722/// Currently only the three vec_distance_* functions; other functions
723/// surface as `NotImplemented` errors with the function name in the
724/// message so users see what they tried.
725fn eval_function(func: &sqlparser::ast::Function, table: &Table, rowid: i64) -> Result<Value> {
726    // Function name lives in `name.0[0]` for unqualified calls. Anything
727    // qualified (e.g. `pkg.fn(...)`) falls through to NotImplemented.
728    let name = match func.name.0.as_slice() {
729        [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
730        _ => {
731            return Err(SQLRiteError::NotImplemented(format!(
732                "qualified function names not supported: {:?}",
733                func.name
734            )));
735        }
736    };
737
738    match name.as_str() {
739        "vec_distance_l2" | "vec_distance_cosine" | "vec_distance_dot" => {
740            let (a, b) = extract_two_vector_args(&name, &func.args, table, rowid)?;
741            let dist = match name.as_str() {
742                "vec_distance_l2" => vec_distance_l2(&a, &b),
743                "vec_distance_cosine" => vec_distance_cosine(&a, &b)?,
744                "vec_distance_dot" => vec_distance_dot(&a, &b),
745                _ => unreachable!(),
746            };
747            // Widen f32 → f64 for the runtime Value. Vectors are stored
748            // as f32 (consistent with industry convention for embeddings),
749            // but the executor's numeric type is f64 so distances slot
750            // into Value::Real cleanly and can be compared / ordered with
751            // other reals via the existing arithmetic + comparison paths.
752            Ok(Value::Real(dist as f64))
753        }
754        other => Err(SQLRiteError::NotImplemented(format!(
755            "unknown function: {other}(...)"
756        ))),
757    }
758}
759
760/// Extracts exactly two `Vec<f32>` arguments from a function call,
761/// validating arity and that both sides are Vector-typed with matching
762/// dimensions. Used by all three vec_distance_* functions.
763fn extract_two_vector_args(
764    fn_name: &str,
765    args: &FunctionArguments,
766    table: &Table,
767    rowid: i64,
768) -> Result<(Vec<f32>, Vec<f32>)> {
769    let arg_list = match args {
770        FunctionArguments::List(l) => &l.args,
771        _ => {
772            return Err(SQLRiteError::General(format!(
773                "{fn_name}() expects exactly two vector arguments"
774            )));
775        }
776    };
777    if arg_list.len() != 2 {
778        return Err(SQLRiteError::General(format!(
779            "{fn_name}() expects exactly 2 arguments, got {}",
780            arg_list.len()
781        )));
782    }
783    let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
784    for (i, arg) in arg_list.iter().enumerate() {
785        let expr = match arg {
786            FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
787            other => {
788                return Err(SQLRiteError::NotImplemented(format!(
789                    "{fn_name}() argument {i} has unsupported shape: {other:?}"
790                )));
791            }
792        };
793        let val = eval_expr(expr, table, rowid)?;
794        match val {
795            Value::Vector(v) => out.push(v),
796            other => {
797                return Err(SQLRiteError::General(format!(
798                    "{fn_name}() argument {i} is not a vector: got {}",
799                    other.to_display_string()
800                )));
801            }
802        }
803    }
804    let b = out.pop().unwrap();
805    let a = out.pop().unwrap();
806    if a.len() != b.len() {
807        return Err(SQLRiteError::General(format!(
808            "{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
809            a.len(),
810            b.len()
811        )));
812    }
813    Ok((a, b))
814}
815
/// Euclidean (L2) distance: √Σ(aᵢ − bᵢ)².
///
/// Smaller-is-closer; identical vectors (and two empty vectors) return 0.0.
/// The sum is accumulated in `f32`, left to right.
///
/// # Panics
/// `a` and `b` must have the same length. Debug builds assert this; in
/// release builds the function panics if `b` is shorter than `a` and ignores
/// any surplus elements of `b`.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // One up-front slice replaces a bounds check on every `b[i]` access.
    let b = &b[..a.len()];
    a.iter()
        .zip(b)
        .fold(0.0f32, |sum, (&x, &y)| {
            let d = x - y;
            sum + d * d
        })
        .sqrt()
}
827
828/// Cosine distance: 1 − (a·b) / (‖a‖·‖b‖).
829/// Smaller-is-closer; identical (non-zero) vectors return 0.0,
830/// orthogonal vectors return 1.0, opposite-direction vectors return 2.0.
831///
832/// Errors if either vector has zero magnitude — cosine similarity is
833/// undefined for the zero vector and silently returning NaN would
834/// poison `ORDER BY` ranking. Callers who want the silent-NaN
835/// behavior can compute `vec_distance_dot(a, b) / (norm(a) * norm(b))`
836/// themselves.
837pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
838    debug_assert_eq!(a.len(), b.len());
839    let mut dot = 0.0f32;
840    let mut norm_a_sq = 0.0f32;
841    let mut norm_b_sq = 0.0f32;
842    for i in 0..a.len() {
843        dot += a[i] * b[i];
844        norm_a_sq += a[i] * a[i];
845        norm_b_sq += b[i] * b[i];
846    }
847    let denom = (norm_a_sq * norm_b_sq).sqrt();
848    if denom == 0.0 {
849        return Err(SQLRiteError::General(
850            "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
851        ));
852    }
853    Ok(1.0 - dot / denom)
854}
855
/// Negated dot product: −(a·b).
///
/// pgvector convention — negated so smaller-is-closer like L2 / cosine.
/// For unit-norm vectors `vec_distance_dot(a, b) == vec_distance_cosine(a, b) - 1`
/// (up to floating-point rounding). The sum is accumulated in `f32`, left to
/// right.
///
/// # Panics
/// `a` and `b` must have the same length. Debug builds assert this; in
/// release builds the function panics if `b` is shorter than `a` and ignores
/// any surplus elements of `b`.
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // One up-front slice replaces a bounds check on every `b[i]` access.
    let b = &b[..a.len()];
    let dot = a.iter().zip(b).fold(0.0f32, |sum, (&x, &y)| sum + x * y);
    -dot
}
867
868/// Evaluates an integer/real arithmetic op. NULL on either side propagates.
869/// Mixed Integer/Real promotes to Real. Divide/Modulo by zero → error.
870fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
871    if matches!(l, Value::Null) || matches!(r, Value::Null) {
872        return Ok(Value::Null);
873    }
874    match (l, r) {
875        (Value::Integer(a), Value::Integer(b)) => match op {
876            BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
877            BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
878            BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
879            BinaryOperator::Divide => {
880                if *b == 0 {
881                    Err(SQLRiteError::General("division by zero".to_string()))
882                } else {
883                    Ok(Value::Integer(a / b))
884                }
885            }
886            BinaryOperator::Modulo => {
887                if *b == 0 {
888                    Err(SQLRiteError::General("modulo by zero".to_string()))
889                } else {
890                    Ok(Value::Integer(a % b))
891                }
892            }
893            _ => unreachable!(),
894        },
895        // Anything involving a Real promotes both sides to f64.
896        (a, b) => {
897            let af = as_number(a)?;
898            let bf = as_number(b)?;
899            match op {
900                BinaryOperator::Plus => Ok(Value::Real(af + bf)),
901                BinaryOperator::Minus => Ok(Value::Real(af - bf)),
902                BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
903                BinaryOperator::Divide => {
904                    if bf == 0.0 {
905                        Err(SQLRiteError::General("division by zero".to_string()))
906                    } else {
907                        Ok(Value::Real(af / bf))
908                    }
909                }
910                BinaryOperator::Modulo => {
911                    if bf == 0.0 {
912                        Err(SQLRiteError::General("modulo by zero".to_string()))
913                    } else {
914                        Ok(Value::Real(af % bf))
915                    }
916                }
917                _ => unreachable!(),
918            }
919        }
920    }
921}
922
923fn as_number(v: &Value) -> Result<f64> {
924    match v {
925        Value::Integer(i) => Ok(*i as f64),
926        Value::Real(f) => Ok(*f),
927        Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
928        other => Err(SQLRiteError::General(format!(
929            "arithmetic on non-numeric value '{}'",
930            other.to_display_string()
931        ))),
932    }
933}
934
935fn as_bool(v: &Value) -> Result<bool> {
936    match v {
937        Value::Bool(b) => Ok(*b),
938        Value::Null => Ok(false),
939        Value::Integer(i) => Ok(*i != 0),
940        other => Err(SQLRiteError::Internal(format!(
941            "expected boolean, got {}",
942            other.to_display_string()
943        ))),
944    }
945}
946
947fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
948    use sqlparser::ast::Value as AstValue;
949    match v {
950        AstValue::Number(n, _) => {
951            if let Ok(i) = n.parse::<i64>() {
952                Ok(Value::Integer(i))
953            } else if let Ok(f) = n.parse::<f64>() {
954                Ok(Value::Real(f))
955            } else {
956                Err(SQLRiteError::Internal(format!(
957                    "could not parse numeric literal '{n}'"
958                )))
959            }
960        }
961        AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
962        AstValue::Boolean(b) => Ok(Value::Bool(*b)),
963        AstValue::Null => Ok(Value::Null),
964        other => Err(SQLRiteError::NotImplemented(format!(
965            "unsupported literal value: {other:?}"
966        ))),
967    }
968}
969
#[cfg(test)]
mod tests {
    use super::*;

    // -----------------------------------------------------------------
    // Phase 7b — math checks for the vec_distance_* helpers
    // -----------------------------------------------------------------

    /// True when `a` and `b` differ by strictly less than `eps`. Distances
    /// are sums of f32 products, so exact equality is too strict in general.
    fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
        let gap = (a - b).abs();
        gap < eps
    }

    #[test]
    fn vec_distance_l2_identical_is_zero() {
        // Every component difference is exactly 0, so the result is exact.
        let v = [0.1_f32, 0.2, 0.3];
        assert_eq!(vec_distance_l2(&v, &v), 0.0);
    }

    #[test]
    fn vec_distance_l2_unit_basis_is_sqrt2() {
        // e₁ vs e₂: √(1² + 1²) = √2.
        let e1 = [1.0_f32, 0.0];
        let e2 = [0.0_f32, 1.0];
        let expected = 2.0_f32.sqrt();
        assert!(approx_eq(vec_distance_l2(&e1, &e2), expected, 1e-6));
    }

    #[test]
    fn vec_distance_l2_known_value() {
        // Origin to (3, 4, 0): the 3-4-5 right triangle.
        let origin = [0.0_f32, 0.0, 0.0];
        let point = [3.0_f32, 4.0, 0.0];
        assert!(approx_eq(vec_distance_l2(&origin, &point), 5.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_identical_is_zero() {
        let v = [0.1_f32, 0.2, 0.3];
        let d = vec_distance_cosine(&v, &v).unwrap();
        assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
    }

    #[test]
    fn vec_distance_cosine_orthogonal_is_one() {
        // Similarity 0 between perpendicular vectors → distance 1 − 0 = 1.
        let e1 = [1.0_f32, 0.0];
        let e2 = [0.0_f32, 1.0];
        let d = vec_distance_cosine(&e1, &e2).unwrap();
        assert!(approx_eq(d, 1.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_opposite_is_two() {
        // Similarity −1 between v and −v → distance 1 − (−1) = 2.
        let v = [1.0_f32, 0.0, 0.0];
        let neg_v = [-1.0_f32, 0.0, 0.0];
        let d = vec_distance_cosine(&v, &neg_v).unwrap();
        assert!(approx_eq(d, 2.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_zero_magnitude_errors() {
        // The zero vector has no direction: expect an error, not NaN.
        let zero = [0.0_f32, 0.0];
        let unit = [1.0_f32, 0.0];
        let err = vec_distance_cosine(&zero, &unit).unwrap_err();
        assert!(format!("{err}").contains("zero-magnitude"));
    }

    #[test]
    fn vec_distance_dot_negates() {
        // 1·4 + 2·5 + 3·6 = 32, and the function returns the negation.
        let a = [1.0_f32, 2.0, 3.0];
        let b = [4.0_f32, 5.0, 6.0];
        assert!(approx_eq(vec_distance_dot(&a, &b), -32.0, 1e-6));
    }

    #[test]
    fn vec_distance_dot_orthogonal_is_zero() {
        // Perpendicular vectors: dot product 0, negated is still 0.
        let e1 = [1.0_f32, 0.0];
        let e2 = [0.0_f32, 1.0];
        assert_eq!(vec_distance_dot(&e1, &e2), 0.0);
    }

    #[test]
    fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
        // On unit vectors a·b equals the cosine similarity, hence
        // −(a·b) = (1 − cos) − 1 = vec_distance_cosine(a, b) − 1.
        // Cross-checks that the two metrics agree where they should.
        let a = [0.6_f32, 0.8]; // ‖a‖ = √(0.36 + 0.64) = 1
        let b = [0.8_f32, 0.6]; // ‖b‖ = 1 as well
        let dot = vec_distance_dot(&a, &b);
        let cos = vec_distance_cosine(&a, &b).unwrap();
        assert!(approx_eq(dot, cos - 1.0, 1e-5));
    }
}