use std::cmp::Ordering;
use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
use sqlparser::ast::{
AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr, FromTable, FunctionArg,
FunctionArgExpr, FunctionArguments, IndexType, ObjectNamePart, Statement, TableFactor,
TableWithJoins, UnaryOperator, Update, Value as AstValue,
};
use crate::error::{Result, SQLRiteError};
use crate::sql::db::database::Database;
use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
use crate::sql::db::table::{
DataType, FtsIndexEntry, HnswIndexEntry, Table, Value, parse_vector_literal,
};
use crate::sql::fts::{Bm25Params, PostingList};
use crate::sql::hnsw::{DistanceMetric, HnswIndex};
use crate::sql::parser::select::{OrderByClause, Projection, SelectQuery};
/// Materialized outcome of a SELECT: the projected column names plus one
/// value vector per returned row.
pub struct SelectResult {
    /// Names of the projected columns, in output order.
    pub columns: Vec<String>,
    /// Returned rows; each inner vector holds one value per entry in `columns`.
    pub rows: Vec<Vec<Value>>,
}
/// Runs a parsed SELECT against `db` and returns the projected rows.
///
/// Steps:
/// 1. Resolve the table and validate the projection.
/// 2. Collect the rowids matching the WHERE clause (secondary-index probe
///    when `select_rowids` finds one, otherwise a full scan).
/// 3. Apply ORDER BY / LIMIT. When there is no WHERE clause, an
///    `ORDER BY <fn>(...) LIMIT k` query may be answered directly from an
///    HNSW or FTS index; otherwise a bounded heap (top-k) or a full sort is
///    used.
/// 4. Materialize the projected values; a missing cell becomes `Value::Null`.
///
/// # Errors
/// Returns `SQLRiteError::Internal` for an unknown table or projected column,
/// and propagates any error raised while evaluating the WHERE or ORDER BY
/// expressions.
pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
    let table = db
        .get_table(query.table_name.clone())
        .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;

    let projected_cols: Vec<String> = match &query.projection {
        Projection::All => table.column_names(),
        Projection::Columns(cols) => {
            for c in cols {
                if !table.contains_column(c.to_string()) {
                    return Err(SQLRiteError::Internal(format!(
                        "Column '{c}' does not exist on table '{}'",
                        query.table_name
                    )));
                }
            }
            cols.clone()
        }
    };

    let mut matching = match select_rowids(table, query.selection.as_ref())? {
        RowidSource::IndexProbe(rowids) => rowids,
        RowidSource::FullScan => {
            let mut out = Vec::new();
            for rowid in table.rowids() {
                if let Some(expr) = &query.selection {
                    if !eval_predicate(expr, table, rowid)? {
                        continue;
                    }
                }
                out.push(rowid);
            }
            out
        }
    };

    // Index-backed `ORDER BY ... LIMIT k` shortcuts. Each probe is attempted
    // at most once. The probes search the whole index, so they are only sound
    // when no WHERE clause restricts the candidate set. The HNSW probe yields
    // nearest-first results, so it is only valid for ascending order; the FTS
    // probe checks the direction itself.
    // NOTE(review): neither probe consults `needs_rebuild`; confirm that stale
    // indexes are rebuilt before queries can reach this point.
    let probed: Option<Vec<i64>> = match (&query.order_by, query.limit) {
        (Some(order), Some(k)) if query.selection.is_none() => {
            let ann = if order.ascending {
                try_hnsw_probe(table, &order.expr, k)
            } else {
                None
            };
            ann.or_else(|| try_fts_probe(table, &order.expr, order.ascending, k))
        }
        _ => None,
    };

    if let Some(rowids) = probed {
        matching = rowids;
    } else {
        match (&query.order_by, query.limit) {
            // A bounded heap beats a full sort when only k < n rows are kept.
            (Some(order), Some(k)) if k < matching.len() => {
                matching = select_topk(&matching, table, order, k)?;
            }
            (Some(order), _) => {
                sort_rowids(&mut matching, table, order)?;
                if let Some(k) = query.limit {
                    matching.truncate(k);
                }
            }
            (None, Some(k)) => {
                matching.truncate(k);
            }
            (None, None) => {}
        }
    }

    let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
    for rowid in &matching {
        let row: Vec<Value> = projected_cols
            .iter()
            .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
            .collect();
        rows.push(row);
    }
    Ok(SelectResult {
        columns: projected_cols,
        rows,
    })
}
pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
let result = execute_select_rows(query, db)?;
let row_count = result.rows.len();
let mut print_table = PrintTable::new();
let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
print_table.add_row(PrintRow::new(header_cells));
for row in &result.rows {
let cells: Vec<PrintCell> = row
.iter()
.map(|v| PrintCell::new(&v.to_display_string()))
.collect();
print_table.add_row(PrintRow::new(cells));
}
Ok((print_table.to_string(), row_count))
}
/// Executes a single-table `DELETE` and returns the number of rows removed.
///
/// Matching rowids are collected under a shared borrow first, then deleted
/// under a mutable borrow. If at least one row was removed, every HNSW and
/// FTS index on the table is flagged with `needs_rebuild`.
///
/// # Errors
/// Fails when `stmt` is not a DELETE, when the FROM clause is not a single
/// plain table, when the table is unknown, or when the WHERE predicate
/// cannot be evaluated.
pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
    let (from, selection) = match stmt {
        Statement::Delete(Delete {
            from, selection, ..
        }) => (from, selection),
        _ => {
            return Err(SQLRiteError::Internal(
                "execute_delete called on a non-DELETE statement".to_string(),
            ));
        }
    };

    let sources = match from {
        FromTable::WithFromKeyword(list) => list,
        FromTable::WithoutKeyword(list) => list,
    };
    let table_name = extract_single_table_name(sources)?;

    // Phase 1 (read-only): decide which rows go away.
    let doomed: Vec<i64> = {
        let table = db
            .get_table(table_name.clone())
            .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
        match select_rowids(table, selection.as_ref())? {
            RowidSource::IndexProbe(hits) => hits,
            RowidSource::FullScan => {
                let mut kept = Vec::new();
                for rowid in table.rowids() {
                    let selected = match selection {
                        Some(predicate) => eval_predicate(predicate, table, rowid)?,
                        None => true,
                    };
                    if selected {
                        kept.push(rowid);
                    }
                }
                kept
            }
        }
    };

    // Phase 2 (mutating): remove the rows and mark derived indexes as stale.
    let table = db.get_table_mut(table_name)?;
    for &rowid in doomed.iter() {
        table.delete_row(rowid);
    }
    if !doomed.is_empty() {
        table
            .hnsw_indexes
            .iter_mut()
            .for_each(|entry| entry.needs_rebuild = true);
        table
            .fts_indexes
            .iter_mut()
            .for_each(|entry| entry.needs_rebuild = true);
    }
    Ok(doomed.len())
}
/// Executes a single-table `UPDATE` and returns the number of rows touched.
///
/// All assignment targets are validated and all new values are computed
/// under a shared borrow before any row is modified, so right-hand sides
/// always see the pre-update row contents. HNSW/FTS indexes whose column was
/// assigned are flagged with `needs_rebuild`.
///
/// # Errors
/// Fails when `stmt` is not an UPDATE, when `UPDATE ... FROM` or tuple
/// targets are used, when the table or an assigned column is unknown, or
/// when evaluating a predicate / assignment / `set_value` fails.
pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
    let (table, assignments, from, selection) = match stmt {
        Statement::Update(Update {
            table,
            assignments,
            from,
            selection,
            ..
        }) => (table, assignments, from, selection),
        _ => {
            return Err(SQLRiteError::Internal(
                "execute_update called on a non-UPDATE statement".to_string(),
            ));
        }
    };
    if from.is_some() {
        return Err(SQLRiteError::NotImplemented(
            "UPDATE ... FROM is not supported yet".to_string(),
        ));
    }
    let table_name = extract_table_name(table)?;

    // Read-only phase: validate targets, find rows, compute the new values.
    let pending: Vec<(i64, Vec<(String, Value)>)> = {
        let tbl = db
            .get_table(table_name.clone())
            .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;

        let mut targets: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
        for a in assignments {
            let col = match &a.target {
                AssignmentTarget::ColumnName(name) => match name.0.last() {
                    Some(part) => part.to_string(),
                    None => {
                        return Err(SQLRiteError::Internal("empty column name".to_string()));
                    }
                },
                AssignmentTarget::Tuple(_) => {
                    return Err(SQLRiteError::NotImplemented(
                        "tuple assignment targets are not supported".to_string(),
                    ));
                }
            };
            if !tbl.contains_column(col.clone()) {
                return Err(SQLRiteError::Internal(format!(
                    "UPDATE references unknown column '{col}'"
                )));
            }
            targets.push((col, a.value.clone()));
        }

        let hits: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
            RowidSource::IndexProbe(rowids) => rowids,
            RowidSource::FullScan => {
                let mut kept = Vec::new();
                for rowid in tbl.rowids() {
                    let selected = match selection {
                        Some(predicate) => eval_predicate(predicate, tbl, rowid)?,
                        None => true,
                    };
                    if selected {
                        kept.push(rowid);
                    }
                }
                kept
            }
        };

        let mut plan = Vec::new();
        for rowid in hits {
            let new_values = targets
                .iter()
                .map(|(col, expr)| eval_expr(expr, tbl, rowid).map(|v| (col.clone(), v)))
                .collect::<Result<Vec<_>>>()?;
            plan.push((rowid, new_values));
        }
        plan
    };

    // Mutating phase: write the precomputed values.
    let tbl = db.get_table_mut(table_name)?;
    for (rowid, new_values) in pending.iter() {
        for (col, value) in new_values.iter() {
            tbl.set_value(col, *rowid, value.clone())?;
        }
    }

    if !pending.is_empty() {
        let mut touched: std::collections::HashSet<&str> = std::collections::HashSet::new();
        for (_, new_values) in pending.iter() {
            for (col, _) in new_values.iter() {
                touched.insert(col.as_str());
            }
        }
        for entry in tbl.hnsw_indexes.iter_mut() {
            if touched.contains(entry.column_name.as_str()) {
                entry.needs_rebuild = true;
            }
        }
        for entry in tbl.fts_indexes.iter_mut() {
            if touched.contains(entry.column_name.as_str()) {
                entry.needs_rebuild = true;
            }
        }
    }
    Ok(pending.len())
}
/// Executes `CREATE INDEX` and returns the name of the index.
///
/// Supported shape: a named, single-column, non-partial index. The `USING`
/// clause selects the method: `hnsw`, `fts`, `btree`, or none (B-tree).
/// Existing rows are indexed immediately.
///
/// With `IF NOT EXISTS`, an index name that already exists on the table is
/// not an error; the existing name is returned and nothing is created.
///
/// # Errors
/// * `NotImplemented` for partial, multi-column, anonymous, or
///   unsupported-method indexes, and for non-column index expressions.
/// * `General` for an unknown table/column or a duplicate index name.
/// * Any error produced by the method-specific builder.
pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
    let Statement::CreateIndex(CreateIndex {
        name,
        table_name,
        columns,
        using,
        unique,
        if_not_exists,
        predicate,
        ..
    }) = stmt
    else {
        return Err(SQLRiteError::Internal(
            "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
        ));
    };
    if predicate.is_some() {
        return Err(SQLRiteError::NotImplemented(
            "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
        ));
    }
    if columns.len() != 1 {
        return Err(SQLRiteError::NotImplemented(format!(
            "multi-column indexes are not supported yet ({} columns given)",
            columns.len()
        )));
    }
    let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
        SQLRiteError::NotImplemented(
            "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
        )
    })?;
    let method = match using {
        // `USING btree` is a recognised keyword and arrives as the dedicated
        // variant rather than `Custom("btree")`, so accept both spellings.
        Some(IndexType::BTree) | None => IndexMethod::Btree,
        Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("hnsw") => {
            IndexMethod::Hnsw
        }
        Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("fts") => {
            IndexMethod::Fts
        }
        Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("btree") => {
            IndexMethod::Btree
        }
        Some(other) => {
            return Err(SQLRiteError::NotImplemented(format!(
                "CREATE INDEX … USING {other:?} is not supported \
                 (try `hnsw`, `fts`, or no USING clause)"
            )));
        }
    };
    let table_name_str = table_name.to_string();
    // Only a bare (optionally qualified) column reference can be indexed.
    let column_name = match &columns[0].column.expr {
        Expr::Identifier(ident) => ident.value.clone(),
        Expr::CompoundIdentifier(parts) => parts
            .last()
            .map(|p| p.value.clone())
            .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
        other => {
            return Err(SQLRiteError::NotImplemented(format!(
                "CREATE INDEX only supports simple column references, got {other:?}"
            )));
        }
    };
    // Snapshot the column type and its current values under a shared borrow
    // so the builders below are free to borrow `db` mutably.
    let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
        let table = db.get_table(table_name_str.clone()).map_err(|_| {
            SQLRiteError::General(format!(
                "CREATE INDEX references unknown table '{table_name_str}'"
            ))
        })?;
        if !table.contains_column(column_name.clone()) {
            return Err(SQLRiteError::General(format!(
                "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
            )));
        }
        let col = table
            .columns
            .iter()
            .find(|c| c.column_name == column_name)
            .expect("we just verified the column exists");
        // Index names share one namespace across all three index kinds.
        if table.index_by_name(&index_name).is_some()
            || table.hnsw_indexes.iter().any(|i| i.name == index_name)
            || table.fts_indexes.iter().any(|i| i.name == index_name)
        {
            if *if_not_exists {
                return Ok(index_name);
            }
            return Err(SQLRiteError::General(format!(
                "index '{index_name}' already exists"
            )));
        }
        let datatype = clone_datatype(&col.datatype);
        let mut pairs = Vec::new();
        for rowid in table.rowids() {
            // Rows with no stored value for the column are left out.
            if let Some(v) = table.get_value(&column_name, rowid) {
                pairs.push((rowid, v));
            }
        }
        (datatype, pairs)
    };
    match method {
        IndexMethod::Btree => create_btree_index(
            db,
            &table_name_str,
            &index_name,
            &column_name,
            &datatype,
            *unique,
            &existing_rowids_and_values,
        ),
        IndexMethod::Hnsw => create_hnsw_index(
            db,
            &table_name_str,
            &index_name,
            &column_name,
            &datatype,
            *unique,
            &existing_rowids_and_values,
        ),
        IndexMethod::Fts => create_fts_index(
            db,
            &table_name_str,
            &index_name,
            &column_name,
            &datatype,
            *unique,
            &existing_rowids_and_values,
        ),
    }
}
/// Index implementation selected by the `USING` clause of `CREATE INDEX`.
#[derive(Debug, Clone, Copy)]
enum IndexMethod {
    /// Regular secondary index; also the default when `USING` is absent.
    Btree,
    /// HNSW vector index (`USING hnsw`).
    Hnsw,
    /// Full-text index (`USING fts`).
    Fts,
}
/// Builds a secondary index over `existing` and attaches it to the table.
///
/// For a UNIQUE index, the first value that would duplicate an already
/// inserted one aborts the build with a `General` error; the table is left
/// unchanged in that case because the index is attached only at the end.
fn create_btree_index(
    db: &mut Database,
    table_name: &str,
    index_name: &str,
    column_name: &str,
    datatype: &DataType,
    unique: bool,
    existing: &[(i64, Value)],
) -> Result<String> {
    let mut index = SecondaryIndex::new(
        index_name.to_string(),
        table_name.to_string(),
        column_name.to_string(),
        datatype,
        unique,
        IndexOrigin::Explicit,
    )?;

    for (rowid, v) in existing.iter() {
        let duplicate = unique && index.would_violate_unique(v);
        if duplicate {
            return Err(SQLRiteError::General(format!(
                "cannot create UNIQUE index '{index_name}': column '{column_name}' \
                 already contains the duplicate value {}",
                v.to_display_string()
            )));
        }
        index.insert(v, *rowid)?;
    }

    db.get_table_mut(table_name.to_string())?
        .secondary_indexes
        .push(index);
    Ok(index_name.to_string())
}
/// Builds an HNSW index (L2 metric) over the vectors in `existing` and
/// attaches it to the table.
///
/// The column must be declared `VECTOR(dim)` and `unique` must be false.
/// Non-vector values (e.g. NULL) are skipped. The graph seed is derived from
/// the index name, so the same name always yields the same seed.
///
/// # Errors
/// * `General` if the column is not a vector column or `UNIQUE` was given.
/// * `Internal` if a stored vector's length differs from the declared
///   dimension.
fn create_hnsw_index(
    db: &mut Database,
    table_name: &str,
    index_name: &str,
    column_name: &str,
    datatype: &DataType,
    unique: bool,
    existing: &[(i64, Value)],
) -> Result<String> {
    let dim = match datatype {
        DataType::Vector(d) => *d,
        other => {
            return Err(SQLRiteError::General(format!(
                "USING hnsw requires a VECTOR column; '{column_name}' is {other}"
            )));
        }
    };
    if unique {
        return Err(SQLRiteError::General(
            "UNIQUE has no meaning for HNSW indexes".to_string(),
        ));
    }
    let seed = hash_str_to_seed(index_name);
    let mut idx = HnswIndex::new(DistanceMetric::L2, seed);

    // First pass: validate dimensions and gather every vector by rowid so the
    // lookup closure handed to `insert` can resolve any node.
    let mut vec_map: std::collections::HashMap<i64, Vec<f32>> =
        std::collections::HashMap::with_capacity(existing.len());
    for (rowid, v) in existing {
        if let Value::Vector(vec) = v {
            if vec.len() != dim {
                return Err(SQLRiteError::Internal(format!(
                    "row {rowid} stores a {}-dim vector in column '{column_name}' \
                     declared as VECTOR({dim}) — schema invariant violated",
                    vec.len()
                )));
            }
            vec_map.insert(*rowid, vec.clone());
        }
    }

    // Second pass: insert in the order of `existing`. The vector is borrowed
    // straight from `vec_map`; the closure only needs a shared borrow of the
    // same map, so no per-row copy of the inserted vector is required.
    for (rowid, _) in existing {
        if let Some(vector) = vec_map.get(rowid) {
            idx.insert(*rowid, vector, |id| {
                vec_map.get(&id).cloned().unwrap_or_default()
            });
        }
    }

    let table_mut = db.get_table_mut(table_name.to_string())?;
    table_mut.hnsw_indexes.push(HnswIndexEntry {
        name: index_name.to_string(),
        column_name: column_name.to_string(),
        index: idx,
        needs_rebuild: false,
    });
    Ok(index_name.to_string())
}
/// Builds a full-text posting list over the text values in `existing` and
/// attaches it to the table.
///
/// The column must be `TEXT` and `unique` must be false; non-text values are
/// skipped.
fn create_fts_index(
    db: &mut Database,
    table_name: &str,
    index_name: &str,
    column_name: &str,
    datatype: &DataType,
    unique: bool,
    existing: &[(i64, Value)],
) -> Result<String> {
    if !matches!(datatype, DataType::Text) {
        let other = datatype;
        return Err(SQLRiteError::General(format!(
            "USING fts requires a TEXT column; '{column_name}' is {other}"
        )));
    }
    if unique {
        return Err(SQLRiteError::General(
            "UNIQUE has no meaning for FTS indexes".to_string(),
        ));
    }

    let mut postings = PostingList::new();
    for (rowid, value) in existing.iter() {
        match value {
            Value::Text(text) => postings.insert(*rowid, text),
            _ => continue,
        };
    }

    let entry = FtsIndexEntry {
        name: index_name.to_string(),
        column_name: column_name.to_string(),
        index: postings,
        needs_rebuild: false,
    };
    db.get_table_mut(table_name.to_string())?
        .fts_indexes
        .push(entry);
    Ok(index_name.to_string())
}
/// Hashes `s` to a 64-bit seed: start from a fixed basis, then for each byte
/// XOR it in and multiply (wrapping) by a fixed prime.
///
/// Deterministic, so the same string always yields the same seed.
fn hash_str_to_seed(s: &str) -> u64 {
    const BASIS: u64 = 0xCBF29CE484222325;
    const PRIME: u64 = 0x100000001B3;
    s.bytes()
        .fold(BASIS, |acc, byte| (acc ^ u64::from(byte)).wrapping_mul(PRIME))
}
/// Produces an owned copy of `dt` by rebuilding the matching variant.
///
/// The match is exhaustive on purpose: adding a `DataType` variant forces
/// this function to be updated.
fn clone_datatype(dt: &DataType) -> DataType {
    match dt {
        // The only variant that carries a payload.
        DataType::Vector(dim) => DataType::Vector(*dim),
        DataType::Integer => DataType::Integer,
        DataType::Real => DataType::Real,
        DataType::Text => DataType::Text,
        DataType::Bool => DataType::Bool,
        DataType::Json => DataType::Json,
        DataType::None => DataType::None,
        DataType::Invalid => DataType::Invalid,
    }
}
/// Returns the table name when `tables` holds exactly one plain table
/// reference; any other count is rejected as a multi-table DELETE.
fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
    match tables {
        [only] => extract_table_name(only),
        _ => Err(SQLRiteError::NotImplemented(
            "multi-table DELETE is not supported yet".to_string(),
        )),
    }
}
/// Extracts the table name from a FROM item, rejecting joins and anything
/// that is not a plain table reference.
fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
    let has_joins = !twj.joins.is_empty();
    if has_joins {
        return Err(SQLRiteError::NotImplemented(
            "JOIN is not supported yet".to_string(),
        ));
    }
    if let TableFactor::Table { name, .. } = &twj.relation {
        Ok(name.to_string())
    } else {
        Err(SQLRiteError::NotImplemented(
            "only plain table references are supported".to_string(),
        ))
    }
}
/// How the candidate rowids for a WHERE clause were (or must be) obtained.
enum RowidSource {
    /// Rowids returned by a secondary-index lookup; callers use them as-is.
    IndexProbe(Vec<i64>),
    /// No usable index: the caller must scan every row and evaluate the
    /// predicate itself.
    FullScan,
}
fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
let Some(expr) = selection else {
return Ok(RowidSource::FullScan);
};
let Some((col, literal)) = try_extract_equality(expr) else {
return Ok(RowidSource::FullScan);
};
let Some(idx) = table.index_for_column(&col) else {
return Ok(RowidSource::FullScan);
};
let literal_value = match convert_literal(&literal) {
Ok(v) => v,
Err(_) => return Ok(RowidSource::FullScan),
};
let mut rowids = idx.lookup(&literal_value);
rowids.sort_unstable();
Ok(RowidSource::IndexProbe(rowids))
}
/// Recognises `column = literal` or `literal = column`, optionally wrapped
/// in one pair of parentheses, and returns the column name and the literal.
///
/// For a qualified column (`t.col`) only the last identifier is used.
/// Returns `None` for any other expression shape.
fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
    /// Column name of a bare or qualified identifier.
    fn column_of(e: &Expr) -> Option<String> {
        match e {
            Expr::Identifier(ident) => Some(ident.value.clone()),
            Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
            _ => None,
        }
    }

    /// Literal carried by a value expression.
    fn literal_of(e: &Expr) -> Option<sqlparser::ast::Value> {
        match e {
            Expr::Value(v) => Some(v.value.clone()),
            _ => None,
        }
    }

    // Peel at most one level of parentheses.
    let core = if let Expr::Nested(inner) = expr {
        inner.as_ref()
    } else {
        expr
    };

    match core {
        Expr::BinaryOp {
            left,
            op: BinaryOperator::Eq,
            right,
        } => column_of(left)
            .zip(literal_of(right))
            .or_else(|| column_of(right).zip(literal_of(left))),
        _ => None,
    }
}
/// Attempts to answer `ORDER BY vec_distance_l2(col, [..]) LIMIT k` from an
/// HNSW index.
///
/// Returns `None` (meaning "use the generic path") unless all of these hold:
/// `k > 0`; the expression is an unqualified `vec_distance_l2` call with
/// exactly two plain expression arguments; one argument is an unquoted
/// column and the other a bracketed vector literal; the column has an HNSW
/// index; and the literal's length equals the column's declared dimension.
///
/// NOTE(review): the sort direction is not an input here, so callers must
/// decide themselves whether the index's result order is acceptable.
fn try_hnsw_probe(table: &Table, order_expr: &Expr, k: usize) -> Option<Vec<i64>> {
    if k == 0 {
        return None;
    }
    let Expr::Function(func) = order_expr else {
        return None;
    };
    let [ObjectNamePart::Identifier(fn_ident)] = func.name.0.as_slice() else {
        return None;
    };
    if fn_ident.value.to_lowercase() != "vec_distance_l2" {
        return None;
    }
    let FunctionArguments::List(list) = &func.args else {
        return None;
    };
    // Exactly two unnamed, plain-expression arguments.
    let [
        FunctionArg::Unnamed(FunctionArgExpr::Expr(first)),
        FunctionArg::Unnamed(FunctionArgExpr::Expr(second)),
    ] = list.args.as_slice()
    else {
        return None;
    };

    // Accept either argument order: (column, literal) or (literal, column).
    let (col_name, query_vec) = identify_indexed_arg_and_literal(first, second)
        .or_else(|| identify_indexed_arg_and_literal(second, first))?;

    let entry = table
        .hnsw_indexes
        .iter()
        .find(|e| e.column_name == col_name)?;

    let column = table.columns.iter().find(|c| c.column_name == col_name)?;
    let DataType::Vector(declared_dim) = &column.datatype else {
        return None;
    };
    if query_vec.len() != *declared_dim {
        return None;
    }

    // The index resolves node ids to vectors through this callback; a row
    // without a vector value yields an empty vector.
    let neighbours = entry.index.search(&query_vec, k, |id| {
        match table.get_value(&col_name, id) {
            Some(Value::Vector(v)) => v,
            _ => Vec::new(),
        }
    });
    Some(neighbours)
}
/// Attempts to answer `ORDER BY bm25_score(col, 'query') DESC LIMIT k` from
/// an FTS index.
///
/// Returns `None` unless: `k > 0`; the order is descending; the expression is
/// an unqualified `bm25_score` call with exactly two plain expression
/// arguments; the first is an unquoted column identifier; the second is a
/// single-quoted string literal; and the column has an FTS index. On success
/// the first `k` ids produced by the index query are returned.
fn try_fts_probe(table: &Table, order_expr: &Expr, ascending: bool, k: usize) -> Option<Vec<i64>> {
    if ascending || k == 0 {
        return None;
    }
    let Expr::Function(func) = order_expr else {
        return None;
    };
    let [ObjectNamePart::Identifier(fn_ident)] = func.name.0.as_slice() else {
        return None;
    };
    if fn_ident.value.to_lowercase() != "bm25_score" {
        return None;
    }
    let FunctionArguments::List(list) = &func.args else {
        return None;
    };
    let [
        FunctionArg::Unnamed(FunctionArgExpr::Expr(column_arg)),
        FunctionArg::Unnamed(FunctionArgExpr::Expr(query_arg)),
    ] = list.args.as_slice()
    else {
        return None;
    };

    let col_name = match column_arg {
        Expr::Identifier(ident) if ident.quote_style.is_none() => &ident.value,
        _ => return None,
    };
    let query = match query_arg {
        Expr::Value(literal) => match &literal.value {
            AstValue::SingleQuotedString(s) => s,
            _ => return None,
        },
        _ => return None,
    };

    let entry = table
        .fts_indexes
        .iter()
        .find(|e| e.column_name == *col_name)?;

    let ranked = entry.index.query(query, &Bm25Params::default());
    Some(ranked.into_iter().map(|(id, _)| id).take(k).collect())
}
/// Interprets `a` as an unquoted column identifier and `b` as a bracketed
/// vector literal, which the parser delivers as an identifier whose quote
/// style is `[`.
///
/// Returns the column name and the parsed vector, or `None` if either side
/// has a different shape or the literal does not parse.
fn identify_indexed_arg_and_literal(a: &Expr, b: &Expr) -> Option<(String, Vec<f32>)> {
    let (Expr::Identifier(column), Expr::Identifier(literal)) = (a, b) else {
        return None;
    };
    if column.quote_style.is_some() || literal.quote_style != Some('[') {
        return None;
    }
    // Re-wrap the contents in brackets so the literal parser sees the
    // original text.
    let bracketed = format!("[{}]", literal.value);
    let parsed = parse_vector_literal(&bracketed).ok()?;
    Some((column.value.clone(), parsed))
}
/// Heap element used by the top-k selection: a sort key, the rowid it
/// belongs to, and the requested sort direction.
///
/// Ordering considers only `key` (via `compare_values`) and is reversed when
/// `asc` is false; `rowid` never takes part in comparisons.
struct HeapEntry {
    key: Value,
    rowid: i64,
    asc: bool,
}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        let natural = compare_values(Some(&self.key), Some(&other.key));
        match self.asc {
            true => natural,
            false => natural.reverse(),
        }
    }
}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(Ord::cmp(self, other))
    }
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        matches!(self.cmp(other), Ordering::Equal)
    }
}

impl Eq for HeapEntry {}
/// Returns the first `k` rowids of `matching` in ORDER BY order without
/// sorting the whole input.
///
/// A max-heap of at most `k` entries is kept; its top is the worst entry
/// retained so far, and a new row replaces it only when it sorts strictly
/// before it. The heap is finally drained in sorted order.
fn select_topk(
    matching: &[i64],
    table: &Table,
    order: &OrderByClause,
    k: usize,
) -> Result<Vec<i64>> {
    use std::collections::BinaryHeap;

    if matching.is_empty() || k == 0 {
        return Ok(Vec::new());
    }

    let mut best: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
    for &rowid in matching {
        let candidate = HeapEntry {
            key: eval_expr(&order.expr, table, rowid)?,
            rowid,
            asc: order.ascending,
        };

        // Still filling up: take everything.
        if best.len() < k {
            best.push(candidate);
            continue;
        }

        // Full: evict the current worst only for a strictly better row.
        let beats_worst = matches!(best.peek(), Some(worst) if candidate < *worst);
        if beats_worst {
            best.pop();
            best.push(candidate);
        }
    }

    let ordered: Vec<i64> = best
        .into_sorted_vec()
        .into_iter()
        .map(|entry| entry.rowid)
        .collect();
    Ok(ordered)
}
/// Sorts `rowids` in place according to `order`.
///
/// Each sort key is evaluated once up front; the first evaluation failure
/// is reported as a `General` error and leaves `rowids` untouched. The sort
/// is stable, so rows with equal keys keep their relative order.
fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
    let mut keyed: Vec<(i64, Value)> = Vec::with_capacity(rowids.len());
    for &rowid in rowids.iter() {
        let key = match eval_expr(&order.expr, table, rowid) {
            Ok(value) => value,
            Err(e) => {
                return Err(SQLRiteError::General(format!(
                    "ORDER BY expression failed: {e}"
                )));
            }
        };
        keyed.push((rowid, key));
    }

    keyed.sort_by(|(_, lhs), (_, rhs)| {
        let natural = compare_values(Some(lhs), Some(rhs));
        if order.ascending {
            natural
        } else {
            natural.reverse()
        }
    });

    for (slot, (rowid, _)) in rowids.iter_mut().zip(keyed) {
        *slot = rowid;
    }
    Ok(())
}
/// Total ordering over optional SQL values.
///
/// * A missing value (`None`) and `Value::Null` each sort before anything
///   else of their kind.
/// * Integers and reals compare numerically, also across the two types;
///   float comparisons that have no ordering are treated as equal.
/// * Text and booleans use their natural ordering.
/// * Any other pairing falls back to comparing the display strings.
fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
    /// Float comparison that treats "no ordering" as equality.
    fn float_order(x: f64, y: f64) -> Ordering {
        x.partial_cmp(&y).unwrap_or(Ordering::Equal)
    }

    let (lhs, rhs) = match (a, b) {
        (None, None) => return Ordering::Equal,
        (None, Some(_)) => return Ordering::Less,
        (Some(_), None) => return Ordering::Greater,
        (Some(lhs), Some(rhs)) => (lhs, rhs),
    };

    match (lhs, rhs) {
        (Value::Null, Value::Null) => Ordering::Equal,
        (Value::Null, _) => Ordering::Less,
        (_, Value::Null) => Ordering::Greater,
        (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
        (Value::Real(x), Value::Real(y)) => float_order(*x, *y),
        (Value::Integer(x), Value::Real(y)) => float_order(*x as f64, *y),
        (Value::Real(x), Value::Integer(y)) => float_order(*x, *y as f64),
        (Value::Text(x), Value::Text(y)) => x.cmp(y),
        (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
        (x, y) => x.to_display_string().cmp(&y.to_display_string()),
    }
}
/// Evaluates `expr` for the given row and interprets the result as a WHERE
/// verdict.
///
/// Booleans are used directly, NULL counts as "no match", and integers are
/// truthy when non-zero. Any other value type is an `Internal` error.
pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
    match eval_expr(expr, table, rowid)? {
        Value::Bool(flag) => Ok(flag),
        Value::Integer(n) => Ok(n != 0),
        Value::Null => Ok(false),
        other => Err(SQLRiteError::Internal(format!(
            "WHERE clause must evaluate to boolean, got {}",
            other.to_display_string()
        ))),
    }
}
/// Evaluates a SQL expression against one row of `table`.
///
/// Supported forms: parenthesised expressions, column references (bare or
/// qualified; only the last identifier of a qualified name is used),
/// bracketed vector literals, literals, unary `NOT` / `-` / `+`, `AND` /
/// `OR`, the six comparison operators, `+ - * / %`, `||`, and function
/// calls.
///
/// NULL handling as implemented here:
/// * a column with no stored value evaluates to `Value::Null`;
/// * `NOT NULL`, `-NULL` and `NULL || x` yield `Value::Null`;
/// * any comparison with a NULL operand yields `Value::Bool(false)`.
///
/// # Errors
/// * `Internal` when `NOT` or unary minus is applied to an unsupported type.
/// * `General` when unary minus overflows (negating `i64::MIN`).
/// * `NotImplemented` for unsupported operators or expression kinds.
/// * Whatever the literal, arithmetic and function helpers report.
fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
    match expr {
        Expr::Nested(inner) => eval_expr(inner, table, rowid),
        Expr::Identifier(ident) => {
            // A `[...]` "identifier" is really a bracketed vector literal.
            if ident.quote_style == Some('[') {
                let raw = format!("[{}]", ident.value);
                let v = parse_vector_literal(&raw)?;
                return Ok(Value::Vector(v));
            }
            Ok(table.get_value(&ident.value, rowid).unwrap_or(Value::Null))
        }
        Expr::CompoundIdentifier(parts) => {
            let col = parts
                .last()
                .map(|i| i.value.as_str())
                .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?;
            Ok(table.get_value(col, rowid).unwrap_or(Value::Null))
        }
        Expr::Value(v) => convert_literal(&v.value),
        Expr::UnaryOp { op, expr } => {
            let inner = eval_expr(expr, table, rowid)?;
            match op {
                UnaryOperator::Not => match inner {
                    Value::Bool(b) => Ok(Value::Bool(!b)),
                    Value::Null => Ok(Value::Null),
                    other => Err(SQLRiteError::Internal(format!(
                        "NOT applied to non-boolean value: {}",
                        other.to_display_string()
                    ))),
                },
                UnaryOperator::Minus => match inner {
                    // `-i64::MIN` is not representable: report it instead of
                    // panicking (debug builds) or silently wrapping (release).
                    Value::Integer(i) => i.checked_neg().map(Value::Integer).ok_or_else(|| {
                        SQLRiteError::General(format!("integer overflow negating {i}"))
                    }),
                    Value::Real(f) => Ok(Value::Real(-f)),
                    Value::Null => Ok(Value::Null),
                    other => Err(SQLRiteError::Internal(format!(
                        "unary minus on non-numeric value: {}",
                        other.to_display_string()
                    ))),
                },
                UnaryOperator::Plus => Ok(inner),
                other => Err(SQLRiteError::NotImplemented(format!(
                    "unary operator {other:?} is not supported"
                ))),
            }
        }
        Expr::BinaryOp { left, op, right } => match op {
            // Both operands are always evaluated; only the boolean coercion
            // of the right-hand side is short-circuited.
            BinaryOperator::And => {
                let l = eval_expr(left, table, rowid)?;
                let r = eval_expr(right, table, rowid)?;
                Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
            }
            BinaryOperator::Or => {
                let l = eval_expr(left, table, rowid)?;
                let r = eval_expr(right, table, rowid)?;
                Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
            }
            cmp @ (BinaryOperator::Eq
            | BinaryOperator::NotEq
            | BinaryOperator::Lt
            | BinaryOperator::LtEq
            | BinaryOperator::Gt
            | BinaryOperator::GtEq) => {
                let l = eval_expr(left, table, rowid)?;
                let r = eval_expr(right, table, rowid)?;
                // A comparison involving NULL never matches.
                if matches!(l, Value::Null) || matches!(r, Value::Null) {
                    return Ok(Value::Bool(false));
                }
                let ord = compare_values(Some(&l), Some(&r));
                let result = match cmp {
                    BinaryOperator::Eq => ord == Ordering::Equal,
                    BinaryOperator::NotEq => ord != Ordering::Equal,
                    BinaryOperator::Lt => ord == Ordering::Less,
                    BinaryOperator::LtEq => ord != Ordering::Greater,
                    BinaryOperator::Gt => ord == Ordering::Greater,
                    BinaryOperator::GtEq => ord != Ordering::Less,
                    // The outer pattern admits only the six operators above.
                    _ => unreachable!(),
                };
                Ok(Value::Bool(result))
            }
            arith @ (BinaryOperator::Plus
            | BinaryOperator::Minus
            | BinaryOperator::Multiply
            | BinaryOperator::Divide
            | BinaryOperator::Modulo) => {
                let l = eval_expr(left, table, rowid)?;
                let r = eval_expr(right, table, rowid)?;
                eval_arith(arith, &l, &r)
            }
            BinaryOperator::StringConcat => {
                let l = eval_expr(left, table, rowid)?;
                let r = eval_expr(right, table, rowid)?;
                if matches!(l, Value::Null) || matches!(r, Value::Null) {
                    return Ok(Value::Null);
                }
                // Non-text operands are concatenated via their display form.
                Ok(Value::Text(format!(
                    "{}{}",
                    l.to_display_string(),
                    r.to_display_string()
                )))
            }
            other => Err(SQLRiteError::NotImplemented(format!(
                "binary operator {other:?} is not supported yet"
            ))),
        },
        Expr::Function(func) => eval_function(func, table, rowid),
        other => Err(SQLRiteError::NotImplemented(format!(
            "unsupported expression in WHERE/projection: {other:?}"
        ))),
    }
}
/// Evaluates a scalar SQL function call for one row.
///
/// The function name must be a single unqualified identifier and is matched
/// case-insensitively. Known functions: the three `vec_distance_*` metrics,
/// the `json_*` helpers, `fts_match` and `bm25_score`. Anything else is
/// `NotImplemented`.
fn eval_function(func: &sqlparser::ast::Function, table: &Table, rowid: i64) -> Result<Value> {
    let [ObjectNamePart::Identifier(ident)] = func.name.0.as_slice() else {
        return Err(SQLRiteError::NotImplemented(format!(
            "qualified function names not supported: {:?}",
            func.name
        )));
    };
    let name = ident.value.to_lowercase();
    let args = &func.args;

    match name.as_str() {
        "vec_distance_l2" => {
            let (a, b) = extract_two_vector_args(&name, args, table, rowid)?;
            Ok(Value::Real(vec_distance_l2(&a, &b) as f64))
        }
        "vec_distance_cosine" => {
            let (a, b) = extract_two_vector_args(&name, args, table, rowid)?;
            Ok(Value::Real(vec_distance_cosine(&a, &b)? as f64))
        }
        "vec_distance_dot" => {
            let (a, b) = extract_two_vector_args(&name, args, table, rowid)?;
            Ok(Value::Real(vec_distance_dot(&a, &b) as f64))
        }
        "json_extract" => json_fn_extract(&name, args, table, rowid),
        "json_type" => json_fn_type(&name, args, table, rowid),
        "json_array_length" => json_fn_array_length(&name, args, table, rowid),
        "json_object_keys" => json_fn_object_keys(&name, args, table, rowid),
        "fts_match" => {
            let (entry, query) = resolve_fts_args(&name, args, table, rowid)?;
            let hit = entry.index.matches(rowid, &query);
            Ok(Value::Bool(hit))
        }
        "bm25_score" => {
            let (entry, query) = resolve_fts_args(&name, args, table, rowid)?;
            let score = entry.index.score(rowid, &query, &Bm25Params::default());
            Ok(Value::Real(score))
        }
        other => Err(SQLRiteError::NotImplemented(format!(
            "unknown function: {other}(...)"
        ))),
    }
}
/// Validates the `(column, query_text)` arguments of an FTS function and
/// returns the column's FTS index entry together with the query string.
///
/// Checks run in this order: argument-list shape, argument count, shape of
/// argument 0, column reference, shape of argument 1, query text type, and
/// finally the existence of an FTS index on the column.
fn resolve_fts_args<'t>(
    fn_name: &str,
    args: &FunctionArguments,
    table: &'t Table,
    rowid: i64,
) -> Result<(&'t FtsIndexEntry, String)> {
    let FunctionArguments::List(list) = args else {
        return Err(SQLRiteError::General(format!(
            "{fn_name}() expects exactly two arguments: (column, query_text)"
        )));
    };
    let (column_arg, query_arg) = match list.args.as_slice() {
        [first, second] => (first, second),
        mismatched => {
            return Err(SQLRiteError::General(format!(
                "{fn_name}() expects exactly 2 arguments, got {}",
                mismatched.len()
            )));
        }
    };

    // Argument 0: a bare or qualified column reference.
    let col_expr = match column_arg {
        FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
        other => {
            return Err(SQLRiteError::NotImplemented(format!(
                "{fn_name}() argument 0 must be a column name, got {other:?}"
            )));
        }
    };
    let col_name = match col_expr {
        Expr::Identifier(ident) => ident.value.clone(),
        Expr::CompoundIdentifier(parts) => match parts.last() {
            Some(last) => last.value.clone(),
            None => {
                return Err(SQLRiteError::Internal(
                    "empty compound identifier".to_string(),
                ));
            }
        },
        other => {
            return Err(SQLRiteError::General(format!(
                "{fn_name}() argument 0 must be a column reference, got {other:?}"
            )));
        }
    };

    // Argument 1: any expression that evaluates to text.
    let q_expr = match query_arg {
        FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
        other => {
            return Err(SQLRiteError::NotImplemented(format!(
                "{fn_name}() argument 1 must be a text expression, got {other:?}"
            )));
        }
    };
    let query = match eval_expr(q_expr, table, rowid)? {
        Value::Text(s) => s,
        other => {
            return Err(SQLRiteError::General(format!(
                "{fn_name}() argument 1 must be TEXT, got {}",
                other.to_display_string()
            )));
        }
    };

    let found = table
        .fts_indexes
        .iter()
        .find(|e| e.column_name == col_name);
    match found {
        Some(entry) => Ok((entry, query)),
        None => Err(SQLRiteError::General(format!(
            "{fn_name}({col_name}, ...): no FTS index on column '{col_name}' \
             (run CREATE INDEX <name> ON <table> USING fts({col_name}) first)"
        ))),
    }
}
/// Evaluates the arguments shared by the JSON functions and returns
/// `(json_text, path)`.
///
/// One or two arguments are accepted. The first must evaluate to text (NULL
/// and other types are `General` errors); the optional second must evaluate
/// to text and defaults to `"$"` when omitted.
fn extract_json_and_path(
    fn_name: &str,
    args: &FunctionArguments,
    table: &Table,
    rowid: i64,
) -> Result<(String, String)> {
    let FunctionArguments::List(list) = args else {
        return Err(SQLRiteError::General(format!(
            "{fn_name}() expects 1 or 2 arguments"
        )));
    };
    let arg_list = &list.args;
    let (json_arg, path_arg) = match arg_list.as_slice() {
        [json] => (json, None),
        [json, path] => (json, Some(path)),
        _ => {
            return Err(SQLRiteError::General(format!(
                "{fn_name}() expects 1 or 2 arguments, got {}",
                arg_list.len()
            )));
        }
    };

    let first_expr = match json_arg {
        FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
        other => {
            return Err(SQLRiteError::NotImplemented(format!(
                "{fn_name}() argument 0 has unsupported shape: {other:?}"
            )));
        }
    };
    let json_text = match eval_expr(first_expr, table, rowid)? {
        Value::Text(s) => s,
        Value::Null => {
            return Err(SQLRiteError::General(format!(
                "{fn_name}() called on NULL — JSON column has no value for this row"
            )));
        }
        other => {
            return Err(SQLRiteError::General(format!(
                "{fn_name}() argument 0 is not JSON-typed: got {}",
                other.to_display_string()
            )));
        }
    };

    let path = match path_arg {
        // No explicit path means "the whole document".
        None => "$".to_string(),
        Some(raw_path) => {
            let path_expr = match raw_path {
                FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
                other => {
                    return Err(SQLRiteError::NotImplemented(format!(
                        "{fn_name}() argument 1 has unsupported shape: {other:?}"
                    )));
                }
            };
            match eval_expr(path_expr, table, rowid)? {
                Value::Text(s) => s,
                other => {
                    return Err(SQLRiteError::General(format!(
                        "{fn_name}() path argument must be a string literal, got {}",
                        other.to_display_string()
                    )));
                }
            }
        }
    };

    Ok((json_text, path))
}
/// Resolves a JSON path such as `$.a.b[2]` against `value`.
///
/// Grammar handled here: a leading `$`, followed by any number of `.key`
/// (key runs until the next `.` or `[`) and `[index]` steps (index is a
/// non-negative integer; surrounding whitespace is trimmed).
///
/// Returns `Ok(None)` when a step does not exist in the document, and a
/// `General` error when the path itself is malformed.
fn walk_json_path<'a>(
    value: &'a serde_json::Value,
    path: &str,
) -> Result<Option<&'a serde_json::Value>> {
    let mut rest = path.chars().peekable();
    if rest.next() != Some('$') {
        return Err(SQLRiteError::General(format!(
            "JSON path must start with '$', got `{path}`"
        )));
    }

    let mut node = value;
    while let Some(&lead) = rest.peek() {
        match lead {
            '.' => {
                rest.next();
                let mut key = String::new();
                while let Some(ch) = rest.next_if(|c| *c != '.' && *c != '[') {
                    key.push(ch);
                }
                if key.is_empty() {
                    return Err(SQLRiteError::General(format!(
                        "JSON path has empty key after '.' in `{path}`"
                    )));
                }
                let Some(child) = node.get(&key) else {
                    return Ok(None);
                };
                node = child;
            }
            '[' => {
                rest.next();
                let mut idx_str = String::new();
                while let Some(ch) = rest.next_if(|c| *c != ']') {
                    idx_str.push(ch);
                }
                // The loop above stops at `]` or at end of input; only the
                // former is a well-formed step.
                if rest.next() != Some(']') {
                    return Err(SQLRiteError::General(format!(
                        "JSON path has unclosed `[` in `{path}`"
                    )));
                }
                let idx: usize = match idx_str.trim().parse() {
                    Ok(n) => n,
                    Err(_) => {
                        return Err(SQLRiteError::General(format!(
                            "JSON path has non-integer index `[{idx_str}]` in `{path}`"
                        )));
                    }
                };
                let Some(child) = node.get(idx) else {
                    return Ok(None);
                };
                node = child;
            }
            other => {
                return Err(SQLRiteError::General(format!(
                    "JSON path has unexpected character `{other}` in `{path}` \
                     (expected `.`, `[`, or end-of-path)"
                )));
            }
        }
    }
    Ok(Some(node))
}
fn json_value_to_sql(v: &serde_json::Value) -> Value {
match v {
serde_json::Value::Null => Value::Null,
serde_json::Value::Bool(b) => Value::Bool(*b),
serde_json::Value::Number(n) => {
if let Some(i) = n.as_i64() {
Value::Integer(i)
} else if let Some(f) = n.as_f64() {
Value::Real(f)
} else {
Value::Null
}
}
serde_json::Value::String(s) => Value::Text(s.clone()),
composite => Value::Text(composite.to_string()),
}
}
fn json_fn_extract(
name: &str,
args: &FunctionArguments,
table: &Table,
rowid: i64,
) -> Result<Value> {
let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
})?;
match walk_json_path(&parsed, &path)? {
Some(v) => Ok(json_value_to_sql(v)),
None => Ok(Value::Null),
}
}
/// Reports the JSON type name of the node addressed by a path, as SQL text.
///
/// The result is one of `"null"`, `"true"`, `"false"`, `"integer"`, `"real"`,
/// `"text"`, `"array"` or `"object"`. A path that resolves to nothing yields
/// `Value::Null` rather than a type name.
///
/// # Errors
/// Returns `SQLRiteError::General` when the text is not valid JSON, and
/// propagates any error from argument extraction or path walking.
fn json_fn_type(name: &str, args: &FunctionArguments, table: &Table, rowid: i64) -> Result<Value> {
    let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
    let document = match serde_json::from_str::<serde_json::Value>(&json_text) {
        Ok(doc) => doc,
        Err(e) => {
            return Err(SQLRiteError::General(format!(
                "{name}() got invalid JSON `{json_text}`: {e}"
            )));
        }
    };
    let node = match walk_json_path(&document, &path)? {
        None => return Ok(Value::Null),
        Some(found) => found,
    };
    let label = match node {
        serde_json::Value::Null => "null",
        serde_json::Value::Bool(flag) => {
            if *flag {
                "true"
            } else {
                "false"
            }
        }
        // Both signed and unsigned integral numbers count as "integer".
        serde_json::Value::Number(num) if num.is_i64() || num.is_u64() => "integer",
        serde_json::Value::Number(_) => "real",
        serde_json::Value::String(_) => "text",
        serde_json::Value::Array(_) => "array",
        serde_json::Value::Object(_) => "object",
    };
    Ok(Value::Text(label.to_string()))
}
/// Returns the element count of the JSON array addressed by a path.
///
/// A path that resolves to nothing yields `Value::Null`.
///
/// # Errors
/// Returns `SQLRiteError::General` when the text is not valid JSON or when the
/// resolved node is not an array, and propagates any error from argument
/// extraction or path walking.
fn json_fn_array_length(
    name: &str,
    args: &FunctionArguments,
    table: &Table,
    rowid: i64,
) -> Result<Value> {
    let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
    let document = match serde_json::from_str::<serde_json::Value>(&json_text) {
        Ok(doc) => doc,
        Err(e) => {
            return Err(SQLRiteError::General(format!(
                "{name}() got invalid JSON `{json_text}`: {e}"
            )));
        }
    };
    match walk_json_path(&document, &path)? {
        None => Ok(Value::Null),
        Some(serde_json::Value::Array(items)) => Ok(Value::Integer(items.len() as i64)),
        Some(_) => Err(SQLRiteError::General(format!(
            "{name}() resolved to a non-array value at path `{path}`"
        ))),
    }
}
/// Returns the keys of the JSON object addressed by a path, serialized as a
/// JSON array of strings inside a `Value::Text`.
///
/// A path that resolves to nothing yields `Value::Null`.
///
/// # Errors
/// Returns `SQLRiteError::General` when the text is not valid JSON or when the
/// resolved node is not an object, and propagates any error from argument
/// extraction or path walking.
fn json_fn_object_keys(
    name: &str,
    args: &FunctionArguments,
    table: &Table,
    rowid: i64,
) -> Result<Value> {
    let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
    let document = match serde_json::from_str::<serde_json::Value>(&json_text) {
        Ok(doc) => doc,
        Err(e) => {
            return Err(SQLRiteError::General(format!(
                "{name}() got invalid JSON `{json_text}`: {e}"
            )));
        }
    };
    match walk_json_path(&document, &path)? {
        None => Ok(Value::Null),
        Some(serde_json::Value::Object(members)) => {
            // Keys are emitted in the map's own iteration order.
            let key_list = members
                .keys()
                .cloned()
                .map(serde_json::Value::String)
                .collect::<Vec<serde_json::Value>>();
            Ok(Value::Text(serde_json::Value::Array(key_list).to_string()))
        }
        Some(_) => Err(SQLRiteError::General(format!(
            "{name}() resolved to a non-object value at path `{path}`"
        ))),
    }
}
fn extract_two_vector_args(
fn_name: &str,
args: &FunctionArguments,
table: &Table,
rowid: i64,
) -> Result<(Vec<f32>, Vec<f32>)> {
let arg_list = match args {
FunctionArguments::List(l) => &l.args,
_ => {
return Err(SQLRiteError::General(format!(
"{fn_name}() expects exactly two vector arguments"
)));
}
};
if arg_list.len() != 2 {
return Err(SQLRiteError::General(format!(
"{fn_name}() expects exactly 2 arguments, got {}",
arg_list.len()
)));
}
let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
for (i, arg) in arg_list.iter().enumerate() {
let expr = match arg {
FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
other => {
return Err(SQLRiteError::NotImplemented(format!(
"{fn_name}() argument {i} has unsupported shape: {other:?}"
)));
}
};
let val = eval_expr(expr, table, rowid)?;
match val {
Value::Vector(v) => out.push(v),
other => {
return Err(SQLRiteError::General(format!(
"{fn_name}() argument {i} is not a vector: got {}",
other.to_display_string()
)));
}
}
}
let b = out.pop().unwrap();
let a = out.pop().unwrap();
if a.len() != b.len() {
return Err(SQLRiteError::General(format!(
"{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
a.len(),
b.len()
)));
}
Ok((a, b))
}
/// Euclidean (L2) distance between two vectors: the square root of the sum of
/// squared component differences.
///
/// Both slices are expected to have the same length; this is checked with a
/// `debug_assert_eq!` only. Two empty slices give `0.0`.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // One up-front slice replaces the per-element bounds check on `b`: it
    // still panics if `b` is shorter than `a`, and extra trailing elements of
    // `b` are ignored.
    let rhs = &b[..a.len()];
    a.iter()
        .zip(rhs)
        .fold(0.0f32, |acc, (x, y)| {
            let diff = x - y;
            acc + diff * diff
        })
        .sqrt()
}
/// Cosine distance between two vectors: `1 - (a · b) / (‖a‖ · ‖b‖)`.
///
/// Both slices are expected to have the same length; this is checked with a
/// `debug_assert_eq!` only.
///
/// # Errors
/// Returns `SQLRiteError::General` when the denominator is exactly zero, which
/// is the case when either vector has zero magnitude.
pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
    debug_assert_eq!(a.len(), b.len());
    // One up-front slice replaces the per-element bounds check on `b`: it
    // still panics if `b` is shorter than `a`, and extra trailing elements of
    // `b` are ignored.
    let rhs = &b[..a.len()];
    // Accumulate the dot product and both squared norms in a single pass.
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(rhs).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot_acc, a_acc, b_acc), (&x, &y)| (dot_acc + x * y, a_acc + x * x, b_acc + y * y),
    );
    let denom = (norm_a_sq * norm_b_sq).sqrt();
    if denom == 0.0 {
        Err(SQLRiteError::General(
            "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
        ))
    } else {
        Ok(1.0 - dot / denom)
    }
}
/// Negated dot product of two vectors, so that a larger inner product gives a
/// smaller "distance".
///
/// Both slices are expected to have the same length; this is checked with a
/// `debug_assert_eq!` only.
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // One up-front slice replaces the per-element bounds check on `b`: it
    // still panics if `b` is shorter than `a`, and extra trailing elements of
    // `b` are ignored.
    let rhs = &b[..a.len()];
    let inner = a.iter().zip(rhs).fold(0.0f32, |acc, (x, y)| acc + x * y);
    -inner
}
fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
if matches!(l, Value::Null) || matches!(r, Value::Null) {
return Ok(Value::Null);
}
match (l, r) {
(Value::Integer(a), Value::Integer(b)) => match op {
BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
BinaryOperator::Divide => {
if *b == 0 {
Err(SQLRiteError::General("division by zero".to_string()))
} else {
Ok(Value::Integer(a / b))
}
}
BinaryOperator::Modulo => {
if *b == 0 {
Err(SQLRiteError::General("modulo by zero".to_string()))
} else {
Ok(Value::Integer(a % b))
}
}
_ => unreachable!(),
},
(a, b) => {
let af = as_number(a)?;
let bf = as_number(b)?;
match op {
BinaryOperator::Plus => Ok(Value::Real(af + bf)),
BinaryOperator::Minus => Ok(Value::Real(af - bf)),
BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
BinaryOperator::Divide => {
if bf == 0.0 {
Err(SQLRiteError::General("division by zero".to_string()))
} else {
Ok(Value::Real(af / bf))
}
}
BinaryOperator::Modulo => {
if bf == 0.0 {
Err(SQLRiteError::General("modulo by zero".to_string()))
} else {
Ok(Value::Real(af % bf))
}
}
_ => unreachable!(),
}
}
}
}
/// Coerces a SQL value to `f64` for arithmetic.
///
/// Integers are cast, reals pass through, and booleans map to `1.0` / `0.0`.
///
/// # Errors
/// Returns `SQLRiteError::General` for every other kind of value.
fn as_number(v: &Value) -> Result<f64> {
    let numeric = match v {
        Value::Real(f) => *f,
        Value::Integer(i) => *i as f64,
        Value::Bool(true) => 1.0,
        Value::Bool(false) => 0.0,
        other => {
            return Err(SQLRiteError::General(format!(
                "arithmetic on non-numeric value '{}'",
                other.to_display_string()
            )));
        }
    };
    Ok(numeric)
}
/// Interprets a SQL value as a truth value.
///
/// Booleans pass through, `Value::Null` counts as `false`, and integers are
/// true when non-zero.
///
/// # Errors
/// Returns `SQLRiteError::Internal` for every other kind of value.
fn as_bool(v: &Value) -> Result<bool> {
    let truthy = match v {
        Value::Null => false,
        Value::Bool(b) => *b,
        Value::Integer(i) => *i != 0,
        other => {
            return Err(SQLRiteError::Internal(format!(
                "expected boolean, got {}",
                other.to_display_string()
            )));
        }
    };
    Ok(truthy)
}
/// Converts a parsed SQL literal into the engine's runtime [`Value`].
///
/// Numeric literals are tried as `i64` first and as `f64` second. Single-quoted
/// strings, booleans and `NULL` map to their direct counterparts.
///
/// # Errors
/// * `SQLRiteError::Internal` when a numeric literal parses as neither `i64`
///   nor `f64`.
/// * `SQLRiteError::NotImplemented` for any other literal form.
fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
    use sqlparser::ast::Value as AstValue;
    let converted = match v {
        AstValue::Null => Value::Null,
        AstValue::Boolean(b) => Value::Bool(*b),
        AstValue::SingleQuotedString(s) => Value::Text(s.clone()),
        // The float parse only runs when the integer parse fails.
        AstValue::Number(n, _) => n
            .parse::<i64>()
            .map(Value::Integer)
            .or_else(|_| n.parse::<f64>().map(Value::Real))
            .map_err(|_| {
                SQLRiteError::Internal(format!("could not parse numeric literal '{n}'"))
            })?,
        other => {
            return Err(SQLRiteError::NotImplemented(format!(
                "unsupported literal value: {other:?}"
            )));
        }
    };
    Ok(converted)
}
// Unit tests for the vector distance helpers and the top-k selection path.
#[cfg(test)]
mod tests {
use super::*;
// True when `a` and `b` differ by strictly less than `eps`.
fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
(a - b).abs() < eps
}
// L2 distance from a vector to itself is exactly zero.
#[test]
fn vec_distance_l2_identical_is_zero() {
let v = vec![0.1, 0.2, 0.3];
assert_eq!(vec_distance_l2(&v, &v), 0.0);
}
// The two 2-D unit basis vectors are sqrt(2) apart.
#[test]
fn vec_distance_l2_unit_basis_is_sqrt2() {
let a = vec![1.0, 0.0];
let b = vec![0.0, 1.0];
assert!(approx_eq(vec_distance_l2(&a, &b), 2.0_f32.sqrt(), 1e-6));
}
// 3-4-5 triangle: the distance from the origin to (3, 4, 0) is 5.
#[test]
fn vec_distance_l2_known_value() {
let a = vec![0.0, 0.0, 0.0];
let b = vec![3.0, 4.0, 0.0];
assert!(approx_eq(vec_distance_l2(&a, &b), 5.0, 1e-6));
}
// Cosine distance of a vector with itself is (approximately) zero.
#[test]
fn vec_distance_cosine_identical_is_zero() {
let v = vec![0.1, 0.2, 0.3];
let d = vec_distance_cosine(&v, &v).unwrap();
assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
}
// Orthogonal vectors have a zero dot product, so the distance is 1.
#[test]
fn vec_distance_cosine_orthogonal_is_one() {
let a = vec![1.0, 0.0];
let b = vec![0.0, 1.0];
assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 1.0, 1e-6));
}
// Opposite vectors have cosine similarity -1, so the distance is 2.
#[test]
fn vec_distance_cosine_opposite_is_two() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![-1.0, 0.0, 0.0];
assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 2.0, 1e-6));
}
// A zero-magnitude operand must produce an error, not a NaN result.
#[test]
fn vec_distance_cosine_zero_magnitude_errors() {
let a = vec![0.0, 0.0];
let b = vec![1.0, 0.0];
let err = vec_distance_cosine(&a, &b).unwrap_err();
assert!(format!("{err}").contains("zero-magnitude"));
}
// vec_distance_dot returns the negated dot product: 4 + 10 + 18 = 32 -> -32.
#[test]
fn vec_distance_dot_negates() {
let a = vec![1.0, 2.0, 3.0];
let b = vec![4.0, 5.0, 6.0];
assert!(approx_eq(vec_distance_dot(&a, &b), -32.0, 1e-6));
}
// Orthogonal vectors give 0 (the negated zero still compares equal to 0.0).
#[test]
fn vec_distance_dot_orthogonal_is_zero() {
let a = vec![1.0, 0.0];
let b = vec![0.0, 1.0];
assert_eq!(vec_distance_dot(&a, &b), 0.0);
}
// For unit-norm inputs, -dot equals (1 - cos_sim) - 1, i.e. the cosine
// distance minus one. (0.6, 0.8) and (0.8, 0.6) both have norm 1.
#[test]
fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
let a = vec![0.6f32, 0.8]; let b = vec![0.8f32, 0.6]; let dot = vec_distance_dot(&a, &b);
let cos = vec_distance_cosine(&a, &b).unwrap();
assert!(approx_eq(dot, cos - 1.0, 1e-5));
}
// Imports used only by the top-k tests below.
use crate::sql::db::database::Database;
use crate::sql::parser::select::SelectQuery;
use sqlparser::dialect::SQLiteDialect;
use sqlparser::parser::Parser;
// Builds a fresh database with a `docs(id, score)` table and inserts `n` rows.
// Scores are a deterministic function of the row index (multiply by a large
// constant, then mod 1,000,000), so they are scattered rather than inserted in
// sorted order. The constant looks like the usual multiplicative-hash
// constant; that is an assumption, not something this file states.
fn seed_score_table(n: usize) -> Database {
let mut db = Database::new("tempdb".to_string());
crate::sql::process_command(
"CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
&mut db,
)
.expect("create");
for i in 0..n {
let score = ((i as u64).wrapping_mul(2_654_435_761) % 1_000_000) as f64;
let sql = format!("INSERT INTO docs (score) VALUES ({score});");
crate::sql::process_command(&sql, &mut db).expect("insert");
}
db
}
// Parses a single SQL statement with the SQLite dialect and converts it to a
// SelectQuery. Panics on any parse or conversion failure.
fn parse_select(sql: &str) -> SelectQuery {
let dialect = SQLiteDialect {};
let mut ast = Parser::parse_sql(&dialect, sql).expect("parse");
let stmt = ast.pop().expect("one statement");
SelectQuery::new(&stmt).expect("select-query")
}
// select_topk must agree with "sort everything, then truncate" for ASC order.
#[test]
fn topk_matches_full_sort_asc() {
let db = seed_score_table(200);
let table = db.get_table("docs".to_string()).unwrap();
let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
let order = q.order_by.as_ref().unwrap();
let all_rowids = table.rowids();
// Reference result: full sort followed by truncation to the first 10.
let mut full = all_rowids.clone();
sort_rowids(&mut full, table, order).unwrap();
full.truncate(10);
let topk = select_topk(&all_rowids, table, order, 10).unwrap();
assert_eq!(topk, full, "top-k via heap should match full-sort+truncate");
}
// Same equivalence check for DESC order.
#[test]
fn topk_matches_full_sort_desc() {
let db = seed_score_table(200);
let table = db.get_table("docs".to_string()).unwrap();
let q = parse_select("SELECT * FROM docs ORDER BY score DESC LIMIT 10;");
let order = q.order_by.as_ref().unwrap();
let all_rowids = table.rowids();
// Reference result: full sort followed by truncation to the first 10.
let mut full = all_rowids.clone();
sort_rowids(&mut full, table, order).unwrap();
full.truncate(10);
let topk = select_topk(&all_rowids, table, order, 10).unwrap();
assert_eq!(
topk, full,
"top-k DESC via heap should match full-sort+truncate"
);
}
// With k larger than the row count, every row comes back, in sorted order.
#[test]
fn topk_k_larger_than_n_returns_everything_sorted() {
let db = seed_score_table(50);
let table = db.get_table("docs".to_string()).unwrap();
let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
let order = q.order_by.as_ref().unwrap();
let topk = select_topk(&table.rowids(), table, order, 1000).unwrap();
assert_eq!(topk.len(), 50);
// Only Real scores are collected; any other value is silently dropped by
// filter_map, so the ordering check below covers the Real values only.
let scores: Vec<f64> = topk
.iter()
.filter_map(|r| match table.get_value("score", *r) {
Some(Value::Real(f)) => Some(f),
_ => None,
})
.collect();
assert!(scores.windows(2).all(|w| w[0] <= w[1]));
}
// k = 0 yields an empty result. The SQL text says LIMIT 1; the k under test is
// the 0 passed directly to select_topk, not the LIMIT in the query.
#[test]
fn topk_k_zero_returns_empty() {
let db = seed_score_table(10);
let table = db.get_table("docs".to_string()).unwrap();
let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
let order = q.order_by.as_ref().unwrap();
let topk = select_topk(&table.rowids(), table, order, 0).unwrap();
assert!(topk.is_empty());
}
// An empty rowid slice yields an empty result.
#[test]
fn topk_empty_input_returns_empty() {
let db = seed_score_table(0);
let table = db.get_table("docs".to_string()).unwrap();
let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
let order = q.order_by.as_ref().unwrap();
let topk = select_topk(&[], table, order, 5).unwrap();
assert!(topk.is_empty());
}
// End-to-end: runs ORDER BY vec_distance_l2(...) LIMIT 3 through
// process_command. Only the row count in the response text is checked, not
// which ids were returned.
#[test]
fn topk_works_through_select_executor_with_distance_function() {
let mut db = Database::new("tempdb".to_string());
crate::sql::process_command(
"CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
&mut db,
)
.unwrap();
for v in &[
"[1.0, 0.0]",
"[2.0, 0.0]",
"[0.0, 3.0]",
"[1.0, 4.0]",
"[10.0, 10.0]",
] {
crate::sql::process_command(&format!("INSERT INTO docs (e) VALUES ({v});"), &mut db)
.unwrap();
}
let resp = crate::sql::process_command(
"SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;",
&mut db,
)
.unwrap();
assert!(resp.contains("3 rows returned"), "got: {resp}");
}
// Timing comparison of select_topk against full sort + truncate. Marked
// #[ignore], so it is skipped unless ignored tests are requested. The final
// assertion depends on wall-clock timing and can vary between machines and
// build profiles.
#[test]
#[ignore]
fn topk_benchmark() {
use std::time::Instant;
const N: usize = 10_000;
const K: usize = 10;
let db = seed_score_table(N);
let table = db.get_table("docs".to_string()).unwrap();
let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
let order = q.order_by.as_ref().unwrap();
let all_rowids = table.rowids();
// Time the top-k path.
let t0 = Instant::now();
let _topk = select_topk(&all_rowids, table, order, K).unwrap();
let heap_dur = t0.elapsed();
// Time the reference path; the clone is included in the measurement.
let t1 = Instant::now();
let mut full = all_rowids.clone();
sort_rowids(&mut full, table, order).unwrap();
full.truncate(K);
let sort_dur = t1.elapsed();
// The max(1e-9) guards against dividing by a zero-length measurement.
let ratio = sort_dur.as_secs_f64() / heap_dur.as_secs_f64().max(1e-9);
println!("\n--- topk_benchmark (N={N}, k={K}) ---");
println!(" bounded heap: {heap_dur:?}");
println!(" full sort+trunc: {sort_dur:?}");
println!(" speedup ratio: {ratio:.2}×");
assert!(
ratio > 1.4,
"bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
);
}
}