use std::cmp::Ordering;
use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
use sqlparser::ast::{
AlterTable, AlterTableOperation, AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr,
FromTable, FunctionArg, FunctionArgExpr, FunctionArguments, IndexType, ObjectName,
ObjectNamePart, RenameTableNameKind, Statement, TableFactor, TableWithJoins, UnaryOperator,
Update, Value as AstValue,
};
use crate::error::{Result, SQLRiteError};
use crate::sql::agg::{AggState, DistinctKey, like_match};
use crate::sql::db::database::Database;
use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
use crate::sql::db::table::{
DataType, FtsIndexEntry, HnswIndexEntry, Table, Value, parse_vector_literal,
};
use crate::sql::fts::{Bm25Params, PostingList};
use crate::sql::hnsw::{DistanceMetric, HnswIndex};
use crate::sql::parser::select::{
AggregateArg, JoinType, OrderByClause, Projection, ProjectionItem, ProjectionKind, SelectQuery,
};
/// Column-resolution context used while evaluating an expression against the
/// "current row" of a query.
///
/// The implementations in this file are [`SingleTableScope`] (one table, one
/// rowid) and [`JoinedScope`] (one rowid slot per table taking part in a JOIN).
pub(crate) trait RowScope {
/// Resolves `col`, optionally qualified by a table name / alias, to the value
/// stored in the current row.
fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value>;
/// Returns the underlying table and rowid when the scope is backed by a
/// single table (as `SingleTableScope` is); `JoinedScope` returns `None`.
fn single_table_view(&self) -> Option<(&Table, i64)>;
}
/// A [`RowScope`] over exactly one row of one table.
pub(crate) struct SingleTableScope<'a> {
// Table the row belongs to.
table: &'a Table,
// Rowid of the row currently being evaluated.
rowid: i64,
}
impl<'a> SingleTableScope<'a> {
pub(crate) fn new(table: &'a Table, rowid: i64) -> Self {
Self { table, rowid }
}
}
impl RowScope for SingleTableScope<'_> {
    /// Reads `col` from the wrapped row.
    ///
    /// The qualifier is deliberately not checked: this scope carries no alias
    /// information to validate it against. A missing value comes back as
    /// `Value::Null` rather than as an error.
    fn lookup(&self, _qualifier: Option<&str>, col: &str) -> Result<Value> {
        let value = match self.table.get_value(col, self.rowid) {
            Some(found) => found,
            None => Value::Null,
        };
        Ok(value)
    }

    /// A single-table scope always exposes its table and rowid.
    fn single_table_view(&self) -> Option<(&Table, i64)> {
        Some((self.table, self.rowid))
    }
}
/// One table participating in a joined SELECT.
pub(crate) struct JoinedTableRef<'a> {
/// The table itself.
pub table: &'a Table,
/// Name the query uses to address this table: the alias when one was given,
/// otherwise the table name (see `execute_select_rows_joined`).
pub scope_name: String,
}
/// A [`RowScope`] over one candidate row of a JOIN.
///
/// `tables` and `rowids` are parallel: `rowids[i]` is the current row of
/// `tables[i]`, and `None` marks the NULL-padded side of an outer join.
pub(crate) struct JoinedScope<'a> {
pub tables: &'a [JoinedTableRef<'a>],
pub rowids: &'a [Option<i64>],
}
impl RowScope for JoinedScope<'_> {
    /// Resolves a column reference against the joined row.
    ///
    /// * Qualified (`q.col`): `q` must match some table's scope name
    ///   (ASCII case-insensitively), and that table must declare `col`.
    /// * Unqualified (`col`): exactly one in-scope table may declare `col`.
    ///   No match is an "unknown column" error; two or more is "ambiguous".
    ///
    /// A table whose rowid slot is `None` (NULL-padded outer-join side)
    /// yields `Value::Null`, as does a missing stored value.
    fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value> {
        let slot = match qualifier {
            Some(q) => {
                let Some(pos) = self
                    .tables
                    .iter()
                    .position(|t| t.scope_name.eq_ignore_ascii_case(q))
                else {
                    return Err(SQLRiteError::Internal(format!(
                        "unknown table qualifier '{q}' in column reference '{q}.{col}'"
                    )));
                };
                let entry = &self.tables[pos];
                if !entry.table.contains_column(col.to_string()) {
                    return Err(SQLRiteError::Internal(format!(
                        "column '{col}' does not exist on '{}'",
                        entry.scope_name
                    )));
                }
                pos
            }
            None => {
                let mut owners = self
                    .tables
                    .iter()
                    .enumerate()
                    .filter(|(_, t)| t.table.contains_column(col.to_string()))
                    .map(|(i, _)| i);
                let Some(first) = owners.next() else {
                    return Err(SQLRiteError::Internal(format!(
                        "unknown column '{col}' in joined SELECT (no in-scope table has it)"
                    )));
                };
                if owners.next().is_some() {
                    return Err(SQLRiteError::Internal(format!(
                        "column reference '{col}' is ambiguous — qualify it as <table>.{col}"
                    )));
                }
                first
            }
        };
        let value = self.rowids[slot]
            .and_then(|r| self.tables[slot].table.get_value(col, r))
            .unwrap_or(Value::Null);
        Ok(value)
    }

    /// A joined scope spans several tables, so there is no single-table view.
    fn single_table_view(&self) -> Option<(&Table, i64)> {
        None
    }
}
/// Materialized result of a SELECT.
pub struct SelectResult {
/// Output column names, in projection order (from `ProjectionItem::output_name`).
pub columns: Vec<String>,
/// One entry per result row; each row's values follow `columns` order.
pub rows: Vec<Vec<Value>>,
}
/// Executes a `SELECT` statement and returns its column header and rows.
///
/// Queries with at least one JOIN are delegated to
/// [`execute_select_rows_joined`]. Single-table queries run this pipeline:
///
/// 1. validate projected and `GROUP BY` columns against the table schema;
/// 2. collect the rowids satisfying `WHERE`: a secondary-index probe when
///    [`select_rowids`] finds a usable `col = literal` predicate, otherwise a
///    full scan;
/// 3. either aggregate (`GROUP BY` / aggregate projections), or order and
///    limit the rowids and project them. `DISTINCT` works on projected rows,
///    so under `DISTINCT` the `LIMIT` is deferred until after de-duplication.
///
/// `ORDER BY vec_distance_l2(..) LIMIT k` and
/// `ORDER BY bm25_score(..) DESC LIMIT k` may be answered from an HNSW / FTS
/// index instead of sorting. That shortcut is taken only when its answer
/// agrees with the filtered result (see the comments at the probe site).
///
/// # Errors
/// Returns an error when the table or a referenced column does not exist, or
/// when evaluating `WHERE` / `ORDER BY` fails.
pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
    if !query.joins.is_empty() {
        return execute_select_rows_joined(query, db);
    }
    let table = db
        .get_table(query.table_name.clone())
        .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
    let proj_items: Vec<ProjectionItem> = match &query.projection {
        Projection::All => table
            .column_names()
            .into_iter()
            .map(|c| ProjectionItem {
                kind: ProjectionKind::Column {
                    qualifier: None,
                    name: c,
                },
                alias: None,
            })
            .collect(),
        Projection::Items(items) => items.clone(),
    };
    let has_aggregates = proj_items
        .iter()
        .any(|i| matches!(i.kind, ProjectionKind::Aggregate(_)));
    for item in &proj_items {
        if let ProjectionKind::Column { name: c, .. } = &item.kind
            && !table.contains_column(c.clone())
        {
            return Err(SQLRiteError::Internal(format!(
                "Column '{c}' does not exist on table '{}'",
                query.table_name
            )));
        }
    }
    for c in &query.group_by {
        if !table.contains_column(c.clone()) {
            return Err(SQLRiteError::Internal(format!(
                "GROUP BY references unknown column '{c}' on table '{}'",
                query.table_name
            )));
        }
    }
    let mut matching = match select_rowids(table, query.selection.as_ref())? {
        RowidSource::IndexProbe(rowids) => rowids,
        RowidSource::FullScan => {
            let mut out = Vec::new();
            for rowid in table.rowids() {
                if let Some(expr) = &query.selection
                    && !eval_predicate(expr, table, rowid)?
                {
                    continue;
                }
                out.push(rowid);
            }
            out
        }
    };
    let aggregating = has_aggregates || !query.group_by.is_empty();
    if aggregating {
        for item in &proj_items {
            if let ProjectionKind::Aggregate(call) = &item.kind
                && let AggregateArg::Column(c) = &call.arg
                && !table.contains_column(c.clone())
            {
                return Err(SQLRiteError::Internal(format!(
                    "{}({}) references unknown column '{c}' on table '{}'",
                    call.func.as_str(),
                    c,
                    query.table_name
                )));
            }
        }
        let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
        let mut rows = aggregate_rows(table, &matching, &query.group_by, &proj_items)?;
        if query.distinct {
            rows = dedupe_rows(rows);
        }
        if let Some(order) = &query.order_by {
            sort_output_rows(&mut rows, &columns, &proj_items, order)?;
        }
        if let Some(k) = query.limit {
            rows.truncate(k);
        }
        return Ok(SelectResult { columns, rows });
    }
    let defer_limit_for_distinct = query.distinct;
    // HNSW / FTS probes rank the *whole* index and know nothing about WHERE.
    // They are skipped under DISTINCT, because de-duplicating k probed rows can
    // leave fewer than k rows even when more distinct rows exist. Each probe
    // is evaluated at most once.
    let probed = match (&query.order_by, query.limit) {
        (Some(order), Some(k)) if !defer_limit_for_distinct => {
            try_hnsw_probe(table, &order.expr, k)
                .or_else(|| try_fts_probe(table, &order.expr, order.ascending, k))
                .and_then(|mut ranked| {
                    if query.selection.is_none() {
                        return Some(ranked);
                    }
                    // Keep only ranked rowids that satisfy WHERE. If a full page of
                    // k survives, it is the filtered top-k (any higher-ranked
                    // qualifying row would itself be in the probe's top-k, modulo
                    // HNSW's approximation). Otherwise use the exact path below.
                    let allowed: std::collections::HashSet<i64> =
                        matching.iter().copied().collect();
                    ranked.retain(|rowid| allowed.contains(rowid));
                    (ranked.len() == k).then_some(ranked)
                })
        }
        _ => None,
    };
    if let Some(rowids) = probed {
        matching = rowids;
    } else {
        match (&query.order_by, query.limit) {
            (Some(order), Some(k)) if !defer_limit_for_distinct && k < matching.len() => {
                matching = select_topk(&matching, table, order, k)?;
            }
            (Some(order), _) => {
                sort_rowids(&mut matching, table, order)?;
                if let Some(k) = query.limit
                    && !defer_limit_for_distinct
                {
                    matching.truncate(k);
                }
            }
            (None, Some(k)) if !defer_limit_for_distinct => {
                matching.truncate(k);
            }
            _ => {}
        }
    }
    let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
    let projected_cols: Vec<String> = proj_items
        .iter()
        .map(|i| match &i.kind {
            ProjectionKind::Column { name, .. } => name.clone(),
            ProjectionKind::Aggregate(_) => unreachable!("aggregation handled above"),
        })
        .collect();
    let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
    for rowid in &matching {
        let row: Vec<Value> = projected_cols
            .iter()
            .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
            .collect();
        rows.push(row);
    }
    if query.distinct {
        rows = dedupe_rows(rows);
        if let Some(k) = query.limit {
            rows.truncate(k);
        }
    }
    Ok(SelectResult { columns, rows })
}
fn execute_select_rows_joined(query: SelectQuery, db: &Database) -> Result<SelectResult> {
let mut joined_tables: Vec<JoinedTableRef<'_>> = Vec::with_capacity(1 + query.joins.len());
let primary = db
.get_table(query.table_name.clone())
.map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
joined_tables.push(JoinedTableRef {
table: primary,
scope_name: query
.table_alias
.clone()
.unwrap_or_else(|| query.table_name.clone()),
});
for j in &query.joins {
let t = db
.get_table(j.right_table.clone())
.map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", j.right_table)))?;
joined_tables.push(JoinedTableRef {
table: t,
scope_name: j
.right_alias
.clone()
.unwrap_or_else(|| j.right_table.clone()),
});
}
{
let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
for t in &joined_tables {
let key = t.scope_name.to_ascii_lowercase();
if !seen.insert(key) {
return Err(SQLRiteError::Internal(format!(
"duplicate table reference '{}' in FROM/JOIN — use AS to alias one side",
t.scope_name
)));
}
}
}
let proj_items: Vec<ProjectionItem> = match &query.projection {
Projection::All => {
let mut all = Vec::new();
for t in &joined_tables {
for col in t.table.column_names() {
all.push(ProjectionItem {
kind: ProjectionKind::Column {
qualifier: Some(t.scope_name.clone()),
name: col,
},
alias: None,
});
}
}
all
}
Projection::Items(items) => items.clone(),
};
let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
let mut acc: Vec<Vec<Option<i64>>> = primary
.rowids()
.into_iter()
.map(|r| {
let mut row = Vec::with_capacity(joined_tables.len());
row.push(Some(r));
row
})
.collect();
for (j_idx, join) in query.joins.iter().enumerate() {
let right_pos = j_idx + 1;
let right_table = joined_tables[right_pos].table;
let right_rowids: Vec<i64> = right_table.rowids();
let mut right_matched: Vec<bool> = vec![false; right_rowids.len()];
let mut next_acc: Vec<Vec<Option<i64>>> = Vec::with_capacity(acc.len());
let on_scope_tables: &[JoinedTableRef<'_>] = &joined_tables[..=right_pos];
for left_row in acc.into_iter() {
let mut left_match_count = 0usize;
for (r_idx, &rrid) in right_rowids.iter().enumerate() {
let mut on_rowids: Vec<Option<i64>> = left_row.clone();
on_rowids.push(Some(rrid));
debug_assert_eq!(on_rowids.len(), on_scope_tables.len());
let scope = JoinedScope {
tables: on_scope_tables,
rowids: &on_rowids,
};
if eval_predicate_scope(&join.on, &scope)? {
left_match_count += 1;
right_matched[r_idx] = true;
next_acc.push(on_rowids);
}
}
if left_match_count == 0
&& matches!(join.join_type, JoinType::LeftOuter | JoinType::FullOuter)
{
let mut padded = left_row;
padded.push(None);
next_acc.push(padded);
}
}
if matches!(join.join_type, JoinType::RightOuter | JoinType::FullOuter) {
for (r_idx, matched) in right_matched.iter().enumerate() {
if *matched {
continue;
}
let mut row: Vec<Option<i64>> = vec![None; right_pos];
row.push(Some(right_rowids[r_idx]));
next_acc.push(row);
}
}
acc = next_acc;
}
let mut filtered: Vec<Vec<Option<i64>>> = if let Some(where_expr) = &query.selection {
let mut out = Vec::with_capacity(acc.len());
for row in acc {
let scope = JoinedScope {
tables: &joined_tables,
rowids: &row,
};
if eval_predicate_scope(where_expr, &scope)? {
out.push(row);
}
}
out
} else {
acc
};
if let Some(order) = &query.order_by {
let mut keys: Vec<(usize, Value)> = Vec::with_capacity(filtered.len());
for (i, row) in filtered.iter().enumerate() {
let scope = JoinedScope {
tables: &joined_tables,
rowids: row,
};
let v = eval_expr_scope(&order.expr, &scope)?;
keys.push((i, v));
}
keys.sort_by(|(_, a), (_, b)| {
let ord = compare_values(Some(a), Some(b));
if order.ascending { ord } else { ord.reverse() }
});
let mut sorted = Vec::with_capacity(filtered.len());
for (i, _) in keys {
sorted.push(filtered[i].clone());
}
filtered = sorted;
}
if let Some(k) = query.limit {
filtered.truncate(k);
}
let mut rows: Vec<Vec<Value>> = Vec::with_capacity(filtered.len());
for row in &filtered {
let scope = JoinedScope {
tables: &joined_tables,
rowids: row,
};
let mut out_row = Vec::with_capacity(proj_items.len());
for item in &proj_items {
let v = match &item.kind {
ProjectionKind::Column { qualifier, name } => {
scope.lookup(qualifier.as_deref(), name)?
}
ProjectionKind::Aggregate(_) => {
return Err(SQLRiteError::Internal(
"aggregate functions over JOIN are not supported".to_string(),
));
}
};
out_row.push(v);
}
rows.push(out_row);
}
Ok(SelectResult { columns, rows })
}
pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
let result = execute_select_rows(query, db)?;
let row_count = result.rows.len();
let mut print_table = PrintTable::new();
let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
print_table.add_row(PrintRow::new(header_cells));
for row in &result.rows {
let cells: Vec<PrintCell> = row
.iter()
.map(|v| PrintCell::new(&v.to_display_string()))
.collect();
print_table.add_row(PrintRow::new(cells));
}
Ok((print_table.to_string(), row_count))
}
/// Executes a `DELETE` statement and returns the number of rows removed.
///
/// Matching rowids are collected under an immutable borrow and deleted
/// afterwards. If any row was deleted, every HNSW and FTS index on the table
/// is flagged `needs_rebuild`.
///
/// # Errors
/// Fails for non-DELETE statements, multi-table or JOIN targets, unknown
/// tables, and WHERE clauses that fail to evaluate.
pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
    let Statement::Delete(Delete {
        from, selection, ..
    }) = stmt
    else {
        return Err(SQLRiteError::Internal(
            "execute_delete called on a non-DELETE statement".to_string(),
        ));
    };
    let (FromTable::WithFromKeyword(targets) | FromTable::WithoutKeyword(targets)) = from;
    let table_name = extract_single_table_name(targets)?;
    let doomed: Vec<i64> = {
        let table = db
            .get_table(table_name.clone())
            .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
        match select_rowids(table, selection.as_ref())? {
            RowidSource::IndexProbe(rowids) => rowids,
            RowidSource::FullScan => {
                let mut kept = Vec::new();
                for rowid in table.rowids() {
                    let keep = match selection {
                        Some(expr) => eval_predicate(expr, table, rowid)?,
                        None => true,
                    };
                    if keep {
                        kept.push(rowid);
                    }
                }
                kept
            }
        }
    };
    let table = db.get_table_mut(table_name)?;
    for &rowid in &doomed {
        table.delete_row(rowid);
    }
    if !doomed.is_empty() {
        table
            .hnsw_indexes
            .iter_mut()
            .for_each(|entry| entry.needs_rebuild = true);
        table
            .fts_indexes
            .iter_mut()
            .for_each(|entry| entry.needs_rebuild = true);
    }
    Ok(doomed.len())
}
pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
let Statement::Update(Update {
table,
assignments,
from,
selection,
..
}) = stmt
else {
return Err(SQLRiteError::Internal(
"execute_update called on a non-UPDATE statement".to_string(),
));
};
if from.is_some() {
return Err(SQLRiteError::NotImplemented(
"UPDATE ... FROM is not supported yet".to_string(),
));
}
let table_name = extract_table_name(table)?;
let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
{
let tbl = db
.get_table(table_name.clone())
.map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
for a in assignments {
let col = match &a.target {
AssignmentTarget::ColumnName(name) => name
.0
.last()
.map(|p| p.to_string())
.ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
AssignmentTarget::Tuple(_) => {
return Err(SQLRiteError::NotImplemented(
"tuple assignment targets are not supported".to_string(),
));
}
};
if !tbl.contains_column(col.clone()) {
return Err(SQLRiteError::Internal(format!(
"UPDATE references unknown column '{col}'"
)));
}
parsed_assignments.push((col, a.value.clone()));
}
}
let work: Vec<(i64, Vec<(String, Value)>)> = {
let tbl = db.get_table(table_name.clone())?;
let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
RowidSource::IndexProbe(rowids) => rowids,
RowidSource::FullScan => {
let mut out = Vec::new();
for rowid in tbl.rowids() {
if let Some(expr) = selection {
if !eval_predicate(expr, tbl, rowid)? {
continue;
}
}
out.push(rowid);
}
out
}
};
let mut rows_to_update = Vec::new();
for rowid in matched_rowids {
let mut values = Vec::with_capacity(parsed_assignments.len());
for (col, expr) in &parsed_assignments {
let v = eval_expr(expr, tbl, rowid)?;
values.push((col.clone(), v));
}
rows_to_update.push((rowid, values));
}
rows_to_update
};
let tbl = db.get_table_mut(table_name)?;
for (rowid, values) in &work {
for (col, v) in values {
tbl.set_value(col, *rowid, v.clone())?;
}
}
if !work.is_empty() {
let updated_columns: std::collections::HashSet<&str> = work
.iter()
.flat_map(|(_, values)| values.iter().map(|(c, _)| c.as_str()))
.collect();
for entry in &mut tbl.hnsw_indexes {
if updated_columns.contains(entry.column_name.as_str()) {
entry.needs_rebuild = true;
}
}
for entry in &mut tbl.fts_indexes {
if updated_columns.contains(entry.column_name.as_str()) {
entry.needs_rebuild = true;
}
}
}
Ok(work.len())
}
/// Executes `CREATE [UNIQUE] INDEX [IF NOT EXISTS] name ON table (column)
/// [USING btree | hnsw | fts]`.
///
/// Only named, single-column, non-partial indexes are supported. Three forms
/// build a B-tree secondary index: no USING clause, `USING BTREE` (parsed by
/// sqlparser as `IndexType::BTree`) and `USING "btree"` (a custom
/// identifier). `hnsw` builds a vector index and `fts` a full-text index.
///
/// Index names share one namespace across all tables, because `DROP INDEX`
/// looks a name up in every table. A name already used anywhere is therefore a
/// conflict, or a no-op under `IF NOT EXISTS`. Returns the index name.
///
/// # Errors
/// Fails for:
/// * unsupported index shapes or methods;
/// * unknown tables or columns;
/// * a name clash without `IF NOT EXISTS`;
/// * method-specific validation in the `create_*_index` helpers (column type,
///   UNIQUE duplicates, ...).
pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
    let Statement::CreateIndex(CreateIndex {
        name,
        table_name,
        columns,
        using,
        unique,
        if_not_exists,
        predicate,
        ..
    }) = stmt
    else {
        return Err(SQLRiteError::Internal(
            "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
        ));
    };
    if predicate.is_some() {
        return Err(SQLRiteError::NotImplemented(
            "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
        ));
    }
    if columns.len() != 1 {
        return Err(SQLRiteError::NotImplemented(format!(
            "multi-column indexes are not supported yet ({} columns given)",
            columns.len()
        )));
    }
    let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
        SQLRiteError::NotImplemented(
            "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
        )
    })?;
    let method = match using {
        // `USING BTREE` is a keyword in sqlparser and never reaches `Custom`.
        None | Some(IndexType::BTree) => IndexMethod::Btree,
        Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("btree") => {
            IndexMethod::Btree
        }
        Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("hnsw") => {
            IndexMethod::Hnsw
        }
        Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("fts") => {
            IndexMethod::Fts
        }
        Some(other) => {
            return Err(SQLRiteError::NotImplemented(format!(
                "CREATE INDEX … USING {other:?} is not supported \
                 (try `btree`, `hnsw`, `fts`, or no USING clause)"
            )));
        }
    };
    let table_name_str = table_name.to_string();
    let column_name = match &columns[0].column.expr {
        Expr::Identifier(ident) => ident.value.clone(),
        Expr::CompoundIdentifier(parts) => parts
            .last()
            .map(|p| p.value.clone())
            .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
        other => {
            return Err(SQLRiteError::NotImplemented(format!(
                "CREATE INDEX only supports simple column references, got {other:?}"
            )));
        }
    };
    let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
        let table = db.get_table(table_name_str.clone()).map_err(|_| {
            SQLRiteError::General(format!(
                "CREATE INDEX references unknown table '{table_name_str}'"
            ))
        })?;
        if !table.contains_column(column_name.clone()) {
            return Err(SQLRiteError::General(format!(
                "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
            )));
        }
        let col = table
            .columns
            .iter()
            .find(|c| c.column_name == column_name)
            .expect("we just verified the column exists");
        // Check every table, not just the target one: DROP INDEX resolves names
        // across all tables, so duplicates would make it pick one arbitrarily.
        let name_taken = db.tables.values().any(|t| {
            t.index_by_name(&index_name).is_some()
                || t.hnsw_indexes.iter().any(|i| i.name == index_name)
                || t.fts_indexes.iter().any(|i| i.name == index_name)
        });
        if name_taken {
            if *if_not_exists {
                return Ok(index_name);
            }
            return Err(SQLRiteError::General(format!(
                "index '{index_name}' already exists"
            )));
        }
        let datatype = clone_datatype(&col.datatype);
        let mut pairs = Vec::new();
        for rowid in table.rowids() {
            if let Some(v) = table.get_value(&column_name, rowid) {
                pairs.push((rowid, v));
            }
        }
        (datatype, pairs)
    };
    match method {
        IndexMethod::Btree => create_btree_index(
            db,
            &table_name_str,
            &index_name,
            &column_name,
            &datatype,
            *unique,
            &existing_rowids_and_values,
        ),
        IndexMethod::Hnsw => create_hnsw_index(
            db,
            &table_name_str,
            &index_name,
            &column_name,
            &datatype,
            *unique,
            &existing_rowids_and_values,
        ),
        IndexMethod::Fts => create_fts_index(
            db,
            &table_name_str,
            &index_name,
            &column_name,
            &datatype,
            *unique,
            &existing_rowids_and_values,
        ),
    }
}
/// Executes `DROP TABLE [IF EXISTS] name` for exactly one table.
///
/// Returns `1` when a table was dropped, and `0` under `IF EXISTS` when the
/// table is missing. The internal schema-catalog table can never be dropped.
pub fn execute_drop_table(
    names: &[ObjectName],
    if_exists: bool,
    db: &mut Database,
) -> Result<usize> {
    let [target] = names else {
        return Err(SQLRiteError::NotImplemented(
            "DROP TABLE supports a single table per statement".to_string(),
        ));
    };
    let name = target.to_string();
    let reserved = crate::sql::pager::MASTER_TABLE_NAME;
    if name == reserved {
        return Err(SQLRiteError::General(format!(
            "'{reserved}' is a reserved name used by the internal schema catalog"
        )));
    }
    if db.contains_table(name.clone()) {
        db.tables.remove(&name);
        return Ok(1);
    }
    if if_exists {
        Ok(0)
    } else {
        Err(SQLRiteError::General(format!(
            "Table '{name}' does not exist"
        )))
    }
}
/// Executes `DROP INDEX [IF EXISTS] name` for a single index.
///
/// Every table is searched for a secondary, HNSW or FTS index with that name
/// (in that order within each table), and the first hit is removed.
/// Auto-created secondary indexes cannot be dropped directly. Returns `1`
/// when an index was removed, and `0` under `IF EXISTS` for an unknown name.
pub fn execute_drop_index(
    names: &[ObjectName],
    if_exists: bool,
    db: &mut Database,
) -> Result<usize> {
    let [target] = names else {
        return Err(SQLRiteError::NotImplemented(
            "DROP INDEX supports a single index per statement".to_string(),
        ));
    };
    let name = target.to_string();
    for table in db.tables.values_mut() {
        let secondary_is_auto = table
            .secondary_indexes
            .iter()
            .find(|i| i.name == name)
            .map(|i| i.origin == IndexOrigin::Auto);
        match secondary_is_auto {
            Some(true) => {
                return Err(SQLRiteError::General(format!(
                    "cannot drop auto-created index '{name}' (drop the column or table instead)"
                )));
            }
            Some(false) => {
                table.secondary_indexes.retain(|i| i.name != name);
                return Ok(1);
            }
            None => {}
        }
        // A length change after `retain` means a matching index was removed.
        let hnsw_before = table.hnsw_indexes.len();
        table.hnsw_indexes.retain(|i| i.name != name);
        if table.hnsw_indexes.len() != hnsw_before {
            return Ok(1);
        }
        let fts_before = table.fts_indexes.len();
        table.fts_indexes.retain(|i| i.name != name);
        if table.fts_indexes.len() != fts_before {
            return Ok(1);
        }
    }
    if if_exists {
        return Ok(0);
    }
    Err(SQLRiteError::General(format!(
        "Index '{name}' does not exist"
    )))
}
/// Executes a single-operation `ALTER TABLE` statement.
///
/// Supported operations:
/// * `RENAME TO`;
/// * `RENAME COLUMN`;
/// * `ADD COLUMN`, honouring `IF NOT EXISTS`;
/// * `DROP COLUMN` of one column, honouring `IF EXISTS`.
///
/// Returns a human-readable status message.
///
/// # Errors
/// Fails when the target is the internal schema catalog, the table is missing
/// (unless `IF EXISTS`), the statement has more than one operation, or the
/// operation is unsupported or fails.
pub fn execute_alter_table(alter: AlterTable, db: &mut Database) -> Result<String> {
    let table_name = alter.name.to_string();
    let reserved = crate::sql::pager::MASTER_TABLE_NAME;
    if table_name == reserved {
        return Err(SQLRiteError::General(format!(
            "'{reserved}' is a reserved name used by the internal schema catalog"
        )));
    }
    if !db.contains_table(table_name.clone()) {
        if alter.if_exists {
            return Ok("ALTER TABLE: no-op (table does not exist)".to_string());
        }
        return Err(SQLRiteError::General(format!(
            "Table '{table_name}' does not exist"
        )));
    }
    let [operation] = alter.operations.as_slice() else {
        return Err(SQLRiteError::NotImplemented(
            "ALTER TABLE supports one operation per statement".to_string(),
        ));
    };
    match operation {
        AlterTableOperation::RenameTable { table_name: kind } => {
            let RenameTableNameKind::To(target) = kind else {
                return Err(SQLRiteError::NotImplemented(
                    "ALTER TABLE ... RENAME AS (MySQL-only) is not supported; use RENAME TO"
                        .to_string(),
                ));
            };
            let new_name = target.to_string();
            alter_rename_table(db, &table_name, &new_name)?;
            Ok(format!(
                "ALTER TABLE '{table_name}' RENAME TO '{new_name}' executed."
            ))
        }
        AlterTableOperation::RenameColumn {
            old_column_name,
            new_column_name,
        } => {
            let (old, new) = (
                old_column_name.value.clone(),
                new_column_name.value.clone(),
            );
            let table = db.get_table_mut(table_name.clone())?;
            table.rename_column(&old, &new)?;
            Ok(format!(
                "ALTER TABLE '{table_name}' RENAME COLUMN '{old}' TO '{new}' executed."
            ))
        }
        AlterTableOperation::AddColumn {
            column_def,
            if_not_exists,
            ..
        } => {
            let parsed = crate::sql::parser::create::parse_one_column(column_def)?;
            let col_name = parsed.name.clone();
            let table = db.get_table_mut(table_name.clone())?;
            if *if_not_exists && table.contains_column(col_name.clone()) {
                return Ok(format!(
                    "ALTER TABLE '{table_name}' ADD COLUMN: no-op (column '{col_name}' already exists)"
                ));
            }
            table.add_column(parsed)?;
            Ok(format!(
                "ALTER TABLE '{table_name}' ADD COLUMN '{col_name}' executed."
            ))
        }
        AlterTableOperation::DropColumn {
            column_names,
            if_exists,
            ..
        } => {
            let [column] = column_names.as_slice() else {
                return Err(SQLRiteError::NotImplemented(
                    "ALTER TABLE DROP COLUMN supports a single column per statement".to_string(),
                ));
            };
            let col_name = column.value.clone();
            let table = db.get_table_mut(table_name.clone())?;
            let missing = !table.contains_column(col_name.clone());
            if *if_exists && missing {
                return Ok(format!(
                    "ALTER TABLE '{table_name}' DROP COLUMN: no-op (column '{col_name}' does not exist)"
                ));
            }
            table.drop_column(&col_name)?;
            Ok(format!(
                "ALTER TABLE '{table_name}' DROP COLUMN '{col_name}' executed."
            ))
        }
        other => Err(SQLRiteError::NotImplemented(format!(
            "ALTER TABLE operation {other:?} is not supported"
        ))),
    }
}
/// Rewrites the database file to reclaim unused pages.
///
/// Refused inside an open transaction, and a no-op for in-memory databases.
/// The pager is checkpointed before and after the rewrite so that the
/// file-size and page-count measurements reflect the database file. This is
/// best-effort: checkpoint errors are ignored. Returns a summary of the pages
/// and bytes reclaimed.
pub fn execute_vacuum(db: &mut Database) -> Result<String> {
    if db.in_transaction() {
        return Err(SQLRiteError::General(
            "VACUUM cannot run inside a transaction".to_string(),
        ));
    }
    let Some(path) = db.source_path.clone() else {
        return Ok("VACUUM is a no-op for in-memory databases".to_string());
    };
    // Checkpoint, then report (file size in bytes, pager page count).
    let checkpoint_and_measure = |db: &mut Database| {
        if let Some(pager) = db.pager.as_mut() {
            let _ = pager.checkpoint();
        }
        let bytes = std::fs::metadata(&path).map(|m| m.len()).unwrap_or(0);
        let pages = db.pager.as_ref().map_or(0, |p| p.header().page_count);
        (bytes, pages)
    };
    let (size_before, pages_before) = checkpoint_and_measure(&mut *db);
    crate::sql::pager::vacuum_database(db, &path)?;
    let (size_after, pages_after) = checkpoint_and_measure(&mut *db);
    let pages_reclaimed = pages_before.saturating_sub(pages_after);
    let bytes_reclaimed = size_before.saturating_sub(size_after);
    Ok(format!(
        "VACUUM completed. {pages_reclaimed} pages reclaimed ({bytes_reclaimed} bytes)."
    ))
}
/// Renames table `old` to `new` and re-keys it in the database's table map.
///
/// Secondary indexes are re-pointed at the new table name. Auto-created
/// indexes whose name still follows the auto-naming scheme for `old` are
/// renamed to match `new`. Renaming a table to itself is a no-op.
///
/// # Errors
/// Fails if `new` is the reserved catalog name, `new` already exists, or
/// `old` does not exist.
fn alter_rename_table(db: &mut Database, old: &str, new: &str) -> Result<()> {
    let reserved = crate::sql::pager::MASTER_TABLE_NAME;
    if new == reserved {
        return Err(SQLRiteError::General(format!(
            "'{reserved}' is a reserved name used by the internal schema catalog"
        )));
    }
    if old == new {
        return Ok(());
    }
    if db.contains_table(new.to_string()) {
        return Err(SQLRiteError::General(format!(
            "target table '{new}' already exists"
        )));
    }
    let Some(mut table) = db.tables.remove(old) else {
        return Err(SQLRiteError::General(format!("Table '{old}' does not exist")));
    };
    table.tb_name = new.to_owned();
    for idx in &mut table.secondary_indexes {
        idx.table_name = new.to_owned();
        let follows_auto_scheme = idx.origin == IndexOrigin::Auto
            && idx.name == SecondaryIndex::auto_name(old, &idx.column_name);
        if follows_auto_scheme {
            idx.name = SecondaryIndex::auto_name(new, &idx.column_name);
        }
    }
    db.tables.insert(new.to_owned(), table);
    Ok(())
}
/// Index implementation selected by `CREATE INDEX ... USING <method>`.
#[derive(Debug, Clone, Copy)]
enum IndexMethod {
/// [`SecondaryIndex`] built by `create_btree_index`; also the choice when no USING clause is given.
Btree,
/// [`HnswIndex`] over a VECTOR column, built by `create_hnsw_index`.
Hnsw,
/// [`PostingList`] over a TEXT column, built by `create_fts_index`.
Fts,
}
/// Builds an explicit [`SecondaryIndex`] from pre-collected `(rowid, value)`
/// pairs and attaches it to `table_name`.
///
/// For a UNIQUE index, the first duplicate value aborts creation and nothing
/// is attached. Returns the index name on success.
fn create_btree_index(
    db: &mut Database,
    table_name: &str,
    index_name: &str,
    column_name: &str,
    datatype: &DataType,
    unique: bool,
    existing: &[(i64, Value)],
) -> Result<String> {
    let mut index = SecondaryIndex::new(
        index_name.to_string(),
        table_name.to_string(),
        column_name.to_string(),
        datatype,
        unique,
        IndexOrigin::Explicit,
    )?;
    for pair in existing.iter() {
        let (rowid, value) = (pair.0, &pair.1);
        let duplicate = unique && index.would_violate_unique(value);
        if duplicate {
            return Err(SQLRiteError::General(format!(
                "cannot create UNIQUE index '{index_name}': column '{column_name}' \
                 already contains the duplicate value {}",
                value.to_display_string()
            )));
        }
        index.insert(value, rowid)?;
    }
    db.get_table_mut(table_name.to_string())?
        .secondary_indexes
        .push(index);
    Ok(index_name.to_string())
}
/// Builds an HNSW vector index (L2 distance) over a `VECTOR(dim)` column and
/// attaches it to `table_name`.
///
/// The random seed is derived from the index name via [`hash_str_to_seed`],
/// so the same name always produces the same seed. Values that are not
/// `Value::Vector` are skipped. Returns the index name.
///
/// # Errors
/// Fails if the column is not a VECTOR or `unique` is set. Also fails with an
/// internal error if a stored vector's length differs from the declared
/// dimension.
fn create_hnsw_index(
    db: &mut Database,
    table_name: &str,
    index_name: &str,
    column_name: &str,
    datatype: &DataType,
    unique: bool,
    existing: &[(i64, Value)],
) -> Result<String> {
    let dim = match datatype {
        DataType::Vector(d) => *d,
        other => {
            return Err(SQLRiteError::General(format!(
                "USING hnsw requires a VECTOR column; '{column_name}' is {other}"
            )));
        }
    };
    if unique {
        return Err(SQLRiteError::General(
            "UNIQUE has no meaning for HNSW indexes".to_string(),
        ));
    }
    let seed = hash_str_to_seed(index_name);
    let mut idx = HnswIndex::new(DistanceMetric::L2, seed);
    // Full rowid -> vector map, used by the insert callback to fetch the
    // vectors of other nodes during graph construction.
    let mut vec_map: std::collections::HashMap<i64, Vec<f32>> =
        std::collections::HashMap::with_capacity(existing.len());
    for (rowid, v) in existing {
        if let Value::Vector(vec) = v {
            if vec.len() != dim {
                return Err(SQLRiteError::Internal(format!(
                    "row {rowid} stores a {}-dim vector in column '{column_name}' \
                     declared as VECTOR({dim}) — schema invariant violated",
                    vec.len()
                )));
            }
            vec_map.insert(*rowid, vec.clone());
        }
    }
    // Insert in the original row order. The vector can be borrowed directly:
    // both it and the callback's `vec_map` are shared borrows, so no per-row
    // clone is needed.
    for (rowid, v) in existing {
        if let Value::Vector(vec) = v {
            idx.insert(*rowid, vec, |id| {
                vec_map.get(&id).cloned().unwrap_or_default()
            });
        }
    }
    let table_mut = db.get_table_mut(table_name.to_string())?;
    table_mut.hnsw_indexes.push(HnswIndexEntry {
        name: index_name.to_string(),
        column_name: column_name.to_string(),
        index: idx,
        needs_rebuild: false,
    });
    Ok(index_name.to_string())
}
/// Builds a full-text [`PostingList`] over a TEXT column and attaches it to
/// `table_name`.
///
/// Only `Value::Text` values are indexed; anything else is skipped. Returns
/// the index name.
///
/// # Errors
/// Fails if the column is not TEXT or `unique` is set.
fn create_fts_index(
    db: &mut Database,
    table_name: &str,
    index_name: &str,
    column_name: &str,
    datatype: &DataType,
    unique: bool,
    existing: &[(i64, Value)],
) -> Result<String> {
    if !matches!(datatype, DataType::Text) {
        return Err(SQLRiteError::General(format!(
            "USING fts requires a TEXT column; '{column_name}' is {datatype}"
        )));
    }
    if unique {
        return Err(SQLRiteError::General(
            "UNIQUE has no meaning for FTS indexes".to_string(),
        ));
    }
    let mut postings = PostingList::new();
    let texts = existing.iter().filter_map(|(rowid, value)| match value {
        Value::Text(text) => Some((*rowid, text)),
        _ => None,
    });
    for (rowid, text) in texts {
        postings.insert(rowid, text);
    }
    let entry = FtsIndexEntry {
        name: index_name.to_string(),
        column_name: column_name.to_string(),
        index: postings,
        needs_rebuild: false,
    };
    db.get_table_mut(table_name.to_string())?
        .fts_indexes
        .push(entry);
    Ok(index_name.to_string())
}
/// Derives a deterministic 64-bit seed from `s` using 64-bit FNV-1a.
///
/// `create_hnsw_index` uses this to seed an HNSW index from its name, so the
/// same name always yields the same seed.
fn hash_str_to_seed(s: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xCBF2_9CE4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01B3;
    s.bytes().fold(FNV_OFFSET_BASIS, |hash, byte| {
        (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
/// Returns an owned copy of `dt`, rebuilt variant by variant.
fn clone_datatype(dt: &DataType) -> DataType {
    match *dt {
        DataType::Integer => DataType::Integer,
        DataType::Text => DataType::Text,
        DataType::Real => DataType::Real,
        DataType::Bool => DataType::Bool,
        DataType::Vector(dim) => DataType::Vector(dim),
        DataType::Json => DataType::Json,
        DataType::None => DataType::None,
        DataType::Invalid => DataType::Invalid,
    }
}
/// Returns the name of the single plain table in a DELETE's `FROM` list.
///
/// # Errors
/// `NotImplemented` when the list holds anything other than exactly one
/// table, plus whatever [`extract_table_name`] rejects.
fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
    match tables {
        [only] => extract_table_name(only),
        _ => Err(SQLRiteError::NotImplemented(
            "multi-table DELETE is not supported yet".to_string(),
        )),
    }
}
/// Returns the table name of a plain, join-free table reference.
///
/// # Errors
/// `NotImplemented` for references with JOINs or for non-table factors
/// (subqueries, table functions, ...).
fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
    let has_joins = !twj.joins.is_empty();
    if has_joins {
        return Err(SQLRiteError::NotImplemented(
            "JOIN is not supported yet".to_string(),
        ));
    }
    if let TableFactor::Table { name, .. } = &twj.relation {
        Ok(name.to_string())
    } else {
        Err(SQLRiteError::NotImplemented(
            "only plain table references are supported".to_string(),
        ))
    }
}
/// How [`select_rowids`] decided to produce candidate rows.
enum RowidSource {
/// Rowids returned by a secondary-index equality lookup, sorted ascending.
IndexProbe(Vec<i64>),
/// No usable index: the caller scans every row and applies the predicate itself.
FullScan,
}
fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
let Some(expr) = selection else {
return Ok(RowidSource::FullScan);
};
let Some((col, literal)) = try_extract_equality(expr) else {
return Ok(RowidSource::FullScan);
};
let Some(idx) = table.index_for_column(&col) else {
return Ok(RowidSource::FullScan);
};
let literal_value = match convert_literal(&literal) {
Ok(v) => v,
Err(_) => return Ok(RowidSource::FullScan),
};
let mut rowids = idx.lookup(&literal_value);
rowids.sort_unstable();
Ok(RowidSource::IndexProbe(rowids))
}
/// Recognizes `col = literal` or `literal = col`, ignoring any number of
/// enclosing parentheses, and returns the column name and the literal.
///
/// For a compound identifier (`t.col`), only the last part is used. Returns
/// `None` for any other expression shape.
fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
    // Parentheses do not change meaning, so `((id = 5))` qualifies like `id = 5`.
    let mut peeled = expr;
    while let Expr::Nested(inner) = peeled {
        peeled = inner.as_ref();
    }
    let Expr::BinaryOp { left, op, right } = peeled else {
        return None;
    };
    if !matches!(op, BinaryOperator::Eq) {
        return None;
    }
    let col_from = |e: &Expr| -> Option<String> {
        match e {
            Expr::Identifier(ident) => Some(ident.value.clone()),
            Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
            _ => None,
        }
    };
    let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
        if let Expr::Value(v) = e {
            Some(v.value.clone())
        } else {
            None
        }
    };
    if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
        return Some((c, l));
    }
    if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
        return Some((c, l));
    }
    None
}
/// Tries to answer `ORDER BY vec_distance_l2(col, [..]) LIMIT k` (arguments in
/// either order) from an HNSW index on `col`.
///
/// Returns the rowids from the index's k-nearest search, or `None` when:
/// * the expression shape does not match, or `k == 0`;
/// * `col` has no HNSW index;
/// * the literal's length differs from the column's declared dimension;
/// * the index is flagged `needs_rebuild`.
///
/// A `None` result sends the caller down the exact sort path.
fn try_hnsw_probe(table: &Table, order_expr: &Expr, k: usize) -> Option<Vec<i64>> {
    if k == 0 {
        return None;
    }
    let Expr::Function(func) = order_expr else {
        return None;
    };
    let fname = match func.name.0.as_slice() {
        [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
        _ => return None,
    };
    if fname != "vec_distance_l2" {
        return None;
    }
    let FunctionArguments::List(list) = &func.args else {
        return None;
    };
    let arg_list = &list.args;
    if arg_list.len() != 2 {
        return None;
    }
    let exprs: Vec<&Expr> = arg_list
        .iter()
        .filter_map(|a| match a {
            FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
            _ => None,
        })
        .collect();
    if exprs.len() != 2 {
        return None;
    }
    let (col_name, query_vec) = identify_indexed_arg_and_literal(exprs[0], exprs[1])
        .or_else(|| identify_indexed_arg_and_literal(exprs[1], exprs[0]))?;
    let entry = table
        .hnsw_indexes
        .iter()
        .find(|e| e.column_name == col_name)?;
    // DELETE / UPDATE flag the graph stale. It may still point at removed rows
    // or be shaped by old vectors, so let the exact path answer instead.
    if entry.needs_rebuild {
        return None;
    }
    let declared_dim = match table
        .columns
        .iter()
        .find(|c| c.column_name == col_name)
        .map(|c| &c.datatype)
    {
        Some(DataType::Vector(d)) => *d,
        _ => return None,
    };
    if query_vec.len() != declared_dim {
        return None;
    }
    let result = entry.index.search(&query_vec, k, |id| {
        match table.get_value(&col_name, id) {
            Some(Value::Vector(v)) => v,
            _ => Vec::new(),
        }
    });
    Some(result)
}
/// Tries to answer `ORDER BY bm25_score(col, '<query>') DESC LIMIT k` straight
/// from the FTS index on `col`.
///
/// Returns at most `k` rowids, in the order the index's `query` yielded them.
///
/// Returns `None` when the shortcut does not apply:
/// - `k == 0`, or the sort is ascending;
/// - the ORDER BY is not an unqualified `bm25_score` call with two positional
///   arguments;
/// - the first argument is not an unquoted column identifier;
/// - the second argument is not a single-quoted string literal;
/// - the column has no FTS index.
fn try_fts_probe(table: &Table, order_expr: &Expr, ascending: bool, k: usize) -> Option<Vec<i64>> {
    // Only the DESC direction (best score first) is served from the index.
    if ascending || k == 0 {
        return None;
    }
    let Expr::Function(func) = order_expr else {
        return None;
    };
    let [ObjectNamePart::Identifier(fn_ident)] = func.name.0.as_slice() else {
        return None;
    };
    if fn_ident.value.to_lowercase() != "bm25_score" {
        return None;
    }
    let FunctionArguments::List(list) = &func.args else {
        return None;
    };
    if list.args.len() != 2 {
        return None;
    }
    let positional: Vec<&Expr> = list
        .args
        .iter()
        .filter_map(|arg| {
            if let FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) = arg {
                Some(e)
            } else {
                None
            }
        })
        .collect();
    let [column_expr, query_expr] = positional.as_slice() else {
        return None;
    };
    let col_name = match column_expr {
        Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
        _ => return None,
    };
    let Expr::Value(literal) = query_expr else {
        return None;
    };
    let AstValue::SingleQuotedString(query) = &literal.value else {
        return None;
    };
    let entry = table
        .fts_indexes
        .iter()
        .find(|e| e.column_name == col_name)?;
    let ranked = entry.index.query(query, &Bm25Params::default());
    Some(ranked.into_iter().take(k).map(|(rowid, _score)| rowid).collect())
}
/// Interprets `(a, b)` as `(column, vector literal)`.
///
/// `a` must be an unquoted identifier (the column name). `b` must be a
/// `[...]`-quoted identifier whose text parses with `parse_vector_literal`.
/// Returns `None` if either shape does not match or the literal fails to parse.
fn identify_indexed_arg_and_literal(a: &Expr, b: &Expr) -> Option<(String, Vec<f32>)> {
    let column = match a {
        Expr::Identifier(id) if id.quote_style.is_none() => id.value.clone(),
        _ => return None,
    };
    match b {
        Expr::Identifier(id) if id.quote_style == Some('[') => {
            // Put the brackets back so the text is a complete vector literal.
            let vector = parse_vector_literal(&format!("[{}]", id.value)).ok()?;
            Some((column, vector))
        }
        _ => None,
    }
}
/// One candidate row in the bounded top-k heap used by `select_topk`.
///
/// The ordering is arranged so that the candidate which would be emitted
/// *last* compares greatest. A `BinaryHeap` (a max-heap) therefore keeps the
/// current worst candidate at its root, where it can be evicted.
struct HeapEntry {
    /// ORDER BY key evaluated for `rowid`.
    key: Value,
    /// Row the key was computed from.
    rowid: i64,
    /// `true` for `ASC`, `false` for `DESC`.
    asc: bool,
}
impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for HeapEntry {}
impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        let raw = compare_values(Some(&self.key), Some(&other.key));
        let by_key = if self.asc { raw } else { raw.reverse() };
        // Break key ties by rowid (ascending in both directions). Without this,
        // which tied row survives at the k boundary and the order of ties in
        // the output depend on heap internals. With it, results are
        // deterministic and match a stable full sort of ascending-rowid input
        // (presumably the usual input order — verify against callers).
        by_key.then_with(|| self.rowid.cmp(&other.rowid))
    }
}
/// Returns the first `k` rowids of `matching` in `order`, without fully
/// sorting every row.
///
/// A max-heap holding at most `k` `HeapEntry`s is kept while scanning. Its
/// root is the worst of the best rows seen so far, and it is replaced
/// whenever a better row arrives. The result is emitted in final sorted order.
///
/// # Errors
/// Propagates the first error from evaluating the ORDER BY expression.
fn select_topk(
    matching: &[i64],
    table: &Table,
    order: &OrderByClause,
    k: usize,
) -> Result<Vec<i64>> {
    use std::collections::BinaryHeap;
    if k == 0 || matching.is_empty() {
        return Ok(Vec::new());
    }
    let mut best: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
    for &rowid in matching {
        let candidate = HeapEntry {
            key: eval_expr(&order.expr, table, rowid)?,
            rowid,
            asc: order.ascending,
        };
        if best.len() < k {
            best.push(candidate);
            continue;
        }
        // The heap is full here, so `peek` always finds the current worst entry.
        let beats_worst = best.peek().is_some_and(|worst| candidate < *worst);
        if beats_worst {
            best.pop();
            best.push(candidate);
        }
    }
    let ranked = best.into_sorted_vec();
    Ok(ranked.into_iter().map(|entry| entry.rowid).collect())
}
/// Sorts `rowids` in place by a single-table ORDER BY expression.
///
/// The key is evaluated once per row up front. Rows are then ordered with
/// `compare_values`, reversed for `DESC`. The sort is stable, so rows with
/// equal keys keep their incoming relative order.
///
/// # Errors
/// If evaluating the key fails for a row, returns a `General` error wrapping
/// the first failure (in `rowids` order). `rowids` is left untouched.
fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
    // Collecting into `Result` stops at the first failing row, rather than
    // evaluating every remaining key and then re-checking and unwrapping.
    let mut keyed: Vec<(i64, Value)> = rowids
        .iter()
        .map(|&rowid| eval_expr(&order.expr, table, rowid).map(|key| (rowid, key)))
        .collect::<Result<Vec<_>>>()
        .map_err(|e| SQLRiteError::General(format!("ORDER BY expression failed: {e}")))?;
    keyed.sort_by(|(_, a), (_, b)| {
        let ord = compare_values(Some(a), Some(b));
        if order.ascending { ord } else { ord.reverse() }
    });
    for (slot, (rowid, _)) in rowids.iter_mut().zip(keyed) {
        *slot = rowid;
    }
    Ok(())
}
/// Total ordering used for ORDER BY, comparisons and `IN` matching.
///
/// A missing value (`None`) sorts first, then SQL `NULL`, then everything
/// else. Integers and reals compare numerically across the two types, and an
/// unordered float comparison (NaN) counts as equal. Text and bools compare
/// natively. Any other mix of types falls back to comparing display strings.
fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
    let (lhs, rhs) = match (a, b) {
        (None, None) => return Ordering::Equal,
        (None, Some(_)) => return Ordering::Less,
        (Some(_), None) => return Ordering::Greater,
        (Some(l), Some(r)) => (l, r),
    };
    let by_float = |x: f64, y: f64| x.partial_cmp(&y).unwrap_or(Ordering::Equal);
    match (lhs, rhs) {
        (Value::Null, Value::Null) => Ordering::Equal,
        (Value::Null, _) => Ordering::Less,
        (_, Value::Null) => Ordering::Greater,
        (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
        (Value::Real(x), Value::Real(y)) => by_float(*x, *y),
        (Value::Integer(x), Value::Real(y)) => by_float(*x as f64, *y),
        (Value::Real(x), Value::Integer(y)) => by_float(*x, *y as f64),
        (Value::Text(x), Value::Text(y)) => x.cmp(y),
        (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
        (x, y) => x.to_display_string().cmp(&y.to_display_string()),
    }
}
/// Evaluates a WHERE-style predicate against one row of `table`.
///
/// Convenience wrapper that builds a `SingleTableScope` and defers to
/// `eval_predicate_scope`.
pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
    let row = SingleTableScope::new(table, rowid);
    eval_predicate_scope(expr, &row)
}
/// Evaluates `expr` as a filter condition over `scope`.
///
/// - `BOOL` is used as-is.
/// - `NULL` counts as false.
/// - An integer is true when non-zero.
///
/// # Errors
/// Any other result type yields an `Internal` error; evaluation errors
/// propagate.
pub(crate) fn eval_predicate_scope(expr: &Expr, scope: &dyn RowScope) -> Result<bool> {
    match eval_expr_scope(expr, scope)? {
        Value::Bool(flag) => Ok(flag),
        Value::Integer(n) => Ok(n != 0),
        Value::Null => Ok(false),
        unexpected => Err(SQLRiteError::Internal(format!(
            "WHERE clause must evaluate to boolean, got {}",
            unexpected.to_display_string()
        ))),
    }
}
/// Evaluates `expr` against a single row of `table`.
///
/// Thin wrapper over `eval_expr_scope` with a `SingleTableScope`.
fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
    let row = SingleTableScope::new(table, rowid);
    eval_expr_scope(expr, &row)
}
fn eval_expr_scope(expr: &Expr, scope: &dyn RowScope) -> Result<Value> {
match expr {
Expr::Nested(inner) => eval_expr_scope(inner, scope),
Expr::Identifier(ident) => {
if ident.quote_style == Some('[') {
let raw = format!("[{}]", ident.value);
let v = parse_vector_literal(&raw)?;
return Ok(Value::Vector(v));
}
scope.lookup(None, &ident.value)
}
Expr::CompoundIdentifier(parts) => {
match parts.as_slice() {
[only] => scope.lookup(None, &only.value),
[q, c] => scope.lookup(Some(&q.value), &c.value),
_ => Err(SQLRiteError::NotImplemented(format!(
"compound identifier with {} parts is not supported",
parts.len()
))),
}
}
Expr::Value(v) => convert_literal(&v.value),
Expr::UnaryOp { op, expr } => {
let inner = eval_expr_scope(expr, scope)?;
match op {
UnaryOperator::Not => match inner {
Value::Bool(b) => Ok(Value::Bool(!b)),
Value::Null => Ok(Value::Null),
other => Err(SQLRiteError::Internal(format!(
"NOT applied to non-boolean value: {}",
other.to_display_string()
))),
},
UnaryOperator::Minus => match inner {
Value::Integer(i) => Ok(Value::Integer(-i)),
Value::Real(f) => Ok(Value::Real(-f)),
Value::Null => Ok(Value::Null),
other => Err(SQLRiteError::Internal(format!(
"unary minus on non-numeric value: {}",
other.to_display_string()
))),
},
UnaryOperator::Plus => Ok(inner),
other => Err(SQLRiteError::NotImplemented(format!(
"unary operator {other:?} is not supported"
))),
}
}
Expr::BinaryOp { left, op, right } => match op {
BinaryOperator::And => {
let l = eval_expr_scope(left, scope)?;
let r = eval_expr_scope(right, scope)?;
Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
}
BinaryOperator::Or => {
let l = eval_expr_scope(left, scope)?;
let r = eval_expr_scope(right, scope)?;
Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
}
cmp @ (BinaryOperator::Eq
| BinaryOperator::NotEq
| BinaryOperator::Lt
| BinaryOperator::LtEq
| BinaryOperator::Gt
| BinaryOperator::GtEq) => {
let l = eval_expr_scope(left, scope)?;
let r = eval_expr_scope(right, scope)?;
if matches!(l, Value::Null) || matches!(r, Value::Null) {
return Ok(Value::Bool(false));
}
let ord = compare_values(Some(&l), Some(&r));
let result = match cmp {
BinaryOperator::Eq => ord == Ordering::Equal,
BinaryOperator::NotEq => ord != Ordering::Equal,
BinaryOperator::Lt => ord == Ordering::Less,
BinaryOperator::LtEq => ord != Ordering::Greater,
BinaryOperator::Gt => ord == Ordering::Greater,
BinaryOperator::GtEq => ord != Ordering::Less,
_ => unreachable!(),
};
Ok(Value::Bool(result))
}
arith @ (BinaryOperator::Plus
| BinaryOperator::Minus
| BinaryOperator::Multiply
| BinaryOperator::Divide
| BinaryOperator::Modulo) => {
let l = eval_expr_scope(left, scope)?;
let r = eval_expr_scope(right, scope)?;
eval_arith(arith, &l, &r)
}
BinaryOperator::StringConcat => {
let l = eval_expr_scope(left, scope)?;
let r = eval_expr_scope(right, scope)?;
if matches!(l, Value::Null) || matches!(r, Value::Null) {
return Ok(Value::Null);
}
Ok(Value::Text(format!(
"{}{}",
l.to_display_string(),
r.to_display_string()
)))
}
other => Err(SQLRiteError::NotImplemented(format!(
"binary operator {other:?} is not supported yet"
))),
},
Expr::IsNull(inner) => {
let v = eval_expr_scope(inner, scope)?;
Ok(Value::Bool(matches!(v, Value::Null)))
}
Expr::IsNotNull(inner) => {
let v = eval_expr_scope(inner, scope)?;
Ok(Value::Bool(!matches!(v, Value::Null)))
}
Expr::Like {
negated,
any,
expr: lhs,
pattern,
escape_char,
} => eval_like(
scope,
*negated,
*any,
lhs,
pattern,
escape_char.as_ref(),
true,
),
Expr::ILike {
negated,
any,
expr: lhs,
pattern,
escape_char,
} => eval_like(
scope,
*negated,
*any,
lhs,
pattern,
escape_char.as_ref(),
true,
),
Expr::InList {
expr: lhs,
list,
negated,
} => eval_in_list(scope, lhs, list, *negated),
Expr::InSubquery { .. } => Err(SQLRiteError::NotImplemented(
"IN (subquery) is not supported (only literal lists are)".to_string(),
)),
Expr::Function(func) => eval_function(func, scope),
other => Err(SQLRiteError::NotImplemented(format!(
"unsupported expression in WHERE/projection: {other:?}"
))),
}
}
/// Evaluates a scalar function call for the current row of `scope`.
///
/// Function names are matched case-insensitively:
/// - the three `vec_distance_*` functions return a `REAL`;
/// - the `json_*` helpers are supported;
/// - `fts_match` / `bm25_score` work in single-table scopes only.
///
/// Aggregate names are rejected in this position; anything else is
/// `NotImplemented`.
fn eval_function(func: &sqlparser::ast::Function, scope: &dyn RowScope) -> Result<Value> {
    let [ObjectNamePart::Identifier(ident)] = func.name.0.as_slice() else {
        return Err(SQLRiteError::NotImplemented(format!(
            "qualified function names not supported: {:?}",
            func.name
        )));
    };
    let name = ident.value.to_lowercase();
    let args = &func.args;
    match name.as_str() {
        "vec_distance_l2" => {
            let (a, b) = extract_two_vector_args(&name, args, scope)?;
            Ok(Value::Real(vec_distance_l2(&a, &b) as f64))
        }
        "vec_distance_cosine" => {
            let (a, b) = extract_two_vector_args(&name, args, scope)?;
            Ok(Value::Real(vec_distance_cosine(&a, &b)? as f64))
        }
        "vec_distance_dot" => {
            let (a, b) = extract_two_vector_args(&name, args, scope)?;
            Ok(Value::Real(vec_distance_dot(&a, &b) as f64))
        }
        "json_extract" => json_fn_extract(&name, args, scope),
        "json_type" => json_fn_type(&name, args, scope),
        "json_array_length" => json_fn_array_length(&name, args, scope),
        "json_object_keys" => json_fn_object_keys(&name, args, scope),
        "fts_match" | "bm25_score" => {
            // FTS lookups need a concrete (table, rowid); joined scopes have none.
            let Some((table, rowid)) = scope.single_table_view() else {
                return Err(SQLRiteError::NotImplemented(format!(
                    "{name}() is not yet supported inside a JOIN query — \
                     use it on a single-table SELECT or move the FTS lookup into a subquery"
                )));
            };
            let (entry, query) = resolve_fts_args(&name, args, table, scope)?;
            if name == "fts_match" {
                Ok(Value::Bool(entry.index.matches(rowid, &query)))
            } else {
                Ok(Value::Real(
                    entry.index.score(rowid, &query, &Bm25Params::default()),
                ))
            }
        }
        "count" | "sum" | "avg" | "min" | "max" => Err(SQLRiteError::NotImplemented(format!(
            "aggregate function '{name}' is not allowed in WHERE / projection-scalar position; \
             use it as a top-level projection item (HAVING is not yet supported)"
        ))),
        unknown => Err(SQLRiteError::NotImplemented(format!(
            "unknown function: {unknown}(...)"
        ))),
    }
}
/// Validates the `(column, query_text)` arguments of `fts_match` / `bm25_score`.
///
/// Returns the FTS index entry for the column together with the evaluated
/// query string. The column may be bare or qualified; only its last part is
/// used.
///
/// # Errors
/// - wrong argument count or shape;
/// - a non-column first argument;
/// - a query that does not evaluate to TEXT;
/// - no FTS index on the column.
fn resolve_fts_args<'t>(
    fn_name: &str,
    args: &FunctionArguments,
    table: &'t Table,
    scope: &dyn RowScope,
) -> Result<(&'t FtsIndexEntry, String)> {
    let FunctionArguments::List(list) = args else {
        return Err(SQLRiteError::General(format!(
            "{fn_name}() expects exactly two arguments: (column, query_text)"
        )));
    };
    let [col_arg, query_arg] = list.args.as_slice() else {
        return Err(SQLRiteError::General(format!(
            "{fn_name}() expects exactly 2 arguments, got {}",
            list.args.len()
        )));
    };
    let FunctionArg::Unnamed(FunctionArgExpr::Expr(col_expr)) = col_arg else {
        return Err(SQLRiteError::NotImplemented(format!(
            "{fn_name}() argument 0 must be a column name, got {col_arg:?}"
        )));
    };
    let col_name = match col_expr {
        Expr::Identifier(ident) => ident.value.clone(),
        Expr::CompoundIdentifier(parts) => match parts.last() {
            Some(last) => last.value.clone(),
            None => {
                return Err(SQLRiteError::Internal(
                    "empty compound identifier".to_string(),
                ));
            }
        },
        other => {
            return Err(SQLRiteError::General(format!(
                "{fn_name}() argument 0 must be a column reference, got {other:?}"
            )));
        }
    };
    let FunctionArg::Unnamed(FunctionArgExpr::Expr(q_expr)) = query_arg else {
        return Err(SQLRiteError::NotImplemented(format!(
            "{fn_name}() argument 1 must be a text expression, got {query_arg:?}"
        )));
    };
    let Value::Text(query) = eval_expr_scope(q_expr, scope)? else {
        // Re-evaluating is avoided: bind the non-text value for the message.
        unreachable!()
    };
    let entry = table
        .fts_indexes
        .iter()
        .find(|e| e.column_name == col_name)
        .ok_or_else(|| {
            SQLRiteError::General(format!(
                "{fn_name}({col_name}, ...): no FTS index on column '{col_name}' \
                 (run CREATE INDEX <name> ON <table> USING fts({col_name}) first)"
            ))
        })?;
    Ok((entry, query))
}
/// Extracts the `(json_text, path)` arguments shared by the `json_*` functions.
///
/// The first argument must evaluate to TEXT. The optional second argument
/// must evaluate to TEXT and defaults to `"$"` when omitted.
///
/// # Errors
/// - an argument count other than 1 or 2;
/// - a non-expression argument;
/// - a NULL or non-text document;
/// - a non-text path.
fn extract_json_and_path(
    fn_name: &str,
    args: &FunctionArguments,
    scope: &dyn RowScope,
) -> Result<(String, String)> {
    let FunctionArguments::List(list) = args else {
        return Err(SQLRiteError::General(format!(
            "{fn_name}() expects 1 or 2 arguments"
        )));
    };
    let arg_count = list.args.len();
    if !(1..=2).contains(&arg_count) {
        return Err(SQLRiteError::General(format!(
            "{fn_name}() expects 1 or 2 arguments, got {arg_count}"
        )));
    }
    let doc_arg = &list.args[0];
    let FunctionArg::Unnamed(FunctionArgExpr::Expr(doc_expr)) = doc_arg else {
        return Err(SQLRiteError::NotImplemented(format!(
            "{fn_name}() argument 0 has unsupported shape: {doc_arg:?}"
        )));
    };
    let json_text = match eval_expr_scope(doc_expr, scope)? {
        Value::Text(text) => text,
        Value::Null => {
            return Err(SQLRiteError::General(format!(
                "{fn_name}() called on NULL — JSON column has no value for this row"
            )));
        }
        other => {
            return Err(SQLRiteError::General(format!(
                "{fn_name}() argument 0 is not JSON-typed: got {}",
                other.to_display_string()
            )));
        }
    };
    let path = match list.args.get(1) {
        None => "$".to_string(),
        Some(FunctionArg::Unnamed(FunctionArgExpr::Expr(path_expr))) => {
            match eval_expr_scope(path_expr, scope)? {
                Value::Text(p) => p,
                other => {
                    return Err(SQLRiteError::General(format!(
                        "{fn_name}() path argument must be a string literal, got {}",
                        other.to_display_string()
                    )));
                }
            }
        }
        Some(other) => {
            return Err(SQLRiteError::NotImplemented(format!(
                "{fn_name}() argument 1 has unsupported shape: {other:?}"
            )));
        }
    };
    Ok((json_text, path))
}
/// Resolves a small JSONPath subset against `value`.
///
/// The path must start with `$`, followed by any sequence of steps:
/// - `.key` member steps, where the key runs until the next `.` or `[`;
/// - `[n]` index steps, where `n` is parsed as `usize` after trimming
///   whitespace.
///
/// Returns `Ok(None)` when a step does not resolve: a missing key, an
/// out-of-range index, or a step applied to the wrong kind of JSON value.
///
/// # Errors
/// A `General` error for a malformed path:
/// - missing leading `$`;
/// - empty key;
/// - unclosed `[`;
/// - non-integer index;
/// - an unexpected character where a step should begin.
fn walk_json_path<'a>(
    value: &'a serde_json::Value,
    path: &str,
) -> Result<Option<&'a serde_json::Value>> {
    let Some(mut rest) = path.strip_prefix('$') else {
        return Err(SQLRiteError::General(format!(
            "JSON path must start with '$', got `{path}`"
        )));
    };
    let mut node = value;
    while let Some(step) = rest.chars().next() {
        // `.` and `[` are single-byte, so slicing past them is boundary-safe.
        let child = match step {
            '.' => {
                let tail = &rest[1..];
                let end = tail
                    .find(|c: char| c == '.' || c == '[')
                    .unwrap_or(tail.len());
                let key = &tail[..end];
                if key.is_empty() {
                    return Err(SQLRiteError::General(format!(
                        "JSON path has empty key after '.' in `{path}`"
                    )));
                }
                rest = &tail[end..];
                node.get(key)
            }
            '[' => {
                let tail = &rest[1..];
                let Some(close) = tail.find(']') else {
                    return Err(SQLRiteError::General(format!(
                        "JSON path has unclosed `[` in `{path}`"
                    )));
                };
                let idx_str = &tail[..close];
                let idx: usize = idx_str.trim().parse().map_err(|_| {
                    SQLRiteError::General(format!(
                        "JSON path has non-integer index `[{idx_str}]` in `{path}`"
                    ))
                })?;
                rest = &tail[close + 1..];
                node.get(idx)
            }
            other => {
                return Err(SQLRiteError::General(format!(
                    "JSON path has unexpected character `{other}` in `{path}` \
                     (expected `.`, `[`, or end-of-path)"
                )));
            }
        };
        match child {
            Some(next) => node = next,
            None => return Ok(None),
        }
    }
    Ok(Some(node))
}
fn json_value_to_sql(v: &serde_json::Value) -> Value {
match v {
serde_json::Value::Null => Value::Null,
serde_json::Value::Bool(b) => Value::Bool(*b),
serde_json::Value::Number(n) => {
if let Some(i) = n.as_i64() {
Value::Integer(i)
} else if let Some(f) = n.as_f64() {
Value::Real(f)
} else {
Value::Null
}
}
serde_json::Value::String(s) => Value::Text(s.clone()),
composite => Value::Text(composite.to_string()),
}
}
fn json_fn_extract(name: &str, args: &FunctionArguments, scope: &dyn RowScope) -> Result<Value> {
let (json_text, path) = extract_json_and_path(name, args, scope)?;
let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
})?;
match walk_json_path(&parsed, &path)? {
Some(v) => Ok(json_value_to_sql(v)),
None => Ok(Value::Null),
}
}
fn json_fn_type(name: &str, args: &FunctionArguments, scope: &dyn RowScope) -> Result<Value> {
let (json_text, path) = extract_json_and_path(name, args, scope)?;
let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
})?;
let resolved = match walk_json_path(&parsed, &path)? {
Some(v) => v,
None => return Ok(Value::Null),
};
let ty = match resolved {
serde_json::Value::Null => "null",
serde_json::Value::Bool(true) => "true",
serde_json::Value::Bool(false) => "false",
serde_json::Value::Number(n) => {
if n.is_i64() || n.is_u64() {
"integer"
} else {
"real"
}
}
serde_json::Value::String(_) => "text",
serde_json::Value::Array(_) => "array",
serde_json::Value::Object(_) => "object",
};
Ok(Value::Text(ty.to_string()))
}
/// `json_array_length(json, [path])`: element count of the array at `path`.
///
/// Returns NULL when the path does not resolve.
///
/// # Errors
/// The resolved value is not an array, invalid JSON text, a malformed path,
/// or argument errors.
fn json_fn_array_length(
    name: &str,
    args: &FunctionArguments,
    scope: &dyn RowScope,
) -> Result<Value> {
    let (json_text, path) = extract_json_and_path(name, args, scope)?;
    let document: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
        SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
    })?;
    let Some(node) = walk_json_path(&document, &path)? else {
        return Ok(Value::Null);
    };
    let items = node.as_array().ok_or_else(|| {
        SQLRiteError::General(format!(
            "{name}() resolved to a non-array value at path `{path}`"
        ))
    })?;
    Ok(Value::Integer(items.len() as i64))
}
/// `json_object_keys(json, [path])`: the keys of the object at `path`.
///
/// The keys are rendered as a JSON array string. Returns NULL when the path
/// does not resolve.
///
/// # Errors
/// The resolved value is not an object, invalid JSON text, a malformed path,
/// or argument errors.
fn json_fn_object_keys(
    name: &str,
    args: &FunctionArguments,
    scope: &dyn RowScope,
) -> Result<Value> {
    let (json_text, path) = extract_json_and_path(name, args, scope)?;
    let document: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
        SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
    })?;
    let Some(node) = walk_json_path(&document, &path)? else {
        return Ok(Value::Null);
    };
    let Some(obj) = node.as_object() else {
        return Err(SQLRiteError::General(format!(
            "{name}() resolved to a non-object value at path `{path}`"
        )));
    };
    let key_list = obj
        .keys()
        .cloned()
        .map(serde_json::Value::String)
        .collect::<Vec<_>>();
    Ok(Value::Text(serde_json::Value::Array(key_list).to_string()))
}
fn extract_two_vector_args(
fn_name: &str,
args: &FunctionArguments,
scope: &dyn RowScope,
) -> Result<(Vec<f32>, Vec<f32>)> {
let arg_list = match args {
FunctionArguments::List(l) => &l.args,
_ => {
return Err(SQLRiteError::General(format!(
"{fn_name}() expects exactly two vector arguments"
)));
}
};
if arg_list.len() != 2 {
return Err(SQLRiteError::General(format!(
"{fn_name}() expects exactly 2 arguments, got {}",
arg_list.len()
)));
}
let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
for (i, arg) in arg_list.iter().enumerate() {
let expr = match arg {
FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
other => {
return Err(SQLRiteError::NotImplemented(format!(
"{fn_name}() argument {i} has unsupported shape: {other:?}"
)));
}
};
let val = eval_expr_scope(expr, scope)?;
match val {
Value::Vector(v) => out.push(v),
other => {
return Err(SQLRiteError::General(format!(
"{fn_name}() argument {i} is not a vector: got {}",
other.to_display_string()
)));
}
}
}
let b = out.pop().unwrap();
let a = out.pop().unwrap();
if a.len() != b.len() {
return Err(SQLRiteError::General(format!(
"{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
a.len(),
b.len()
)));
}
Ok((a, b))
}
/// Euclidean (L2) distance between two equal-length vectors.
///
/// # Panics
/// If `b` is shorter than `a` (lengths are also debug-asserted equal).
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Slicing up front keeps the original panic on a short `b`.
    let b = &b[..a.len()];
    a.iter()
        .zip(b)
        .fold(0.0f32, |acc, (&x, &y)| {
            let diff = x - y;
            acc + diff * diff
        })
        .sqrt()
}
/// Cosine distance `1 - cos(a, b)` between two equal-length vectors.
///
/// The result ranges over [0, 2]: 0 for the same direction, 1 for
/// orthogonal, 2 for opposite.
///
/// # Errors
/// If either vector has zero magnitude (the cosine is undefined).
pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
    debug_assert_eq!(a.len(), b.len());
    let b = &b[..a.len()];
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );
    let denom = (norm_a_sq * norm_b_sq).sqrt();
    if denom == 0.0 {
        return Err(SQLRiteError::General(
            "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
        ));
    }
    Ok(1.0 - dot / denom)
}
/// Negated dot product of two equal-length vectors.
///
/// The dot product is negated so that, as with the other distances, smaller
/// means closer.
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let b = &b[..a.len()];
    let dot = a.iter().zip(b).fold(0.0f32, |acc, (&x, &y)| acc + x * y);
    -dot
}
fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
if matches!(l, Value::Null) || matches!(r, Value::Null) {
return Ok(Value::Null);
}
match (l, r) {
(Value::Integer(a), Value::Integer(b)) => match op {
BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
BinaryOperator::Divide => {
if *b == 0 {
Err(SQLRiteError::General("division by zero".to_string()))
} else {
Ok(Value::Integer(a / b))
}
}
BinaryOperator::Modulo => {
if *b == 0 {
Err(SQLRiteError::General("modulo by zero".to_string()))
} else {
Ok(Value::Integer(a % b))
}
}
_ => unreachable!(),
},
(a, b) => {
let af = as_number(a)?;
let bf = as_number(b)?;
match op {
BinaryOperator::Plus => Ok(Value::Real(af + bf)),
BinaryOperator::Minus => Ok(Value::Real(af - bf)),
BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
BinaryOperator::Divide => {
if bf == 0.0 {
Err(SQLRiteError::General("division by zero".to_string()))
} else {
Ok(Value::Real(af / bf))
}
}
BinaryOperator::Modulo => {
if bf == 0.0 {
Err(SQLRiteError::General("modulo by zero".to_string()))
} else {
Ok(Value::Real(af % bf))
}
}
_ => unreachable!(),
}
}
}
}
/// Coerces a value to `f64` for mixed arithmetic.
///
/// INTEGER and REAL convert directly; BOOL maps to 1.0 / 0.0.
///
/// # Errors
/// A `General` error for any other value (including TEXT and NULL).
fn as_number(v: &Value) -> Result<f64> {
    let n = match v {
        Value::Integer(i) => *i as f64,
        Value::Real(f) => *f,
        Value::Bool(true) => 1.0,
        Value::Bool(false) => 0.0,
        other => {
            return Err(SQLRiteError::General(format!(
                "arithmetic on non-numeric value '{}'",
                other.to_display_string()
            )));
        }
    };
    Ok(n)
}
/// Truthiness used by `AND` / `OR`.
///
/// - BOOL is used as-is.
/// - NULL counts as false.
/// - An integer is true when non-zero.
///
/// # Errors
/// An `Internal` error for any other value.
fn as_bool(v: &Value) -> Result<bool> {
    let truthy = match v {
        Value::Bool(b) => *b,
        Value::Integer(i) => *i != 0,
        Value::Null => false,
        other => {
            return Err(SQLRiteError::Internal(format!(
                "expected boolean, got {}",
                other.to_display_string()
            )));
        }
    };
    Ok(truthy)
}
/// Evaluates `lhs [NOT] LIKE pattern`; also used for `ILIKE`.
///
/// - NULL on either side yields NULL.
/// - Non-text operands are compared through their display strings.
/// - Matching is delegated to `like_match` with `case_insensitive` forwarded.
///
/// # Errors
/// `NotImplemented` for `LIKE ANY (...)` and for an explicit `ESCAPE` clause;
/// otherwise any error from evaluating the operands.
// The parameters mirror the LIKE/ILIKE AST node's fields plus the case flag.
#[allow(clippy::too_many_arguments)]
fn eval_like(
    scope: &dyn RowScope,
    negated: bool,
    any: bool,
    lhs: &Expr,
    pattern: &Expr,
    escape_char: Option<&AstValue>,
    case_insensitive: bool,
) -> Result<Value> {
    if any {
        return Err(SQLRiteError::NotImplemented(
            "LIKE ANY (...) is not supported".to_string(),
        ));
    }
    if escape_char.is_some() {
        return Err(SQLRiteError::NotImplemented(
            "LIKE ... ESCAPE '<char>' is not supported (default `\\` escape only)".to_string(),
        ));
    }
    let subject = eval_expr_scope(lhs, scope)?;
    let template = eval_expr_scope(pattern, scope)?;
    let stringify = |v: Value| match v {
        Value::Text(s) => s,
        other => other.to_display_string(),
    };
    let (text, pat) = match (subject, template) {
        (Value::Null, _) | (_, Value::Null) => return Ok(Value::Null),
        (s, t) => (stringify(s), stringify(t)),
    };
    let matched = like_match(&text, &pat, case_insensitive);
    // XOR with `negated` flips the result for NOT LIKE.
    Ok(Value::Bool(matched != negated))
}
/// Evaluates `lhs [NOT] IN (list...)` using SQL three-valued logic.
///
/// - A NULL `lhs` yields NULL.
/// - A match (via `compare_values`) yields `!negated`.
/// - No match but a NULL item in the list yields NULL.
/// - Otherwise yields `negated`.
fn eval_in_list(scope: &dyn RowScope, lhs: &Expr, list: &[Expr], negated: bool) -> Result<Value> {
    let needle = eval_expr_scope(lhs, scope)?;
    if matches!(needle, Value::Null) {
        return Ok(Value::Null);
    }
    let mut null_seen = false;
    for item in list {
        match eval_expr_scope(item, scope)? {
            Value::Null => null_seen = true,
            candidate if compare_values(Some(&needle), Some(&candidate)) == Ordering::Equal => {
                return Ok(Value::Bool(!negated));
            }
            _ => {}
        }
    }
    Ok(if null_seen {
        Value::Null
    } else {
        Value::Bool(negated)
    })
}
/// Runs GROUP BY + aggregate evaluation over the matching rows.
///
/// Rows are bucketed by the values of the `group_by` columns, compared via
/// `DistinctKey`. Groups are emitted in order of first appearance. Each
/// output row has one value per projection item:
/// - a plain column takes the group's key value;
/// - an aggregate takes `AggState::finalize()`.
///
/// With no GROUP BY and no matching rows, a single row built from fresh
/// (never-updated) aggregate states is still produced.
///
/// # Errors
/// Propagates errors from `AggState::update`.
///
/// # Panics
/// If a plain-column projection item is not one of the `group_by` columns.
/// The `expect` message indicates this is validated before the call.
fn aggregate_rows(
    table: &Table,
    matching: &[i64],
    group_by: &[String],
    proj_items: &[ProjectionItem],
) -> Result<Vec<Vec<Value>>> {
    use std::collections::HashMap;

    // Fresh accumulator per projection slot; `None` for plain columns.
    let template: Vec<Option<AggState>> = proj_items
        .iter()
        .map(|i| match &i.kind {
            ProjectionKind::Aggregate(call) => Some(AggState::new(call)),
            ProjectionKind::Column { .. } => None,
        })
        .collect();
    // Group key -> index into `group_states` / `group_key_values`. Hashing
    // keeps the per-row lookup O(1) instead of scanning every known group.
    let mut group_index: HashMap<Vec<DistinctKey>, usize> = HashMap::new();
    let mut group_states: Vec<Vec<Option<AggState>>> = Vec::new();
    let mut group_key_values: Vec<Vec<Value>> = Vec::new();
    for &rowid in matching {
        let mut key_values: Vec<Value> = Vec::with_capacity(group_by.len());
        let mut key: Vec<DistinctKey> = Vec::with_capacity(group_by.len());
        for col in group_by {
            let v = table.get_value(col, rowid).unwrap_or(Value::Null);
            key.push(DistinctKey::from_value(&v));
            key_values.push(v);
        }
        let next_idx = group_states.len();
        let idx = *group_index.entry(key).or_insert_with(|| {
            group_states.push(template.clone());
            group_key_values.push(key_values);
            next_idx
        });
        let states = &mut group_states[idx];
        for (slot, item) in proj_items.iter().enumerate() {
            if let ProjectionKind::Aggregate(call) = &item.kind {
                let v = match &call.arg {
                    AggregateArg::Star => Value::Null,
                    AggregateArg::Column(c) => table.get_value(c, rowid).unwrap_or(Value::Null),
                };
                if let Some(state) = states[slot].as_mut() {
                    state.update(&v)?;
                }
            }
        }
    }
    // A global aggregate over zero rows still yields one output row.
    if group_states.is_empty() && group_by.is_empty() {
        group_states.push(template);
        group_key_values.push(Vec::new());
    }
    let mut rows: Vec<Vec<Value>> = Vec::with_capacity(group_states.len());
    for (states, key_values) in group_states.iter().zip(&group_key_values) {
        let mut row: Vec<Value> = Vec::with_capacity(proj_items.len());
        for (slot, item) in proj_items.iter().enumerate() {
            match &item.kind {
                ProjectionKind::Column { name: c, .. } => {
                    let pos = group_by
                        .iter()
                        .position(|g| g == c)
                        .expect("validated to be in GROUP BY");
                    row.push(key_values[pos].clone());
                }
                ProjectionKind::Aggregate(_) => {
                    let state = states[slot]
                        .as_ref()
                        .expect("aggregate slot has state");
                    row.push(state.finalize());
                }
            }
        }
        rows.push(row);
    }
    Ok(rows)
}
/// Removes duplicate output rows for `SELECT DISTINCT`.
///
/// Rows are compared value-by-value through `DistinctKey`. The first
/// occurrence of each distinct row is kept, in original order.
fn dedupe_rows(rows: Vec<Vec<Value>>) -> Vec<Vec<Value>> {
    use std::collections::HashSet;
    let mut seen: HashSet<Vec<DistinctKey>> = HashSet::new();
    let mut unique = rows;
    // `insert` returns false for a key already seen, which drops that row.
    unique.retain(|row| seen.insert(row.iter().map(DistinctKey::from_value).collect()));
    unique
}
fn sort_output_rows(
rows: &mut [Vec<Value>],
columns: &[String],
proj_items: &[ProjectionItem],
order: &OrderByClause,
) -> Result<()> {
let target_idx = resolve_order_by_index(&order.expr, columns, proj_items)?;
rows.sort_by(|a, b| {
let va = &a[target_idx];
let vb = &b[target_idx];
let ord = compare_values(Some(va), Some(vb));
if order.ascending { ord } else { ord.reverse() }
});
Ok(())
}
/// Maps an ORDER BY expression on an aggregating query to an output column index.
///
/// - Identifiers (and the last part of compound identifiers) match output
///   column names case-insensitively.
/// - Function calls match an aggregate projection whose display name equals
///   `format_function_display(func)`, ignoring ASCII case.
/// - Parentheses are peeled.
///
/// # Errors
/// `Internal` when nothing matches; `NotImplemented` for other expression kinds.
fn resolve_order_by_index(
    expr: &Expr,
    columns: &[String],
    proj_items: &[ProjectionItem],
) -> Result<usize> {
    let by_name = |name: &str| -> Result<usize> {
        columns
            .iter()
            .position(|c| c.eq_ignore_ascii_case(name))
            .ok_or_else(|| {
                SQLRiteError::Internal(format!(
                    "ORDER BY references unknown column '{name}' in the SELECT output"
                ))
            })
    };
    match expr {
        Expr::Nested(inner) => resolve_order_by_index(inner, columns, proj_items),
        Expr::Identifier(ident) => by_name(&ident.value),
        Expr::CompoundIdentifier(parts) => match parts.last() {
            Some(last) => by_name(&last.value),
            None => Err(SQLRiteError::Internal(
                "ORDER BY expression could not be resolved against the output columns"
                    .to_string(),
            )),
        },
        Expr::Function(func) => {
            let user_disp = format_function_display(func);
            proj_items
                .iter()
                .position(|item| {
                    matches!(
                        &item.kind,
                        ProjectionKind::Aggregate(call)
                            if call.display_name().eq_ignore_ascii_case(&user_disp)
                    )
                })
                .ok_or_else(|| {
                    SQLRiteError::Internal(format!(
                        "ORDER BY references aggregate '{user_disp}' that isn't in the SELECT output"
                    ))
                })
        }
        other => Err(SQLRiteError::NotImplemented(format!(
            "ORDER BY expression not supported on aggregating queries: {other:?}"
        ))),
    }
}
/// Renders a function call as `NAME(arg)` so it can be compared with an
/// aggregate's display name.
///
/// - The name is upper-cased.
/// - Only the first argument is rendered: `*`, a column name (the last part
///   of a compound identifier), or empty for anything else.
/// - `DISTINCT ` is prefixed for distinct non-`*` arguments.
fn format_function_display(func: &sqlparser::ast::Function) -> String {
    let name = if let [ObjectNamePart::Identifier(ident)] = func.name.0.as_slice() {
        ident.value.to_uppercase()
    } else {
        format!("{:?}", func.name).to_uppercase()
    };
    let inner = if let FunctionArguments::List(list) = &func.args {
        let is_distinct = matches!(
            list.duplicate_treatment,
            Some(sqlparser::ast::DuplicateTreatment::Distinct)
        );
        let first = list.args.first().map(|arg| match arg {
            FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => "*".to_string(),
            FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier(i))) => i.value.clone(),
            FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::CompoundIdentifier(parts))) => parts
                .last()
                .map(|p| p.value.clone())
                .unwrap_or_default(),
            _ => String::new(),
        });
        match first {
            None => String::new(),
            Some(a) if is_distinct && a != "*" => format!("DISTINCT {a}"),
            Some(a) => a,
        }
    } else {
        String::new()
    };
    format!("{name}({inner})")
}
/// Converts a parsed SQL literal into an engine value.
///
/// - Numbers become INTEGER when they parse as `i64`, otherwise REAL.
/// - Single-quoted strings become TEXT; booleans and NULL map directly.
///
/// # Errors
/// `Internal` for an unparseable number; `NotImplemented` for other literal
/// kinds.
fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
    use sqlparser::ast::Value as AstValue;
    match v {
        AstValue::Number(n, _) => n
            .parse::<i64>()
            .map(Value::Integer)
            .or_else(|_| n.parse::<f64>().map(Value::Real))
            .map_err(|_| {
                SQLRiteError::Internal(format!("could not parse numeric literal '{n}'"))
            }),
        AstValue::SingleQuotedString(text) => Ok(Value::Text(text.clone())),
        AstValue::Boolean(flag) => Ok(Value::Bool(*flag)),
        AstValue::Null => Ok(Value::Null),
        unsupported => Err(SQLRiteError::NotImplemented(format!(
            "unsupported literal value: {unsupported:?}"
        ))),
    }
}
#[cfg(test)]
mod tests {
use super::*;
fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
(a - b).abs() < eps
}
#[test]
fn vec_distance_l2_identical_is_zero() {
let v = vec![0.1, 0.2, 0.3];
assert_eq!(vec_distance_l2(&v, &v), 0.0);
}
#[test]
fn vec_distance_l2_unit_basis_is_sqrt2() {
let a = vec![1.0, 0.0];
let b = vec![0.0, 1.0];
assert!(approx_eq(vec_distance_l2(&a, &b), 2.0_f32.sqrt(), 1e-6));
}
#[test]
fn vec_distance_l2_known_value() {
let a = vec![0.0, 0.0, 0.0];
let b = vec![3.0, 4.0, 0.0];
assert!(approx_eq(vec_distance_l2(&a, &b), 5.0, 1e-6));
}
#[test]
fn vec_distance_cosine_identical_is_zero() {
let v = vec![0.1, 0.2, 0.3];
let d = vec_distance_cosine(&v, &v).unwrap();
assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
}
#[test]
fn vec_distance_cosine_orthogonal_is_one() {
let a = vec![1.0, 0.0];
let b = vec![0.0, 1.0];
assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 1.0, 1e-6));
}
#[test]
fn vec_distance_cosine_opposite_is_two() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![-1.0, 0.0, 0.0];
assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 2.0, 1e-6));
}
#[test]
fn vec_distance_cosine_zero_magnitude_errors() {
let a = vec![0.0, 0.0];
let b = vec![1.0, 0.0];
let err = vec_distance_cosine(&a, &b).unwrap_err();
assert!(format!("{err}").contains("zero-magnitude"));
}
#[test]
fn vec_distance_dot_negates() {
let a = vec![1.0, 2.0, 3.0];
let b = vec![4.0, 5.0, 6.0];
assert!(approx_eq(vec_distance_dot(&a, &b), -32.0, 1e-6));
}
#[test]
fn vec_distance_dot_orthogonal_is_zero() {
let a = vec![1.0, 0.0];
let b = vec![0.0, 1.0];
assert_eq!(vec_distance_dot(&a, &b), 0.0);
}
#[test]
fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
let a = vec![0.6f32, 0.8]; let b = vec![0.8f32, 0.6]; let dot = vec_distance_dot(&a, &b);
let cos = vec_distance_cosine(&a, &b).unwrap();
assert!(approx_eq(dot, cos - 1.0, 1e-5));
}
use crate::sql::db::database::Database;
use crate::sql::parser::select::SelectQuery;
use sqlparser::dialect::SQLiteDialect;
use sqlparser::parser::Parser;
fn seed_score_table(n: usize) -> Database {
let mut db = Database::new("tempdb".to_string());
crate::sql::process_command(
"CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
&mut db,
)
.expect("create");
for i in 0..n {
let score = ((i as u64).wrapping_mul(2_654_435_761) % 1_000_000) as f64;
let sql = format!("INSERT INTO docs (score) VALUES ({score});");
crate::sql::process_command(&sql, &mut db).expect("insert");
}
db
}
/// Parses `sql` with the SQLite dialect and converts its last statement into a
/// `SelectQuery`. Panics if parsing fails, if there is no statement, or if the
/// statement is not a supported SELECT.
fn parse_select(sql: &str) -> SelectQuery {
    let statement = Parser::parse_sql(&SQLiteDialect {}, sql)
        .expect("parse")
        .pop()
        .expect("one statement");
    SelectQuery::new(&statement).expect("select-query")
}
/// With an ascending ORDER BY, `select_topk` returns the same rowids as a full
/// `sort_rowids` followed by truncation to `k`.
#[test]
fn topk_matches_full_sort_asc() {
    let db = seed_score_table(200);
    let table = db.get_table("docs".to_string()).unwrap();
    let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
    let order = query.order_by.as_ref().unwrap();
    let candidates = table.rowids();
    let expected = {
        let mut sorted = candidates.clone();
        sort_rowids(&mut sorted, table, order).unwrap();
        sorted.truncate(10);
        sorted
    };
    let topk = select_topk(&candidates, table, order, 10).unwrap();
    assert_eq!(topk, expected, "top-k via heap should match full-sort+truncate");
}
/// With a descending ORDER BY, `select_topk` returns the same rowids as a full
/// `sort_rowids` followed by truncation to `k`.
#[test]
fn topk_matches_full_sort_desc() {
    let db = seed_score_table(200);
    let table = db.get_table("docs".to_string()).unwrap();
    let query = parse_select("SELECT * FROM docs ORDER BY score DESC LIMIT 10;");
    let order = query.order_by.as_ref().unwrap();
    let candidates = table.rowids();
    let expected = {
        let mut sorted = candidates.clone();
        sort_rowids(&mut sorted, table, order).unwrap();
        sorted.truncate(10);
        sorted
    };
    let topk = select_topk(&candidates, table, order, 10).unwrap();
    assert_eq!(
        topk, expected,
        "top-k DESC via heap should match full-sort+truncate"
    );
}
/// When `k` exceeds the number of candidates, `select_topk` returns every rowid,
/// sorted ascending by `score`.
#[test]
fn topk_k_larger_than_n_returns_everything_sorted() {
    let db = seed_score_table(50);
    let table = db.get_table("docs".to_string()).unwrap();
    let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
    let order = q.order_by.as_ref().unwrap();
    let topk = select_topk(&table.rowids(), table, order, 1000).unwrap();
    assert_eq!(topk.len(), 50);
    // The seeder formats whole-number scores without a decimal point, so they
    // may be stored as integers. Accept both numeric representations and check
    // the count. Otherwise a silently filtered-out (empty) vector would make
    // the sortedness check below pass vacuously.
    let scores: Vec<f64> = topk
        .iter()
        .filter_map(|r| match table.get_value("score", *r) {
            Some(Value::Real(f)) => Some(f),
            Some(Value::Integer(i)) => Some(i as f64),
            _ => None,
        })
        .collect();
    assert_eq!(
        scores.len(),
        topk.len(),
        "every returned row should have a numeric score"
    );
    assert!(scores.windows(2).all(|w| w[0] <= w[1]));
}
/// `k == 0` yields no rowids, even though the table is non-empty.
///
/// The SQL text only supplies the ORDER BY clause. Its `LIMIT 1` is not used,
/// because `k` is passed to `select_topk` directly.
#[test]
fn topk_k_zero_returns_empty() {
    let db = seed_score_table(10);
    let table = db.get_table("docs".to_string()).unwrap();
    let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
    let order = query.order_by.as_ref().unwrap();
    let selected = select_topk(&table.rowids(), table, order, 0).unwrap();
    assert!(selected.is_empty());
}
/// An empty candidate slice yields an empty result regardless of `k`.
#[test]
fn topk_empty_input_returns_empty() {
    let db = seed_score_table(0);
    let table = db.get_table("docs".to_string()).unwrap();
    let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
    let order = query.order_by.as_ref().unwrap();
    let selected = select_topk(&[], table, order, 5).unwrap();
    assert!(selected.is_empty());
}
/// End-to-end check that `ORDER BY vec_distance_l2(...) ASC LIMIT k` runs through
/// the SELECT executor and returns the k nearest rows, nearest first.
#[test]
fn topk_works_through_select_executor_with_distance_function() {
    let mut db = Database::new("tempdb".to_string());
    crate::sql::process_command(
        "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
        &mut db,
    )
    .unwrap();
    for v in &[
        "[1.0, 0.0]",
        "[2.0, 0.0]",
        "[0.0, 3.0]",
        "[1.0, 4.0]",
        "[10.0, 10.0]",
    ] {
        crate::sql::process_command(&format!("INSERT INTO docs (e) VALUES ({v});"), &mut db)
            .unwrap();
    }
    let sql = "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;";
    let resp = crate::sql::process_command(sql, &mut db).unwrap();
    assert!(resp.contains("3 rows returned"), "got: {resp}");
    // A row count alone cannot catch a wrong ordering. The L2 distances from
    // [1, 0] are: id 1 → 0, id 2 → 1, id 3 → √10 ≈ 3.16, id 4 → 4, id 5 → √181.
    // The three nearest rows are therefore ids 1, 2, 3, in that order.
    // NOTE(review): assumes ids are auto-assigned 1..=5 in insertion order, as the
    // other fixtures in this module rely on.
    let r = run_rows(&db, sql);
    assert_eq!(r.rows.len(), 3);
    assert_eq!(r.rows[0][0], Value::Integer(1));
    assert_eq!(r.rows[1][0], Value::Integer(2));
    assert_eq!(r.rows[2][0], Value::Integer(3));
}
/// Manual benchmark that compares `select_topk`'s bounded heap against a full
/// `sort_rowids` + `truncate` over 10k rows. It asserts that the heap is at least
/// 1.4× faster.
///
/// The test is marked `#[ignore]`, so it only runs when requested explicitly
/// (e.g. `cargo test -- --ignored`).
#[test]
#[ignore]
fn topk_benchmark() {
    use std::time::Instant;
    const N: usize = 10_000;
    const K: usize = 10;
    let db = seed_score_table(N);
    let table = db.get_table("docs".to_string()).unwrap();
    let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
    let order = query.order_by.as_ref().unwrap();
    let candidates = table.rowids();

    let heap_start = Instant::now();
    let _topk = select_topk(&candidates, table, order, K).unwrap();
    let heap_dur = heap_start.elapsed();

    let sort_start = Instant::now();
    let mut full = candidates.clone();
    sort_rowids(&mut full, table, order).unwrap();
    full.truncate(K);
    let sort_dur = sort_start.elapsed();

    // Clamp the denominator so a near-zero heap timing cannot divide by zero.
    let ratio = sort_dur.as_secs_f64() / heap_dur.as_secs_f64().max(1e-9);
    println!("\n--- topk_benchmark (N={N}, k={K}) ---");
    println!(" bounded heap: {heap_dur:?}");
    println!(" full sort+trunc: {sort_dur:?}");
    println!(" speedup ratio: {ratio:.2}×");
    assert!(
        ratio > 1.4,
        "bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
    );
}
/// Runs `sql` through `process_command` and returns its textual response,
/// panicking with "select" on error.
fn run_select(db: &mut Database, sql: &str) -> String {
    let outcome = crate::sql::process_command(sql, db);
    outcome.expect("select")
}
/// `WHERE n IS NULL` selects exactly the rows whose `n` was inserted as NULL.
#[test]
fn where_is_null_returns_null_rows() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
        "INSERT INTO t (id, n) VALUES (1, 10);",
        "INSERT INTO t (id, n) VALUES (2, NULL);",
        "INSERT INTO t (id, n) VALUES (3, 30);",
        "INSERT INTO t (id, n) VALUES (4, NULL);",
    ];
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }
    let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL;");
    assert!(
        response.contains("2 rows returned"),
        "IS NULL should return 2 rows, got: {response}"
    );
}
/// `WHERE n IS NOT NULL` selects exactly the rows with a concrete `n`.
#[test]
fn where_is_not_null_returns_non_null_rows() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
        "INSERT INTO t (id, n) VALUES (1, 10);",
        "INSERT INTO t (id, n) VALUES (2, NULL);",
        "INSERT INTO t (id, n) VALUES (3, 30);",
    ];
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }
    let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NOT NULL;");
    assert!(
        response.contains("2 rows returned"),
        "IS NOT NULL should return 2 rows, got: {response}"
    );
}
/// `IS NULL` / `IS NOT NULL` behave correctly on a UNIQUE (indexed) column
/// that holds a NULL.
#[test]
fn where_is_null_on_indexed_column() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT UNIQUE);",
        "INSERT INTO t (id, name) VALUES (1, 'alice');",
        "INSERT INTO t (id, name) VALUES (2, NULL);",
        "INSERT INTO t (id, name) VALUES (3, 'bob');",
    ];
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }

    let null_rows = run_select(&mut db, "SELECT id FROM t WHERE name IS NULL;");
    assert!(
        null_rows.contains("1 row returned"),
        "indexed IS NULL should return 1 row, got: {null_rows}"
    );

    let not_null_rows = run_select(&mut db, "SELECT id FROM t WHERE name IS NOT NULL;");
    assert!(
        not_null_rows.contains("2 rows returned"),
        "indexed IS NOT NULL should return 2 rows, got: {not_null_rows}"
    );
}
/// A column omitted from an INSERT's column list reads back as NULL and
/// therefore matches `IS NULL`.
#[test]
fn where_is_null_works_on_omitted_column() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, qty INTEGER, label TEXT);",
        "INSERT INTO t (id, qty, label) VALUES (1, 7, 'a');",
        "INSERT INTO t (id, label) VALUES (2, 'b');",
    ];
    for stmt in setup {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }
    let response = run_select(&mut db, "SELECT id FROM t WHERE qty IS NULL;");
    assert!(
        response.contains("1 row returned"),
        "IS NULL should match the omitted-column row, got: {response}"
    );
}
/// `IS NULL` composes with both `AND` and `OR` in a WHERE clause.
#[test]
fn where_is_null_combines_with_and_or() {
    let mut db = Database::new("t".to_string());
    crate::sql::process_command(
        "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
        &mut db,
    )
    .unwrap();
    crate::sql::process_command("INSERT INTO t (id, n) VALUES (1, NULL);", &mut db).unwrap();
    crate::sql::process_command("INSERT INTO t (id, n) VALUES (2, NULL);", &mut db).unwrap();
    crate::sql::process_command("INSERT INTO t (id, n) VALUES (3, 30);", &mut db).unwrap();

    let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL AND id > 1;");
    assert!(
        response.contains("1 row returned"),
        "IS NULL combined with AND should match exactly row 2, got: {response}"
    );

    // Rows 1 and 2 satisfy the IS NULL arm; row 3 satisfies only `id = 3`.
    let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL OR id = 3;");
    assert!(
        response.contains("3 rows returned"),
        "IS NULL combined with OR should match all 3 rows, got: {response}"
    );
}
/// Builds an `emp` table with six employees in three departments
/// (`eng` ×3, `sales` ×2, `ops` ×1). Dave's salary is NULL.
fn seed_employees() -> Database {
    let mut db = Database::new("t".to_string());
    let schema =
        "CREATE TABLE emp (id INTEGER PRIMARY KEY, name TEXT, dept TEXT, salary INTEGER);";
    crate::sql::process_command(schema, &mut db).unwrap();
    let inserts = [
        "INSERT INTO emp (name, dept, salary) VALUES ('Alice', 'eng', 100);",
        "INSERT INTO emp (name, dept, salary) VALUES ('alex', 'eng', 120);",
        "INSERT INTO emp (name, dept, salary) VALUES ('Bob', 'eng', 100);",
        "INSERT INTO emp (name, dept, salary) VALUES ('Carol', 'sales', 90);",
        "INSERT INTO emp (name, dept, salary) VALUES ('Dave', 'sales', NULL);",
        "INSERT INTO emp (name, dept, salary) VALUES ('Eve', 'ops', 80);",
    ];
    for insert in inserts.iter() {
        crate::sql::process_command(insert, &mut db).unwrap();
    }
    db
}
/// Parses `sql` as a SELECT and executes it against `db`, returning the
/// structured rows. Panics with "select" on execution error.
fn run_rows(db: &Database, sql: &str) -> SelectResult {
    execute_select_rows(parse_select(sql), db).expect("select")
}
/// `LIKE 'a%'` matches names that start with `a`, ignoring ASCII case.
#[test]
fn like_percent_prefix_case_insensitive() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT name FROM emp WHERE name LIKE 'a%';");
    let names: Vec<_> = result
        .rows
        .iter()
        .map(|row| row[0].to_display_string())
        .collect();
    assert_eq!(names.len(), 2, "expected 2 rows, got {names:?}");
    for wanted in ["Alice", "alex"] {
        assert!(names.contains(&wanted.to_string()));
    }
}
/// `_` in a LIKE pattern matches exactly one character.
#[test]
fn like_underscore_singlechar() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT name FROM emp WHERE name LIKE '_ve';");
    let matched: Vec<_> = result
        .rows
        .iter()
        .map(|row| row[0].to_display_string())
        .collect();
    assert_eq!(matched, vec!["Eve".to_string()]);
}
/// `NOT LIKE 'a%'` keeps the four employees that `LIKE 'a%'` rejects.
#[test]
fn not_like_excludes_match() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT name FROM emp WHERE name NOT LIKE 'a%';");
    assert_eq!(result.rows.len(), 4);
}
/// LIKE combined with `IS NULL` on another column selects only Dave.
///
/// NOTE(review): despite the name, this query never applies LIKE to a NULL
/// operand (`dept` is non-NULL for every row). Consider adding a case like
/// `salary LIKE '%'` once the intended NULL semantics for LIKE are confirmed.
#[test]
fn like_with_null_excludes_row() {
    let db = seed_employees();
    let result = run_rows(
        &db,
        "SELECT name FROM emp WHERE dept LIKE 'sales' AND salary IS NULL;",
    );
    assert_eq!(result.rows.len(), 1);
    let only = &result.rows[0];
    assert_eq!(only[0].to_display_string(), "Dave");
}
/// `id IN (1, 3, 5)` selects the first, third and fifth inserted employees.
#[test]
fn in_list_positive() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT name FROM emp WHERE id IN (1, 3, 5);");
    let names: Vec<_> = result
        .rows
        .iter()
        .map(|row| row[0].to_display_string())
        .collect();
    assert_eq!(names.len(), 3);
    for wanted in ["Alice", "Bob", "Dave"] {
        assert!(names.contains(&wanted.to_string()));
    }
}
/// `id NOT IN (1, 2)` drops exactly the two listed ids, leaving four rows.
#[test]
fn not_in_excludes_listed() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT name FROM emp WHERE id NOT IN (1, 2);");
    assert_eq!(result.rows.len(), 4);
}
/// A NULL inside an IN list does not match anything: only the explicit `1` does.
#[test]
fn in_list_with_null_three_valued() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT name FROM emp WHERE id IN (1, NULL);");
    assert_eq!(result.rows.len(), 1);
    let only = &result.rows[0];
    assert_eq!(only[0].to_display_string(), "Alice");
}
/// `SELECT DISTINCT dept` collapses the six rows to the three departments.
#[test]
fn distinct_single_column() {
    let db = seed_employees();
    let distinct = run_rows(&db, "SELECT DISTINCT dept FROM emp;");
    assert_eq!(distinct.rows.len(), 3);
}
/// DISTINCT over `(dept, salary)` treats the duplicate `(eng, 100)` as one row
/// and keeps `(sales, NULL)`, leaving five rows.
#[test]
fn distinct_multi_column_with_null() {
    let db = seed_employees();
    let distinct = run_rows(&db, "SELECT DISTINCT dept, salary FROM emp;");
    assert_eq!(distinct.rows.len(), 5);
}
/// `COUNT(*)` without GROUP BY emits a single row counting every employee.
#[test]
fn count_star_no_groupby() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT COUNT(*) FROM emp;");
    assert_eq!(result.rows.len(), 1);
    let row = &result.rows[0];
    assert_eq!(row[0], Value::Integer(6));
}
/// `COUNT(salary)` ignores Dave's NULL salary.
#[test]
fn count_col_skips_nulls() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT COUNT(salary) FROM emp;");
    let count = &result.rows[0][0];
    assert_eq!(*count, Value::Integer(5));
}
/// `COUNT(DISTINCT salary)` counts {80, 90, 100, 120}: duplicates collapse and
/// NULL is skipped.
#[test]
fn count_distinct_dedupes_and_skips_nulls() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT COUNT(DISTINCT salary) FROM emp;");
    let count = &result.rows[0][0];
    assert_eq!(*count, Value::Integer(4));
}
/// SUM over an INTEGER column yields an `Integer` (100+120+100+90+80 = 490).
#[test]
fn sum_int_stays_integer() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT SUM(salary) FROM emp;");
    let total = &result.rows[0][0];
    assert_eq!(*total, Value::Integer(490));
}
/// AVG returns a `Real` even over integers: 490 / 5 non-NULL salaries = 98.0.
#[test]
fn avg_returns_real() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT AVG(salary) FROM emp;");
    let cell = &result.rows[0][0];
    match cell {
        Value::Real(avg) => assert!((avg - 98.0).abs() < 1e-9),
        other => panic!("expected Real, got {other:?}"),
    }
}
/// MIN and MAX ignore the NULL salary.
#[test]
fn min_max_skip_nulls() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT MIN(salary), MAX(salary) FROM emp;");
    let row = &result.rows[0];
    assert_eq!(row[0], Value::Integer(80));
    assert_eq!(row[1], Value::Integer(120));
}
/// Aggregates over an empty table still produce one row: `COUNT(*)` is 0 and
/// SUM/AVG/MIN/MAX are NULL.
#[test]
fn aggregates_on_empty_table_emit_one_row() {
    let mut db = Database::new("t".to_string());
    crate::sql::process_command("CREATE TABLE t (x INTEGER);", &mut db).unwrap();
    let result = run_rows(
        &db,
        "SELECT COUNT(*), SUM(x), AVG(x), MIN(x), MAX(x) FROM t;",
    );
    assert_eq!(result.rows.len(), 1);
    let row = &result.rows[0];
    assert_eq!(row[0], Value::Integer(0));
    for value in &row[1..5] {
        assert_eq!(value, &Value::Null);
    }
}
/// GROUP BY a single column with `COUNT(*)` yields one row per department and
/// the correct headcount.
#[test]
fn group_by_single_col_with_count() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT dept, COUNT(*) FROM emp GROUP BY dept;");
    assert_eq!(result.rows.len(), 3);
    let by_dept: std::collections::HashMap<String, i64> = result
        .rows
        .iter()
        .map(|row| {
            let dept = row[0].to_display_string();
            let count = match &row[1] {
                Value::Integer(i) => *i,
                v => panic!("expected Integer count, got {v:?}"),
            };
            (dept, count)
        })
        .collect();
    assert_eq!(by_dept["eng"], 3);
    assert_eq!(by_dept["sales"], 2);
    assert_eq!(by_dept["ops"], 1);
}
/// WHERE filters rows before grouping. `salary > 80` removes Eve (so `ops`
/// disappears) and Dave (NULL), leaving eng = 320 and sales = 90.
#[test]
fn group_by_with_where_filter() {
    let db = seed_employees();
    let result = run_rows(
        &db,
        "SELECT dept, SUM(salary) FROM emp WHERE salary > 80 GROUP BY dept;",
    );
    let mut by: std::collections::HashMap<String, i64> = std::collections::HashMap::new();
    for row in &result.rows {
        let dept = row[0].to_display_string();
        let sum = match &row[1] {
            Value::Integer(i) => *i,
            v => panic!("expected Integer sum, got {v:?}"),
        };
        by.insert(dept, sum);
    }
    assert_eq!(by.len(), 2);
    assert_eq!(by["eng"], 320);
    assert_eq!(by["sales"], 90);
}
/// GROUP BY with no aggregate behaves like DISTINCT over the grouping column.
#[test]
fn group_by_without_aggregates_is_distinct() {
    let db = seed_employees();
    let grouped = run_rows(&db, "SELECT dept FROM emp GROUP BY dept;");
    assert_eq!(grouped.rows.len(), 3);
}
/// ORDER BY an aggregate's alias, descending, with LIMIT. The department counts
/// are eng 3, sales 2, ops 1, so the two surviving rows are eng and then sales.
#[test]
fn order_by_count_desc() {
    let db = seed_employees();
    let r = run_rows(
        &db,
        "SELECT dept, COUNT(*) AS n FROM emp GROUP BY dept ORDER BY n DESC LIMIT 2;",
    );
    assert_eq!(r.rows.len(), 2);
    assert_eq!(r.rows[0][0].to_display_string(), "eng");
    assert_eq!(r.rows[0][1], Value::Integer(3));
    // LIMIT must keep the second-highest group too, not just the first.
    assert_eq!(r.rows[1][0].to_display_string(), "sales");
    assert_eq!(r.rows[1][1], Value::Integer(2));
}
/// ORDER BY may repeat the aggregate call (`COUNT(*)`) instead of using an alias.
#[test]
fn order_by_aggregate_call_form() {
    let db = seed_employees();
    let result = run_rows(
        &db,
        "SELECT dept, COUNT(*) FROM emp GROUP BY dept ORDER BY COUNT(*) DESC;",
    );
    assert_eq!(result.rows.len(), 3);
    let top = &result.rows[0];
    assert_eq!(top[0].to_display_string(), "eng");
}
/// A bare column that is neither grouped nor aggregated must be rejected.
#[test]
fn group_by_invalid_bare_column_errors() {
    let mut db = Database::new("t".to_string());
    let schema = "CREATE TABLE t (id INTEGER PRIMARY KEY, dept TEXT, name TEXT);";
    crate::sql::process_command(schema, &mut db).unwrap();
    let outcome =
        crate::sql::process_command("SELECT dept, name FROM t GROUP BY dept;", &mut db);
    assert!(outcome.is_err(), "should reject bare 'name' not in GROUP BY");
}
/// Aggregate functions are not allowed in WHERE and must produce an error.
#[test]
fn aggregate_in_where_errors_friendly() {
    let mut db = Database::new("t".to_string());
    for stmt in ["CREATE TABLE t (x INTEGER);", "INSERT INTO t (x) VALUES (1);"] {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }
    let outcome = crate::sql::process_command("SELECT x FROM t WHERE COUNT(*) > 0;", &mut db);
    assert!(outcome.is_err(), "aggregates must not be allowed in WHERE");
}
/// Builds the shared JOIN fixture:
/// - `customers` holds Alice, Bob and Carol.
/// - `orders` holds two orders for customer 1, one for customer 2, and a
///   dangling order (amount 999) for the non-existent customer 4.
/// - Carol has no orders.
fn seed_join_fixture() -> Database {
    let mut db = Database::new("t".to_string());
    let script = [
        "CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT);",
        "CREATE TABLE orders (id INTEGER PRIMARY KEY, customer_id INTEGER, amount INTEGER);",
        "INSERT INTO customers (name) VALUES ('Alice');",
        "INSERT INTO customers (name) VALUES ('Bob');",
        "INSERT INTO customers (name) VALUES ('Carol');",
        "INSERT INTO orders (customer_id, amount) VALUES (1, 100);",
        "INSERT INTO orders (customer_id, amount) VALUES (1, 200);",
        "INSERT INTO orders (customer_id, amount) VALUES (2, 50);",
        "INSERT INTO orders (customer_id, amount) VALUES (4, 999);",
    ];
    for stmt in script.iter() {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }
    db
}
/// INNER JOIN keeps only the customer/order pairs that satisfy ON. Carol (no
/// orders) and the dangling order 999 are both dropped.
#[test]
fn inner_join_returns_only_matched_rows() {
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT customers.name, orders.amount FROM customers \
         INNER JOIN orders ON customers.id = orders.customer_id;",
    );
    assert_eq!(result.columns, vec!["name".to_string(), "amount".to_string()]);
    let mut pairs: Vec<(String, i64)> = Vec::new();
    for row in &result.rows {
        let name = row[0].to_display_string();
        let amount = match &row[1] {
            Value::Integer(i) => *i,
            v => panic!("expected integer amount, got {v:?}"),
        };
        pairs.push((name, amount));
    }
    assert_eq!(pairs.len(), 3);
    for (name, amount) in [("Alice", 100), ("Alice", 200), ("Bob", 50)] {
        assert!(pairs.contains(&(name.to_string(), amount)));
    }
}
/// A plain `JOIN` with no type keyword behaves as an INNER JOIN.
#[test]
fn bare_join_defaults_to_inner() {
    let db = seed_join_fixture();
    let joined = run_rows(
        &db,
        "SELECT customers.name FROM customers \
         JOIN orders ON customers.id = orders.customer_id;",
    );
    assert_eq!(joined.rows.len(), 3, "JOIN without prefix should be INNER");
}
/// LEFT OUTER JOIN keeps Carol, who has no orders, with a NULL amount.
#[test]
fn left_outer_join_preserves_unmatched_left() {
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT customers.name, orders.amount FROM customers \
         LEFT OUTER JOIN orders ON customers.id = orders.customer_id;",
    );
    assert_eq!(result.rows.len(), 4);
    let carol = result
        .rows
        .iter()
        .find(|candidate| candidate[0].to_display_string() == "Carol")
        .expect("Carol should appear with a NULL-padded right side");
    assert_eq!(carol[1], Value::Null);
}
/// RIGHT OUTER JOIN keeps the dangling order 999 with a NULL customer name.
#[test]
fn right_outer_join_preserves_unmatched_right() {
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT customers.name, orders.amount FROM customers \
         RIGHT OUTER JOIN orders ON customers.id = orders.customer_id;",
    );
    assert_eq!(result.rows.len(), 4);
    let dangling = result
        .rows
        .iter()
        .find(|candidate| matches!(candidate[1], Value::Integer(999)))
        .expect("dangling order 999 should appear with a NULL-padded customer name");
    assert_eq!(dangling[0], Value::Null);
}
/// FULL OUTER JOIN keeps both unmatched sides: Carol with a NULL amount, and
/// order 999 with a NULL name (3 matched + 2 padded = 5 rows).
#[test]
fn full_outer_join_preserves_both_sides() {
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT customers.name, orders.amount FROM customers \
         FULL OUTER JOIN orders ON customers.id = orders.customer_id;",
    );
    assert_eq!(result.rows.len(), 5);
    let carol_padded = result
        .rows
        .iter()
        .any(|row| row[0].to_display_string() == "Carol" && matches!(row[1], Value::Null));
    assert!(carol_padded);
    let dangling_padded = result
        .rows
        .iter()
        .any(|row| matches!(row[1], Value::Integer(999)) && matches!(row[0], Value::Null));
    assert!(dangling_padded);
}
/// Table aliases (`AS c`, `AS o`) are usable as column qualifiers, and output
/// headers use the bare column names.
#[test]
fn join_with_table_aliases_resolves_qualifiers() {
    let db = seed_join_fixture();
    let aliased = run_rows(
        &db,
        "SELECT c.name, o.amount FROM customers AS c \
         INNER JOIN orders AS o ON c.id = o.customer_id;",
    );
    assert_eq!(aliased.rows.len(), 3);
    assert_eq!(
        aliased.columns,
        vec!["name".to_string(), "amount".to_string()]
    );
}
/// WHERE filters the joined rows. Only Alice's orders have `amount >= 100`.
#[test]
fn join_with_where_filter_applies_after_join() {
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT customers.name, orders.amount FROM customers \
         INNER JOIN orders ON customers.id = orders.customer_id \
         WHERE orders.amount >= 100;",
    );
    assert_eq!(result.rows.len(), 2);
    let all_alice = result
        .rows
        .iter()
        .map(|row| row[0].to_display_string())
        .all(|name| name == "Alice");
    assert!(all_alice);
}
/// WHERE applied after a LEFT JOIN can select the NULL-padded rows. This shows
/// that the join is not silently evaluated as INNER.
#[test]
fn left_join_with_where_on_right_side_is_not_inner() {
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT customers.name, orders.amount FROM customers \
         LEFT OUTER JOIN orders ON customers.id = orders.customer_id \
         WHERE orders.amount IS NULL;",
    );
    assert_eq!(result.rows.len(), 1);
    let only = &result.rows[0];
    assert_eq!(only[0].to_display_string(), "Carol");
    assert_eq!(only[1], Value::Null);
}
/// `SELECT *` over a join emits every column of the left table, then every
/// column of the right table. Duplicate names such as `id` are kept.
#[test]
fn select_star_over_join_emits_all_columns_from_both_tables() {
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT * FROM customers \
         INNER JOIN orders ON customers.id = orders.customer_id;",
    );
    let expected: Vec<String> = ["id", "name", "id", "customer_id", "amount"]
        .iter()
        .map(|column| column.to_string())
        .collect();
    assert_eq!(result.columns, expected);
    assert_eq!(result.rows.len(), 3);
}
/// ORDER BY a right-side column sorts the joined rows ascending by that column.
#[test]
fn join_order_by_sorts_full_joined_rows() {
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT c.name, o.amount FROM customers AS c \
         INNER JOIN orders AS o ON c.id = o.customer_id \
         ORDER BY o.amount;",
    );
    let mut amounts: Vec<i64> = Vec::with_capacity(result.rows.len());
    for row in &result.rows {
        let amount = match &row[1] {
            Value::Integer(i) => *i,
            v => panic!("expected integer, got {v:?}"),
        };
        amounts.push(amount);
    }
    assert_eq!(amounts, vec![50, 100, 200]);
}
/// LIMIT is applied after the join and the sort: the two largest amounts remain.
#[test]
fn join_limit_truncates_after_join_and_sort() {
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT c.name, o.amount FROM customers AS c \
         INNER JOIN orders AS o ON c.id = o.customer_id \
         ORDER BY o.amount DESC LIMIT 2;",
    );
    assert_eq!(result.rows.len(), 2);
    let mut amounts: Vec<i64> = Vec::with_capacity(result.rows.len());
    for row in &result.rows {
        let amount = match &row[1] {
            Value::Integer(i) => *i,
            v => panic!("expected integer, got {v:?}"),
        };
        amounts.push(amount);
    }
    assert_eq!(amounts, vec![200, 100]);
}
/// Two chained INNER JOINs keep only the a→b→c path that matches at every level.
#[test]
fn three_table_join_chains_correctly() {
    let mut db = Database::new("t".to_string());
    let script = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, label TEXT);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, a_id INTEGER, tag TEXT);",
        "CREATE TABLE c (id INTEGER PRIMARY KEY, b_id INTEGER, note TEXT);",
        "INSERT INTO a (label) VALUES ('a-one');",
        "INSERT INTO a (label) VALUES ('a-two');",
        "INSERT INTO b (a_id, tag) VALUES (1, 'b1');",
        "INSERT INTO b (a_id, tag) VALUES (2, 'b2');",
        "INSERT INTO c (b_id, note) VALUES (1, 'c1');",
    ];
    for stmt in script.iter() {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }
    let result = run_rows(
        &db,
        "SELECT a.label, b.tag, c.note FROM a \
         INNER JOIN b ON a.id = b.a_id \
         INNER JOIN c ON b.id = c.b_id;",
    );
    assert_eq!(result.rows.len(), 1);
    let only = &result.rows[0];
    for (col, want) in ["a-one", "b1", "c1"].iter().enumerate() {
        assert_eq!(only[col].to_display_string(), *want);
    }
}
/// An unqualified column that exists in more than one joined table (`id`) is
/// ambiguous and must be rejected.
#[test]
fn ambiguous_unqualified_column_in_join_errors() {
    let db = seed_join_fixture();
    let query = parse_select(
        "SELECT id FROM customers INNER JOIN orders ON customers.id = orders.customer_id;",
    );
    let outcome = execute_select_rows(query, &db);
    assert!(outcome.is_err(), "unqualified ambiguous 'id' should error");
}
/// Self-joining a table without aliases produces a duplicate qualifier (`n`),
/// which must be rejected.
#[test]
fn join_self_without_alias_is_rejected() {
    let mut db = Database::new("t".to_string());
    let schema = "CREATE TABLE n (id INTEGER PRIMARY KEY, parent INTEGER);";
    crate::sql::process_command(schema, &mut db).unwrap();
    let query = parse_select("SELECT n.id FROM n INNER JOIN n ON n.id = n.parent;");
    let outcome = execute_select_rows(query, &db);
    assert!(
        outcome.is_err(),
        "self-join without an alias should error on duplicate qualifier"
    );
}
/// `JOIN ... USING (...)` and `NATURAL JOIN` are unsupported and must error.
#[test]
fn using_or_natural_join_returns_not_implemented() {
    let mut db = Database::new("t".to_string());
    for ddl in [
        "CREATE TABLE a (id INTEGER PRIMARY KEY);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY);",
    ] {
        crate::sql::process_command(ddl, &mut db).unwrap();
    }
    let using_outcome =
        crate::sql::process_command("SELECT * FROM a INNER JOIN b USING (id);", &mut db);
    assert!(using_outcome.is_err(), "USING is not yet supported");
    let natural_outcome = crate::sql::process_command("SELECT * FROM a NATURAL JOIN b;", &mut db);
    assert!(natural_outcome.is_err(), "NATURAL is not supported");
}
/// Aggregates over a JOIN are not supported yet and must be rejected with an error.
///
/// Builds the fixture once and passes it mutably. The previous version built an
/// unused copy and then a second one inline.
#[test]
fn aggregates_over_join_are_rejected() {
    let mut db = seed_join_fixture();
    let err = crate::sql::process_command(
        "SELECT COUNT(*) FROM customers \
         INNER JOIN orders ON customers.id = orders.customer_id;",
        &mut db,
    );
    assert!(err.is_err(), "aggregates over JOIN are not yet supported");
}
/// When ON never matches, LEFT OUTER JOIN still emits every left row with a
/// NULL right side.
#[test]
fn left_join_with_no_matches_pads_every_row() {
    let mut db = Database::new("t".to_string());
    let script = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
        "INSERT INTO a (x) VALUES (1);",
        "INSERT INTO a (x) VALUES (2);",
        "INSERT INTO b (y) VALUES (10);",
    ];
    for stmt in script.iter() {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }
    let result = run_rows(
        &db,
        "SELECT a.x, b.y FROM a LEFT OUTER JOIN b ON a.x = b.y;",
    );
    assert_eq!(result.rows.len(), 2);
    result
        .rows
        .iter()
        .for_each(|row| assert_eq!(row[1], Value::Null));
}
/// In an ascending ORDER BY over a LEFT JOIN, the NULL-padded row (Carol)
/// sorts first.
#[test]
fn left_outer_join_order_by_places_nulls_first() {
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT c.name, o.amount FROM customers AS c \
         LEFT OUTER JOIN orders AS o ON c.id = o.customer_id \
         ORDER BY o.amount ASC;",
    );
    assert_eq!(result.rows.len(), 4);
    let first = &result.rows[0];
    assert_eq!(first[0].to_display_string(), "Carol");
    assert_eq!(first[1], Value::Null);
}
/// Two chained LEFT OUTER JOINs keep every `a` row. `a-one` reaches `b1` but
/// has no `c`, and `a-two` has neither a `b` nor a `c`.
#[test]
fn chained_left_outer_join_preserves_left_through_two_levels() {
    let mut db = Database::new("t".to_string());
    let script = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, label TEXT);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, a_id INTEGER, tag TEXT);",
        "CREATE TABLE c (id INTEGER PRIMARY KEY, b_id INTEGER, note TEXT);",
        "INSERT INTO a (label) VALUES ('a-one');",
        "INSERT INTO a (label) VALUES ('a-two');",
        "INSERT INTO b (a_id, tag) VALUES (1, 'b1');",
    ];
    for stmt in script.iter() {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }
    let result = run_rows(
        &db,
        "SELECT a.label, b.tag, c.note FROM a \
         LEFT OUTER JOIN b ON a.id = b.a_id \
         LEFT OUTER JOIN c ON b.id = c.b_id;",
    );
    assert_eq!(result.rows.len(), 2);
    let mut by_label: std::collections::HashMap<String, &Vec<Value>> =
        std::collections::HashMap::new();
    for row in &result.rows {
        by_label.insert(row[0].to_display_string(), row);
    }
    let one = by_label["a-one"];
    assert_eq!(one[1].to_display_string(), "b1");
    assert_eq!(one[2], Value::Null);
    let two = by_label["a-two"];
    assert_eq!(two[1], Value::Null);
    assert_eq!(two[2], Value::Null);
}
/// An ON clause that references a table joined later in the FROM clause (`c`)
/// must be rejected.
#[test]
fn on_clause_referencing_not_yet_joined_table_errors_clearly() {
    let mut db = Database::new("t".to_string());
    let script = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE c (id INTEGER PRIMARY KEY, x INTEGER);",
        "INSERT INTO a (x) VALUES (1);",
        "INSERT INTO b (x) VALUES (1);",
        "INSERT INTO c (x) VALUES (1);",
    ];
    for stmt in script.iter() {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }
    let query =
        parse_select("SELECT a.x FROM a INNER JOIN b ON a.x = c.x INNER JOIN c ON b.x = c.x;");
    let outcome = execute_select_rows(query, &db);
    assert!(
        outcome.is_err(),
        "ON referencing not-yet-joined table 'c' should error"
    );
}
/// A constant truthy ON condition (`ON 1`) turns the join into a cross product:
/// 2 × 2 = 4 rows.
#[test]
fn join_on_truthy_integer_is_accepted() {
    let mut db = Database::new("t".to_string());
    let script = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
        "INSERT INTO a (x) VALUES (1);",
        "INSERT INTO a (x) VALUES (2);",
        "INSERT INTO b (y) VALUES (10);",
        "INSERT INTO b (y) VALUES (20);",
    ];
    for stmt in script.iter() {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }
    let product = run_rows(&db, "SELECT a.x, b.y FROM a INNER JOIN b ON 1;");
    assert_eq!(product.rows.len(), 4);
}
/// FULL OUTER JOIN of two empty tables produces no rows at all.
#[test]
fn full_join_on_empty_tables_returns_empty() {
    let mut db = Database::new("t".to_string());
    let script = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
    ];
    for stmt in script.iter() {
        crate::sql::process_command(stmt, &mut db).unwrap();
    }
    let joined = run_rows(
        &db,
        "SELECT a.x, b.y FROM a FULL OUTER JOIN b ON a.x = b.y;",
    );
    assert!(joined.rows.is_empty());
}
}