1use std::cmp::Ordering;
5
6use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
7use sqlparser::ast::{
8 AlterTable, AlterTableOperation, AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr,
9 FromTable, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, IndexType, ObjectName,
10 ObjectNamePart, RenameTableNameKind, Statement, TableFactor, TableWithJoins, UnaryOperator,
11 Update, Value as AstValue,
12};
13
14use crate::error::{Result, SQLRiteError};
15use crate::sql::agg::{AggState, DistinctKey, like_match};
16use crate::sql::db::database::Database;
17use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
18use crate::sql::db::table::{
19 DataType, FtsIndexEntry, HnswIndexEntry, Table, Value, parse_vector_literal,
20};
21use crate::sql::fts::{Bm25Params, PostingList};
22use crate::sql::hnsw::{DistanceMetric, HnswIndex};
23use crate::sql::parser::select::{
24 AggregateArg, JoinConstraintKind, JoinType, OrderByClause, Projection, ProjectionItem,
25 ProjectionKind, SelectQuery,
26};
27
/// Name-resolution context for evaluating an expression against one logical row.
///
/// Implemented in this file by [`SingleTableScope`] (one table, one rowid) and
/// [`JoinedScope`] (several tables, one optional rowid per table).
pub(crate) trait RowScope {
    /// Resolves the column `col`, optionally qualified by `qualifier`
    /// (a table name or alias), to its value for the current row.
    fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value>;

    /// Returns the backing `(table, rowid)` pair when the scope wraps exactly one
    /// table row, or `None` otherwise (the joined scope always returns `None`).
    fn single_table_view(&self) -> Option<(&Table, i64)>;
}
66
67pub(crate) struct SingleTableScope<'a> {
69 table: &'a Table,
70 rowid: i64,
71}
72
73impl<'a> SingleTableScope<'a> {
74 pub(crate) fn new(table: &'a Table, rowid: i64) -> Self {
75 Self { table, rowid }
76 }
77}
78
79impl RowScope for SingleTableScope<'_> {
80 fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value> {
81 let _ = qualifier;
86 Ok(self.table.get_value(col, self.rowid).unwrap_or(Value::Null))
87 }
88
89 fn single_table_view(&self) -> Option<(&Table, i64)> {
90 Some((self.table, self.rowid))
91 }
92}
93
/// One table participating in a joined `SELECT`, together with the name it is
/// addressed by inside the query.
pub(crate) struct JoinedTableRef<'a> {
    /// The underlying table.
    pub table: &'a Table,
    /// Name used to qualify columns of this table; `JoinedScope::lookup`
    /// compares it against qualifiers case-insensitively (ASCII).
    pub scope_name: String,
}
101
102pub(crate) struct JoinedScope<'a> {
106 pub tables: &'a [JoinedTableRef<'a>],
107 pub rowids: &'a [Option<i64>],
108}
109
110impl RowScope for JoinedScope<'_> {
111 fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value> {
112 if let Some(q) = qualifier {
113 let pos = self
116 .tables
117 .iter()
118 .position(|t| t.scope_name.eq_ignore_ascii_case(q))
119 .ok_or_else(|| {
120 SQLRiteError::Internal(format!(
121 "unknown table qualifier '{q}' in column reference '{q}.{col}'"
122 ))
123 })?;
124 if !self.tables[pos].table.contains_column(col.to_string()) {
125 return Err(SQLRiteError::Internal(format!(
126 "column '{col}' does not exist on '{}'",
127 self.tables[pos].scope_name
128 )));
129 }
130 return Ok(match self.rowids[pos] {
131 None => Value::Null,
132 Some(r) => self.tables[pos]
133 .table
134 .get_value(col, r)
135 .unwrap_or(Value::Null),
136 });
137 }
138 let mut hit: Option<usize> = None;
142 for (i, t) in self.tables.iter().enumerate() {
143 if t.table.contains_column(col.to_string()) {
144 if hit.is_some() {
145 return Err(SQLRiteError::Internal(format!(
146 "column reference '{col}' is ambiguous — qualify it as <table>.{col}"
147 )));
148 }
149 hit = Some(i);
150 }
151 }
152 let i = hit.ok_or_else(|| {
153 SQLRiteError::Internal(format!(
154 "unknown column '{col}' in joined SELECT (no in-scope table has it)"
155 ))
156 })?;
157 Ok(match self.rowids[i] {
158 None => Value::Null,
159 Some(r) => self.tables[i]
160 .table
161 .get_value(col, r)
162 .unwrap_or(Value::Null),
163 })
164 }
165
166 fn single_table_view(&self) -> Option<(&Table, i64)> {
167 None
168 }
169}
170
/// Materialized result of a `SELECT`: output column names plus the row values.
pub struct SelectResult {
    /// Output column names, one per projected item.
    pub columns: Vec<String>,
    /// Result rows; each inner vector holds one value per entry in `columns`.
    pub rows: Vec<Vec<Value>>,
}
183
184pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
188 if !query.joins.is_empty() {
193 return execute_select_rows_joined(query, db);
194 }
195
196 let master_snapshot;
205 let table: &Table = if query.table_name == crate::sql::pager::MASTER_TABLE_NAME {
206 master_snapshot = crate::sql::pager::build_master_table_snapshot(db)?;
207 &master_snapshot
208 } else {
209 db.get_table(query.table_name.clone()).map_err(|_| {
210 SQLRiteError::Internal(format!("Table '{}' not found", query.table_name))
211 })?
212 };
213
214 let proj_items: Vec<ProjectionItem> = match &query.projection {
219 Projection::All => table
220 .column_names()
221 .into_iter()
222 .map(|c| ProjectionItem {
223 kind: ProjectionKind::Column {
224 qualifier: None,
225 name: c,
226 },
227 alias: None,
228 })
229 .collect(),
230 Projection::Items(items) => items.clone(),
231 };
232 let has_aggregates = proj_items
233 .iter()
234 .any(|i| matches!(i.kind, ProjectionKind::Aggregate(_)));
235 for item in &proj_items {
237 if let ProjectionKind::Column { name: c, .. } = &item.kind
238 && !table.contains_column(c.clone())
239 {
240 return Err(SQLRiteError::Internal(format!(
241 "Column '{c}' does not exist on table '{}'",
242 query.table_name
243 )));
244 }
245 }
246 for c in &query.group_by {
247 if !table.contains_column(c.clone()) {
248 return Err(SQLRiteError::Internal(format!(
249 "GROUP BY references unknown column '{c}' on table '{}'",
250 query.table_name
251 )));
252 }
253 }
254 let matching = match select_rowids(table, query.selection.as_ref())? {
258 RowidSource::IndexProbe(rowids) => rowids,
259 RowidSource::FullScan => {
260 let mut out = Vec::new();
261 for rowid in table.rowids() {
262 if let Some(expr) = &query.selection
263 && !eval_predicate(expr, table, rowid)?
264 {
265 continue;
266 }
267 out.push(rowid);
268 }
269 out
270 }
271 };
272 let mut matching = matching;
273
274 let aggregating = has_aggregates || !query.group_by.is_empty();
275
276 if aggregating {
282 for item in &proj_items {
284 if let ProjectionKind::Aggregate(call) = &item.kind
285 && let AggregateArg::Column(c) = &call.arg
286 && !table.contains_column(c.clone())
287 {
288 return Err(SQLRiteError::Internal(format!(
289 "{}({}) references unknown column '{c}' on table '{}'",
290 call.func.as_str(),
291 c,
292 query.table_name
293 )));
294 }
295 }
296
297 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
298 let mut rows = aggregate_rows(table, &matching, &query.group_by, &proj_items)?;
299
300 if query.distinct {
301 rows = dedupe_rows(rows);
302 }
303
304 if let Some(order) = &query.order_by {
305 sort_output_rows(&mut rows, &columns, &proj_items, order)?;
306 }
307 if let Some(k) = query.limit {
308 rows.truncate(k);
309 }
310
311 return Ok(SelectResult { columns, rows });
312 }
313
314 let defer_limit_for_distinct = query.distinct;
352 match (&query.order_by, query.limit) {
353 (Some(order), Some(k)) if try_hnsw_probe(table, &order.expr, k).is_some() => {
354 matching = try_hnsw_probe(table, &order.expr, k).unwrap();
355 }
356 (Some(order), Some(k))
357 if try_fts_probe(table, &order.expr, order.ascending, k).is_some() =>
358 {
359 matching = try_fts_probe(table, &order.expr, order.ascending, k).unwrap();
360 }
361 (Some(order), Some(k)) if !defer_limit_for_distinct && k < matching.len() => {
362 matching = select_topk(&matching, table, order, k)?;
363 }
364 (Some(order), _) => {
365 sort_rowids(&mut matching, table, order)?;
366 if let Some(k) = query.limit
367 && !defer_limit_for_distinct
368 {
369 matching.truncate(k);
370 }
371 }
372 (None, Some(k)) if !defer_limit_for_distinct => {
373 matching.truncate(k);
374 }
375 _ => {}
376 }
377
378 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
379 let projected_cols: Vec<String> = proj_items
380 .iter()
381 .map(|i| match &i.kind {
382 ProjectionKind::Column { name, .. } => name.clone(),
383 ProjectionKind::Aggregate(_) => unreachable!("aggregation handled above"),
384 })
385 .collect();
386
387 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
391 for rowid in &matching {
392 let row: Vec<Value> = projected_cols
393 .iter()
394 .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
395 .collect();
396 rows.push(row);
397 }
398
399 if query.distinct {
400 rows = dedupe_rows(rows);
401 if let Some(k) = query.limit {
402 rows.truncate(k);
403 }
404 }
405
406 Ok(SelectResult { columns, rows })
407}
408
/// A join constraint normalized into a single predicate.
struct ResolvedJoin {
    /// Predicate evaluated for each candidate (left row, right row) pairing.
    on: Expr,
    /// Columns named by USING / NATURAL (empty for a plain ON constraint); the
    /// joined executor skips these on the right-hand table when expanding `SELECT *`.
    using_columns: Vec<String>,
}
417
418fn resolve_join_constraint(
433 constraint: &JoinConstraintKind,
434 tables: &[JoinedTableRef<'_>],
435 right_pos: usize,
436) -> Result<ResolvedJoin> {
437 match constraint {
438 JoinConstraintKind::On(expr) => Ok(ResolvedJoin {
439 on: (**expr).clone(),
440 using_columns: Vec::new(),
441 }),
442 JoinConstraintKind::Using(cols) => build_using_join(cols, tables, right_pos),
443 JoinConstraintKind::Natural => {
444 let shared: Vec<String> = tables[right_pos]
448 .table
449 .column_names()
450 .into_iter()
451 .filter(|c| {
452 tables[..right_pos]
453 .iter()
454 .any(|t| t.table.contains_column(c.clone()))
455 })
456 .collect();
457 build_using_join(&shared, tables, right_pos)
458 }
459 }
460}
461
462fn build_using_join(
467 cols: &[String],
468 tables: &[JoinedTableRef<'_>],
469 right_pos: usize,
470) -> Result<ResolvedJoin> {
471 let right = &tables[right_pos];
472 let mut predicate: Option<Expr> = None;
473 for col in cols {
474 if !right.table.contains_column(col.clone()) {
476 return Err(SQLRiteError::Internal(format!(
477 "cannot join USING column '{col}' — it is not present on table '{}'",
478 right.scope_name
479 )));
480 }
481 let left = tables[..right_pos]
484 .iter()
485 .find(|t| t.table.contains_column(col.clone()))
486 .ok_or_else(|| {
487 SQLRiteError::Internal(format!(
488 "cannot join USING column '{col}' — it is not present on any left-side table"
489 ))
490 })?;
491 let eq = col_eq(&left.scope_name, &right.scope_name, col);
492 predicate = Some(match predicate {
493 None => eq,
494 Some(prev) => Expr::BinaryOp {
495 left: Box::new(prev),
496 op: BinaryOperator::And,
497 right: Box::new(eq),
498 },
499 });
500 }
501 Ok(ResolvedJoin {
502 on: predicate
503 .unwrap_or_else(|| Expr::Value(sqlparser::ast::Value::Boolean(true).with_empty_span())),
504 using_columns: cols.to_vec(),
505 })
506}
507
508fn col_eq(left_scope: &str, right_scope: &str, col: &str) -> Expr {
511 let col_ref = |scope: &str| {
512 Expr::CompoundIdentifier(vec![
513 Ident::new(scope.to_string()),
514 Ident::new(col.to_string()),
515 ])
516 };
517 Expr::BinaryOp {
518 left: Box::new(col_ref(left_scope)),
519 op: BinaryOperator::Eq,
520 right: Box::new(col_ref(right_scope)),
521 }
522}
523
524fn execute_select_rows_joined(query: SelectQuery, db: &Database) -> Result<SelectResult> {
551 let mut joined_tables: Vec<JoinedTableRef<'_>> = Vec::with_capacity(1 + query.joins.len());
558
559 let primary = db
560 .get_table(query.table_name.clone())
561 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
562 joined_tables.push(JoinedTableRef {
563 table: primary,
564 scope_name: query
565 .table_alias
566 .clone()
567 .unwrap_or_else(|| query.table_name.clone()),
568 });
569 for j in &query.joins {
570 let t = db
571 .get_table(j.right_table.clone())
572 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", j.right_table)))?;
573 joined_tables.push(JoinedTableRef {
574 table: t,
575 scope_name: j
576 .right_alias
577 .clone()
578 .unwrap_or_else(|| j.right_table.clone()),
579 });
580 }
581
582 {
587 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
588 for t in &joined_tables {
589 let key = t.scope_name.to_ascii_lowercase();
590 if !seen.insert(key) {
591 return Err(SQLRiteError::Internal(format!(
592 "duplicate table reference '{}' in FROM/JOIN — use AS to alias one side",
593 t.scope_name
594 )));
595 }
596 }
597 }
598
599 let resolved: Vec<ResolvedJoin> = query
607 .joins
608 .iter()
609 .enumerate()
610 .map(|(j_idx, join)| resolve_join_constraint(&join.constraint, &joined_tables, j_idx + 1))
611 .collect::<Result<Vec<_>>>()?;
612
613 let proj_items: Vec<ProjectionItem> = match &query.projection {
619 Projection::All => {
620 let mut all = Vec::new();
636 for (t_idx, t) in joined_tables.iter().enumerate() {
637 let dedup: &[String] = t_idx
640 .checked_sub(1)
641 .map(|r| resolved[r].using_columns.as_slice())
642 .unwrap_or(&[]);
643 for col in t.table.column_names() {
644 if dedup.contains(&col) {
645 continue;
646 }
647 all.push(ProjectionItem {
648 kind: ProjectionKind::Column {
649 qualifier: Some(t.scope_name.clone()),
654 name: col,
655 },
656 alias: None,
657 });
658 }
659 }
660 all
661 }
662 Projection::Items(items) => items.clone(),
663 };
664
665 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
666
667 let mut acc: Vec<Vec<Option<i64>>> = primary
672 .rowids()
673 .into_iter()
674 .map(|r| {
675 let mut row = Vec::with_capacity(joined_tables.len());
676 row.push(Some(r));
677 row
678 })
679 .collect();
680
681 for (j_idx, join) in query.joins.iter().enumerate() {
686 let right_pos = j_idx + 1;
687 let right_table = joined_tables[right_pos].table;
688 let right_rowids: Vec<i64> = right_table.rowids();
689
690 let mut right_matched: Vec<bool> = vec![false; right_rowids.len()];
694
695 let mut next_acc: Vec<Vec<Option<i64>>> = Vec::with_capacity(acc.len());
696
697 let on_scope_tables: &[JoinedTableRef<'_>] = &joined_tables[..=right_pos];
705
706 for left_row in acc.into_iter() {
707 let mut left_match_count = 0usize;
711 for (r_idx, &rrid) in right_rowids.iter().enumerate() {
712 let mut on_rowids: Vec<Option<i64>> = left_row.clone();
713 on_rowids.push(Some(rrid));
714 debug_assert_eq!(on_rowids.len(), on_scope_tables.len());
715 let scope = JoinedScope {
716 tables: on_scope_tables,
717 rowids: &on_rowids,
718 };
719 if eval_predicate_scope(&resolved[j_idx].on, &scope)? {
726 left_match_count += 1;
727 right_matched[r_idx] = true;
728 next_acc.push(on_rowids);
733 }
734 }
735
736 if left_match_count == 0
737 && matches!(join.join_type, JoinType::LeftOuter | JoinType::FullOuter)
738 {
739 let mut padded = left_row;
742 padded.push(None);
743 next_acc.push(padded);
744 }
745 }
746
747 if matches!(join.join_type, JoinType::RightOuter | JoinType::FullOuter) {
751 for (r_idx, matched) in right_matched.iter().enumerate() {
752 if *matched {
753 continue;
754 }
755 let mut row: Vec<Option<i64>> = vec![None; right_pos];
756 row.push(Some(right_rowids[r_idx]));
757 next_acc.push(row);
758 }
759 }
760
761 acc = next_acc;
762 }
763
764 let mut filtered: Vec<Vec<Option<i64>>> = if let Some(where_expr) = &query.selection {
769 let mut out = Vec::with_capacity(acc.len());
770 for row in acc {
771 let scope = JoinedScope {
772 tables: &joined_tables,
773 rowids: &row,
774 };
775 if eval_predicate_scope(where_expr, &scope)? {
776 out.push(row);
777 }
778 }
779 out
780 } else {
781 acc
782 };
783
784 if let Some(order) = &query.order_by {
788 let mut keys: Vec<(usize, Value)> = Vec::with_capacity(filtered.len());
791 for (i, row) in filtered.iter().enumerate() {
792 let scope = JoinedScope {
793 tables: &joined_tables,
794 rowids: row,
795 };
796 let v = eval_expr_scope(&order.expr, &scope)?;
797 keys.push((i, v));
798 }
799 keys.sort_by(|(_, a), (_, b)| {
800 let ord = compare_values(Some(a), Some(b));
801 if order.ascending { ord } else { ord.reverse() }
802 });
803 let mut sorted = Vec::with_capacity(filtered.len());
804 for (i, _) in keys {
805 sorted.push(filtered[i].clone());
806 }
807 filtered = sorted;
808 }
809
810 if let Some(k) = query.limit {
812 filtered.truncate(k);
813 }
814
815 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(filtered.len());
818 for row in &filtered {
819 let scope = JoinedScope {
820 tables: &joined_tables,
821 rowids: row,
822 };
823 let mut out_row = Vec::with_capacity(proj_items.len());
824 for item in &proj_items {
825 let v = match &item.kind {
826 ProjectionKind::Column { qualifier, name } => {
827 scope.lookup(qualifier.as_deref(), name)?
828 }
829 ProjectionKind::Aggregate(_) => {
830 return Err(SQLRiteError::Internal(
833 "aggregate functions over JOIN are not supported".to_string(),
834 ));
835 }
836 };
837 out_row.push(v);
838 }
839 rows.push(out_row);
840 }
841
842 Ok(SelectResult { columns, rows })
843}
844
845pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
850 let result = execute_select_rows(query, db)?;
851 let row_count = result.rows.len();
852
853 let mut print_table = PrintTable::new();
854 let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
855 print_table.add_row(PrintRow::new(header_cells));
856
857 for row in &result.rows {
858 let cells: Vec<PrintCell> = row
859 .iter()
860 .map(|v| PrintCell::new(&v.to_display_string()))
861 .collect();
862 print_table.add_row(PrintRow::new(cells));
863 }
864
865 Ok((print_table.to_string(), row_count))
866}
867
868pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
870 let Statement::Delete(Delete {
871 from, selection, ..
872 }) = stmt
873 else {
874 return Err(SQLRiteError::Internal(
875 "execute_delete called on a non-DELETE statement".to_string(),
876 ));
877 };
878
879 let tables = match from {
880 FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
881 };
882 let table_name = extract_single_table_name(tables)?;
883
884 let matching: Vec<i64> = {
886 let table = db
887 .get_table(table_name.clone())
888 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
889 match select_rowids(table, selection.as_ref())? {
890 RowidSource::IndexProbe(rowids) => rowids,
891 RowidSource::FullScan => {
892 let mut out = Vec::new();
893 for rowid in table.rowids() {
894 if let Some(expr) = selection {
895 if !eval_predicate(expr, table, rowid)? {
896 continue;
897 }
898 }
899 out.push(rowid);
900 }
901 out
902 }
903 }
904 };
905
906 let table = db.get_table_mut(table_name)?;
907 for rowid in &matching {
908 table.delete_row(*rowid);
909 }
910 if !matching.is_empty() {
919 for entry in &mut table.hnsw_indexes {
920 entry.needs_rebuild = true;
921 }
922 for entry in &mut table.fts_indexes {
923 entry.needs_rebuild = true;
924 }
925 }
926 Ok(matching.len())
927}
928
929pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
931 let Statement::Update(Update {
932 table,
933 assignments,
934 from,
935 selection,
936 ..
937 }) = stmt
938 else {
939 return Err(SQLRiteError::Internal(
940 "execute_update called on a non-UPDATE statement".to_string(),
941 ));
942 };
943
944 if from.is_some() {
945 return Err(SQLRiteError::NotImplemented(
946 "UPDATE ... FROM is not supported yet".to_string(),
947 ));
948 }
949
950 let table_name = extract_table_name(table)?;
951
952 let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
954 {
955 let tbl = db
956 .get_table(table_name.clone())
957 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
958 for a in assignments {
959 let col = match &a.target {
960 AssignmentTarget::ColumnName(name) => name
961 .0
962 .last()
963 .map(|p| p.to_string())
964 .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
965 AssignmentTarget::Tuple(_) => {
966 return Err(SQLRiteError::NotImplemented(
967 "tuple assignment targets are not supported".to_string(),
968 ));
969 }
970 };
971 if !tbl.contains_column(col.clone()) {
972 return Err(SQLRiteError::Internal(format!(
973 "UPDATE references unknown column '{col}'"
974 )));
975 }
976 parsed_assignments.push((col, a.value.clone()));
977 }
978 }
979
980 let work: Vec<(i64, Vec<(String, Value)>)> = {
984 let tbl = db.get_table(table_name.clone())?;
985 let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
986 RowidSource::IndexProbe(rowids) => rowids,
987 RowidSource::FullScan => {
988 let mut out = Vec::new();
989 for rowid in tbl.rowids() {
990 if let Some(expr) = selection {
991 if !eval_predicate(expr, tbl, rowid)? {
992 continue;
993 }
994 }
995 out.push(rowid);
996 }
997 out
998 }
999 };
1000 let mut rows_to_update = Vec::new();
1001 for rowid in matched_rowids {
1002 let mut values = Vec::with_capacity(parsed_assignments.len());
1003 for (col, expr) in &parsed_assignments {
1004 let v = eval_expr(expr, tbl, rowid)?;
1007 values.push((col.clone(), v));
1008 }
1009 rows_to_update.push((rowid, values));
1010 }
1011 rows_to_update
1012 };
1013
1014 let tbl = db.get_table_mut(table_name)?;
1015 for (rowid, values) in &work {
1016 for (col, v) in values {
1017 tbl.set_value(col, *rowid, v.clone())?;
1018 }
1019 }
1020
1021 if !work.is_empty() {
1030 let updated_columns: std::collections::HashSet<&str> = work
1031 .iter()
1032 .flat_map(|(_, values)| values.iter().map(|(c, _)| c.as_str()))
1033 .collect();
1034 for entry in &mut tbl.hnsw_indexes {
1035 if updated_columns.contains(entry.column_name.as_str()) {
1036 entry.needs_rebuild = true;
1037 }
1038 }
1039 for entry in &mut tbl.fts_indexes {
1040 if updated_columns.contains(entry.column_name.as_str()) {
1041 entry.needs_rebuild = true;
1042 }
1043 }
1044 }
1045 Ok(work.len())
1046}
1047
1048pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
1060 let Statement::CreateIndex(CreateIndex {
1061 name,
1062 table_name,
1063 columns,
1064 using,
1065 unique,
1066 if_not_exists,
1067 predicate,
1068 with,
1069 ..
1070 }) = stmt
1071 else {
1072 return Err(SQLRiteError::Internal(
1073 "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
1074 ));
1075 };
1076
1077 if predicate.is_some() {
1078 return Err(SQLRiteError::NotImplemented(
1079 "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
1080 ));
1081 }
1082
1083 if columns.len() != 1 {
1084 return Err(SQLRiteError::NotImplemented(format!(
1085 "multi-column indexes are not supported yet ({} columns given)",
1086 columns.len()
1087 )));
1088 }
1089
1090 let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
1091 SQLRiteError::NotImplemented(
1092 "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
1093 )
1094 })?;
1095
1096 let method = match using {
1102 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("hnsw") => {
1103 IndexMethod::Hnsw
1104 }
1105 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("fts") => {
1106 IndexMethod::Fts
1107 }
1108 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("btree") => {
1109 IndexMethod::Btree
1110 }
1111 Some(other) => {
1112 return Err(SQLRiteError::NotImplemented(format!(
1113 "CREATE INDEX … USING {other:?} is not supported \
1114 (try `hnsw`, `fts`, or no USING clause)"
1115 )));
1116 }
1117 None => IndexMethod::Btree,
1118 };
1119
1120 let hnsw_metric = parse_hnsw_with_options(with, &index_name, method)?;
1126
1127 let table_name_str = table_name.to_string();
1128 let column_name = match &columns[0].column.expr {
1129 Expr::Identifier(ident) => ident.value.clone(),
1130 Expr::CompoundIdentifier(parts) => parts
1131 .last()
1132 .map(|p| p.value.clone())
1133 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
1134 other => {
1135 return Err(SQLRiteError::NotImplemented(format!(
1136 "CREATE INDEX only supports simple column references, got {other:?}"
1137 )));
1138 }
1139 };
1140
1141 let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
1146 let table = db.get_table(table_name_str.clone()).map_err(|_| {
1147 SQLRiteError::General(format!(
1148 "CREATE INDEX references unknown table '{table_name_str}'"
1149 ))
1150 })?;
1151 if !table.contains_column(column_name.clone()) {
1152 return Err(SQLRiteError::General(format!(
1153 "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
1154 )));
1155 }
1156 let col = table
1157 .columns
1158 .iter()
1159 .find(|c| c.column_name == column_name)
1160 .expect("we just verified the column exists");
1161
1162 if table.index_by_name(&index_name).is_some()
1165 || table.hnsw_indexes.iter().any(|i| i.name == index_name)
1166 || table.fts_indexes.iter().any(|i| i.name == index_name)
1167 {
1168 if *if_not_exists {
1169 return Ok(index_name);
1170 }
1171 return Err(SQLRiteError::General(format!(
1172 "index '{index_name}' already exists"
1173 )));
1174 }
1175 let datatype = clone_datatype(&col.datatype);
1176
1177 let mut pairs = Vec::new();
1178 for rowid in table.rowids() {
1179 if let Some(v) = table.get_value(&column_name, rowid) {
1180 pairs.push((rowid, v));
1181 }
1182 }
1183 (datatype, pairs)
1184 };
1185
1186 match method {
1187 IndexMethod::Btree => create_btree_index(
1188 db,
1189 &table_name_str,
1190 &index_name,
1191 &column_name,
1192 &datatype,
1193 *unique,
1194 &existing_rowids_and_values,
1195 ),
1196 IndexMethod::Hnsw => create_hnsw_index(
1197 db,
1198 &table_name_str,
1199 &index_name,
1200 &column_name,
1201 &datatype,
1202 *unique,
1203 hnsw_metric.unwrap_or(DistanceMetric::L2),
1204 &existing_rowids_and_values,
1205 ),
1206 IndexMethod::Fts => create_fts_index(
1207 db,
1208 &table_name_str,
1209 &index_name,
1210 &column_name,
1211 &datatype,
1212 *unique,
1213 &existing_rowids_and_values,
1214 ),
1215 }
1216}
1217
1218pub fn execute_drop_table(
1229 names: &[ObjectName],
1230 if_exists: bool,
1231 db: &mut Database,
1232) -> Result<usize> {
1233 if names.len() != 1 {
1234 return Err(SQLRiteError::NotImplemented(
1235 "DROP TABLE supports a single table per statement".to_string(),
1236 ));
1237 }
1238 let name = names[0].to_string();
1239
1240 if name == crate::sql::pager::MASTER_TABLE_NAME {
1241 return Err(SQLRiteError::General(format!(
1242 "'{}' is a reserved name used by the internal schema catalog",
1243 crate::sql::pager::MASTER_TABLE_NAME
1244 )));
1245 }
1246
1247 if !db.contains_table(name.clone()) {
1248 return if if_exists {
1249 Ok(0)
1250 } else {
1251 Err(SQLRiteError::General(format!(
1252 "Table '{name}' does not exist"
1253 )))
1254 };
1255 }
1256
1257 db.tables.remove(&name);
1258 Ok(1)
1259}
1260
1261pub fn execute_drop_index(
1270 names: &[ObjectName],
1271 if_exists: bool,
1272 db: &mut Database,
1273) -> Result<usize> {
1274 if names.len() != 1 {
1275 return Err(SQLRiteError::NotImplemented(
1276 "DROP INDEX supports a single index per statement".to_string(),
1277 ));
1278 }
1279 let name = names[0].to_string();
1280
1281 for table in db.tables.values_mut() {
1282 if let Some(secondary) = table.secondary_indexes.iter().find(|i| i.name == name) {
1283 if secondary.origin == IndexOrigin::Auto {
1284 return Err(SQLRiteError::General(format!(
1285 "cannot drop auto-created index '{name}' (drop the column or table instead)"
1286 )));
1287 }
1288 table.secondary_indexes.retain(|i| i.name != name);
1289 return Ok(1);
1290 }
1291 if table.hnsw_indexes.iter().any(|i| i.name == name) {
1292 table.hnsw_indexes.retain(|i| i.name != name);
1293 return Ok(1);
1294 }
1295 if table.fts_indexes.iter().any(|i| i.name == name) {
1296 table.fts_indexes.retain(|i| i.name != name);
1297 return Ok(1);
1298 }
1299 }
1300
1301 if if_exists {
1302 Ok(0)
1303 } else {
1304 Err(SQLRiteError::General(format!(
1305 "Index '{name}' does not exist"
1306 )))
1307 }
1308}
1309
1310pub fn execute_alter_table(alter: AlterTable, db: &mut Database) -> Result<String> {
1322 let table_name = alter.name.to_string();
1323
1324 if table_name == crate::sql::pager::MASTER_TABLE_NAME {
1325 return Err(SQLRiteError::General(format!(
1326 "'{}' is a reserved name used by the internal schema catalog",
1327 crate::sql::pager::MASTER_TABLE_NAME
1328 )));
1329 }
1330
1331 if !db.contains_table(table_name.clone()) {
1332 return if alter.if_exists {
1333 Ok("ALTER TABLE: no-op (table does not exist)".to_string())
1334 } else {
1335 Err(SQLRiteError::General(format!(
1336 "Table '{table_name}' does not exist"
1337 )))
1338 };
1339 }
1340
1341 if alter.operations.len() != 1 {
1342 return Err(SQLRiteError::NotImplemented(
1343 "ALTER TABLE supports one operation per statement".to_string(),
1344 ));
1345 }
1346
1347 match &alter.operations[0] {
1348 AlterTableOperation::RenameTable { table_name: kind } => {
1349 let new_name = match kind {
1350 RenameTableNameKind::To(name) => name.to_string(),
1351 RenameTableNameKind::As(_) => {
1352 return Err(SQLRiteError::NotImplemented(
1353 "ALTER TABLE ... RENAME AS (MySQL-only) is not supported; use RENAME TO"
1354 .to_string(),
1355 ));
1356 }
1357 };
1358 alter_rename_table(db, &table_name, &new_name)?;
1359 Ok(format!(
1360 "ALTER TABLE '{table_name}' RENAME TO '{new_name}' executed."
1361 ))
1362 }
1363 AlterTableOperation::RenameColumn {
1364 old_column_name,
1365 new_column_name,
1366 } => {
1367 let old = old_column_name.value.clone();
1368 let new = new_column_name.value.clone();
1369 db.get_table_mut(table_name.clone())?
1370 .rename_column(&old, &new)?;
1371 Ok(format!(
1372 "ALTER TABLE '{table_name}' RENAME COLUMN '{old}' TO '{new}' executed."
1373 ))
1374 }
1375 AlterTableOperation::AddColumn {
1376 column_def,
1377 if_not_exists,
1378 ..
1379 } => {
1380 let parsed = crate::sql::parser::create::parse_one_column(column_def)?;
1381 let table = db.get_table_mut(table_name.clone())?;
1382 if *if_not_exists && table.contains_column(parsed.name.clone()) {
1383 return Ok(format!(
1384 "ALTER TABLE '{table_name}' ADD COLUMN: no-op (column '{}' already exists)",
1385 parsed.name
1386 ));
1387 }
1388 let col_name = parsed.name.clone();
1389 table.add_column(parsed)?;
1390 Ok(format!(
1391 "ALTER TABLE '{table_name}' ADD COLUMN '{col_name}' executed."
1392 ))
1393 }
1394 AlterTableOperation::DropColumn {
1395 column_names,
1396 if_exists,
1397 ..
1398 } => {
1399 if column_names.len() != 1 {
1400 return Err(SQLRiteError::NotImplemented(
1401 "ALTER TABLE DROP COLUMN supports a single column per statement".to_string(),
1402 ));
1403 }
1404 let col_name = column_names[0].value.clone();
1405 let table = db.get_table_mut(table_name.clone())?;
1406 if *if_exists && !table.contains_column(col_name.clone()) {
1407 return Ok(format!(
1408 "ALTER TABLE '{table_name}' DROP COLUMN: no-op (column '{col_name}' does not exist)"
1409 ));
1410 }
1411 table.drop_column(&col_name)?;
1412 Ok(format!(
1413 "ALTER TABLE '{table_name}' DROP COLUMN '{col_name}' executed."
1414 ))
1415 }
1416 other => Err(SQLRiteError::NotImplemented(format!(
1417 "ALTER TABLE operation {other:?} is not supported"
1418 ))),
1419 }
1420}
1421
1422pub fn execute_vacuum(db: &mut Database) -> Result<String> {
1432 if db.in_transaction() {
1433 return Err(SQLRiteError::General(
1434 "VACUUM cannot run inside a transaction".to_string(),
1435 ));
1436 }
1437 let path = match db.source_path.clone() {
1438 Some(p) => p,
1439 None => {
1440 return Ok("VACUUM is a no-op for in-memory databases".to_string());
1441 }
1442 };
1443 if let Some(pager) = db.pager.as_mut() {
1449 let _ = pager.checkpoint();
1450 }
1451 let size_before = std::fs::metadata(&path).ok().map(|m| m.len()).unwrap_or(0);
1452 let pages_before = db
1453 .pager
1454 .as_ref()
1455 .map(|p| p.header().page_count)
1456 .unwrap_or(0);
1457 crate::sql::pager::vacuum_database(db, &path)?;
1458 if let Some(pager) = db.pager.as_mut() {
1461 let _ = pager.checkpoint();
1462 }
1463 let size_after = std::fs::metadata(&path).ok().map(|m| m.len()).unwrap_or(0);
1464 let pages_after = db
1465 .pager
1466 .as_ref()
1467 .map(|p| p.header().page_count)
1468 .unwrap_or(0);
1469 let pages_reclaimed = pages_before.saturating_sub(pages_after);
1470 let bytes_reclaimed = size_before.saturating_sub(size_after);
1471 Ok(format!(
1472 "VACUUM completed. {pages_reclaimed} pages reclaimed ({bytes_reclaimed} bytes)."
1473 ))
1474}
1475
1476fn alter_rename_table(db: &mut Database, old: &str, new: &str) -> Result<()> {
1482 if new == crate::sql::pager::MASTER_TABLE_NAME {
1483 return Err(SQLRiteError::General(format!(
1484 "'{}' is a reserved name used by the internal schema catalog",
1485 crate::sql::pager::MASTER_TABLE_NAME
1486 )));
1487 }
1488 if old == new {
1489 return Ok(());
1490 }
1491 if db.contains_table(new.to_string()) {
1492 return Err(SQLRiteError::General(format!(
1493 "target table '{new}' already exists"
1494 )));
1495 }
1496
1497 let mut table = db
1498 .tables
1499 .remove(old)
1500 .ok_or_else(|| SQLRiteError::General(format!("Table '{old}' does not exist")))?;
1501 table.tb_name = new.to_string();
1502 for idx in table.secondary_indexes.iter_mut() {
1503 idx.table_name = new.to_string();
1504 if idx.origin == IndexOrigin::Auto
1505 && idx.name == SecondaryIndex::auto_name(old, &idx.column_name)
1506 {
1507 idx.name = SecondaryIndex::auto_name(new, &idx.column_name);
1508 }
1509 }
1510 db.tables.insert(new.to_string(), table);
1511 Ok(())
1512}
1513
/// Index implementation selected for a `CREATE INDEX` statement.
#[derive(Debug, Clone, Copy)]
enum IndexMethod {
    /// B-tree secondary index (presumably built by `create_btree_index`; confirm at the dispatch site).
    Btree,
    /// HNSW vector index; the only method for which `parse_hnsw_with_options`
    /// accepts a `WITH (...)` clause.
    Hnsw,
    /// Full-text index (presumably built by `create_fts_index`; confirm at the dispatch site).
    Fts,
}
1524
1525fn create_btree_index(
1527 db: &mut Database,
1528 table_name: &str,
1529 index_name: &str,
1530 column_name: &str,
1531 datatype: &DataType,
1532 unique: bool,
1533 existing: &[(i64, Value)],
1534) -> Result<String> {
1535 let mut idx = SecondaryIndex::new(
1536 index_name.to_string(),
1537 table_name.to_string(),
1538 column_name.to_string(),
1539 datatype,
1540 unique,
1541 IndexOrigin::Explicit,
1542 )?;
1543
1544 for (rowid, v) in existing {
1548 if unique && idx.would_violate_unique(v) {
1549 return Err(SQLRiteError::General(format!(
1550 "cannot create UNIQUE index '{index_name}': column '{column_name}' \
1551 already contains the duplicate value {}",
1552 v.to_display_string()
1553 )));
1554 }
1555 idx.insert(v, *rowid)?;
1556 }
1557
1558 let table_mut = db.get_table_mut(table_name.to_string())?;
1559 table_mut.secondary_indexes.push(idx);
1560 Ok(index_name.to_string())
1561}
1562
1563fn create_hnsw_index(
1565 db: &mut Database,
1566 table_name: &str,
1567 index_name: &str,
1568 column_name: &str,
1569 datatype: &DataType,
1570 unique: bool,
1571 metric: DistanceMetric,
1572 existing: &[(i64, Value)],
1573) -> Result<String> {
1574 let dim = match datatype {
1577 DataType::Vector(d) => *d,
1578 other => {
1579 return Err(SQLRiteError::General(format!(
1580 "USING hnsw requires a VECTOR column; '{column_name}' is {other}"
1581 )));
1582 }
1583 };
1584
1585 if unique {
1586 return Err(SQLRiteError::General(
1587 "UNIQUE has no meaning for HNSW indexes".to_string(),
1588 ));
1589 }
1590
1591 let seed = hash_str_to_seed(index_name);
1602 let mut idx = HnswIndex::new(metric, seed);
1603
1604 let mut vec_map: std::collections::HashMap<i64, Vec<f32>> =
1608 std::collections::HashMap::with_capacity(existing.len());
1609 for (rowid, v) in existing {
1610 match v {
1611 Value::Vector(vec) => {
1612 if vec.len() != dim {
1613 return Err(SQLRiteError::Internal(format!(
1614 "row {rowid} stores a {}-dim vector in column '{column_name}' \
1615 declared as VECTOR({dim}) — schema invariant violated",
1616 vec.len()
1617 )));
1618 }
1619 vec_map.insert(*rowid, vec.clone());
1620 }
1621 _ => continue,
1625 }
1626 }
1627
1628 for (rowid, _) in existing {
1629 if let Some(v) = vec_map.get(rowid) {
1630 let v_clone = v.clone();
1631 idx.insert(*rowid, &v_clone, |id| {
1632 vec_map.get(&id).cloned().unwrap_or_default()
1633 })?;
1634 }
1635 }
1636
1637 let table_mut = db.get_table_mut(table_name.to_string())?;
1638 table_mut.hnsw_indexes.push(HnswIndexEntry {
1639 name: index_name.to_string(),
1640 column_name: column_name.to_string(),
1641 metric,
1642 index: idx,
1643 needs_rebuild: false,
1645 });
1646 Ok(index_name.to_string())
1647}
1648
1649fn parse_hnsw_with_options(
1660 with: &[Expr],
1661 index_name: &str,
1662 method: IndexMethod,
1663) -> Result<Option<DistanceMetric>> {
1664 if with.is_empty() {
1665 return Ok(None);
1666 }
1667 if !matches!(method, IndexMethod::Hnsw) {
1668 return Err(SQLRiteError::General(format!(
1669 "CREATE INDEX '{index_name}' has a WITH (...) clause but its index method \
1670 doesn't support any options — only `USING hnsw` recognises `WITH (metric = ...)`"
1671 )));
1672 }
1673
1674 let mut metric: Option<DistanceMetric> = None;
1675 for opt in with {
1676 let Expr::BinaryOp { left, op, right } = opt else {
1677 return Err(SQLRiteError::General(format!(
1678 "CREATE INDEX '{index_name}': unsupported WITH option {opt:?} \
1679 (expected `key = 'value'`)"
1680 )));
1681 };
1682 if !matches!(op, BinaryOperator::Eq) {
1683 return Err(SQLRiteError::General(format!(
1684 "CREATE INDEX '{index_name}': WITH options must use `=` (got {op:?})"
1685 )));
1686 }
1687 let key = match left.as_ref() {
1688 Expr::Identifier(ident) => ident.value.clone(),
1689 other => {
1690 return Err(SQLRiteError::General(format!(
1691 "CREATE INDEX '{index_name}': WITH option key must be a bare identifier, \
1692 got {other:?}"
1693 )));
1694 }
1695 };
1696 let value = match right.as_ref() {
1697 Expr::Value(v) => match &v.value {
1698 AstValue::SingleQuotedString(s) => s.clone(),
1699 AstValue::DoubleQuotedString(s) => s.clone(),
1700 other => {
1701 return Err(SQLRiteError::General(format!(
1702 "CREATE INDEX '{index_name}': WITH option '{key}' value must be \
1703 a quoted string, got {other:?}"
1704 )));
1705 }
1706 },
1707 Expr::Identifier(ident) => ident.value.clone(),
1708 other => {
1709 return Err(SQLRiteError::General(format!(
1710 "CREATE INDEX '{index_name}': WITH option '{key}' value must be a \
1711 quoted string, got {other:?}"
1712 )));
1713 }
1714 };
1715
1716 if key.eq_ignore_ascii_case("metric") {
1717 let parsed = DistanceMetric::from_sql_name(&value).ok_or_else(|| {
1718 SQLRiteError::General(format!(
1719 "CREATE INDEX '{index_name}': unknown HNSW metric '{value}' \
1720 (try 'l2', 'cosine', or 'dot')"
1721 ))
1722 })?;
1723 if metric.is_some() {
1724 return Err(SQLRiteError::General(format!(
1725 "CREATE INDEX '{index_name}': metric specified more than once in WITH (...)"
1726 )));
1727 }
1728 metric = Some(parsed);
1729 } else {
1730 return Err(SQLRiteError::General(format!(
1731 "CREATE INDEX '{index_name}': unknown WITH option '{key}' \
1732 (only 'metric' is recognised on HNSW indexes)"
1733 )));
1734 }
1735 }
1736
1737 Ok(metric)
1738}
1739
1740fn create_fts_index(
1745 db: &mut Database,
1746 table_name: &str,
1747 index_name: &str,
1748 column_name: &str,
1749 datatype: &DataType,
1750 unique: bool,
1751 existing: &[(i64, Value)],
1752) -> Result<String> {
1753 match datatype {
1758 DataType::Text => {}
1759 other => {
1760 return Err(SQLRiteError::General(format!(
1761 "USING fts requires a TEXT column; '{column_name}' is {other}"
1762 )));
1763 }
1764 }
1765
1766 if unique {
1767 return Err(SQLRiteError::General(
1768 "UNIQUE has no meaning for FTS indexes".to_string(),
1769 ));
1770 }
1771
1772 let mut idx = PostingList::new();
1773 for (rowid, v) in existing {
1774 if let Value::Text(text) = v {
1775 idx.insert(*rowid, text);
1776 }
1777 }
1780
1781 let table_mut = db.get_table_mut(table_name.to_string())?;
1782 table_mut.fts_indexes.push(FtsIndexEntry {
1783 name: index_name.to_string(),
1784 column_name: column_name.to_string(),
1785 index: idx,
1786 needs_rebuild: false,
1787 });
1788 Ok(index_name.to_string())
1789}
1790
/// Hashes `s` to a `u64` with 64-bit FNV-1a (XOR each byte, then multiply by
/// the FNV prime with wrapping arithmetic). The empty string maps to the
/// offset basis.
fn hash_str_to_seed(s: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xCBF2_9CE4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01B3;

    s.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
1802
1803fn clone_datatype(dt: &DataType) -> DataType {
1806 match dt {
1807 DataType::Integer => DataType::Integer,
1808 DataType::Text => DataType::Text,
1809 DataType::Real => DataType::Real,
1810 DataType::Bool => DataType::Bool,
1811 DataType::Vector(dim) => DataType::Vector(*dim),
1812 DataType::Json => DataType::Json,
1813 DataType::None => DataType::None,
1814 DataType::Invalid => DataType::Invalid,
1815 }
1816}
1817
1818fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
1819 if tables.len() != 1 {
1820 return Err(SQLRiteError::NotImplemented(
1821 "multi-table DELETE is not supported yet".to_string(),
1822 ));
1823 }
1824 extract_table_name(&tables[0])
1825}
1826
1827fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
1828 if !twj.joins.is_empty() {
1829 return Err(SQLRiteError::NotImplemented(
1830 "JOIN is not supported yet".to_string(),
1831 ));
1832 }
1833 match &twj.relation {
1834 TableFactor::Table { name, .. } => Ok(name.to_string()),
1835 _ => Err(SQLRiteError::NotImplemented(
1836 "only plain table references are supported".to_string(),
1837 )),
1838 }
1839}
1840
/// Result of `select_rowids`: how the candidate rowids for a statement were obtained.
enum RowidSource {
    /// Rowids returned by an index lookup for an equality predicate, sorted ascending.
    IndexProbe(Vec<i64>),
    /// No index probe applied (no predicate, no `col = literal` shape, no index
    /// on the column, or an unconvertible literal). NOTE(review): the caller
    /// presumably scans every row — confirm at the call site.
    FullScan,
}
1851
1852fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
1857 let Some(expr) = selection else {
1858 return Ok(RowidSource::FullScan);
1859 };
1860 let Some((col, literal)) = try_extract_equality(expr) else {
1861 return Ok(RowidSource::FullScan);
1862 };
1863 let Some(idx) = table.index_for_column(&col) else {
1864 return Ok(RowidSource::FullScan);
1865 };
1866
1867 let literal_value = match convert_literal(&literal) {
1871 Ok(v) => v,
1872 Err(_) => return Ok(RowidSource::FullScan),
1873 };
1874
1875 let mut rowids = idx.lookup(&literal_value);
1879 rowids.sort_unstable();
1880 Ok(RowidSource::IndexProbe(rowids))
1881}
1882
1883fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
1887 let peeled = match expr {
1889 Expr::Nested(inner) => inner.as_ref(),
1890 other => other,
1891 };
1892 let Expr::BinaryOp { left, op, right } = peeled else {
1893 return None;
1894 };
1895 if !matches!(op, BinaryOperator::Eq) {
1896 return None;
1897 }
1898 let col_from = |e: &Expr| -> Option<String> {
1899 match e {
1900 Expr::Identifier(ident) => Some(ident.value.clone()),
1901 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
1902 _ => None,
1903 }
1904 };
1905 let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
1906 if let Expr::Value(v) = e {
1907 Some(v.value.clone())
1908 } else {
1909 None
1910 }
1911 };
1912 if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
1913 return Some((c, l));
1914 }
1915 if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
1916 return Some((c, l));
1917 }
1918 None
1919}
1920
1921fn try_hnsw_probe(table: &Table, order_expr: &Expr, k: usize) -> Option<Vec<i64>> {
1946 if k == 0 {
1947 return None;
1948 }
1949
1950 let func = match order_expr {
1953 Expr::Function(f) => f,
1954 _ => return None,
1955 };
1956 let fname = match func.name.0.as_slice() {
1957 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
1958 _ => return None,
1959 };
1960 let query_metric = match fname.as_str() {
1961 "vec_distance_l2" => DistanceMetric::L2,
1962 "vec_distance_cosine" => DistanceMetric::Cosine,
1963 "vec_distance_dot" => DistanceMetric::Dot,
1964 _ => return None,
1965 };
1966
1967 let arg_list = match &func.args {
1969 FunctionArguments::List(l) => &l.args,
1970 _ => return None,
1971 };
1972 if arg_list.len() != 2 {
1973 return None;
1974 }
1975 let exprs: Vec<&Expr> = arg_list
1976 .iter()
1977 .filter_map(|a| match a {
1978 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
1979 _ => None,
1980 })
1981 .collect();
1982 if exprs.len() != 2 {
1983 return None;
1984 }
1985
1986 let (col_name, query_vec) = match identify_indexed_arg_and_literal(exprs[0], exprs[1]) {
1991 Some(v) => v,
1992 None => match identify_indexed_arg_and_literal(exprs[1], exprs[0]) {
1993 Some(v) => v,
1994 None => return None,
1995 },
1996 };
1997
1998 let entry = table
2003 .hnsw_indexes
2004 .iter()
2005 .find(|e| e.column_name == col_name && e.metric == query_metric)?;
2006
2007 let declared_dim = match table.columns.iter().find(|c| c.column_name == col_name) {
2013 Some(c) => match &c.datatype {
2014 DataType::Vector(d) => *d,
2015 _ => return None,
2016 },
2017 None => return None,
2018 };
2019 if query_vec.len() != declared_dim {
2020 return None;
2021 }
2022
2023 let column_for_closure = col_name.clone();
2027 let table_ref = table;
2028 let result = entry
2029 .index
2030 .search(&query_vec, k, |id| {
2031 match table_ref.get_value(&column_for_closure, id) {
2032 Some(Value::Vector(v)) => v,
2033 _ => Vec::new(),
2034 }
2035 })
2036 .ok()?;
2037 Some(result)
2038}
2039
2040fn try_fts_probe(table: &Table, order_expr: &Expr, ascending: bool, k: usize) -> Option<Vec<i64>> {
2056 if k == 0 || ascending {
2057 return None;
2061 }
2062
2063 let func = match order_expr {
2064 Expr::Function(f) => f,
2065 _ => return None,
2066 };
2067 let fname = match func.name.0.as_slice() {
2068 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
2069 _ => return None,
2070 };
2071 if fname != "bm25_score" {
2072 return None;
2073 }
2074
2075 let arg_list = match &func.args {
2076 FunctionArguments::List(l) => &l.args,
2077 _ => return None,
2078 };
2079 if arg_list.len() != 2 {
2080 return None;
2081 }
2082 let exprs: Vec<&Expr> = arg_list
2083 .iter()
2084 .filter_map(|a| match a {
2085 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
2086 _ => None,
2087 })
2088 .collect();
2089 if exprs.len() != 2 {
2090 return None;
2091 }
2092
2093 let col_name = match exprs[0] {
2095 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
2096 _ => return None,
2097 };
2098
2099 let query = match exprs[1] {
2103 Expr::Value(v) => match &v.value {
2104 AstValue::SingleQuotedString(s) => s.clone(),
2105 _ => return None,
2106 },
2107 _ => return None,
2108 };
2109
2110 let entry = table
2111 .fts_indexes
2112 .iter()
2113 .find(|e| e.column_name == col_name)?;
2114
2115 let scored = entry.index.query(&query, &Bm25Params::default());
2116 let mut out: Vec<i64> = scored.into_iter().map(|(id, _)| id).collect();
2117 if out.len() > k {
2118 out.truncate(k);
2119 }
2120 Some(out)
2121}
2122
2123fn identify_indexed_arg_and_literal(a: &Expr, b: &Expr) -> Option<(String, Vec<f32>)> {
2128 let col_name = match a {
2129 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
2130 _ => return None,
2131 };
2132 let lit_str = match b {
2133 Expr::Identifier(ident) if ident.quote_style == Some('[') => {
2134 format!("[{}]", ident.value)
2135 }
2136 _ => return None,
2137 };
2138 let v = parse_vector_literal(&lit_str).ok()?;
2139 Some((col_name, v))
2140}
2141
2142struct HeapEntry {
2155 key: Value,
2156 rowid: i64,
2157 asc: bool,
2158}
2159
2160impl PartialEq for HeapEntry {
2161 fn eq(&self, other: &Self) -> bool {
2162 self.cmp(other) == Ordering::Equal
2163 }
2164}
2165
2166impl Eq for HeapEntry {}
2167
2168impl PartialOrd for HeapEntry {
2169 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
2170 Some(self.cmp(other))
2171 }
2172}
2173
2174impl Ord for HeapEntry {
2175 fn cmp(&self, other: &Self) -> Ordering {
2176 let raw = compare_values(Some(&self.key), Some(&other.key));
2177 if self.asc { raw } else { raw.reverse() }
2178 }
2179}
2180
2181fn select_topk(
2190 matching: &[i64],
2191 table: &Table,
2192 order: &OrderByClause,
2193 k: usize,
2194) -> Result<Vec<i64>> {
2195 use std::collections::BinaryHeap;
2196
2197 if k == 0 || matching.is_empty() {
2198 return Ok(Vec::new());
2199 }
2200
2201 let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
2202
2203 for &rowid in matching {
2204 let key = eval_expr(&order.expr, table, rowid)?;
2205 let entry = HeapEntry {
2206 key,
2207 rowid,
2208 asc: order.ascending,
2209 };
2210
2211 if heap.len() < k {
2212 heap.push(entry);
2213 } else {
2214 if entry < *heap.peek().unwrap() {
2218 heap.pop();
2219 heap.push(entry);
2220 }
2221 }
2222 }
2223
2224 Ok(heap
2229 .into_sorted_vec()
2230 .into_iter()
2231 .map(|e| e.rowid)
2232 .collect())
2233}
2234
2235fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
2236 let mut keys: Vec<(i64, Result<Value>)> = rowids
2244 .iter()
2245 .map(|r| (*r, eval_expr(&order.expr, table, *r)))
2246 .collect();
2247
2248 for (_, k) in &keys {
2252 if let Err(e) = k {
2253 return Err(SQLRiteError::General(format!(
2254 "ORDER BY expression failed: {e}"
2255 )));
2256 }
2257 }
2258
2259 keys.sort_by(|(_, ka), (_, kb)| {
2260 let va = ka.as_ref().unwrap();
2263 let vb = kb.as_ref().unwrap();
2264 let ord = compare_values(Some(va), Some(vb));
2265 if order.ascending { ord } else { ord.reverse() }
2266 });
2267
2268 for (i, (rowid, _)) in keys.into_iter().enumerate() {
2270 rowids[i] = rowid;
2271 }
2272 Ok(())
2273}
2274
2275fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
2276 match (a, b) {
2277 (None, None) => Ordering::Equal,
2278 (None, _) => Ordering::Less,
2279 (_, None) => Ordering::Greater,
2280 (Some(a), Some(b)) => match (a, b) {
2281 (Value::Null, Value::Null) => Ordering::Equal,
2282 (Value::Null, _) => Ordering::Less,
2283 (_, Value::Null) => Ordering::Greater,
2284 (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
2285 (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
2286 (Value::Integer(x), Value::Real(y)) => {
2287 (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
2288 }
2289 (Value::Real(x), Value::Integer(y)) => {
2290 x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
2291 }
2292 (Value::Text(x), Value::Text(y)) => x.cmp(y),
2293 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
2294 (x, y) => x.to_display_string().cmp(&y.to_display_string()),
2296 },
2297 }
2298}
2299
2300pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
2302 eval_predicate_scope(expr, &SingleTableScope::new(table, rowid))
2303}
2304
2305pub(crate) fn eval_predicate_scope(expr: &Expr, scope: &dyn RowScope) -> Result<bool> {
2309 let v = eval_expr_scope(expr, scope)?;
2310 match v {
2311 Value::Bool(b) => Ok(b),
2312 Value::Null => Ok(false), Value::Integer(i) => Ok(i != 0),
2314 other => Err(SQLRiteError::Internal(format!(
2315 "WHERE clause must evaluate to boolean, got {}",
2316 other.to_display_string()
2317 ))),
2318 }
2319}
2320
2321fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
2323 eval_expr_scope(expr, &SingleTableScope::new(table, rowid))
2324}
2325
2326fn eval_expr_scope(expr: &Expr, scope: &dyn RowScope) -> Result<Value> {
2327 match expr {
2328 Expr::Nested(inner) => eval_expr_scope(inner, scope),
2329
2330 Expr::Identifier(ident) => {
2331 if ident.quote_style == Some('[') {
2341 let raw = format!("[{}]", ident.value);
2342 let v = parse_vector_literal(&raw)?;
2343 return Ok(Value::Vector(v));
2344 }
2345 scope.lookup(None, &ident.value)
2346 }
2347
2348 Expr::CompoundIdentifier(parts) => {
2349 match parts.as_slice() {
2355 [only] => scope.lookup(None, &only.value),
2356 [q, c] => scope.lookup(Some(&q.value), &c.value),
2357 _ => Err(SQLRiteError::NotImplemented(format!(
2358 "compound identifier with {} parts is not supported",
2359 parts.len()
2360 ))),
2361 }
2362 }
2363
2364 Expr::Value(v) => convert_literal(&v.value),
2365
2366 Expr::UnaryOp { op, expr } => {
2367 let inner = eval_expr_scope(expr, scope)?;
2368 match op {
2369 UnaryOperator::Not => match inner {
2370 Value::Bool(b) => Ok(Value::Bool(!b)),
2371 Value::Null => Ok(Value::Null),
2372 other => Err(SQLRiteError::Internal(format!(
2373 "NOT applied to non-boolean value: {}",
2374 other.to_display_string()
2375 ))),
2376 },
2377 UnaryOperator::Minus => match inner {
2378 Value::Integer(i) => Ok(Value::Integer(-i)),
2379 Value::Real(f) => Ok(Value::Real(-f)),
2380 Value::Null => Ok(Value::Null),
2381 other => Err(SQLRiteError::Internal(format!(
2382 "unary minus on non-numeric value: {}",
2383 other.to_display_string()
2384 ))),
2385 },
2386 UnaryOperator::Plus => Ok(inner),
2387 other => Err(SQLRiteError::NotImplemented(format!(
2388 "unary operator {other:?} is not supported"
2389 ))),
2390 }
2391 }
2392
2393 Expr::BinaryOp { left, op, right } => match op {
2394 BinaryOperator::And => {
2395 let l = eval_expr_scope(left, scope)?;
2396 let r = eval_expr_scope(right, scope)?;
2397 Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
2398 }
2399 BinaryOperator::Or => {
2400 let l = eval_expr_scope(left, scope)?;
2401 let r = eval_expr_scope(right, scope)?;
2402 Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
2403 }
2404 cmp @ (BinaryOperator::Eq
2405 | BinaryOperator::NotEq
2406 | BinaryOperator::Lt
2407 | BinaryOperator::LtEq
2408 | BinaryOperator::Gt
2409 | BinaryOperator::GtEq) => {
2410 let l = eval_expr_scope(left, scope)?;
2411 let r = eval_expr_scope(right, scope)?;
2412 if matches!(l, Value::Null) || matches!(r, Value::Null) {
2414 return Ok(Value::Bool(false));
2415 }
2416 let ord = compare_values(Some(&l), Some(&r));
2417 let result = match cmp {
2418 BinaryOperator::Eq => ord == Ordering::Equal,
2419 BinaryOperator::NotEq => ord != Ordering::Equal,
2420 BinaryOperator::Lt => ord == Ordering::Less,
2421 BinaryOperator::LtEq => ord != Ordering::Greater,
2422 BinaryOperator::Gt => ord == Ordering::Greater,
2423 BinaryOperator::GtEq => ord != Ordering::Less,
2424 _ => unreachable!(),
2425 };
2426 Ok(Value::Bool(result))
2427 }
2428 arith @ (BinaryOperator::Plus
2429 | BinaryOperator::Minus
2430 | BinaryOperator::Multiply
2431 | BinaryOperator::Divide
2432 | BinaryOperator::Modulo) => {
2433 let l = eval_expr_scope(left, scope)?;
2434 let r = eval_expr_scope(right, scope)?;
2435 eval_arith(arith, &l, &r)
2436 }
2437 BinaryOperator::StringConcat => {
2438 let l = eval_expr_scope(left, scope)?;
2439 let r = eval_expr_scope(right, scope)?;
2440 if matches!(l, Value::Null) || matches!(r, Value::Null) {
2441 return Ok(Value::Null);
2442 }
2443 Ok(Value::Text(format!(
2444 "{}{}",
2445 l.to_display_string(),
2446 r.to_display_string()
2447 )))
2448 }
2449 other => Err(SQLRiteError::NotImplemented(format!(
2450 "binary operator {other:?} is not supported yet"
2451 ))),
2452 },
2453
2454 Expr::IsNull(inner) => {
2462 let v = eval_expr_scope(inner, scope)?;
2463 Ok(Value::Bool(matches!(v, Value::Null)))
2464 }
2465 Expr::IsNotNull(inner) => {
2466 let v = eval_expr_scope(inner, scope)?;
2467 Ok(Value::Bool(!matches!(v, Value::Null)))
2468 }
2469
2470 Expr::Like {
2477 negated,
2478 any,
2479 expr: lhs,
2480 pattern,
2481 escape_char,
2482 } => eval_like(
2483 scope,
2484 *negated,
2485 *any,
2486 lhs,
2487 pattern,
2488 escape_char.as_ref(),
2489 true,
2490 ),
2491 Expr::ILike {
2492 negated,
2493 any,
2494 expr: lhs,
2495 pattern,
2496 escape_char,
2497 } => eval_like(
2498 scope,
2499 *negated,
2500 *any,
2501 lhs,
2502 pattern,
2503 escape_char.as_ref(),
2504 true,
2505 ),
2506
2507 Expr::InList {
2513 expr: lhs,
2514 list,
2515 negated,
2516 } => eval_in_list(scope, lhs, list, *negated),
2517 Expr::InSubquery { .. } => Err(SQLRiteError::NotImplemented(
2518 "IN (subquery) is not supported (only literal lists are)".to_string(),
2519 )),
2520
2521 Expr::Function(func) => eval_function(func, scope),
2532
2533 other => Err(SQLRiteError::NotImplemented(format!(
2534 "unsupported expression in WHERE/projection: {other:?}"
2535 ))),
2536 }
2537}
2538
2539fn eval_function(func: &sqlparser::ast::Function, scope: &dyn RowScope) -> Result<Value> {
2544 let name = match func.name.0.as_slice() {
2547 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
2548 _ => {
2549 return Err(SQLRiteError::NotImplemented(format!(
2550 "qualified function names not supported: {:?}",
2551 func.name
2552 )));
2553 }
2554 };
2555
2556 match name.as_str() {
2557 "vec_distance_l2" | "vec_distance_cosine" | "vec_distance_dot" => {
2558 let (a, b) = extract_two_vector_args(&name, &func.args, scope)?;
2559 let dist = match name.as_str() {
2560 "vec_distance_l2" => vec_distance_l2(&a, &b),
2561 "vec_distance_cosine" => vec_distance_cosine(&a, &b)?,
2562 "vec_distance_dot" => vec_distance_dot(&a, &b),
2563 _ => unreachable!(),
2564 };
2565 Ok(Value::Real(dist as f64))
2571 }
2572 "json_extract" => json_fn_extract(&name, &func.args, scope),
2577 "json_type" => json_fn_type(&name, &func.args, scope),
2578 "json_array_length" => json_fn_array_length(&name, &func.args, scope),
2579 "json_object_keys" => json_fn_object_keys(&name, &func.args, scope),
2580 "fts_match" | "bm25_score" => {
2591 let Some((table, rowid)) = scope.single_table_view() else {
2592 return Err(SQLRiteError::NotImplemented(format!(
2593 "{name}() is not yet supported inside a JOIN query — \
2594 use it on a single-table SELECT or move the FTS lookup into a subquery"
2595 )));
2596 };
2597 let (entry, query) = resolve_fts_args(&name, &func.args, table, scope)?;
2598 Ok(match name.as_str() {
2599 "fts_match" => Value::Bool(entry.index.matches(rowid, &query)),
2600 "bm25_score" => {
2601 Value::Real(entry.index.score(rowid, &query, &Bm25Params::default()))
2602 }
2603 _ => unreachable!(),
2604 })
2605 }
2606 "count" | "sum" | "avg" | "min" | "max" => Err(SQLRiteError::NotImplemented(format!(
2610 "aggregate function '{name}' is not allowed in WHERE / projection-scalar position; \
2611 use it as a top-level projection item (HAVING is not yet supported)"
2612 ))),
2613 other => Err(SQLRiteError::NotImplemented(format!(
2614 "unknown function: {other}(...)"
2615 ))),
2616 }
2617}
2618
2619fn resolve_fts_args<'t>(
2624 fn_name: &str,
2625 args: &FunctionArguments,
2626 table: &'t Table,
2627 scope: &dyn RowScope,
2628) -> Result<(&'t FtsIndexEntry, String)> {
2629 let arg_list = match args {
2630 FunctionArguments::List(l) => &l.args,
2631 _ => {
2632 return Err(SQLRiteError::General(format!(
2633 "{fn_name}() expects exactly two arguments: (column, query_text)"
2634 )));
2635 }
2636 };
2637 if arg_list.len() != 2 {
2638 return Err(SQLRiteError::General(format!(
2639 "{fn_name}() expects exactly 2 arguments, got {}",
2640 arg_list.len()
2641 )));
2642 }
2643
2644 let col_expr = match &arg_list[0] {
2648 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2649 other => {
2650 return Err(SQLRiteError::NotImplemented(format!(
2651 "{fn_name}() argument 0 must be a column name, got {other:?}"
2652 )));
2653 }
2654 };
2655 let col_name = match col_expr {
2656 Expr::Identifier(ident) => ident.value.clone(),
2657 Expr::CompoundIdentifier(parts) => parts
2658 .last()
2659 .map(|p| p.value.clone())
2660 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
2661 other => {
2662 return Err(SQLRiteError::General(format!(
2663 "{fn_name}() argument 0 must be a column reference, got {other:?}"
2664 )));
2665 }
2666 };
2667
2668 let q_expr = match &arg_list[1] {
2672 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2673 other => {
2674 return Err(SQLRiteError::NotImplemented(format!(
2675 "{fn_name}() argument 1 must be a text expression, got {other:?}"
2676 )));
2677 }
2678 };
2679 let query = match eval_expr_scope(q_expr, scope)? {
2680 Value::Text(s) => s,
2681 other => {
2682 return Err(SQLRiteError::General(format!(
2683 "{fn_name}() argument 1 must be TEXT, got {}",
2684 other.to_display_string()
2685 )));
2686 }
2687 };
2688
2689 let entry = table
2690 .fts_indexes
2691 .iter()
2692 .find(|e| e.column_name == col_name)
2693 .ok_or_else(|| {
2694 SQLRiteError::General(format!(
2695 "{fn_name}({col_name}, ...): no FTS index on column '{col_name}' \
2696 (run CREATE INDEX <name> ON <table> USING fts({col_name}) first)"
2697 ))
2698 })?;
2699 Ok((entry, query))
2700}
2701
2702fn extract_json_and_path(
2716 fn_name: &str,
2717 args: &FunctionArguments,
2718 scope: &dyn RowScope,
2719) -> Result<(String, String)> {
2720 let arg_list = match args {
2721 FunctionArguments::List(l) => &l.args,
2722 _ => {
2723 return Err(SQLRiteError::General(format!(
2724 "{fn_name}() expects 1 or 2 arguments"
2725 )));
2726 }
2727 };
2728 if !(arg_list.len() == 1 || arg_list.len() == 2) {
2729 return Err(SQLRiteError::General(format!(
2730 "{fn_name}() expects 1 or 2 arguments, got {}",
2731 arg_list.len()
2732 )));
2733 }
2734 let first_expr = match &arg_list[0] {
2736 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2737 other => {
2738 return Err(SQLRiteError::NotImplemented(format!(
2739 "{fn_name}() argument 0 has unsupported shape: {other:?}"
2740 )));
2741 }
2742 };
2743 let json_text = match eval_expr_scope(first_expr, scope)? {
2744 Value::Text(s) => s,
2745 Value::Null => {
2746 return Err(SQLRiteError::General(format!(
2747 "{fn_name}() called on NULL — JSON column has no value for this row"
2748 )));
2749 }
2750 other => {
2751 return Err(SQLRiteError::General(format!(
2752 "{fn_name}() argument 0 is not JSON-typed: got {}",
2753 other.to_display_string()
2754 )));
2755 }
2756 };
2757
2758 let path = if arg_list.len() == 2 {
2760 let path_expr = match &arg_list[1] {
2761 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2762 other => {
2763 return Err(SQLRiteError::NotImplemented(format!(
2764 "{fn_name}() argument 1 has unsupported shape: {other:?}"
2765 )));
2766 }
2767 };
2768 match eval_expr_scope(path_expr, scope)? {
2769 Value::Text(s) => s,
2770 other => {
2771 return Err(SQLRiteError::General(format!(
2772 "{fn_name}() path argument must be a string literal, got {}",
2773 other.to_display_string()
2774 )));
2775 }
2776 }
2777 } else {
2778 "$".to_string()
2779 };
2780
2781 Ok((json_text, path))
2782}
2783
2784fn walk_json_path<'a>(
2794 value: &'a serde_json::Value,
2795 path: &str,
2796) -> Result<Option<&'a serde_json::Value>> {
2797 let mut chars = path.chars().peekable();
2798 if chars.next() != Some('$') {
2799 return Err(SQLRiteError::General(format!(
2800 "JSON path must start with '$', got `{path}`"
2801 )));
2802 }
2803 let mut current = value;
2804 while let Some(&c) = chars.peek() {
2805 match c {
2806 '.' => {
2807 chars.next();
2808 let mut key = String::new();
2809 while let Some(&c) = chars.peek() {
2810 if c == '.' || c == '[' {
2811 break;
2812 }
2813 key.push(c);
2814 chars.next();
2815 }
2816 if key.is_empty() {
2817 return Err(SQLRiteError::General(format!(
2818 "JSON path has empty key after '.' in `{path}`"
2819 )));
2820 }
2821 match current.get(&key) {
2822 Some(v) => current = v,
2823 None => return Ok(None),
2824 }
2825 }
2826 '[' => {
2827 chars.next();
2828 let mut idx_str = String::new();
2829 while let Some(&c) = chars.peek() {
2830 if c == ']' {
2831 break;
2832 }
2833 idx_str.push(c);
2834 chars.next();
2835 }
2836 if chars.next() != Some(']') {
2837 return Err(SQLRiteError::General(format!(
2838 "JSON path has unclosed `[` in `{path}`"
2839 )));
2840 }
2841 let idx: usize = idx_str.trim().parse().map_err(|_| {
2842 SQLRiteError::General(format!(
2843 "JSON path has non-integer index `[{idx_str}]` in `{path}`"
2844 ))
2845 })?;
2846 match current.get(idx) {
2847 Some(v) => current = v,
2848 None => return Ok(None),
2849 }
2850 }
2851 other => {
2852 return Err(SQLRiteError::General(format!(
2853 "JSON path has unexpected character `{other}` in `{path}` \
2854 (expected `.`, `[`, or end-of-path)"
2855 )));
2856 }
2857 }
2858 }
2859 Ok(Some(current))
2860}
2861
2862fn json_value_to_sql(v: &serde_json::Value) -> Value {
2866 match v {
2867 serde_json::Value::Null => Value::Null,
2868 serde_json::Value::Bool(b) => Value::Bool(*b),
2869 serde_json::Value::Number(n) => {
2870 if let Some(i) = n.as_i64() {
2872 Value::Integer(i)
2873 } else if let Some(f) = n.as_f64() {
2874 Value::Real(f)
2875 } else {
2876 Value::Null
2877 }
2878 }
2879 serde_json::Value::String(s) => Value::Text(s.clone()),
2880 composite => Value::Text(composite.to_string()),
2884 }
2885}
2886
2887fn json_fn_extract(name: &str, args: &FunctionArguments, scope: &dyn RowScope) -> Result<Value> {
2888 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2889 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2890 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2891 })?;
2892 match walk_json_path(&parsed, &path)? {
2893 Some(v) => Ok(json_value_to_sql(v)),
2894 None => Ok(Value::Null),
2895 }
2896}
2897
2898fn json_fn_type(name: &str, args: &FunctionArguments, scope: &dyn RowScope) -> Result<Value> {
2899 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2900 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2901 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2902 })?;
2903 let resolved = match walk_json_path(&parsed, &path)? {
2904 Some(v) => v,
2905 None => return Ok(Value::Null),
2906 };
2907 let ty = match resolved {
2908 serde_json::Value::Null => "null",
2909 serde_json::Value::Bool(true) => "true",
2910 serde_json::Value::Bool(false) => "false",
2911 serde_json::Value::Number(n) => {
2912 if n.is_i64() || n.is_u64() {
2913 "integer"
2914 } else {
2915 "real"
2916 }
2917 }
2918 serde_json::Value::String(_) => "text",
2919 serde_json::Value::Array(_) => "array",
2920 serde_json::Value::Object(_) => "object",
2921 };
2922 Ok(Value::Text(ty.to_string()))
2923}
2924
2925fn json_fn_array_length(
2926 name: &str,
2927 args: &FunctionArguments,
2928 scope: &dyn RowScope,
2929) -> Result<Value> {
2930 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2931 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2932 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2933 })?;
2934 let resolved = match walk_json_path(&parsed, &path)? {
2935 Some(v) => v,
2936 None => return Ok(Value::Null),
2937 };
2938 match resolved.as_array() {
2939 Some(arr) => Ok(Value::Integer(arr.len() as i64)),
2940 None => Err(SQLRiteError::General(format!(
2941 "{name}() resolved to a non-array value at path `{path}`"
2942 ))),
2943 }
2944}
2945
2946fn json_fn_object_keys(
2947 name: &str,
2948 args: &FunctionArguments,
2949 scope: &dyn RowScope,
2950) -> Result<Value> {
2951 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2952 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2953 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2954 })?;
2955 let resolved = match walk_json_path(&parsed, &path)? {
2956 Some(v) => v,
2957 None => return Ok(Value::Null),
2958 };
2959 let obj = resolved.as_object().ok_or_else(|| {
2960 SQLRiteError::General(format!(
2961 "{name}() resolved to a non-object value at path `{path}`"
2962 ))
2963 })?;
2964 let keys: Vec<serde_json::Value> = obj
2971 .keys()
2972 .map(|k| serde_json::Value::String(k.clone()))
2973 .collect();
2974 Ok(Value::Text(serde_json::Value::Array(keys).to_string()))
2975}
2976
2977fn extract_two_vector_args(
2981 fn_name: &str,
2982 args: &FunctionArguments,
2983 scope: &dyn RowScope,
2984) -> Result<(Vec<f32>, Vec<f32>)> {
2985 let arg_list = match args {
2986 FunctionArguments::List(l) => &l.args,
2987 _ => {
2988 return Err(SQLRiteError::General(format!(
2989 "{fn_name}() expects exactly two vector arguments"
2990 )));
2991 }
2992 };
2993 if arg_list.len() != 2 {
2994 return Err(SQLRiteError::General(format!(
2995 "{fn_name}() expects exactly 2 arguments, got {}",
2996 arg_list.len()
2997 )));
2998 }
2999 let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
3000 for (i, arg) in arg_list.iter().enumerate() {
3001 let expr = match arg {
3002 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
3003 other => {
3004 return Err(SQLRiteError::NotImplemented(format!(
3005 "{fn_name}() argument {i} has unsupported shape: {other:?}"
3006 )));
3007 }
3008 };
3009 let val = eval_expr_scope(expr, scope)?;
3010 match val {
3011 Value::Vector(v) => out.push(v),
3012 other => {
3013 return Err(SQLRiteError::General(format!(
3014 "{fn_name}() argument {i} is not a vector: got {}",
3015 other.to_display_string()
3016 )));
3017 }
3018 }
3019 }
3020 let b = out.pop().unwrap();
3021 let a = out.pop().unwrap();
3022 if a.len() != b.len() {
3023 return Err(SQLRiteError::General(format!(
3024 "{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
3025 a.len(),
3026 b.len()
3027 )));
3028 }
3029 Ok((a, b))
3030}
3031
/// Euclidean (L2) distance between two vectors.
///
/// Callers are expected to pass vectors of equal length; this is checked with
/// a debug assertion only. Accumulation is done in `f32`, element by element
/// in index order.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Pair every element of `a` with the element of `b` at the same index.
    let b = &b[..a.len()];
    let squared_sum = a.iter().zip(b).fold(0.0f32, |acc, (x, y)| {
        let diff = x - y;
        acc + diff * diff
    });
    squared_sum.sqrt()
}
3043
3044pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
3054 debug_assert_eq!(a.len(), b.len());
3055 let mut dot = 0.0f32;
3056 let mut norm_a_sq = 0.0f32;
3057 let mut norm_b_sq = 0.0f32;
3058 for i in 0..a.len() {
3059 dot += a[i] * b[i];
3060 norm_a_sq += a[i] * a[i];
3061 norm_b_sq += b[i] * b[i];
3062 }
3063 let denom = (norm_a_sq * norm_b_sq).sqrt();
3064 if denom == 0.0 {
3065 return Err(SQLRiteError::General(
3066 "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
3067 ));
3068 }
3069 Ok(1.0 - dot / denom)
3070}
3071
/// Negated dot product of two vectors, so that a larger dot product yields a
/// smaller "distance".
///
/// Callers are expected to pass vectors of equal length; this is checked with
/// a debug assertion only. Accumulation is done in `f32`, element by element
/// in index order.
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Pair every element of `a` with the element of `b` at the same index.
    let b = &b[..a.len()];
    let dot = a.iter().zip(b).fold(0.0f32, |acc, (x, y)| acc + x * y);
    -dot
}
3083
3084fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
3087 if matches!(l, Value::Null) || matches!(r, Value::Null) {
3088 return Ok(Value::Null);
3089 }
3090 match (l, r) {
3091 (Value::Integer(a), Value::Integer(b)) => match op {
3092 BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
3093 BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
3094 BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
3095 BinaryOperator::Divide => {
3096 if *b == 0 {
3097 Err(SQLRiteError::General("division by zero".to_string()))
3098 } else {
3099 Ok(Value::Integer(a / b))
3100 }
3101 }
3102 BinaryOperator::Modulo => {
3103 if *b == 0 {
3104 Err(SQLRiteError::General("modulo by zero".to_string()))
3105 } else {
3106 Ok(Value::Integer(a % b))
3107 }
3108 }
3109 _ => unreachable!(),
3110 },
3111 (a, b) => {
3113 let af = as_number(a)?;
3114 let bf = as_number(b)?;
3115 match op {
3116 BinaryOperator::Plus => Ok(Value::Real(af + bf)),
3117 BinaryOperator::Minus => Ok(Value::Real(af - bf)),
3118 BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
3119 BinaryOperator::Divide => {
3120 if bf == 0.0 {
3121 Err(SQLRiteError::General("division by zero".to_string()))
3122 } else {
3123 Ok(Value::Real(af / bf))
3124 }
3125 }
3126 BinaryOperator::Modulo => {
3127 if bf == 0.0 {
3128 Err(SQLRiteError::General("modulo by zero".to_string()))
3129 } else {
3130 Ok(Value::Real(af % bf))
3131 }
3132 }
3133 _ => unreachable!(),
3134 }
3135 }
3136 }
3137}
3138
3139fn as_number(v: &Value) -> Result<f64> {
3140 match v {
3141 Value::Integer(i) => Ok(*i as f64),
3142 Value::Real(f) => Ok(*f),
3143 Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
3144 other => Err(SQLRiteError::General(format!(
3145 "arithmetic on non-numeric value '{}'",
3146 other.to_display_string()
3147 ))),
3148 }
3149}
3150
3151fn as_bool(v: &Value) -> Result<bool> {
3152 match v {
3153 Value::Bool(b) => Ok(*b),
3154 Value::Null => Ok(false),
3155 Value::Integer(i) => Ok(*i != 0),
3156 other => Err(SQLRiteError::Internal(format!(
3157 "expected boolean, got {}",
3158 other.to_display_string()
3159 ))),
3160 }
3161}
3162
3163#[allow(clippy::too_many_arguments)]
3168fn eval_like(
3169 scope: &dyn RowScope,
3170 negated: bool,
3171 any: bool,
3172 lhs: &Expr,
3173 pattern: &Expr,
3174 escape_char: Option<&AstValue>,
3175 case_insensitive: bool,
3176) -> Result<Value> {
3177 if any {
3178 return Err(SQLRiteError::NotImplemented(
3179 "LIKE ANY (...) is not supported".to_string(),
3180 ));
3181 }
3182 if escape_char.is_some() {
3183 return Err(SQLRiteError::NotImplemented(
3184 "LIKE ... ESCAPE '<char>' is not supported (default `\\` escape only)".to_string(),
3185 ));
3186 }
3187
3188 let l = eval_expr_scope(lhs, scope)?;
3189 let p = eval_expr_scope(pattern, scope)?;
3190 if matches!(l, Value::Null) || matches!(p, Value::Null) {
3191 return Ok(Value::Null);
3192 }
3193 let text = match l {
3194 Value::Text(s) => s,
3195 other => other.to_display_string(),
3196 };
3197 let pat = match p {
3198 Value::Text(s) => s,
3199 other => other.to_display_string(),
3200 };
3201 let m = like_match(&text, &pat, case_insensitive);
3202 Ok(Value::Bool(if negated { !m } else { m }))
3203}
3204
3205fn eval_in_list(scope: &dyn RowScope, lhs: &Expr, list: &[Expr], negated: bool) -> Result<Value> {
3206 let l = eval_expr_scope(lhs, scope)?;
3207 if matches!(l, Value::Null) {
3208 return Ok(Value::Null);
3209 }
3210 let mut saw_null = false;
3211 for item in list {
3212 let r = eval_expr_scope(item, scope)?;
3213 if matches!(r, Value::Null) {
3214 saw_null = true;
3215 continue;
3216 }
3217 if compare_values(Some(&l), Some(&r)) == Ordering::Equal {
3218 return Ok(Value::Bool(!negated));
3219 }
3220 }
3221 if saw_null {
3222 Ok(Value::Null)
3225 } else {
3226 Ok(Value::Bool(negated))
3227 }
3228}
3229
3230fn aggregate_rows(
3241 table: &Table,
3242 matching: &[i64],
3243 group_by: &[String],
3244 proj_items: &[ProjectionItem],
3245) -> Result<Vec<Vec<Value>>> {
3246 let template: Vec<Option<AggState>> = proj_items
3250 .iter()
3251 .map(|i| match &i.kind {
3252 ProjectionKind::Aggregate(call) => Some(AggState::new(call)),
3253 ProjectionKind::Column { .. } => None,
3254 })
3255 .collect();
3256
3257 let mut keys: Vec<Vec<DistinctKey>> = Vec::new();
3263 let mut group_states: Vec<Vec<Option<AggState>>> = Vec::new();
3264 let mut group_key_values: Vec<Vec<Value>> = Vec::new();
3265
3266 for &rowid in matching {
3267 let mut key_values: Vec<Value> = Vec::with_capacity(group_by.len());
3268 let mut key: Vec<DistinctKey> = Vec::with_capacity(group_by.len());
3269 for col in group_by {
3270 let v = table.get_value(col, rowid).unwrap_or(Value::Null);
3271 key.push(DistinctKey::from_value(&v));
3272 key_values.push(v);
3273 }
3274 let idx = match keys.iter().position(|k| k == &key) {
3275 Some(i) => i,
3276 None => {
3277 keys.push(key);
3278 group_states.push(template.clone());
3279 group_key_values.push(key_values);
3280 keys.len() - 1
3281 }
3282 };
3283
3284 for (slot, item) in proj_items.iter().enumerate() {
3285 if let ProjectionKind::Aggregate(call) = &item.kind {
3286 let v = match &call.arg {
3287 AggregateArg::Star => Value::Null,
3288 AggregateArg::Column(c) => table.get_value(c, rowid).unwrap_or(Value::Null),
3289 };
3290 if let Some(state) = group_states[idx][slot].as_mut() {
3291 state.update(&v)?;
3292 }
3293 }
3294 }
3295 }
3296
3297 if keys.is_empty() && group_by.is_empty() {
3303 keys.push(Vec::new());
3306 group_states.push(template.clone());
3307 group_key_values.push(Vec::new());
3308 }
3309
3310 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(keys.len());
3312 for (group_idx, _) in keys.iter().enumerate() {
3313 let mut row: Vec<Value> = Vec::with_capacity(proj_items.len());
3314 for (slot, item) in proj_items.iter().enumerate() {
3315 match &item.kind {
3316 ProjectionKind::Column { name: c, .. } => {
3317 let pos = group_by
3320 .iter()
3321 .position(|g| g == c)
3322 .expect("validated to be in GROUP BY");
3323 row.push(group_key_values[group_idx][pos].clone());
3324 }
3325 ProjectionKind::Aggregate(_) => {
3326 let state = group_states[group_idx][slot]
3327 .as_ref()
3328 .expect("aggregate slot has state");
3329 row.push(state.finalize());
3330 }
3331 }
3332 }
3333 rows.push(row);
3334 }
3335 Ok(rows)
3336}
3337
3338fn dedupe_rows(rows: Vec<Vec<Value>>) -> Vec<Vec<Value>> {
3342 use std::collections::HashSet;
3343 let mut seen: HashSet<Vec<DistinctKey>> = HashSet::new();
3344 let mut out = Vec::with_capacity(rows.len());
3345 for row in rows {
3346 let key: Vec<DistinctKey> = row.iter().map(DistinctKey::from_value).collect();
3347 if seen.insert(key) {
3348 out.push(row);
3349 }
3350 }
3351 out
3352}
3353
3354fn sort_output_rows(
3358 rows: &mut [Vec<Value>],
3359 columns: &[String],
3360 proj_items: &[ProjectionItem],
3361 order: &OrderByClause,
3362) -> Result<()> {
3363 let target_idx = resolve_order_by_index(&order.expr, columns, proj_items)?;
3364 rows.sort_by(|a, b| {
3365 let va = &a[target_idx];
3366 let vb = &b[target_idx];
3367 let ord = compare_values(Some(va), Some(vb));
3368 if order.ascending { ord } else { ord.reverse() }
3369 });
3370 Ok(())
3371}
3372
3373fn resolve_order_by_index(
3376 expr: &Expr,
3377 columns: &[String],
3378 proj_items: &[ProjectionItem],
3379) -> Result<usize> {
3380 let target_name: Option<String> = match expr {
3382 Expr::Identifier(ident) => Some(ident.value.clone()),
3383 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
3384 Expr::Function(_) => None,
3385 Expr::Nested(inner) => return resolve_order_by_index(inner, columns, proj_items),
3386 other => {
3387 return Err(SQLRiteError::NotImplemented(format!(
3388 "ORDER BY expression not supported on aggregating queries: {other:?}"
3389 )));
3390 }
3391 };
3392 if let Some(name) = target_name {
3393 if let Some(i) = columns.iter().position(|c| c.eq_ignore_ascii_case(&name)) {
3394 return Ok(i);
3395 }
3396 return Err(SQLRiteError::Internal(format!(
3397 "ORDER BY references unknown column '{name}' in the SELECT output"
3398 )));
3399 }
3400 if let Expr::Function(func) = expr {
3404 let user_disp = format_function_display(func);
3405 for (i, item) in proj_items.iter().enumerate() {
3406 if let ProjectionKind::Aggregate(call) = &item.kind
3407 && call.display_name().eq_ignore_ascii_case(&user_disp)
3408 {
3409 return Ok(i);
3410 }
3411 }
3412 return Err(SQLRiteError::Internal(format!(
3413 "ORDER BY references aggregate '{user_disp}' that isn't in the SELECT output"
3414 )));
3415 }
3416 Err(SQLRiteError::Internal(
3417 "ORDER BY expression could not be resolved against the output columns".to_string(),
3418 ))
3419}
3420
3421fn format_function_display(func: &sqlparser::ast::Function) -> String {
3425 let name = match func.name.0.as_slice() {
3426 [ObjectNamePart::Identifier(ident)] => ident.value.to_uppercase(),
3427 _ => format!("{:?}", func.name).to_uppercase(),
3428 };
3429 let inner = match &func.args {
3430 FunctionArguments::List(l) => {
3431 let distinct = matches!(
3432 l.duplicate_treatment,
3433 Some(sqlparser::ast::DuplicateTreatment::Distinct)
3434 );
3435 let arg = l.args.first().map(|a| match a {
3436 FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => "*".to_string(),
3437 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier(i))) => i.value.clone(),
3438 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::CompoundIdentifier(parts))) => {
3439 parts.last().map(|p| p.value.clone()).unwrap_or_default()
3440 }
3441 _ => String::new(),
3442 });
3443 match (distinct, arg) {
3444 (true, Some(a)) if a != "*" => format!("DISTINCT {a}"),
3445 (_, Some(a)) => a,
3446 _ => String::new(),
3447 }
3448 }
3449 _ => String::new(),
3450 };
3451 format!("{name}({inner})")
3452}
3453
3454fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
3455 use sqlparser::ast::Value as AstValue;
3456 match v {
3457 AstValue::Number(n, _) => {
3458 if let Ok(i) = n.parse::<i64>() {
3459 Ok(Value::Integer(i))
3460 } else if let Ok(f) = n.parse::<f64>() {
3461 Ok(Value::Real(f))
3462 } else {
3463 Err(SQLRiteError::Internal(format!(
3464 "could not parse numeric literal '{n}'"
3465 )))
3466 }
3467 }
3468 AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
3469 AstValue::Boolean(b) => Ok(Value::Bool(*b)),
3470 AstValue::Null => Ok(Value::Null),
3471 other => Err(SQLRiteError::NotImplemented(format!(
3472 "unsupported literal value: {other:?}"
3473 ))),
3474 }
3475}
3476
3477#[cfg(test)]
3478mod tests {
3479 use super::*;
3480
/// Returns true when `a` and `b` differ by strictly less than `eps`.
fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
    let gap = (a - b).abs();
    gap < eps
}
3490
/// The L2 distance from a vector to itself is exactly zero.
#[test]
fn vec_distance_l2_identical_is_zero() {
    let point = [0.1_f32, 0.2, 0.3];
    let distance = vec_distance_l2(&point, &point);
    assert_eq!(distance, 0.0);
}
3496
/// Two distinct unit basis vectors are sqrt(2) apart.
#[test]
fn vec_distance_l2_unit_basis_is_sqrt2() {
    let x_axis = [1.0_f32, 0.0];
    let y_axis = [0.0_f32, 1.0];
    let expected = 2.0_f32.sqrt();
    assert!(approx_eq(vec_distance_l2(&x_axis, &y_axis), expected, 1e-6));
}
3504
/// Classic 3-4-5 triangle: the distance from the origin to (3, 4, 0) is 5.
#[test]
fn vec_distance_l2_known_value() {
    let origin = [0.0_f32; 3];
    let target = [3.0_f32, 4.0, 0.0];
    let distance = vec_distance_l2(&origin, &target);
    assert!(approx_eq(distance, 5.0, 1e-6));
}
3512
/// The cosine distance from a vector to itself is approximately zero.
#[test]
fn vec_distance_cosine_identical_is_zero() {
    let point = [0.1_f32, 0.2, 0.3];
    let d = vec_distance_cosine(&point, &point).unwrap();
    assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
}
3519
/// Orthogonal vectors have cosine distance 1.
#[test]
fn vec_distance_cosine_orthogonal_is_one() {
    let x_axis = [1.0_f32, 0.0];
    let y_axis = [0.0_f32, 1.0];
    let distance = vec_distance_cosine(&x_axis, &y_axis).unwrap();
    assert!(approx_eq(distance, 1.0, 1e-6));
}
3528
/// Vectors pointing in opposite directions have cosine distance 2.
#[test]
fn vec_distance_cosine_opposite_is_two() {
    let forward = [1.0_f32, 0.0, 0.0];
    let backward = [-1.0_f32, 0.0, 0.0];
    let distance = vec_distance_cosine(&forward, &backward).unwrap();
    assert!(approx_eq(distance, 2.0, 1e-6));
}
3536
/// A zero-magnitude operand makes cosine distance an error that mentions
/// "zero-magnitude".
#[test]
fn vec_distance_cosine_zero_magnitude_errors() {
    let zero = [0.0_f32, 0.0];
    let unit = [1.0_f32, 0.0];
    let message = vec_distance_cosine(&zero, &unit).unwrap_err().to_string();
    assert!(message.contains("zero-magnitude"));
}
3545
/// (1,2,3)·(4,5,6) = 32, so the dot "distance" is -32.
#[test]
fn vec_distance_dot_negates() {
    let lhs = [1.0_f32, 2.0, 3.0];
    let rhs = [4.0_f32, 5.0, 6.0];
    let distance = vec_distance_dot(&lhs, &rhs);
    assert!(approx_eq(distance, -32.0, 1e-6));
}
3553
/// Orthogonal vectors have a dot "distance" of zero.
#[test]
fn vec_distance_dot_orthogonal_is_zero() {
    let x_axis = [1.0_f32, 0.0];
    let y_axis = [0.0_f32, 1.0];
    let distance = vec_distance_dot(&x_axis, &y_axis);
    assert_eq!(distance, 0.0);
}
3561
/// For unit-norm vectors, the dot "distance" equals the cosine distance
/// minus one.
#[test]
fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
    // Both vectors have norm 1 (0.36 + 0.64).
    let a = [0.6_f32, 0.8];
    let b = [0.8_f32, 0.6];
    let dot = vec_distance_dot(&a, &b);
    let cos = vec_distance_cosine(&a, &b).unwrap();
    let expected = cos - 1.0;
    assert!(approx_eq(dot, expected, 1e-5));
}
3573
3574 use crate::sql::db::database::Database;
3579 use crate::sql::dialect::SqlriteDialect;
3580 use crate::sql::parser::select::SelectQuery;
3581 use sqlparser::parser::Parser;
3582
3583 fn seed_score_table(n: usize) -> Database {
3596 let mut db = Database::new("tempdb".to_string());
3597 crate::sql::process_command(
3598 "CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
3599 &mut db,
3600 )
3601 .expect("create");
3602 for i in 0..n {
3603 let score = ((i as u64).wrapping_mul(2_654_435_761) % 1_000_000) as f64;
3607 let sql = format!("INSERT INTO docs (score) VALUES ({score});");
3608 crate::sql::process_command(&sql, &mut db).expect("insert");
3609 }
3610 db
3611 }
3612
3613 fn parse_select(sql: &str) -> SelectQuery {
3617 let dialect = SqlriteDialect::new();
3618 let mut ast = Parser::parse_sql(&dialect, sql).expect("parse");
3619 let stmt = ast.pop().expect("one statement");
3620 SelectQuery::new(&stmt).expect("select-query")
3621 }
3622
/// Ascending top-10 via `select_topk` must equal a full sort truncated to 10.
#[test]
fn topk_matches_full_sort_asc() {
    let db = seed_score_table(200);
    let table = db.get_table("docs".to_string()).unwrap();
    let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
    let order = query.order_by.as_ref().unwrap();
    let rowids = table.rowids();

    // Reference answer: sort everything, keep the first ten.
    let mut expected = rowids.clone();
    sort_rowids(&mut expected, table, order).unwrap();
    expected.truncate(10);

    let actual = select_topk(&rowids, table, order, 10).unwrap();

    assert_eq!(
        actual, expected,
        "top-k via heap should match full-sort+truncate"
    );
}
3643
/// Descending top-10 via `select_topk` must equal a full sort truncated to 10.
#[test]
fn topk_matches_full_sort_desc() {
    let db = seed_score_table(200);
    let table = db.get_table("docs".to_string()).unwrap();
    let query = parse_select("SELECT * FROM docs ORDER BY score DESC LIMIT 10;");
    let order = query.order_by.as_ref().unwrap();
    let rowids = table.rowids();

    // Reference answer: sort everything, keep the first ten.
    let mut expected = rowids.clone();
    sort_rowids(&mut expected, table, order).unwrap();
    expected.truncate(10);

    let actual = select_topk(&rowids, table, order, 10).unwrap();

    assert_eq!(
        actual, expected,
        "top-k DESC via heap should match full-sort+truncate"
    );
}
3664
/// When k exceeds the row count, `select_topk` returns every row, sorted.
#[test]
fn topk_k_larger_than_n_returns_everything_sorted() {
    let db = seed_score_table(50);
    let table = db.get_table("docs".to_string()).unwrap();
    let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
    let order = query.order_by.as_ref().unwrap();

    let topk = select_topk(&table.rowids(), table, order, 1000).unwrap();
    assert_eq!(topk.len(), 50);

    // Collect the Real scores in result order; other values are skipped.
    let mut scores: Vec<f64> = Vec::new();
    for rowid in &topk {
        if let Some(Value::Real(score)) = table.get_value("score", *rowid) {
            scores.push(score);
        }
    }
    assert!(scores.windows(2).all(|pair| pair[0] <= pair[1]));
}
3687
/// Asking for zero rows yields an empty result.
#[test]
fn topk_k_zero_returns_empty() {
    let db = seed_score_table(10);
    let table = db.get_table("docs".to_string()).unwrap();
    let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
    let order = query.order_by.as_ref().unwrap();

    let rowids = table.rowids();
    let topk = select_topk(&rowids, table, order, 0).unwrap();
    assert!(topk.is_empty());
}
3697
/// An empty rowid slice yields an empty result regardless of k.
#[test]
fn topk_empty_input_returns_empty() {
    let db = seed_score_table(0);
    let table = db.get_table("docs".to_string()).unwrap();
    let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
    let order = query.order_by.as_ref().unwrap();

    let no_rows: &[i64] = &[];
    let topk = select_topk(no_rows, table, order, 5).unwrap();
    assert!(topk.is_empty());
}
3707
/// End to end: ORDER BY a vector distance function with LIMIT 3 returns
/// three rows through the regular command pipeline.
#[test]
fn topk_works_through_select_executor_with_distance_function() {
    let mut db = Database::new("tempdb".to_string());
    crate::sql::process_command(
        "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
        &mut db,
    )
    .unwrap();

    let embeddings = [
        "[1.0, 0.0]",
        "[2.0, 0.0]",
        "[0.0, 3.0]",
        "[1.0, 4.0]",
        "[10.0, 10.0]",
    ];
    for v in embeddings {
        let insert = format!("INSERT INTO docs (e) VALUES ({v});");
        crate::sql::process_command(&insert, &mut db).unwrap();
    }

    let resp = crate::sql::process_command(
        "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;",
        &mut db,
    )
    .unwrap();
    assert!(resp.contains("3 rows returned"), "got: {resp}");
}
3744
/// Ignored-by-default timing comparison of `select_topk` against a full sort
/// plus truncate on 10 000 rows; asserts the bounded heap is at least 1.4×
/// faster.
#[test]
#[ignore]
fn topk_benchmark() {
    use std::time::Instant;
    const N: usize = 10_000;
    const K: usize = 10;

    let db = seed_score_table(N);
    let table = db.get_table("docs".to_string()).unwrap();
    let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
    let order = query.order_by.as_ref().unwrap();
    let rowids = table.rowids();

    // Bounded-heap path.
    let heap_start = Instant::now();
    let _topk = select_topk(&rowids, table, order, K).unwrap();
    let heap_dur = heap_start.elapsed();

    // Full sort path; the clone is part of the timed work.
    let sort_start = Instant::now();
    let mut sorted = rowids.clone();
    sort_rowids(&mut sorted, table, order).unwrap();
    sorted.truncate(K);
    let sort_dur = sort_start.elapsed();

    // Clamp the divisor so a zero-length heap timing cannot divide by zero.
    let heap_secs = heap_dur.as_secs_f64().max(1e-9);
    let ratio = sort_dur.as_secs_f64() / heap_secs;
    println!("\n--- topk_benchmark (N={N}, k={K}) ---");
    println!(" bounded heap: {heap_dur:?}");
    println!(" full sort+trunc: {sort_dur:?}");
    println!(" speedup ratio: {ratio:.2}×");

    assert!(
        ratio > 1.4,
        "bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
    );
}
3809
3810 fn run_select(db: &mut Database, sql: &str) -> String {
3818 crate::sql::process_command(sql, db).expect("select")
3819 }
3820
/// `WHERE n IS NULL` returns exactly the rows whose `n` is NULL.
#[test]
fn where_is_null_returns_null_rows() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
        "INSERT INTO t (id, n) VALUES (1, 10);",
        "INSERT INTO t (id, n) VALUES (2, NULL);",
        "INSERT INTO t (id, n) VALUES (3, 30);",
        "INSERT INTO t (id, n) VALUES (4, NULL);",
    ];
    for statement in setup {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL;");
    assert!(
        response.contains("2 rows returned"),
        "IS NULL should return 2 rows, got: {response}"
    );
}
3840
/// `WHERE n IS NOT NULL` returns exactly the rows whose `n` has a value.
#[test]
fn where_is_not_null_returns_non_null_rows() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
        "INSERT INTO t (id, n) VALUES (1, 10);",
        "INSERT INTO t (id, n) VALUES (2, NULL);",
        "INSERT INTO t (id, n) VALUES (3, 30);",
    ];
    for statement in setup {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NOT NULL;");
    assert!(
        response.contains("2 rows returned"),
        "IS NOT NULL should return 2 rows, got: {response}"
    );
}
3859
/// IS NULL / IS NOT NULL behave correctly on a `UNIQUE` column.
#[test]
fn where_is_null_on_indexed_column() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT UNIQUE);",
        "INSERT INTO t (id, name) VALUES (1, 'alice');",
        "INSERT INTO t (id, name) VALUES (2, NULL);",
        "INSERT INTO t (id, name) VALUES (3, 'bob');",
    ];
    for statement in setup {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    let null_rows = run_select(&mut db, "SELECT id FROM t WHERE name IS NULL;");
    assert!(
        null_rows.contains("1 row returned"),
        "indexed IS NULL should return 1 row, got: {null_rows}"
    );

    let not_null_rows = run_select(&mut db, "SELECT id FROM t WHERE name IS NOT NULL;");
    assert!(
        not_null_rows.contains("2 rows returned"),
        "indexed IS NOT NULL should return 2 rows, got: {not_null_rows}"
    );
}
3889
/// A column left out of the INSERT column list matches `IS NULL`.
#[test]
fn where_is_null_works_on_omitted_column() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, qty INTEGER, label TEXT);",
        "INSERT INTO t (id, qty, label) VALUES (1, 7, 'a');",
        // Row 2 never mentions `qty`.
        "INSERT INTO t (id, label) VALUES (2, 'b');",
    ];
    for statement in setup {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    let response = run_select(&mut db, "SELECT id FROM t WHERE qty IS NULL;");
    assert!(
        response.contains("1 row returned"),
        "IS NULL should match the omitted-column row, got: {response}"
    );
}
3915
/// `IS NULL` composes with `AND`: only row 2 is both NULL and has id > 1.
#[test]
fn where_is_null_combines_with_and_or() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
        "INSERT INTO t (id, n) VALUES (1, NULL);",
        "INSERT INTO t (id, n) VALUES (2, NULL);",
        "INSERT INTO t (id, n) VALUES (3, 30);",
    ];
    for statement in setup {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL AND id > 1;");
    assert!(
        response.contains("1 row returned"),
        "IS NULL combined with AND should match exactly row 2, got: {response}"
    );
}
3937
3938 fn seed_employees() -> Database {
3944 let mut db = Database::new("t".to_string());
3945 crate::sql::process_command(
3946 "CREATE TABLE emp (id INTEGER PRIMARY KEY, name TEXT, dept TEXT, salary INTEGER);",
3947 &mut db,
3948 )
3949 .unwrap();
3950 let rows = [
3951 "INSERT INTO emp (name, dept, salary) VALUES ('Alice', 'eng', 100);",
3952 "INSERT INTO emp (name, dept, salary) VALUES ('alex', 'eng', 120);",
3953 "INSERT INTO emp (name, dept, salary) VALUES ('Bob', 'eng', 100);",
3954 "INSERT INTO emp (name, dept, salary) VALUES ('Carol', 'sales', 90);",
3955 "INSERT INTO emp (name, dept, salary) VALUES ('Dave', 'sales', NULL);",
3956 "INSERT INTO emp (name, dept, salary) VALUES ('Eve', 'ops', 80);",
3957 ];
3958 for sql in rows {
3959 crate::sql::process_command(sql, &mut db).unwrap();
3960 }
3961 db
3962 }
3963
3964 fn run_rows(db: &Database, sql: &str) -> SelectResult {
3966 let q = parse_select(sql);
3967 execute_select_rows(q, db).expect("select")
3968 }
3969
/// `LIKE 'a%'` matches both "Alice" and "alex", so matching ignores case.
#[test]
fn like_percent_prefix_case_insensitive() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT name FROM emp WHERE name LIKE 'a%';");
    let names: Vec<String> = result
        .rows
        .iter()
        .map(|row| row[0].to_display_string())
        .collect();
    assert_eq!(names.len(), 2, "expected 2 rows, got {names:?}");
    for expected in ["Alice", "alex"] {
        assert!(names.iter().any(|name| name == expected));
    }
}
3982
/// `_` matches exactly one character: `'_ve'` matches only "Eve".
#[test]
fn like_underscore_singlechar() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT name FROM emp WHERE name LIKE '_ve';");
    let mut names: Vec<String> = Vec::new();
    for row in &result.rows {
        names.push(row[0].to_display_string());
    }
    assert_eq!(names, vec!["Eve".to_string()]);
}
3991
/// `NOT LIKE 'a%'` keeps the four employees whose names do not start with "a".
#[test]
fn not_like_excludes_match() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT name FROM emp WHERE name NOT LIKE 'a%';");
    let row_count = result.rows.len();
    assert_eq!(row_count, 4);
}
3999
/// LIKE combined with `IS NULL`: the only sales employee with a NULL salary
/// is Dave.
#[test]
fn like_with_null_excludes_row() {
    let db = seed_employees();
    let sql = "SELECT name FROM emp WHERE dept LIKE 'sales' AND salary IS NULL;";
    let result = run_rows(&db, sql);
    assert_eq!(result.rows.len(), 1);
    let only_name = result.rows[0][0].to_display_string();
    assert_eq!(only_name, "Dave");
}
4011
/// `id IN (1, 3, 5)` selects Alice, Bob and Dave.
#[test]
fn in_list_positive() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT name FROM emp WHERE id IN (1, 3, 5);");
    let names: Vec<String> = result
        .rows
        .iter()
        .map(|row| row[0].to_display_string())
        .collect();
    assert_eq!(names.len(), 3);
    for expected in ["Alice", "Bob", "Dave"] {
        assert!(names.iter().any(|name| name == expected));
    }
}
4024
/// `id NOT IN (1, 2)` leaves the other four employees.
#[test]
fn not_in_excludes_listed() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT name FROM emp WHERE id NOT IN (1, 2);");
    let row_count = result.rows.len();
    assert_eq!(row_count, 4);
}
4032
/// A NULL in the IN list does not match anything: only the row matching the
/// non-NULL item (id 1, Alice) is returned.
#[test]
fn in_list_with_null_three_valued() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT name FROM emp WHERE id IN (1, NULL);");
    assert_eq!(result.rows.len(), 1);
    let only_name = result.rows[0][0].to_display_string();
    assert_eq!(only_name, "Alice");
}
4042
/// `SELECT DISTINCT dept` collapses the six rows to the three departments.
#[test]
fn distinct_single_column() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT DISTINCT dept FROM emp;");
    let row_count = result.rows.len();
    assert_eq!(row_count, 3);
}
4052
/// DISTINCT over (dept, salary) yields five combinations: the two
/// ('eng', 100) rows collapse, and the NULL-salary row counts as its own
/// combination.
#[test]
fn distinct_multi_column_with_null() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT DISTINCT dept, salary FROM emp;");
    let row_count = result.rows.len();
    assert_eq!(row_count, 5);
}
4061
/// `COUNT(*)` without GROUP BY returns a single row holding the row count.
#[test]
fn count_star_no_groupby() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT COUNT(*) FROM emp;");
    assert_eq!(result.rows.len(), 1);
    let count = &result.rows[0][0];
    assert_eq!(*count, Value::Integer(6));
}
4071
/// `COUNT(salary)` ignores the one NULL salary: 5 of 6 rows are counted.
#[test]
fn count_col_skips_nulls() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT COUNT(salary) FROM emp;");
    let count = &result.rows[0][0];
    assert_eq!(*count, Value::Integer(5));
}
4079
/// `COUNT(DISTINCT salary)` counts the distinct non-NULL salaries
/// (100, 120, 90, 80), giving 4.
#[test]
fn count_distinct_dedupes_and_skips_nulls() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT COUNT(DISTINCT salary) FROM emp;");
    let count = &result.rows[0][0];
    assert_eq!(*count, Value::Integer(4));
}
4087
/// Summing an INTEGER column yields an Integer: 100+120+100+90+80 = 490.
#[test]
fn sum_int_stays_integer() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT SUM(salary) FROM emp;");
    let total = &result.rows[0][0];
    assert_eq!(*total, Value::Integer(490));
}
4095
/// `AVG` yields a Real: 490 over the 5 non-NULL salaries is 98.0.
#[test]
fn avg_returns_real() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT AVG(salary) FROM emp;");
    let other = &result.rows[0][0];
    let Value::Real(v) = other else {
        panic!("expected Real, got {other:?}");
    };
    assert!((v - 98.0).abs() < 1e-9);
}
4106
/// MIN and MAX ignore the NULL salary: the extremes are 80 and 120.
#[test]
fn min_max_skip_nulls() {
    let db = seed_employees();
    let result = run_rows(&db, "SELECT MIN(salary), MAX(salary) FROM emp;");
    let row = &result.rows[0];
    assert_eq!(row[0], Value::Integer(80));
    assert_eq!(row[1], Value::Integer(120));
}
4114
4115 #[test]
4116 fn aggregates_on_empty_table_emit_one_row() {
4117 let mut db = Database::new("t".to_string());
4118 crate::sql::process_command("CREATE TABLE t (x INTEGER);", &mut db).unwrap();
4119 let r = run_rows(
4120 &db,
4121 "SELECT COUNT(*), SUM(x), AVG(x), MIN(x), MAX(x) FROM t;",
4122 );
4123 assert_eq!(r.rows.len(), 1);
4124 assert_eq!(r.rows[0][0], Value::Integer(0));
4125 assert_eq!(r.rows[0][1], Value::Null);
4126 assert_eq!(r.rows[0][2], Value::Null);
4127 assert_eq!(r.rows[0][3], Value::Null);
4128 assert_eq!(r.rows[0][4], Value::Null);
4129 }
4130
4131 #[test]
4134 fn group_by_single_col_with_count() {
4135 let db = seed_employees();
4136 let r = run_rows(&db, "SELECT dept, COUNT(*) FROM emp GROUP BY dept;");
4137 assert_eq!(r.rows.len(), 3);
4138 let mut by_dept: std::collections::HashMap<String, i64> = Default::default();
4140 for row in &r.rows {
4141 let d = row[0].to_display_string();
4142 let c = match &row[1] {
4143 Value::Integer(i) => *i,
4144 v => panic!("expected Integer count, got {v:?}"),
4145 };
4146 by_dept.insert(d, c);
4147 }
4148 assert_eq!(by_dept["eng"], 3);
4149 assert_eq!(by_dept["sales"], 2);
4150 assert_eq!(by_dept["ops"], 1);
4151 }
4152
4153 #[test]
4154 fn group_by_with_where_filter() {
4155 let db = seed_employees();
4156 let r = run_rows(
4157 &db,
4158 "SELECT dept, SUM(salary) FROM emp WHERE salary > 80 GROUP BY dept;",
4159 );
4160 let by: std::collections::HashMap<String, i64> = r
4163 .rows
4164 .iter()
4165 .map(|row| {
4166 (
4167 row[0].to_display_string(),
4168 match &row[1] {
4169 Value::Integer(i) => *i,
4170 v => panic!("expected Integer sum, got {v:?}"),
4171 },
4172 )
4173 })
4174 .collect();
4175 assert_eq!(by.len(), 2);
4176 assert_eq!(by["eng"], 320);
4177 assert_eq!(by["sales"], 90);
4178 }
4179
4180 #[test]
4181 fn group_by_without_aggregates_is_distinct() {
4182 let db = seed_employees();
4183 let r = run_rows(&db, "SELECT dept FROM emp GROUP BY dept;");
4184 assert_eq!(r.rows.len(), 3);
4185 }
4186
4187 #[test]
4188 fn order_by_count_desc() {
4189 let db = seed_employees();
4190 let r = run_rows(
4191 &db,
4192 "SELECT dept, COUNT(*) AS n FROM emp GROUP BY dept ORDER BY n DESC LIMIT 2;",
4193 );
4194 assert_eq!(r.rows.len(), 2);
4195 assert_eq!(r.rows[0][0].to_display_string(), "eng");
4197 assert_eq!(r.rows[0][1], Value::Integer(3));
4198 }
4199
4200 #[test]
4201 fn order_by_aggregate_call_form() {
4202 let db = seed_employees();
4203 let r = run_rows(
4205 &db,
4206 "SELECT dept, COUNT(*) FROM emp GROUP BY dept ORDER BY COUNT(*) DESC;",
4207 );
4208 assert_eq!(r.rows.len(), 3);
4209 assert_eq!(r.rows[0][0].to_display_string(), "eng");
4210 }
4211
4212 #[test]
4213 fn group_by_invalid_bare_column_errors() {
4214 let mut db = Database::new("t".to_string());
4216 crate::sql::process_command(
4217 "CREATE TABLE t (id INTEGER PRIMARY KEY, dept TEXT, name TEXT);",
4218 &mut db,
4219 )
4220 .unwrap();
4221 let err = crate::sql::process_command("SELECT dept, name FROM t GROUP BY dept;", &mut db);
4222 assert!(err.is_err(), "should reject bare 'name' not in GROUP BY");
4223 }
4224
4225 #[test]
4226 fn aggregate_in_where_errors_friendly() {
4227 let mut db = Database::new("t".to_string());
4228 crate::sql::process_command("CREATE TABLE t (x INTEGER);", &mut db).unwrap();
4229 crate::sql::process_command("INSERT INTO t (x) VALUES (1);", &mut db).unwrap();
4230 let err = crate::sql::process_command("SELECT x FROM t WHERE COUNT(*) > 0;", &mut db);
4231 assert!(err.is_err(), "aggregates must not be allowed in WHERE");
4232 }
4233
4234 fn seed_join_fixture() -> Database {
4245 let mut db = Database::new("t".to_string());
4246 for sql in [
4247 "CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT);",
4248 "CREATE TABLE orders (id INTEGER PRIMARY KEY, customer_id INTEGER, amount INTEGER);",
4249 "INSERT INTO customers (name) VALUES ('Alice');",
4250 "INSERT INTO customers (name) VALUES ('Bob');",
4251 "INSERT INTO customers (name) VALUES ('Carol');",
4252 "INSERT INTO orders (customer_id, amount) VALUES (1, 100);",
4253 "INSERT INTO orders (customer_id, amount) VALUES (1, 200);",
4254 "INSERT INTO orders (customer_id, amount) VALUES (2, 50);",
4255 "INSERT INTO orders (customer_id, amount) VALUES (4, 999);",
4256 ] {
4257 crate::sql::process_command(sql, &mut db).unwrap();
4258 }
4259 db
4260 }
4261
4262 #[test]
4263 fn inner_join_returns_only_matched_rows() {
4264 let db = seed_join_fixture();
4265 let r = run_rows(
4266 &db,
4267 "SELECT customers.name, orders.amount FROM customers \
4268 INNER JOIN orders ON customers.id = orders.customer_id;",
4269 );
4270 assert_eq!(r.columns, vec!["name".to_string(), "amount".to_string()]);
4271 let pairs: Vec<(String, i64)> = r
4274 .rows
4275 .iter()
4276 .map(|row| {
4277 (
4278 row[0].to_display_string(),
4279 match row[1] {
4280 Value::Integer(i) => i,
4281 ref v => panic!("expected integer amount, got {v:?}"),
4282 },
4283 )
4284 })
4285 .collect();
4286 assert_eq!(pairs.len(), 3);
4287 assert!(pairs.contains(&("Alice".to_string(), 100)));
4288 assert!(pairs.contains(&("Alice".to_string(), 200)));
4289 assert!(pairs.contains(&("Bob".to_string(), 50)));
4290 }
4291
4292 #[test]
4293 fn bare_join_defaults_to_inner() {
4294 let db = seed_join_fixture();
4295 let r = run_rows(
4296 &db,
4297 "SELECT customers.name FROM customers \
4298 JOIN orders ON customers.id = orders.customer_id;",
4299 );
4300 assert_eq!(r.rows.len(), 3, "JOIN without prefix should be INNER");
4301 }
4302
4303 #[test]
4304 fn left_outer_join_preserves_unmatched_left() {
4305 let db = seed_join_fixture();
4306 let r = run_rows(
4307 &db,
4308 "SELECT customers.name, orders.amount FROM customers \
4309 LEFT OUTER JOIN orders ON customers.id = orders.customer_id;",
4310 );
4311 assert_eq!(r.rows.len(), 4);
4314 let carol = r
4315 .rows
4316 .iter()
4317 .find(|row| row[0].to_display_string() == "Carol")
4318 .expect("Carol should appear with a NULL-padded right side");
4319 assert_eq!(carol[1], Value::Null);
4320 }
4321
4322 #[test]
4323 fn right_outer_join_preserves_unmatched_right() {
4324 let db = seed_join_fixture();
4325 let r = run_rows(
4326 &db,
4327 "SELECT customers.name, orders.amount FROM customers \
4328 RIGHT OUTER JOIN orders ON customers.id = orders.customer_id;",
4329 );
4330 assert_eq!(r.rows.len(), 4);
4334 let dangling = r
4335 .rows
4336 .iter()
4337 .find(|row| matches!(row[1], Value::Integer(999)))
4338 .expect("dangling order 999 should appear with a NULL-padded customer name");
4339 assert_eq!(dangling[0], Value::Null);
4340 }
4341
4342 #[test]
4343 fn full_outer_join_preserves_both_sides() {
4344 let db = seed_join_fixture();
4345 let r = run_rows(
4346 &db,
4347 "SELECT customers.name, orders.amount FROM customers \
4348 FULL OUTER JOIN orders ON customers.id = orders.customer_id;",
4349 );
4350 assert_eq!(r.rows.len(), 5);
4353 assert!(
4355 r.rows
4356 .iter()
4357 .any(|row| row[0].to_display_string() == "Carol" && matches!(row[1], Value::Null))
4358 );
4359 assert!(
4361 r.rows
4362 .iter()
4363 .any(|row| matches!(row[1], Value::Integer(999)) && matches!(row[0], Value::Null))
4364 );
4365 }
4366
4367 #[test]
4368 fn join_with_table_aliases_resolves_qualifiers() {
4369 let db = seed_join_fixture();
4370 let r = run_rows(
4371 &db,
4372 "SELECT c.name, o.amount FROM customers AS c \
4373 INNER JOIN orders AS o ON c.id = o.customer_id;",
4374 );
4375 assert_eq!(r.rows.len(), 3);
4376 assert_eq!(r.columns, vec!["name".to_string(), "amount".to_string()]);
4377 }
4378
4379 #[test]
4380 fn join_with_where_filter_applies_after_join() {
4381 let db = seed_join_fixture();
4382 let r = run_rows(
4385 &db,
4386 "SELECT customers.name, orders.amount FROM customers \
4387 INNER JOIN orders ON customers.id = orders.customer_id \
4388 WHERE orders.amount >= 100;",
4389 );
4390 assert_eq!(r.rows.len(), 2);
4391 assert!(
4392 r.rows
4393 .iter()
4394 .all(|row| row[0].to_display_string() == "Alice")
4395 );
4396 }
4397
4398 #[test]
4399 fn left_join_with_where_on_right_side_is_not_inner() {
4400 let db = seed_join_fixture();
4404 let r = run_rows(
4405 &db,
4406 "SELECT customers.name, orders.amount FROM customers \
4407 LEFT OUTER JOIN orders ON customers.id = orders.customer_id \
4408 WHERE orders.amount IS NULL;",
4409 );
4410 assert_eq!(r.rows.len(), 1);
4412 assert_eq!(r.rows[0][0].to_display_string(), "Carol");
4413 assert_eq!(r.rows[0][1], Value::Null);
4414 }
4415
4416 #[test]
4417 fn select_star_over_join_emits_all_columns_from_both_tables() {
4418 let db = seed_join_fixture();
4419 let r = run_rows(
4420 &db,
4421 "SELECT * FROM customers \
4422 INNER JOIN orders ON customers.id = orders.customer_id;",
4423 );
4424 assert_eq!(
4428 r.columns,
4429 vec![
4430 "id".to_string(),
4431 "name".to_string(),
4432 "id".to_string(),
4433 "customer_id".to_string(),
4434 "amount".to_string(),
4435 ]
4436 );
4437 assert_eq!(r.rows.len(), 3);
4438 }
4439
4440 #[test]
4441 fn join_order_by_sorts_full_joined_rows() {
4442 let db = seed_join_fixture();
4443 let r = run_rows(
4444 &db,
4445 "SELECT c.name, o.amount FROM customers AS c \
4446 INNER JOIN orders AS o ON c.id = o.customer_id \
4447 ORDER BY o.amount;",
4448 );
4449 let amounts: Vec<i64> = r
4450 .rows
4451 .iter()
4452 .map(|row| match row[1] {
4453 Value::Integer(i) => i,
4454 ref v => panic!("expected integer, got {v:?}"),
4455 })
4456 .collect();
4457 assert_eq!(amounts, vec![50, 100, 200]);
4458 }
4459
4460 #[test]
4461 fn join_limit_truncates_after_join_and_sort() {
4462 let db = seed_join_fixture();
4463 let r = run_rows(
4464 &db,
4465 "SELECT c.name, o.amount FROM customers AS c \
4466 INNER JOIN orders AS o ON c.id = o.customer_id \
4467 ORDER BY o.amount DESC LIMIT 2;",
4468 );
4469 assert_eq!(r.rows.len(), 2);
4470 let amounts: Vec<i64> = r
4472 .rows
4473 .iter()
4474 .map(|row| match row[1] {
4475 Value::Integer(i) => i,
4476 ref v => panic!("expected integer, got {v:?}"),
4477 })
4478 .collect();
4479 assert_eq!(amounts, vec![200, 100]);
4480 }
4481
4482 #[test]
4483 fn three_table_join_chains_correctly() {
4484 let mut db = Database::new("t".to_string());
4485 for sql in [
4486 "CREATE TABLE a (id INTEGER PRIMARY KEY, label TEXT);",
4487 "CREATE TABLE b (id INTEGER PRIMARY KEY, a_id INTEGER, tag TEXT);",
4488 "CREATE TABLE c (id INTEGER PRIMARY KEY, b_id INTEGER, note TEXT);",
4489 "INSERT INTO a (label) VALUES ('a-one');",
4490 "INSERT INTO a (label) VALUES ('a-two');",
4491 "INSERT INTO b (a_id, tag) VALUES (1, 'b1');",
4492 "INSERT INTO b (a_id, tag) VALUES (2, 'b2');",
4493 "INSERT INTO c (b_id, note) VALUES (1, 'c1');",
4494 ] {
4495 crate::sql::process_command(sql, &mut db).unwrap();
4496 }
4497 let r = run_rows(
4498 &db,
4499 "SELECT a.label, b.tag, c.note FROM a \
4500 INNER JOIN b ON a.id = b.a_id \
4501 INNER JOIN c ON b.id = c.b_id;",
4502 );
4503 assert_eq!(r.rows.len(), 1);
4505 assert_eq!(r.rows[0][0].to_display_string(), "a-one");
4506 assert_eq!(r.rows[0][1].to_display_string(), "b1");
4507 assert_eq!(r.rows[0][2].to_display_string(), "c1");
4508 }
4509
4510 #[test]
4511 fn ambiguous_unqualified_column_in_join_errors() {
4512 let db = seed_join_fixture();
4516 let q = parse_select(
4517 "SELECT id FROM customers INNER JOIN orders ON customers.id = orders.customer_id;",
4518 );
4519 let res = execute_select_rows(q, &db);
4520 assert!(res.is_err(), "unqualified ambiguous 'id' should error");
4521 }
4522
4523 #[test]
4524 fn join_self_without_alias_is_rejected() {
4525 let mut db = Database::new("t".to_string());
4526 crate::sql::process_command(
4527 "CREATE TABLE n (id INTEGER PRIMARY KEY, parent INTEGER);",
4528 &mut db,
4529 )
4530 .unwrap();
4531 let q = parse_select("SELECT n.id FROM n INNER JOIN n ON n.id = n.parent;");
4532 let res = execute_select_rows(q, &db);
4533 assert!(
4534 res.is_err(),
4535 "self-join without an alias should error on duplicate qualifier"
4536 );
4537 }
4538
4539 #[test]
4545 fn join_using_matches_same_rows_as_on() {
4546 let db = seed_join_fixture();
4547 let using = run_rows(
4548 &db,
4549 "SELECT customers.name, orders.amount FROM customers \
4550 INNER JOIN orders USING (id) ORDER BY orders.amount;",
4551 );
4552 let on = run_rows(
4553 &db,
4554 "SELECT customers.name, orders.amount FROM customers \
4555 INNER JOIN orders ON customers.id = orders.id ORDER BY orders.amount;",
4556 );
4557 let pairs: Vec<(String, Value)> = using
4559 .rows
4560 .iter()
4561 .map(|r| (r[0].to_display_string(), r[1].clone()))
4562 .collect();
4563 assert_eq!(pairs.len(), 3);
4564 assert_eq!(
4565 using.rows, on.rows,
4566 "USING must mirror the explicit ON rows"
4567 );
4568 }
4569
4570 #[test]
4573 fn select_star_using_dedups_joined_column() {
4574 let db = seed_join_fixture();
4575 let r = run_rows(&db, "SELECT * FROM customers INNER JOIN orders USING (id);");
4576 assert_eq!(
4580 r.columns,
4581 vec![
4582 "id".to_string(),
4583 "name".to_string(),
4584 "customer_id".to_string(),
4585 "amount".to_string(),
4586 ]
4587 );
4588 assert_eq!(r.rows.len(), 3);
4589 for row in &r.rows {
4592 assert!(matches!(row[0], Value::Integer(_)));
4593 }
4594 }
4595
4596 fn seed_natural_fixture() -> Database {
4597 let mut db = Database::new("t".to_string());
4598 for sql in [
4599 "CREATE TABLE l (lid INTEGER PRIMARY KEY, k1 INTEGER, k2 INTEGER, v1 TEXT);",
4602 "CREATE TABLE r (rid INTEGER PRIMARY KEY, k1 INTEGER, k2 INTEGER, v2 TEXT);",
4603 "INSERT INTO l (k1, k2, v1) VALUES (1, 1, 'l-a');",
4604 "INSERT INTO l (k1, k2, v1) VALUES (1, 2, 'l-b');",
4605 "INSERT INTO l (k1, k2, v1) VALUES (2, 1, 'l-c');",
4606 "INSERT INTO r (k1, k2, v2) VALUES (1, 1, 'r-a');",
4607 "INSERT INTO r (k1, k2, v2) VALUES (1, 2, 'r-b');",
4608 "INSERT INTO r (k1, k2, v2) VALUES (9, 9, 'r-z');",
4609 ] {
4610 crate::sql::process_command(sql, &mut db).unwrap();
4611 }
4612 db
4613 }
4614
4615 #[test]
4618 fn natural_join_matches_on_all_shared_columns() {
4619 let db = seed_natural_fixture();
4620 let natural = run_rows(&db, "SELECT v1, v2 FROM l NATURAL JOIN r ORDER BY v1;");
4621 let pairs: Vec<(String, String)> = natural
4623 .rows
4624 .iter()
4625 .map(|r| (r[0].to_display_string(), r[1].to_display_string()))
4626 .collect();
4627 assert_eq!(
4628 pairs,
4629 vec![
4630 ("l-a".to_string(), "r-a".to_string()),
4631 ("l-b".to_string(), "r-b".to_string()),
4632 ]
4633 );
4634 let explicit = run_rows(
4636 &db,
4637 "SELECT v1, v2 FROM l INNER JOIN r ON l.k1 = r.k1 AND l.k2 = r.k2 ORDER BY v1;",
4638 );
4639 assert_eq!(natural.rows, explicit.rows);
4640 }
4641
4642 #[test]
4644 fn select_star_natural_dedups_shared_columns() {
4645 let db = seed_natural_fixture();
4646 let r = run_rows(&db, "SELECT * FROM l NATURAL JOIN r;");
4647 assert_eq!(
4650 r.columns,
4651 vec![
4652 "lid".to_string(),
4653 "k1".to_string(),
4654 "k2".to_string(),
4655 "v1".to_string(),
4656 "rid".to_string(),
4657 "v2".to_string(),
4658 ]
4659 );
4660 assert_eq!(r.rows.len(), 2);
4661 }
4662
4663 #[test]
4666 fn natural_join_without_common_columns_is_cross_product() {
4667 let mut db = Database::new("t".to_string());
4668 for sql in [
4669 "CREATE TABLE p (pid INTEGER PRIMARY KEY, pa TEXT);",
4670 "CREATE TABLE q (qid INTEGER PRIMARY KEY, qb TEXT);",
4671 "INSERT INTO p (pa) VALUES ('p1');",
4672 "INSERT INTO p (pa) VALUES ('p2');",
4673 "INSERT INTO q (qb) VALUES ('q1');",
4674 "INSERT INTO q (qb) VALUES ('q2');",
4675 "INSERT INTO q (qb) VALUES ('q3');",
4676 ] {
4677 crate::sql::process_command(sql, &mut db).unwrap();
4678 }
4679 let r = run_rows(&db, "SELECT p.pa, q.qb FROM p NATURAL JOIN q;");
4680 assert_eq!(r.rows.len(), 2 * 3, "no shared columns ⇒ cross product");
4681 }
4682
4683 #[test]
4686 fn cross_join_produces_cartesian_product() {
4687 let db = seed_join_fixture();
4688 let cross = run_rows(
4689 &db,
4690 "SELECT customers.name, orders.amount FROM customers CROSS JOIN orders;",
4691 );
4692 assert_eq!(cross.rows.len(), 12);
4694 let on_true = run_rows(
4695 &db,
4696 "SELECT customers.name, orders.amount FROM customers INNER JOIN orders ON 1;",
4697 );
4698 assert_eq!(cross.rows.len(), on_true.rows.len());
4699 let star = run_rows(&db, "SELECT * FROM customers CROSS JOIN orders;");
4701 assert_eq!(star.columns.len(), 5);
4702 assert_eq!(star.rows.len(), 12);
4703 }
4704
4705 #[test]
4709 fn left_outer_join_using_preserves_unmatched_left() {
4710 let db = seed_join_fixture();
4711 let r = run_rows(
4712 &db,
4713 "SELECT * FROM customers LEFT OUTER JOIN orders USING (id);",
4714 );
4715 assert_eq!(r.columns.len(), 4, "id is shown once");
4719 assert_eq!(r.rows.len(), 3);
4720 }
4721
4722 #[test]
4725 fn using_unknown_column_errors() {
4726 let db = seed_join_fixture();
4727 let q = parse_select("SELECT * FROM customers INNER JOIN orders USING (nope);");
4728 let res = execute_select_rows(q, &db);
4729 assert!(res.is_err(), "USING (nope) must error — column absent");
4730 }
4731
4732 #[test]
4733 fn aggregates_over_join_are_rejected() {
4734 let db = seed_join_fixture();
4735 let err = crate::sql::process_command(
4736 "SELECT COUNT(*) FROM customers \
4737 INNER JOIN orders ON customers.id = orders.customer_id;",
4738 &mut seed_join_fixture(),
4739 );
4740 assert!(err.is_err(), "aggregates over JOIN are not yet supported");
4741 let _ = db; }
4743
4744 #[test]
4745 fn left_join_with_no_matches_pads_every_row() {
4746 let mut db = Database::new("t".to_string());
4747 for sql in [
4748 "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
4749 "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
4750 "INSERT INTO a (x) VALUES (1);",
4751 "INSERT INTO a (x) VALUES (2);",
4752 "INSERT INTO b (y) VALUES (10);",
4753 ] {
4754 crate::sql::process_command(sql, &mut db).unwrap();
4755 }
4756 let r = run_rows(
4758 &db,
4759 "SELECT a.x, b.y FROM a LEFT OUTER JOIN b ON a.x = b.y;",
4760 );
4761 assert_eq!(r.rows.len(), 2);
4762 for row in &r.rows {
4763 assert_eq!(row[1], Value::Null);
4764 }
4765 }
4766
4767 #[test]
4768 fn left_outer_join_order_by_places_nulls_first() {
4769 let db = seed_join_fixture();
4774 let r = run_rows(
4775 &db,
4776 "SELECT c.name, o.amount FROM customers AS c \
4777 LEFT OUTER JOIN orders AS o ON c.id = o.customer_id \
4778 ORDER BY o.amount ASC;",
4779 );
4780 assert_eq!(r.rows.len(), 4);
4781 assert_eq!(r.rows[0][0].to_display_string(), "Carol");
4783 assert_eq!(r.rows[0][1], Value::Null);
4784 }
4785
4786 #[test]
4787 fn chained_left_outer_join_preserves_left_through_two_levels() {
4788 let mut db = Database::new("t".to_string());
4791 for sql in [
4792 "CREATE TABLE a (id INTEGER PRIMARY KEY, label TEXT);",
4793 "CREATE TABLE b (id INTEGER PRIMARY KEY, a_id INTEGER, tag TEXT);",
4794 "CREATE TABLE c (id INTEGER PRIMARY KEY, b_id INTEGER, note TEXT);",
4795 "INSERT INTO a (label) VALUES ('a-one');",
4796 "INSERT INTO a (label) VALUES ('a-two');",
4797 "INSERT INTO b (a_id, tag) VALUES (1, 'b1');",
4799 ] {
4801 crate::sql::process_command(sql, &mut db).unwrap();
4802 }
4803 let r = run_rows(
4804 &db,
4805 "SELECT a.label, b.tag, c.note FROM a \
4806 LEFT OUTER JOIN b ON a.id = b.a_id \
4807 LEFT OUTER JOIN c ON b.id = c.b_id;",
4808 );
4809 assert_eq!(r.rows.len(), 2);
4811 let by_label: std::collections::HashMap<String, &Vec<Value>> = r
4812 .rows
4813 .iter()
4814 .map(|row| (row[0].to_display_string(), row))
4815 .collect();
4816 assert_eq!(by_label["a-one"][1].to_display_string(), "b1");
4817 assert_eq!(by_label["a-one"][2], Value::Null);
4818 assert_eq!(by_label["a-two"][1], Value::Null);
4819 assert_eq!(by_label["a-two"][2], Value::Null);
4820 }
4821
4822 #[test]
4823 fn on_clause_referencing_not_yet_joined_table_errors_clearly() {
4824 let mut db = Database::new("t".to_string());
4828 for sql in [
4829 "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
4830 "CREATE TABLE b (id INTEGER PRIMARY KEY, x INTEGER);",
4831 "CREATE TABLE c (id INTEGER PRIMARY KEY, x INTEGER);",
4832 "INSERT INTO a (x) VALUES (1);",
4833 "INSERT INTO b (x) VALUES (1);",
4834 "INSERT INTO c (x) VALUES (1);",
4835 ] {
4836 crate::sql::process_command(sql, &mut db).unwrap();
4837 }
4838 let q =
4839 parse_select("SELECT a.x FROM a INNER JOIN b ON a.x = c.x INNER JOIN c ON b.x = c.x;");
4840 let res = execute_select_rows(q, &db);
4841 assert!(
4842 res.is_err(),
4843 "ON referencing not-yet-joined table 'c' should error"
4844 );
4845 }
4846
4847 #[test]
4848 fn join_on_truthy_integer_is_accepted() {
4849 let mut db = Database::new("t".to_string());
4853 for sql in [
4854 "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
4855 "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
4856 "INSERT INTO a (x) VALUES (1);",
4857 "INSERT INTO a (x) VALUES (2);",
4858 "INSERT INTO b (y) VALUES (10);",
4859 "INSERT INTO b (y) VALUES (20);",
4860 ] {
4861 crate::sql::process_command(sql, &mut db).unwrap();
4862 }
4863 let r = run_rows(&db, "SELECT a.x, b.y FROM a INNER JOIN b ON 1;");
4864 assert_eq!(r.rows.len(), 4);
4866 }
4867
4868 #[test]
4869 fn full_join_on_empty_tables_returns_empty() {
4870 let mut db = Database::new("t".to_string());
4871 for sql in [
4872 "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
4873 "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
4874 ] {
4875 crate::sql::process_command(sql, &mut db).unwrap();
4876 }
4877 let r = run_rows(
4878 &db,
4879 "SELECT a.x, b.y FROM a FULL OUTER JOIN b ON a.x = b.y;",
4880 );
4881 assert!(r.rows.is_empty());
4882 }
4883}