use std::cmp::Ordering;
use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
use sqlparser::ast::{
AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr, FromTable, FunctionArg,
FunctionArgExpr, FunctionArguments, IndexType, ObjectNamePart, Statement, TableFactor,
TableWithJoins, UnaryOperator, Update,
};
use crate::error::{Result, SQLRiteError};
use crate::sql::db::database::Database;
use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
use crate::sql::db::table::{DataType, HnswIndexEntry, Table, Value, parse_vector_literal};
use crate::sql::hnsw::{DistanceMetric, HnswIndex};
use crate::sql::parser::select::{OrderByClause, Projection, SelectQuery};
/// Materialized result of a `SELECT`, as produced by `execute_select_rows`.
pub struct SelectResult {
    /// Output column names, in projection order.
    pub columns: Vec<String>,
    /// One entry per result row; each inner vector lines up with `columns`.
    pub rows: Vec<Vec<Value>>,
}
/// Executes a parsed `SELECT` against `db` and returns the projected rows.
///
/// Plan, in order:
/// 1. `ORDER BY vec_distance_l2(col, [..]) ASC LIMIT k` with **no** `WHERE`
///    clause is answered from an HNSW index on `col` when `try_hnsw_probe`
///    finds a usable one.
/// 2. Otherwise candidate rowids come from `select_rowids` (equality index
///    probe) or a full scan filtered by `eval_predicate`. They are then ordered
///    with a bounded heap (`select_topk`) when `LIMIT` is smaller than the
///    candidate set, or with a full sort (`sort_rowids`) followed by truncation.
///
/// Projected columns that have no stored value for a row come back as
/// `Value::Null`.
///
/// # Errors
/// Unknown table, unknown projected column, or any error raised while
/// evaluating the `WHERE` / `ORDER BY` expressions.
pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
    let table = db
        .get_table(query.table_name.clone())
        .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;

    let projected_cols: Vec<String> = match &query.projection {
        Projection::All => table.column_names(),
        Projection::Columns(cols) => {
            for c in cols {
                if !table.contains_column(c.to_string()) {
                    return Err(SQLRiteError::Internal(format!(
                        "Column '{c}' does not exist on table '{}'",
                        query.table_name
                    )));
                }
            }
            cols.clone()
        }
    };

    // The HNSW index answers "the k nearest rows of the whole table", nearest
    // first. It may only stand in for the exact plan when there is no WHERE
    // clause to honour and the ordering is ASC. Otherwise it would return
    // rows the filter rejects, or the nearest rows for a DESC query. Probe at
    // most once.
    let hnsw_hit = match (&query.order_by, query.limit) {
        (Some(order), Some(k)) if query.selection.is_none() && order.ascending => {
            try_hnsw_probe(table, &order.expr, k)
        }
        _ => None,
    };

    let matching = match hnsw_hit {
        Some(rowids) => rowids,
        None => {
            let mut matching = match select_rowids(table, query.selection.as_ref())? {
                RowidSource::IndexProbe(rowids) => rowids,
                RowidSource::FullScan => {
                    let mut out = Vec::new();
                    for rowid in table.rowids() {
                        if let Some(expr) = &query.selection {
                            if !eval_predicate(expr, table, rowid)? {
                                continue;
                            }
                        }
                        out.push(rowid);
                    }
                    out
                }
            };
            match (&query.order_by, query.limit) {
                // Bounded heap only pays off when it actually discards rows.
                (Some(order), Some(k)) if k < matching.len() => {
                    select_topk(&matching, table, order, k)?
                }
                (Some(order), limit) => {
                    sort_rowids(&mut matching, table, order)?;
                    if let Some(k) = limit {
                        matching.truncate(k);
                    }
                    matching
                }
                (None, Some(k)) => {
                    matching.truncate(k);
                    matching
                }
                (None, None) => matching,
            }
        }
    };

    let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
    for rowid in &matching {
        let row: Vec<Value> = projected_cols
            .iter()
            .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
            .collect();
        rows.push(row);
    }
    Ok(SelectResult {
        columns: projected_cols,
        rows,
    })
}
pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
let result = execute_select_rows(query, db)?;
let row_count = result.rows.len();
let mut print_table = PrintTable::new();
let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
print_table.add_row(PrintRow::new(header_cells));
for row in &result.rows {
let cells: Vec<PrintCell> = row
.iter()
.map(|v| PrintCell::new(&v.to_display_string()))
.collect();
print_table.add_row(PrintRow::new(cells));
}
Ok((print_table.to_string(), row_count))
}
/// Executes `DELETE FROM <table> [WHERE ...]` and returns the number of rows removed.
///
/// Victims are collected under a shared borrow of the table, either by an
/// index probe or by a full scan filtered by `eval_predicate`. They are then
/// removed under a mutable borrow. When anything was deleted, every HNSW index
/// on the table is flagged `needs_rebuild`.
///
/// # Errors
/// Non-DELETE statements, multi-table or joined targets, unknown tables, and
/// predicate-evaluation errors.
pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
    let (from, selection) = match stmt {
        Statement::Delete(Delete {
            from, selection, ..
        }) => (from, selection),
        _ => {
            return Err(SQLRiteError::Internal(
                "execute_delete called on a non-DELETE statement".to_string(),
            ));
        }
    };

    let table_name = match from {
        FromTable::WithFromKeyword(tables) | FromTable::WithoutKeyword(tables) => {
            extract_single_table_name(tables)?
        }
    };

    // Phase 1: read-only pass to decide which rows go.
    let doomed: Vec<i64> = {
        let table = db
            .get_table(table_name.clone())
            .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
        match select_rowids(table, selection.as_ref())? {
            RowidSource::IndexProbe(rowids) => rowids,
            RowidSource::FullScan => {
                let mut hits = Vec::new();
                for rowid in table.rowids() {
                    let keep = match selection {
                        Some(predicate) => eval_predicate(predicate, table, rowid)?,
                        None => true,
                    };
                    if keep {
                        hits.push(rowid);
                    }
                }
                hits
            }
        }
    };

    // Phase 2: mutate.
    let table = db.get_table_mut(table_name)?;
    for &rowid in &doomed {
        table.delete_row(rowid);
    }
    if !doomed.is_empty() {
        // The graph still references the deleted rows.
        for entry in table.hnsw_indexes.iter_mut() {
            entry.needs_rebuild = true;
        }
    }
    Ok(doomed.len())
}
/// Executes `UPDATE <table> SET col = expr[, ...] [WHERE ...]` and returns the
/// number of rows updated.
///
/// Works in three phases so every new value is computed from the row's
/// pre-update state:
/// 1. resolve and validate the assignment targets;
/// 2. under a shared borrow, find the matching rows and evaluate every
///    assignment expression for each of them;
/// 3. under a mutable borrow, write the values back with `Table::set_value`.
///
/// HNSW indexes over any assigned column are flagged `needs_rebuild`.
///
/// # Errors
/// Non-UPDATE statements, `UPDATE ... FROM`, tuple targets, unknown table or
/// column, expression-evaluation errors, and errors from `set_value`. Values
/// written before a `set_value` error are not rolled back here.
pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
    let (table, assignments, from, selection) = match stmt {
        Statement::Update(Update {
            table,
            assignments,
            from,
            selection,
            ..
        }) => (table, assignments, from, selection),
        _ => {
            return Err(SQLRiteError::Internal(
                "execute_update called on a non-UPDATE statement".to_string(),
            ));
        }
    };
    if from.is_some() {
        return Err(SQLRiteError::NotImplemented(
            "UPDATE ... FROM is not supported yet".to_string(),
        ));
    }
    let table_name = extract_table_name(table)?;

    // Phase 1: map each assignment to (column name, value expression).
    let parsed_assignments: Vec<(String, Expr)> = {
        let tbl = db
            .get_table(table_name.clone())
            .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
        assignments
            .iter()
            .map(|assignment| -> Result<(String, Expr)> {
                let col = match &assignment.target {
                    AssignmentTarget::ColumnName(name) => name
                        .0
                        .last()
                        .map(|part| part.to_string())
                        .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
                    AssignmentTarget::Tuple(_) => {
                        return Err(SQLRiteError::NotImplemented(
                            "tuple assignment targets are not supported".to_string(),
                        ));
                    }
                };
                if !tbl.contains_column(col.clone()) {
                    return Err(SQLRiteError::Internal(format!(
                        "UPDATE references unknown column '{col}'"
                    )));
                }
                Ok((col, assignment.value.clone()))
            })
            .collect::<Result<Vec<_>>>()?
    };

    // Phase 2: compute every new value against the unmodified table.
    let work: Vec<(i64, Vec<(String, Value)>)> = {
        let tbl = db.get_table(table_name.clone())?;
        let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
            RowidSource::IndexProbe(rowids) => rowids,
            RowidSource::FullScan => {
                let mut hits = Vec::new();
                for rowid in tbl.rowids() {
                    let keep = match selection {
                        Some(predicate) => eval_predicate(predicate, tbl, rowid)?,
                        None => true,
                    };
                    if keep {
                        hits.push(rowid);
                    }
                }
                hits
            }
        };
        let mut pending = Vec::with_capacity(matched_rowids.len());
        for rowid in matched_rowids {
            let values = parsed_assignments
                .iter()
                .map(|(col, expr)| eval_expr(expr, tbl, rowid).map(|v| (col.clone(), v)))
                .collect::<Result<Vec<_>>>()?;
            pending.push((rowid, values));
        }
        pending
    };

    // Phase 3: apply.
    let tbl = db.get_table_mut(table_name)?;
    for (rowid, values) in &work {
        for (col, v) in values {
            tbl.set_value(col, *rowid, v.clone())?;
        }
    }
    if !work.is_empty() {
        let touched: std::collections::HashSet<&str> = work
            .iter()
            .flat_map(|(_, values)| values.iter().map(|(c, _)| c.as_str()))
            .collect();
        for entry in tbl.hnsw_indexes.iter_mut() {
            if touched.contains(entry.column_name.as_str()) {
                entry.needs_rebuild = true;
            }
        }
    }
    Ok(work.len())
}
/// Executes `CREATE [UNIQUE] INDEX [IF NOT EXISTS] name ON table (column) [USING method]`.
///
/// Only named, single-column, non-partial indexes over a plain column
/// reference are supported. The method is a B-tree secondary index (no
/// `USING`, `USING btree`) or an HNSW vector index (`USING hnsw`). The current
/// contents of the column are snapshotted under a shared borrow and handed to
/// `create_btree_index` / `create_hnsw_index`.
///
/// Returns the index name. With `IF NOT EXISTS`, an existing index of the same
/// name (of either kind) is not an error and its name is returned unchanged.
///
/// # Errors
/// Unsupported syntax (`NotImplemented`), unknown table or column, or a
/// duplicate index name without `IF NOT EXISTS` (`General`), plus anything the
/// per-method builders reject.
pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
    let Statement::CreateIndex(CreateIndex {
        name,
        table_name,
        columns,
        using,
        unique,
        if_not_exists,
        predicate,
        ..
    }) = stmt
    else {
        return Err(SQLRiteError::Internal(
            "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
        ));
    };
    if predicate.is_some() {
        return Err(SQLRiteError::NotImplemented(
            "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
        ));
    }
    if columns.len() != 1 {
        return Err(SQLRiteError::NotImplemented(format!(
            "multi-column indexes are not supported yet ({} columns given)",
            columns.len()
        )));
    }
    let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
        SQLRiteError::NotImplemented(
            "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
        )
    })?;

    // sqlparser turns the BTREE keyword into `IndexType::BTree` rather than
    // `Custom("btree")`. Without that arm, an explicit `USING btree` was
    // rejected as unsupported.
    let method = match using {
        None | Some(IndexType::BTree) => IndexMethod::Btree,
        Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("btree") => {
            IndexMethod::Btree
        }
        Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("hnsw") => {
            IndexMethod::Hnsw
        }
        Some(other) => {
            return Err(SQLRiteError::NotImplemented(format!(
                "CREATE INDEX … USING {other:?} is not supported (try `hnsw` or no USING clause)"
            )));
        }
    };

    let table_name_str = table_name.to_string();
    let column_name = match &columns[0].column.expr {
        Expr::Identifier(ident) => ident.value.clone(),
        Expr::CompoundIdentifier(parts) => parts
            .last()
            .map(|p| p.value.clone())
            .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
        other => {
            return Err(SQLRiteError::NotImplemented(format!(
                "CREATE INDEX only supports simple column references, got {other:?}"
            )));
        }
    };

    // Validate and snapshot under a shared borrow. The builders below need `&mut db`.
    let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
        let table = db.get_table(table_name_str.clone()).map_err(|_| {
            SQLRiteError::General(format!(
                "CREATE INDEX references unknown table '{table_name_str}'"
            ))
        })?;
        if !table.contains_column(column_name.clone()) {
            return Err(SQLRiteError::General(format!(
                "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
            )));
        }
        let col = table
            .columns
            .iter()
            .find(|c| c.column_name == column_name)
            .expect("we just verified the column exists");
        // Index names share one namespace across B-tree and HNSW indexes.
        if table.index_by_name(&index_name).is_some()
            || table.hnsw_indexes.iter().any(|i| i.name == index_name)
        {
            if *if_not_exists {
                return Ok(index_name);
            }
            return Err(SQLRiteError::General(format!(
                "index '{index_name}' already exists"
            )));
        }
        let datatype = clone_datatype(&col.datatype);
        let mut pairs = Vec::new();
        for rowid in table.rowids() {
            if let Some(v) = table.get_value(&column_name, rowid) {
                pairs.push((rowid, v));
            }
        }
        (datatype, pairs)
    };

    match method {
        IndexMethod::Btree => create_btree_index(
            db,
            &table_name_str,
            &index_name,
            &column_name,
            &datatype,
            *unique,
            &existing_rowids_and_values,
        ),
        IndexMethod::Hnsw => create_hnsw_index(
            db,
            &table_name_str,
            &index_name,
            &column_name,
            &datatype,
            *unique,
            &existing_rowids_and_values,
        ),
    }
}
/// Index implementation selected by the `USING` clause of `CREATE INDEX`.
#[derive(Debug, Clone, Copy)]
enum IndexMethod {
    /// Secondary index built by `create_btree_index` (the default when no `USING` is given).
    Btree,
    /// Vector index built by `create_hnsw_index` (`USING hnsw`).
    Hnsw,
}
/// Builds an explicit secondary index over `existing` (rowid, value) pairs and
/// attaches it to `table_name`.
///
/// For `UNIQUE` indexes each value is checked with `would_violate_unique`
/// before it is inserted, so the first duplicate aborts the build. In that
/// case nothing is attached to the table.
///
/// Returns `index_name` on success.
fn create_btree_index(
    db: &mut Database,
    table_name: &str,
    index_name: &str,
    column_name: &str,
    datatype: &DataType,
    unique: bool,
    existing: &[(i64, Value)],
) -> Result<String> {
    let mut index = SecondaryIndex::new(
        index_name.to_string(),
        table_name.to_string(),
        column_name.to_string(),
        datatype,
        unique,
        IndexOrigin::Explicit,
    )?;

    for (rowid, value) in existing {
        let duplicate = unique && index.would_violate_unique(value);
        if duplicate {
            return Err(SQLRiteError::General(format!(
                "cannot create UNIQUE index '{index_name}': column '{column_name}' \
                 already contains the duplicate value {}",
                value.to_display_string()
            )));
        }
        index.insert(value, *rowid)?;
    }

    db.get_table_mut(table_name.to_string())?
        .secondary_indexes
        .push(index);
    Ok(index_name.to_string())
}
/// Builds an HNSW index (L2 metric) over a `VECTOR(dim)` column and attaches it
/// to `table_name`.
///
/// The graph seed is derived from `index_name` via `hash_str_to_seed`, so
/// rebuilding an index of the same name starts from the same seed. Rows whose
/// value is not a `Value::Vector` (e.g. NULL) are skipped. The metric is fixed
/// to L2, matching the only function `try_hnsw_probe` serves from an index
/// (`vec_distance_l2`).
///
/// Returns `index_name` on success.
///
/// # Errors
/// `General` if the column is not a VECTOR or `UNIQUE` was requested;
/// `Internal` if a stored vector's length disagrees with the declared dimension.
fn create_hnsw_index(
    db: &mut Database,
    table_name: &str,
    index_name: &str,
    column_name: &str,
    datatype: &DataType,
    unique: bool,
    existing: &[(i64, Value)],
) -> Result<String> {
    let dim = match datatype {
        DataType::Vector(d) => *d,
        other => {
            return Err(SQLRiteError::General(format!(
                "USING hnsw requires a VECTOR column; '{column_name}' is {other}"
            )));
        }
    };
    if unique {
        return Err(SQLRiteError::General(
            "UNIQUE has no meaning for HNSW indexes".to_string(),
        ));
    }

    let seed = hash_str_to_seed(index_name);
    let mut idx = HnswIndex::new(DistanceMetric::L2, seed);

    // rowid -> vector lookup used by the index's distance callback.
    let mut vec_map: std::collections::HashMap<i64, Vec<f32>> =
        std::collections::HashMap::with_capacity(existing.len());
    for (rowid, v) in existing {
        let Value::Vector(vec) = v else { continue };
        if vec.len() != dim {
            return Err(SQLRiteError::Internal(format!(
                "row {rowid} stores a {}-dim vector in column '{column_name}' \
                 declared as VECTOR({dim}) — schema invariant violated",
                vec.len()
            )));
        }
        vec_map.insert(*rowid, vec.clone());
    }

    // Insert in snapshot order. The vector and the callback both only read
    // `vec_map`, so the vector can be borrowed directly; no per-row copy is needed.
    for (rowid, _) in existing {
        if let Some(v) = vec_map.get(rowid) {
            idx.insert(*rowid, v, |id| {
                vec_map.get(&id).cloned().unwrap_or_default()
            });
        }
    }

    let table_mut = db.get_table_mut(table_name.to_string())?;
    table_mut.hnsw_indexes.push(HnswIndexEntry {
        name: index_name.to_string(),
        column_name: column_name.to_string(),
        index: idx,
        needs_rebuild: false,
    });
    Ok(index_name.to_string())
}
/// Hashes `s` to a `u64` with 64-bit FNV-1a (xor the byte in, then multiply by
/// the FNV prime, wrapping).
///
/// Used to derive a deterministic per-index RNG seed from the index name.
fn hash_str_to_seed(s: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xCBF2_9CE4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01B3;
    s.bytes()
        .fold(FNV_OFFSET_BASIS, |state, byte| {
            (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
        })
}
/// Returns an owned copy of `dt`, variant by variant.
///
/// NOTE(review): if `DataType` derives `Clone` elsewhere, this could simply be
/// `dt.clone()`. That derive is not visible from this file.
fn clone_datatype(dt: &DataType) -> DataType {
    match *dt {
        DataType::Vector(dim) => DataType::Vector(dim),
        DataType::Integer => DataType::Integer,
        DataType::Real => DataType::Real,
        DataType::Text => DataType::Text,
        DataType::Bool => DataType::Bool,
        DataType::None => DataType::None,
        DataType::Invalid => DataType::Invalid,
    }
}
/// Returns the name of the single target table of a `DELETE`.
///
/// # Errors
/// `NotImplemented` unless exactly one table is listed, plus whatever
/// `extract_table_name` rejects (joins, non-table factors).
fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
    match tables {
        [only] => extract_table_name(only),
        _ => Err(SQLRiteError::NotImplemented(
            "multi-table DELETE is not supported yet".to_string(),
        )),
    }
}
/// Returns the (possibly qualified) name of a plain, join-free table reference.
///
/// # Errors
/// `NotImplemented` for any JOIN, or for relations that are not a plain table
/// (subqueries, table functions, ...).
fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
    if !twj.joins.is_empty() {
        return Err(SQLRiteError::NotImplemented(
            "JOIN is not supported yet".to_string(),
        ));
    }
    let TableFactor::Table { name, .. } = &twj.relation else {
        return Err(SQLRiteError::NotImplemented(
            "only plain table references are supported".to_string(),
        ));
    };
    Ok(name.to_string())
}
/// How `select_rowids` decided to produce candidate rows for a WHERE clause.
enum RowidSource {
    /// The clause was answered from a secondary index. These rowids (sorted
    /// ascending) are the final matches and need no further filtering.
    IndexProbe(Vec<i64>),
    /// No index applies. The caller must scan every row and evaluate the
    /// predicate itself.
    FullScan,
}
fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
let Some(expr) = selection else {
return Ok(RowidSource::FullScan);
};
let Some((col, literal)) = try_extract_equality(expr) else {
return Ok(RowidSource::FullScan);
};
let Some(idx) = table.index_for_column(&col) else {
return Ok(RowidSource::FullScan);
};
let literal_value = match convert_literal(&literal) {
Ok(v) => v,
Err(_) => return Ok(RowidSource::FullScan),
};
let mut rowids = idx.lookup(&literal_value);
rowids.sort_unstable();
Ok(RowidSource::IndexProbe(rowids))
}
/// Recognises `column = literal` or `literal = column`, optionally wrapped in
/// one pair of parentheses, and returns `(column, literal)`.
///
/// A compound identifier (`t.col`) contributes only its last part. Anything
/// else, including other operators, yields `None`.
fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
    fn column_of(e: &Expr) -> Option<String> {
        match e {
            Expr::Identifier(ident) => Some(ident.value.clone()),
            Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
            _ => None,
        }
    }
    fn literal_of(e: &Expr) -> Option<sqlparser::ast::Value> {
        match e {
            Expr::Value(v) => Some(v.value.clone()),
            _ => None,
        }
    }

    let target = if let Expr::Nested(inner) = expr {
        inner.as_ref()
    } else {
        expr
    };
    match target {
        Expr::BinaryOp {
            left,
            op: BinaryOperator::Eq,
            right,
        } => column_of(left).zip(literal_of(right)).or_else(|| {
            literal_of(left)
                .zip(column_of(right))
                .map(|(lit, col)| (col, lit))
        }),
        _ => None,
    }
}
/// Tries to answer `ORDER BY vec_distance_l2(col, [..]) LIMIT k` from an HNSW
/// index on `col`.
///
/// Returns the index's k-nearest rowids (nearest first, approximate) when all
/// of the following hold:
/// * `k > 0`;
/// * `order_expr` is an unqualified call to `vec_distance_l2` (any case) with
///   exactly two positional arguments: an unquoted column name and a bracketed
///   vector literal, in either order;
/// * `col` has an HNSW index that is **not** flagged `needs_rebuild`. A flagged
///   index may still contain deleted rows or stale vectors, so the caller falls
///   back to the exact plan;
/// * `col` is declared `VECTOR(d)` and the literal has length `d`.
///
/// Otherwise returns `None`. The caller is responsible for making sure the
/// query has no WHERE clause and orders ascending.
fn try_hnsw_probe(table: &Table, order_expr: &Expr, k: usize) -> Option<Vec<i64>> {
    if k == 0 {
        return None;
    }
    let Expr::Function(func) = order_expr else {
        return None;
    };
    let fname = match func.name.0.as_slice() {
        [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
        _ => return None,
    };
    if fname != "vec_distance_l2" {
        return None;
    }
    let FunctionArguments::List(list) = &func.args else {
        return None;
    };
    let [FunctionArg::Unnamed(FunctionArgExpr::Expr(first)), FunctionArg::Unnamed(FunctionArgExpr::Expr(second))] =
        list.args.as_slice()
    else {
        return None;
    };
    let (col_name, query_vec) = identify_indexed_arg_and_literal(first, second)
        .or_else(|| identify_indexed_arg_and_literal(second, first))?;

    let entry = table
        .hnsw_indexes
        .iter()
        .find(|e| e.column_name == col_name && !e.needs_rebuild)?;

    let declared_dim = table
        .columns
        .iter()
        .find(|c| c.column_name == col_name)
        .and_then(|c| match &c.datatype {
            DataType::Vector(d) => Some(*d),
            _ => None,
        })?;
    if query_vec.len() != declared_dim {
        return None;
    }

    // Distance callback: fetch the candidate's stored vector, or an empty one
    // if the row has none.
    Some(entry.index.search(&query_vec, k, |id| {
        match table.get_value(&col_name, id) {
            Some(Value::Vector(v)) => v,
            _ => Vec::new(),
        }
    }))
}
/// Interprets `a` as a column reference and `b` as a vector literal.
///
/// `a` must be an unquoted identifier. `b` must be an identifier quoted with
/// `[`, which is how a `[x, y, ...]` vector literal is handed over here. Its
/// text is re-bracketed and parsed with `parse_vector_literal`. Returns
/// `(column name, vector)` or `None` if either side does not fit.
fn identify_indexed_arg_and_literal(a: &Expr, b: &Expr) -> Option<(String, Vec<f32>)> {
    let (Expr::Identifier(column), Expr::Identifier(literal)) = (a, b) else {
        return None;
    };
    if column.quote_style.is_some() || literal.quote_style != Some('[') {
        return None;
    }
    let parsed = parse_vector_literal(&format!("[{}]", literal.value)).ok()?;
    Some((column.value.clone(), parsed))
}
/// One candidate row in the bounded top-k heap used by `select_topk`.
///
/// Ordering is "better first": ascending key for ASC, descending key for DESC.
/// Equal keys are broken by ascending `rowid`. That makes the order total and
/// deterministic, and it makes ties come out the same way the stable full sort
/// in `sort_rowids` emits them for rowid-ordered input. Previously ties were
/// resolved by whatever order the heap happened to pop them in.
struct HeapEntry {
    /// Evaluated ORDER BY key for `rowid`.
    key: Value,
    rowid: i64,
    /// `true` for ASC, `false` for DESC.
    asc: bool,
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for HeapEntry {}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        let raw = compare_values(Some(&self.key), Some(&other.key));
        let directed = if self.asc { raw } else { raw.reverse() };
        // Tie-break is not reversed for DESC: a stable sort keeps equal keys in
        // input order regardless of direction.
        directed.then_with(|| self.rowid.cmp(&other.rowid))
    }
}
/// Returns the first `k` of `matching` under `order`, best first, in
/// O(n log k) time and O(k) space.
///
/// Keeps a max-heap of the `k` best entries seen so far. Its root is the worst
/// survivor, which is replaced whenever a better candidate shows up.
/// Returns an empty vector when `k == 0` or `matching` is empty.
///
/// # Errors
/// Propagates the first error from evaluating the ORDER BY expression.
fn select_topk(
    matching: &[i64],
    table: &Table,
    order: &OrderByClause,
    k: usize,
) -> Result<Vec<i64>> {
    use std::collections::BinaryHeap;
    if k == 0 || matching.is_empty() {
        return Ok(Vec::new());
    }
    // Never reserve more than the input can fill. This also avoids `k + 1`
    // overflowing or over-allocating for a huge LIMIT.
    let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k.min(matching.len()) + 1);
    for &rowid in matching {
        let entry = HeapEntry {
            key: eval_expr(&order.expr, table, rowid)?,
            rowid,
            asc: order.ascending,
        };
        if heap.len() < k {
            heap.push(entry);
        } else if let Some(mut worst) = heap.peek_mut() {
            // Overwriting through `PeekMut` re-sifts when it is dropped.
            if entry < *worst {
                *worst = entry;
            }
        }
    }
    Ok(heap
        .into_sorted_vec()
        .into_iter()
        .map(|e| e.rowid)
        .collect())
}
/// Sorts `rowids` in place by the ORDER BY expression, using a stable sort.
///
/// Each key is evaluated once up front. Rows with equal keys keep their
/// relative input order, in both ASC and DESC.
///
/// # Errors
/// `General("ORDER BY expression failed: ...")` wrapping the first key
/// evaluation error; `rowids` is left untouched in that case.
fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
    // Collecting into `Result` replaces the original's separate error scan
    // followed by `unwrap()` inside the comparator.
    let mut keyed: Vec<(i64, Value)> = rowids
        .iter()
        .map(|&rowid| eval_expr(&order.expr, table, rowid).map(|key| (rowid, key)))
        .collect::<Result<_>>()
        .map_err(|e| SQLRiteError::General(format!("ORDER BY expression failed: {e}")))?;

    keyed.sort_by(|(_, a), (_, b)| {
        let ord = compare_values(Some(a), Some(b));
        if order.ascending { ord } else { ord.reverse() }
    });

    for (slot, (rowid, _)) in rowids.iter_mut().zip(keyed) {
        *slot = rowid;
    }
    Ok(())
}
/// Orders two (optional) values for ORDER BY, top-k and comparison operators.
///
/// Ranking: absent < `Null` < everything else. Integers and reals compare
/// numerically with each other (integers are widened to `f64`). Text, bool and
/// same-type pairs use their natural order. Any other mixed pair falls back to
/// comparing display strings.
///
/// Floats: NaN sorts above every number and equals other NaNs, while
/// `-0.0 == 0.0`. Treating NaN as "equal to everything", as before, is not a
/// total order, and `slice::sort_by` may panic on such comparators since
/// Rust 1.81.
fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
    let (a, b) = match (a, b) {
        (None, None) => return Ordering::Equal,
        (None, Some(_)) => return Ordering::Less,
        (Some(_), None) => return Ordering::Greater,
        (Some(a), Some(b)) => (a, b),
    };
    match (a, b) {
        (Value::Null, Value::Null) => Ordering::Equal,
        (Value::Null, _) => Ordering::Less,
        (_, Value::Null) => Ordering::Greater,
        (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
        (Value::Real(x), Value::Real(y)) => compare_f64_nan_last(*x, *y),
        (Value::Integer(x), Value::Real(y)) => compare_f64_nan_last(*x as f64, *y),
        (Value::Real(x), Value::Integer(y)) => compare_f64_nan_last(*x, *y as f64),
        (Value::Text(x), Value::Text(y)) => x.cmp(y),
        (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
        (x, y) => x.to_display_string().cmp(&y.to_display_string()),
    }
}

/// Numeric order for `f64` with NaN ranked above all numbers (NaN == NaN).
/// Unlike `f64::total_cmp`, this keeps `-0.0 == 0.0`, so SQL equality is unaffected.
fn compare_f64_nan_last(x: f64, y: f64) -> Ordering {
    x.partial_cmp(&y)
        .unwrap_or_else(|| x.is_nan().cmp(&y.is_nan()))
}
/// Evaluates a WHERE expression for `rowid` and decides whether the row matches.
///
/// `Bool` is taken as is, `Null` counts as "no match", and an `Integer` matches
/// when it is non-zero.
///
/// # Errors
/// Evaluation errors, or `Internal` when the expression produces any other type.
pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
    match eval_expr(expr, table, rowid)? {
        Value::Bool(flag) => Ok(flag),
        Value::Integer(n) => Ok(n != 0),
        Value::Null => Ok(false),
        unexpected => Err(SQLRiteError::Internal(format!(
            "WHERE clause must evaluate to boolean, got {}",
            unexpected.to_display_string()
        ))),
    }
}
fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
match expr {
Expr::Nested(inner) => eval_expr(inner, table, rowid),
Expr::Identifier(ident) => {
if ident.quote_style == Some('[') {
let raw = format!("[{}]", ident.value);
let v = parse_vector_literal(&raw)?;
return Ok(Value::Vector(v));
}
Ok(table.get_value(&ident.value, rowid).unwrap_or(Value::Null))
}
Expr::CompoundIdentifier(parts) => {
let col = parts
.last()
.map(|i| i.value.as_str())
.ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?;
Ok(table.get_value(col, rowid).unwrap_or(Value::Null))
}
Expr::Value(v) => convert_literal(&v.value),
Expr::UnaryOp { op, expr } => {
let inner = eval_expr(expr, table, rowid)?;
match op {
UnaryOperator::Not => match inner {
Value::Bool(b) => Ok(Value::Bool(!b)),
Value::Null => Ok(Value::Null),
other => Err(SQLRiteError::Internal(format!(
"NOT applied to non-boolean value: {}",
other.to_display_string()
))),
},
UnaryOperator::Minus => match inner {
Value::Integer(i) => Ok(Value::Integer(-i)),
Value::Real(f) => Ok(Value::Real(-f)),
Value::Null => Ok(Value::Null),
other => Err(SQLRiteError::Internal(format!(
"unary minus on non-numeric value: {}",
other.to_display_string()
))),
},
UnaryOperator::Plus => Ok(inner),
other => Err(SQLRiteError::NotImplemented(format!(
"unary operator {other:?} is not supported"
))),
}
}
Expr::BinaryOp { left, op, right } => match op {
BinaryOperator::And => {
let l = eval_expr(left, table, rowid)?;
let r = eval_expr(right, table, rowid)?;
Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
}
BinaryOperator::Or => {
let l = eval_expr(left, table, rowid)?;
let r = eval_expr(right, table, rowid)?;
Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
}
cmp @ (BinaryOperator::Eq
| BinaryOperator::NotEq
| BinaryOperator::Lt
| BinaryOperator::LtEq
| BinaryOperator::Gt
| BinaryOperator::GtEq) => {
let l = eval_expr(left, table, rowid)?;
let r = eval_expr(right, table, rowid)?;
if matches!(l, Value::Null) || matches!(r, Value::Null) {
return Ok(Value::Bool(false));
}
let ord = compare_values(Some(&l), Some(&r));
let result = match cmp {
BinaryOperator::Eq => ord == Ordering::Equal,
BinaryOperator::NotEq => ord != Ordering::Equal,
BinaryOperator::Lt => ord == Ordering::Less,
BinaryOperator::LtEq => ord != Ordering::Greater,
BinaryOperator::Gt => ord == Ordering::Greater,
BinaryOperator::GtEq => ord != Ordering::Less,
_ => unreachable!(),
};
Ok(Value::Bool(result))
}
arith @ (BinaryOperator::Plus
| BinaryOperator::Minus
| BinaryOperator::Multiply
| BinaryOperator::Divide
| BinaryOperator::Modulo) => {
let l = eval_expr(left, table, rowid)?;
let r = eval_expr(right, table, rowid)?;
eval_arith(arith, &l, &r)
}
BinaryOperator::StringConcat => {
let l = eval_expr(left, table, rowid)?;
let r = eval_expr(right, table, rowid)?;
if matches!(l, Value::Null) || matches!(r, Value::Null) {
return Ok(Value::Null);
}
Ok(Value::Text(format!(
"{}{}",
l.to_display_string(),
r.to_display_string()
)))
}
other => Err(SQLRiteError::NotImplemented(format!(
"binary operator {other:?} is not supported yet"
))),
},
Expr::Function(func) => eval_function(func, table, rowid),
other => Err(SQLRiteError::NotImplemented(format!(
"unsupported expression in WHERE/projection: {other:?}"
))),
}
}
/// Evaluates a scalar function call for one row.
///
/// Supported (case-insensitive, unqualified): `vec_distance_l2`,
/// `vec_distance_cosine` and `vec_distance_dot`. Each takes two equal-length
/// vectors and returns the distance as `Value::Real`.
///
/// # Errors
/// `NotImplemented` for qualified names or unknown functions, plus argument
/// and distance errors.
fn eval_function(func: &sqlparser::ast::Function, table: &Table, rowid: i64) -> Result<Value> {
    let [ObjectNamePart::Identifier(ident)] = func.name.0.as_slice() else {
        return Err(SQLRiteError::NotImplemented(format!(
            "qualified function names not supported: {:?}",
            func.name
        )));
    };
    let name = ident.value.to_lowercase();

    let distance: f32 = match name.as_str() {
        "vec_distance_l2" => {
            let (a, b) = extract_two_vector_args(&name, &func.args, table, rowid)?;
            vec_distance_l2(&a, &b)
        }
        "vec_distance_cosine" => {
            let (a, b) = extract_two_vector_args(&name, &func.args, table, rowid)?;
            vec_distance_cosine(&a, &b)?
        }
        "vec_distance_dot" => {
            let (a, b) = extract_two_vector_args(&name, &func.args, table, rowid)?;
            vec_distance_dot(&a, &b)
        }
        other => {
            return Err(SQLRiteError::NotImplemented(format!(
                "unknown function: {other}(...)"
            )));
        }
    };
    Ok(Value::Real(f64::from(distance)))
}
fn extract_two_vector_args(
fn_name: &str,
args: &FunctionArguments,
table: &Table,
rowid: i64,
) -> Result<(Vec<f32>, Vec<f32>)> {
let arg_list = match args {
FunctionArguments::List(l) => &l.args,
_ => {
return Err(SQLRiteError::General(format!(
"{fn_name}() expects exactly two vector arguments"
)));
}
};
if arg_list.len() != 2 {
return Err(SQLRiteError::General(format!(
"{fn_name}() expects exactly 2 arguments, got {}",
arg_list.len()
)));
}
let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
for (i, arg) in arg_list.iter().enumerate() {
let expr = match arg {
FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
other => {
return Err(SQLRiteError::NotImplemented(format!(
"{fn_name}() argument {i} has unsupported shape: {other:?}"
)));
}
};
let val = eval_expr(expr, table, rowid)?;
match val {
Value::Vector(v) => out.push(v),
other => {
return Err(SQLRiteError::General(format!(
"{fn_name}() argument {i} is not a vector: got {}",
other.to_display_string()
)));
}
}
}
let b = out.pop().unwrap();
let a = out.pop().unwrap();
if a.len() != b.len() {
return Err(SQLRiteError::General(format!(
"{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
a.len(),
b.len()
)));
}
Ok((a, b))
}
/// Euclidean (L2) distance `sqrt(Σ (aᵢ - bᵢ)²)` between two vectors.
///
/// Callers must pass equal-length slices (checked by `debug_assert` only).
/// In release builds a shorter `b` panics and any excess in `b` is ignored.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // One up-front bounds check instead of one per element.
    let b = &b[..a.len()];
    a.iter()
        .zip(b)
        .fold(0.0f32, |acc, (&x, &y)| {
            let diff = x - y;
            acc + diff * diff
        })
        .sqrt()
}
/// Cosine distance `1 - (a·b) / (|a|·|b|)`, ranging from 0 (same direction)
/// to 2 (opposite).
///
/// Callers must pass equal-length slices (checked by `debug_assert` only).
///
/// # Errors
/// `General` when either vector has zero magnitude, where the distance is undefined.
pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
    debug_assert_eq!(a.len(), b.len());
    let b = &b[..a.len()];
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );
    let magnitude_product = (norm_a_sq * norm_b_sq).sqrt();
    if magnitude_product == 0.0 {
        return Err(SQLRiteError::General(
            "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
        ));
    }
    Ok(1.0 - dot / magnitude_product)
}
/// Negated dot product `-(a·b)`. It is negated so that "more similar" means
/// "smaller distance", like the other distance functions.
///
/// Callers must pass equal-length slices (checked by `debug_assert` only).
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let b = &b[..a.len()];
    let dot = a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| acc + x * y);
    -dot
}
fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
if matches!(l, Value::Null) || matches!(r, Value::Null) {
return Ok(Value::Null);
}
match (l, r) {
(Value::Integer(a), Value::Integer(b)) => match op {
BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
BinaryOperator::Divide => {
if *b == 0 {
Err(SQLRiteError::General("division by zero".to_string()))
} else {
Ok(Value::Integer(a / b))
}
}
BinaryOperator::Modulo => {
if *b == 0 {
Err(SQLRiteError::General("modulo by zero".to_string()))
} else {
Ok(Value::Integer(a % b))
}
}
_ => unreachable!(),
},
(a, b) => {
let af = as_number(a)?;
let bf = as_number(b)?;
match op {
BinaryOperator::Plus => Ok(Value::Real(af + bf)),
BinaryOperator::Minus => Ok(Value::Real(af - bf)),
BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
BinaryOperator::Divide => {
if bf == 0.0 {
Err(SQLRiteError::General("division by zero".to_string()))
} else {
Ok(Value::Real(af / bf))
}
}
BinaryOperator::Modulo => {
if bf == 0.0 {
Err(SQLRiteError::General("modulo by zero".to_string()))
} else {
Ok(Value::Real(af % bf))
}
}
_ => unreachable!(),
}
}
}
}
/// Widens a value to `f64` for mixed-type arithmetic.
///
/// Integers are cast, reals pass through, and booleans become `1.0` / `0.0`.
///
/// # Errors
/// `General` for any other value (text, vectors, NULL).
fn as_number(v: &Value) -> Result<f64> {
    let widened = match v {
        Value::Real(f) => *f,
        Value::Integer(i) => *i as f64,
        Value::Bool(true) => 1.0,
        Value::Bool(false) => 0.0,
        other => {
            return Err(SQLRiteError::General(format!(
                "arithmetic on non-numeric value '{}'",
                other.to_display_string()
            )));
        }
    };
    Ok(widened)
}
/// Interprets a value as a boolean.
///
/// `Bool` is taken as is, `Null` is false, and an `Integer` is true when non-zero.
///
/// # Errors
/// `Internal` for any other value.
fn as_bool(v: &Value) -> Result<bool> {
    let truth = match v {
        Value::Bool(flag) => *flag,
        Value::Integer(n) => *n != 0,
        Value::Null => false,
        other => {
            return Err(SQLRiteError::Internal(format!(
                "expected boolean, got {}",
                other.to_display_string()
            )));
        }
    };
    Ok(truth)
}
/// Converts a parsed SQL literal into a runtime `Value`.
///
/// Numbers become `Integer` when they fit in `i64`, otherwise `Real` when
/// they parse as `f64`. Single-quoted strings become `Text`; booleans and NULL
/// map directly.
///
/// # Errors
/// `Internal` for an unparsable number; `NotImplemented` for any other kind of literal.
fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
    use sqlparser::ast::Value as AstValue;
    match v {
        AstValue::Number(n, _) => n
            .parse::<i64>()
            .map(Value::Integer)
            .or_else(|_| n.parse::<f64>().map(Value::Real))
            .map_err(|_| {
                SQLRiteError::Internal(format!("could not parse numeric literal '{n}'"))
            }),
        AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
        AstValue::Boolean(b) => Ok(Value::Bool(*b)),
        AstValue::Null => Ok(Value::Null),
        other => Err(SQLRiteError::NotImplemented(format!(
            "unsupported literal value: {other:?}"
        ))),
    }
}
// Unit tests for the vector-distance helpers and the bounded top-k ORDER BY path.
#[cfg(test)]
mod tests {
    use super::*;

    // Absolute-tolerance float comparison for distance results.
    fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
        (a - b).abs() < eps
    }

    // ---- vec_distance_l2 ----

    #[test]
    fn vec_distance_l2_identical_is_zero() {
        let v = vec![0.1, 0.2, 0.3];
        assert_eq!(vec_distance_l2(&v, &v), 0.0);
    }

    #[test]
    fn vec_distance_l2_unit_basis_is_sqrt2() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(approx_eq(vec_distance_l2(&a, &b), 2.0_f32.sqrt(), 1e-6));
    }

    // 3-4-5 right triangle.
    #[test]
    fn vec_distance_l2_known_value() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![3.0, 4.0, 0.0];
        assert!(approx_eq(vec_distance_l2(&a, &b), 5.0, 1e-6));
    }

    // ---- vec_distance_cosine ----

    #[test]
    fn vec_distance_cosine_identical_is_zero() {
        let v = vec![0.1, 0.2, 0.3];
        let d = vec_distance_cosine(&v, &v).unwrap();
        assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
    }

    #[test]
    fn vec_distance_cosine_orthogonal_is_one() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 1.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_opposite_is_two() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0];
        assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 2.0, 1e-6));
    }

    // Zero-magnitude input must be an error, not NaN.
    #[test]
    fn vec_distance_cosine_zero_magnitude_errors() {
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 0.0];
        let err = vec_distance_cosine(&a, &b).unwrap_err();
        assert!(format!("{err}").contains("zero-magnitude"));
    }

    // ---- vec_distance_dot ----

    // 1*4 + 2*5 + 3*6 = 32, negated.
    #[test]
    fn vec_distance_dot_negates() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        assert!(approx_eq(vec_distance_dot(&a, &b), -32.0, 1e-6));
    }

    #[test]
    fn vec_distance_dot_orthogonal_is_zero() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert_eq!(vec_distance_dot(&a, &b), 0.0);
    }

    // For unit-norm inputs: -(a·b) == (1 - cos_sim) - 1 == cosine_distance - 1.
    #[test]
    fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
        let a = vec![0.6f32, 0.8]; let b = vec![0.8f32, 0.6]; let dot = vec_distance_dot(&a, &b);
        let cos = vec_distance_cosine(&a, &b).unwrap();
        assert!(approx_eq(dot, cos - 1.0, 1e-5));
    }

    // ---- bounded top-k (select_topk) vs full sort ----

    use crate::sql::db::database::Database;
    use crate::sql::parser::select::SelectQuery;
    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    // Builds `docs(id, score REAL)` with `n` rows whose scores are a
    // deterministic multiplicative hash of the insert index (mod 1_000_000).
    fn seed_score_table(n: usize) -> Database {
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
            &mut db,
        )
        .expect("create");
        for i in 0..n {
            let score = ((i as u64).wrapping_mul(2_654_435_761) % 1_000_000) as f64;
            let sql = format!("INSERT INTO docs (score) VALUES ({score});");
            crate::sql::process_command(&sql, &mut db).expect("insert");
        }
        db
    }

    // Parses a single SELECT statement into the executor's query form.
    fn parse_select(sql: &str) -> SelectQuery {
        let dialect = SQLiteDialect {};
        let mut ast = Parser::parse_sql(&dialect, sql).expect("parse");
        let stmt = ast.pop().expect("one statement");
        SelectQuery::new(&stmt).expect("select-query")
    }

    #[test]
    fn topk_matches_full_sort_asc() {
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();
        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(10);
        let topk = select_topk(&all_rowids, table, order, 10).unwrap();
        assert_eq!(topk, full, "top-k via heap should match full-sort+truncate");
    }

    #[test]
    fn topk_matches_full_sort_desc() {
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score DESC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();
        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(10);
        let topk = select_topk(&all_rowids, table, order, 10).unwrap();
        assert_eq!(
            topk, full,
            "top-k DESC via heap should match full-sort+truncate"
        );
    }

    // k larger than the input: everything comes back, in order.
    #[test]
    fn topk_k_larger_than_n_returns_everything_sorted() {
        let db = seed_score_table(50);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&table.rowids(), table, order, 1000).unwrap();
        assert_eq!(topk.len(), 50);
        let scores: Vec<f64> = topk
            .iter()
            .filter_map(|r| match table.get_value("score", *r) {
                Some(Value::Real(f)) => Some(f),
                _ => None,
            })
            .collect();
        assert!(scores.windows(2).all(|w| w[0] <= w[1]));
    }

    // k == 0 is passed directly (the parsed LIMIT is only used to obtain `order`).
    #[test]
    fn topk_k_zero_returns_empty() {
        let db = seed_score_table(10);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&table.rowids(), table, order, 0).unwrap();
        assert!(topk.is_empty());
    }

    #[test]
    fn topk_empty_input_returns_empty() {
        let db = seed_score_table(0);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&[], table, order, 5).unwrap();
        assert!(topk.is_empty());
    }

    // End-to-end: ORDER BY a distance function with LIMIT through process_command.
    // Only the reported row count is checked.
    #[test]
    fn topk_works_through_select_executor_with_distance_function() {
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
            &mut db,
        )
        .unwrap();
        for v in &[
            "[1.0, 0.0]",
            "[2.0, 0.0]",
            "[0.0, 3.0]",
            "[1.0, 4.0]",
            "[10.0, 10.0]",
        ] {
            crate::sql::process_command(&format!("INSERT INTO docs (e) VALUES ({v});"), &mut db)
                .unwrap();
        }
        let resp = crate::sql::process_command(
            "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;",
            &mut db,
        )
        .unwrap();
        assert!(resp.contains("3 rows returned"), "got: {resp}");
    }

    // Timing comparison, ignored by default: run with `--ignored`. The ratio
    // threshold depends on the machine and build profile.
    #[test]
    #[ignore]
    fn topk_benchmark() {
        use std::time::Instant;
        const N: usize = 10_000;
        const K: usize = 10;
        let db = seed_score_table(N);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();
        let t0 = Instant::now();
        let _topk = select_topk(&all_rowids, table, order, K).unwrap();
        let heap_dur = t0.elapsed();
        let t1 = Instant::now();
        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(K);
        let sort_dur = t1.elapsed();
        let ratio = sort_dur.as_secs_f64() / heap_dur.as_secs_f64().max(1e-9);
        println!("\n--- topk_benchmark (N={N}, k={K}) ---");
        println!(" bounded heap: {heap_dur:?}");
        println!(" full sort+trunc: {sort_dur:?}");
        println!(" speedup ratio: {ratio:.2}×");
        assert!(
            ratio > 1.4,
            "bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
        );
    }
}