1use std::cmp::Ordering;
5
6use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
7use sqlparser::ast::{
8 AlterTable, AlterTableOperation, AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr,
9 FromTable, FunctionArg, FunctionArgExpr, FunctionArguments, IndexType, ObjectName,
10 ObjectNamePart, RenameTableNameKind, Statement, TableFactor, TableWithJoins, UnaryOperator,
11 Update, Value as AstValue,
12};
13
14use crate::error::{Result, SQLRiteError};
15use crate::sql::agg::{AggState, DistinctKey, like_match};
16use crate::sql::db::database::Database;
17use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
18use crate::sql::db::table::{
19 DataType, FtsIndexEntry, HnswIndexEntry, Table, Value, parse_vector_literal,
20};
21use crate::sql::fts::{Bm25Params, PostingList};
22use crate::sql::hnsw::{DistanceMetric, HnswIndex};
23use crate::sql::parser::select::{
24 AggregateArg, JoinType, OrderByClause, Projection, ProjectionItem, ProjectionKind, SelectQuery,
25};
26
/// Name-resolution context for one logical row.
///
/// A scope turns a (possibly qualified) column reference into the value that
/// column holds for the row currently being evaluated. Scope objects are what
/// `eval_predicate_scope` / `eval_expr_scope` receive in this module.
pub(crate) trait RowScope {
    /// Resolves column `col`, optionally qualified by a table name or alias
    /// (`qualifier`), to its value for the current row.
    fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value>;

    /// Returns the backing table and rowid when the scope wraps a single
    /// table row (see `SingleTableScope`); joined scopes return `None`.
    fn single_table_view(&self) -> Option<(&Table, i64)>;
}
65
/// Scope over exactly one row of one table.
pub(crate) struct SingleTableScope<'a> {
    /// Table the row lives in.
    table: &'a Table,
    /// Rowid of the row being evaluated.
    rowid: i64,
}
71
72impl<'a> SingleTableScope<'a> {
73 pub(crate) fn new(table: &'a Table, rowid: i64) -> Self {
74 Self { table, rowid }
75 }
76}
77
78impl RowScope for SingleTableScope<'_> {
79 fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value> {
80 let _ = qualifier;
85 Ok(self.table.get_value(col, self.rowid).unwrap_or(Value::Null))
86 }
87
88 fn single_table_view(&self) -> Option<(&Table, i64)> {
89 Some((self.table, self.rowid))
90 }
91}
92
/// One table participating in a joined `SELECT`.
pub(crate) struct JoinedTableRef<'a> {
    /// The underlying table.
    pub table: &'a Table,
    /// Name the table is addressed by inside the query: its alias when one
    /// was given, otherwise the table name (see `execute_select_rows_joined`).
    pub scope_name: String,
}
100
/// Scope over one combined row of a join.
///
/// `tables` and `rowids` are parallel slices: `rowids[i]` is the row of
/// `tables[i]` contributing to the combined row, or `None` when that side is
/// NULL-padded by an outer join.
pub(crate) struct JoinedScope<'a> {
    /// Tables in scope, in FROM/JOIN order.
    pub tables: &'a [JoinedTableRef<'a>],
    /// Per-table rowid for the current combined row (`None` = NULL-padded).
    pub rowids: &'a [Option<i64>],
}
108
109impl RowScope for JoinedScope<'_> {
110 fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value> {
111 if let Some(q) = qualifier {
112 let pos = self
115 .tables
116 .iter()
117 .position(|t| t.scope_name.eq_ignore_ascii_case(q))
118 .ok_or_else(|| {
119 SQLRiteError::Internal(format!(
120 "unknown table qualifier '{q}' in column reference '{q}.{col}'"
121 ))
122 })?;
123 if !self.tables[pos].table.contains_column(col.to_string()) {
124 return Err(SQLRiteError::Internal(format!(
125 "column '{col}' does not exist on '{}'",
126 self.tables[pos].scope_name
127 )));
128 }
129 return Ok(match self.rowids[pos] {
130 None => Value::Null,
131 Some(r) => self.tables[pos]
132 .table
133 .get_value(col, r)
134 .unwrap_or(Value::Null),
135 });
136 }
137 let mut hit: Option<usize> = None;
141 for (i, t) in self.tables.iter().enumerate() {
142 if t.table.contains_column(col.to_string()) {
143 if hit.is_some() {
144 return Err(SQLRiteError::Internal(format!(
145 "column reference '{col}' is ambiguous — qualify it as <table>.{col}"
146 )));
147 }
148 hit = Some(i);
149 }
150 }
151 let i = hit.ok_or_else(|| {
152 SQLRiteError::Internal(format!(
153 "unknown column '{col}' in joined SELECT (no in-scope table has it)"
154 ))
155 })?;
156 Ok(match self.rowids[i] {
157 None => Value::Null,
158 Some(r) => self.tables[i]
159 .table
160 .get_value(col, r)
161 .unwrap_or(Value::Null),
162 })
163 }
164
165 fn single_table_view(&self) -> Option<(&Table, i64)> {
166 None
167 }
168}
169
/// Materialised output of a `SELECT`.
pub struct SelectResult {
    /// Output column names, in projection order.
    pub columns: Vec<String>,
    /// Result rows; each row holds one value per entry in `columns`.
    pub rows: Vec<Vec<Value>>,
}
182
183pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
187 if !query.joins.is_empty() {
192 return execute_select_rows_joined(query, db);
193 }
194
195 let table = db
196 .get_table(query.table_name.clone())
197 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
198
199 let proj_items: Vec<ProjectionItem> = match &query.projection {
204 Projection::All => table
205 .column_names()
206 .into_iter()
207 .map(|c| ProjectionItem {
208 kind: ProjectionKind::Column {
209 qualifier: None,
210 name: c,
211 },
212 alias: None,
213 })
214 .collect(),
215 Projection::Items(items) => items.clone(),
216 };
217 let has_aggregates = proj_items
218 .iter()
219 .any(|i| matches!(i.kind, ProjectionKind::Aggregate(_)));
220 for item in &proj_items {
222 if let ProjectionKind::Column { name: c, .. } = &item.kind
223 && !table.contains_column(c.clone())
224 {
225 return Err(SQLRiteError::Internal(format!(
226 "Column '{c}' does not exist on table '{}'",
227 query.table_name
228 )));
229 }
230 }
231 for c in &query.group_by {
232 if !table.contains_column(c.clone()) {
233 return Err(SQLRiteError::Internal(format!(
234 "GROUP BY references unknown column '{c}' on table '{}'",
235 query.table_name
236 )));
237 }
238 }
239 let matching = match select_rowids(table, query.selection.as_ref())? {
243 RowidSource::IndexProbe(rowids) => rowids,
244 RowidSource::FullScan => {
245 let mut out = Vec::new();
246 for rowid in table.rowids() {
247 if let Some(expr) = &query.selection
248 && !eval_predicate(expr, table, rowid)?
249 {
250 continue;
251 }
252 out.push(rowid);
253 }
254 out
255 }
256 };
257 let mut matching = matching;
258
259 let aggregating = has_aggregates || !query.group_by.is_empty();
260
261 if aggregating {
267 for item in &proj_items {
269 if let ProjectionKind::Aggregate(call) = &item.kind
270 && let AggregateArg::Column(c) = &call.arg
271 && !table.contains_column(c.clone())
272 {
273 return Err(SQLRiteError::Internal(format!(
274 "{}({}) references unknown column '{c}' on table '{}'",
275 call.func.as_str(),
276 c,
277 query.table_name
278 )));
279 }
280 }
281
282 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
283 let mut rows = aggregate_rows(table, &matching, &query.group_by, &proj_items)?;
284
285 if query.distinct {
286 rows = dedupe_rows(rows);
287 }
288
289 if let Some(order) = &query.order_by {
290 sort_output_rows(&mut rows, &columns, &proj_items, order)?;
291 }
292 if let Some(k) = query.limit {
293 rows.truncate(k);
294 }
295
296 return Ok(SelectResult { columns, rows });
297 }
298
299 let defer_limit_for_distinct = query.distinct;
337 match (&query.order_by, query.limit) {
338 (Some(order), Some(k)) if try_hnsw_probe(table, &order.expr, k).is_some() => {
339 matching = try_hnsw_probe(table, &order.expr, k).unwrap();
340 }
341 (Some(order), Some(k))
342 if try_fts_probe(table, &order.expr, order.ascending, k).is_some() =>
343 {
344 matching = try_fts_probe(table, &order.expr, order.ascending, k).unwrap();
345 }
346 (Some(order), Some(k)) if !defer_limit_for_distinct && k < matching.len() => {
347 matching = select_topk(&matching, table, order, k)?;
348 }
349 (Some(order), _) => {
350 sort_rowids(&mut matching, table, order)?;
351 if let Some(k) = query.limit
352 && !defer_limit_for_distinct
353 {
354 matching.truncate(k);
355 }
356 }
357 (None, Some(k)) if !defer_limit_for_distinct => {
358 matching.truncate(k);
359 }
360 _ => {}
361 }
362
363 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
364 let projected_cols: Vec<String> = proj_items
365 .iter()
366 .map(|i| match &i.kind {
367 ProjectionKind::Column { name, .. } => name.clone(),
368 ProjectionKind::Aggregate(_) => unreachable!("aggregation handled above"),
369 })
370 .collect();
371
372 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
376 for rowid in &matching {
377 let row: Vec<Value> = projected_cols
378 .iter()
379 .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
380 .collect();
381 rows.push(row);
382 }
383
384 if query.distinct {
385 rows = dedupe_rows(rows);
386 if let Some(k) = query.limit {
387 rows.truncate(k);
388 }
389 }
390
391 Ok(SelectResult { columns, rows })
392}
393
394fn execute_select_rows_joined(query: SelectQuery, db: &Database) -> Result<SelectResult> {
421 let mut joined_tables: Vec<JoinedTableRef<'_>> = Vec::with_capacity(1 + query.joins.len());
428
429 let primary = db
430 .get_table(query.table_name.clone())
431 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
432 joined_tables.push(JoinedTableRef {
433 table: primary,
434 scope_name: query
435 .table_alias
436 .clone()
437 .unwrap_or_else(|| query.table_name.clone()),
438 });
439 for j in &query.joins {
440 let t = db
441 .get_table(j.right_table.clone())
442 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", j.right_table)))?;
443 joined_tables.push(JoinedTableRef {
444 table: t,
445 scope_name: j
446 .right_alias
447 .clone()
448 .unwrap_or_else(|| j.right_table.clone()),
449 });
450 }
451
452 {
457 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
458 for t in &joined_tables {
459 let key = t.scope_name.to_ascii_lowercase();
460 if !seen.insert(key) {
461 return Err(SQLRiteError::Internal(format!(
462 "duplicate table reference '{}' in FROM/JOIN — use AS to alias one side",
463 t.scope_name
464 )));
465 }
466 }
467 }
468
469 let proj_items: Vec<ProjectionItem> = match &query.projection {
475 Projection::All => {
476 let mut all = Vec::new();
485 for t in &joined_tables {
486 for col in t.table.column_names() {
487 all.push(ProjectionItem {
488 kind: ProjectionKind::Column {
489 qualifier: Some(t.scope_name.clone()),
494 name: col,
495 },
496 alias: None,
497 });
498 }
499 }
500 all
501 }
502 Projection::Items(items) => items.clone(),
503 };
504
505 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
506
507 let mut acc: Vec<Vec<Option<i64>>> = primary
512 .rowids()
513 .into_iter()
514 .map(|r| {
515 let mut row = Vec::with_capacity(joined_tables.len());
516 row.push(Some(r));
517 row
518 })
519 .collect();
520
521 for (j_idx, join) in query.joins.iter().enumerate() {
526 let right_pos = j_idx + 1;
527 let right_table = joined_tables[right_pos].table;
528 let right_rowids: Vec<i64> = right_table.rowids();
529
530 let mut right_matched: Vec<bool> = vec![false; right_rowids.len()];
534
535 let mut next_acc: Vec<Vec<Option<i64>>> = Vec::with_capacity(acc.len());
536
537 let on_scope_tables: &[JoinedTableRef<'_>] = &joined_tables[..=right_pos];
545
546 for left_row in acc.into_iter() {
547 let mut left_match_count = 0usize;
551 for (r_idx, &rrid) in right_rowids.iter().enumerate() {
552 let mut on_rowids: Vec<Option<i64>> = left_row.clone();
553 on_rowids.push(Some(rrid));
554 debug_assert_eq!(on_rowids.len(), on_scope_tables.len());
555 let scope = JoinedScope {
556 tables: on_scope_tables,
557 rowids: &on_rowids,
558 };
559 if eval_predicate_scope(&join.on, &scope)? {
564 left_match_count += 1;
565 right_matched[r_idx] = true;
566 next_acc.push(on_rowids);
571 }
572 }
573
574 if left_match_count == 0
575 && matches!(join.join_type, JoinType::LeftOuter | JoinType::FullOuter)
576 {
577 let mut padded = left_row;
580 padded.push(None);
581 next_acc.push(padded);
582 }
583 }
584
585 if matches!(join.join_type, JoinType::RightOuter | JoinType::FullOuter) {
589 for (r_idx, matched) in right_matched.iter().enumerate() {
590 if *matched {
591 continue;
592 }
593 let mut row: Vec<Option<i64>> = vec![None; right_pos];
594 row.push(Some(right_rowids[r_idx]));
595 next_acc.push(row);
596 }
597 }
598
599 acc = next_acc;
600 }
601
602 let mut filtered: Vec<Vec<Option<i64>>> = if let Some(where_expr) = &query.selection {
607 let mut out = Vec::with_capacity(acc.len());
608 for row in acc {
609 let scope = JoinedScope {
610 tables: &joined_tables,
611 rowids: &row,
612 };
613 if eval_predicate_scope(where_expr, &scope)? {
614 out.push(row);
615 }
616 }
617 out
618 } else {
619 acc
620 };
621
622 if let Some(order) = &query.order_by {
626 let mut keys: Vec<(usize, Value)> = Vec::with_capacity(filtered.len());
629 for (i, row) in filtered.iter().enumerate() {
630 let scope = JoinedScope {
631 tables: &joined_tables,
632 rowids: row,
633 };
634 let v = eval_expr_scope(&order.expr, &scope)?;
635 keys.push((i, v));
636 }
637 keys.sort_by(|(_, a), (_, b)| {
638 let ord = compare_values(Some(a), Some(b));
639 if order.ascending { ord } else { ord.reverse() }
640 });
641 let mut sorted = Vec::with_capacity(filtered.len());
642 for (i, _) in keys {
643 sorted.push(filtered[i].clone());
644 }
645 filtered = sorted;
646 }
647
648 if let Some(k) = query.limit {
650 filtered.truncate(k);
651 }
652
653 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(filtered.len());
656 for row in &filtered {
657 let scope = JoinedScope {
658 tables: &joined_tables,
659 rowids: row,
660 };
661 let mut out_row = Vec::with_capacity(proj_items.len());
662 for item in &proj_items {
663 let v = match &item.kind {
664 ProjectionKind::Column { qualifier, name } => {
665 scope.lookup(qualifier.as_deref(), name)?
666 }
667 ProjectionKind::Aggregate(_) => {
668 return Err(SQLRiteError::Internal(
671 "aggregate functions over JOIN are not supported".to_string(),
672 ));
673 }
674 };
675 out_row.push(v);
676 }
677 rows.push(out_row);
678 }
679
680 Ok(SelectResult { columns, rows })
681}
682
683pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
688 let result = execute_select_rows(query, db)?;
689 let row_count = result.rows.len();
690
691 let mut print_table = PrintTable::new();
692 let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
693 print_table.add_row(PrintRow::new(header_cells));
694
695 for row in &result.rows {
696 let cells: Vec<PrintCell> = row
697 .iter()
698 .map(|v| PrintCell::new(&v.to_display_string()))
699 .collect();
700 print_table.add_row(PrintRow::new(cells));
701 }
702
703 Ok((print_table.to_string(), row_count))
704}
705
706pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
708 let Statement::Delete(Delete {
709 from, selection, ..
710 }) = stmt
711 else {
712 return Err(SQLRiteError::Internal(
713 "execute_delete called on a non-DELETE statement".to_string(),
714 ));
715 };
716
717 let tables = match from {
718 FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
719 };
720 let table_name = extract_single_table_name(tables)?;
721
722 let matching: Vec<i64> = {
724 let table = db
725 .get_table(table_name.clone())
726 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
727 match select_rowids(table, selection.as_ref())? {
728 RowidSource::IndexProbe(rowids) => rowids,
729 RowidSource::FullScan => {
730 let mut out = Vec::new();
731 for rowid in table.rowids() {
732 if let Some(expr) = selection {
733 if !eval_predicate(expr, table, rowid)? {
734 continue;
735 }
736 }
737 out.push(rowid);
738 }
739 out
740 }
741 }
742 };
743
744 let table = db.get_table_mut(table_name)?;
745 for rowid in &matching {
746 table.delete_row(*rowid);
747 }
748 if !matching.is_empty() {
757 for entry in &mut table.hnsw_indexes {
758 entry.needs_rebuild = true;
759 }
760 for entry in &mut table.fts_indexes {
761 entry.needs_rebuild = true;
762 }
763 }
764 Ok(matching.len())
765}
766
767pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
769 let Statement::Update(Update {
770 table,
771 assignments,
772 from,
773 selection,
774 ..
775 }) = stmt
776 else {
777 return Err(SQLRiteError::Internal(
778 "execute_update called on a non-UPDATE statement".to_string(),
779 ));
780 };
781
782 if from.is_some() {
783 return Err(SQLRiteError::NotImplemented(
784 "UPDATE ... FROM is not supported yet".to_string(),
785 ));
786 }
787
788 let table_name = extract_table_name(table)?;
789
790 let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
792 {
793 let tbl = db
794 .get_table(table_name.clone())
795 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
796 for a in assignments {
797 let col = match &a.target {
798 AssignmentTarget::ColumnName(name) => name
799 .0
800 .last()
801 .map(|p| p.to_string())
802 .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
803 AssignmentTarget::Tuple(_) => {
804 return Err(SQLRiteError::NotImplemented(
805 "tuple assignment targets are not supported".to_string(),
806 ));
807 }
808 };
809 if !tbl.contains_column(col.clone()) {
810 return Err(SQLRiteError::Internal(format!(
811 "UPDATE references unknown column '{col}'"
812 )));
813 }
814 parsed_assignments.push((col, a.value.clone()));
815 }
816 }
817
818 let work: Vec<(i64, Vec<(String, Value)>)> = {
822 let tbl = db.get_table(table_name.clone())?;
823 let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
824 RowidSource::IndexProbe(rowids) => rowids,
825 RowidSource::FullScan => {
826 let mut out = Vec::new();
827 for rowid in tbl.rowids() {
828 if let Some(expr) = selection {
829 if !eval_predicate(expr, tbl, rowid)? {
830 continue;
831 }
832 }
833 out.push(rowid);
834 }
835 out
836 }
837 };
838 let mut rows_to_update = Vec::new();
839 for rowid in matched_rowids {
840 let mut values = Vec::with_capacity(parsed_assignments.len());
841 for (col, expr) in &parsed_assignments {
842 let v = eval_expr(expr, tbl, rowid)?;
845 values.push((col.clone(), v));
846 }
847 rows_to_update.push((rowid, values));
848 }
849 rows_to_update
850 };
851
852 let tbl = db.get_table_mut(table_name)?;
853 for (rowid, values) in &work {
854 for (col, v) in values {
855 tbl.set_value(col, *rowid, v.clone())?;
856 }
857 }
858
859 if !work.is_empty() {
868 let updated_columns: std::collections::HashSet<&str> = work
869 .iter()
870 .flat_map(|(_, values)| values.iter().map(|(c, _)| c.as_str()))
871 .collect();
872 for entry in &mut tbl.hnsw_indexes {
873 if updated_columns.contains(entry.column_name.as_str()) {
874 entry.needs_rebuild = true;
875 }
876 }
877 for entry in &mut tbl.fts_indexes {
878 if updated_columns.contains(entry.column_name.as_str()) {
879 entry.needs_rebuild = true;
880 }
881 }
882 }
883 Ok(work.len())
884}
885
886pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
898 let Statement::CreateIndex(CreateIndex {
899 name,
900 table_name,
901 columns,
902 using,
903 unique,
904 if_not_exists,
905 predicate,
906 with,
907 ..
908 }) = stmt
909 else {
910 return Err(SQLRiteError::Internal(
911 "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
912 ));
913 };
914
915 if predicate.is_some() {
916 return Err(SQLRiteError::NotImplemented(
917 "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
918 ));
919 }
920
921 if columns.len() != 1 {
922 return Err(SQLRiteError::NotImplemented(format!(
923 "multi-column indexes are not supported yet ({} columns given)",
924 columns.len()
925 )));
926 }
927
928 let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
929 SQLRiteError::NotImplemented(
930 "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
931 )
932 })?;
933
934 let method = match using {
940 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("hnsw") => {
941 IndexMethod::Hnsw
942 }
943 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("fts") => {
944 IndexMethod::Fts
945 }
946 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("btree") => {
947 IndexMethod::Btree
948 }
949 Some(other) => {
950 return Err(SQLRiteError::NotImplemented(format!(
951 "CREATE INDEX … USING {other:?} is not supported \
952 (try `hnsw`, `fts`, or no USING clause)"
953 )));
954 }
955 None => IndexMethod::Btree,
956 };
957
958 let hnsw_metric = parse_hnsw_with_options(with, &index_name, method)?;
964
965 let table_name_str = table_name.to_string();
966 let column_name = match &columns[0].column.expr {
967 Expr::Identifier(ident) => ident.value.clone(),
968 Expr::CompoundIdentifier(parts) => parts
969 .last()
970 .map(|p| p.value.clone())
971 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
972 other => {
973 return Err(SQLRiteError::NotImplemented(format!(
974 "CREATE INDEX only supports simple column references, got {other:?}"
975 )));
976 }
977 };
978
979 let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
984 let table = db.get_table(table_name_str.clone()).map_err(|_| {
985 SQLRiteError::General(format!(
986 "CREATE INDEX references unknown table '{table_name_str}'"
987 ))
988 })?;
989 if !table.contains_column(column_name.clone()) {
990 return Err(SQLRiteError::General(format!(
991 "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
992 )));
993 }
994 let col = table
995 .columns
996 .iter()
997 .find(|c| c.column_name == column_name)
998 .expect("we just verified the column exists");
999
1000 if table.index_by_name(&index_name).is_some()
1003 || table.hnsw_indexes.iter().any(|i| i.name == index_name)
1004 || table.fts_indexes.iter().any(|i| i.name == index_name)
1005 {
1006 if *if_not_exists {
1007 return Ok(index_name);
1008 }
1009 return Err(SQLRiteError::General(format!(
1010 "index '{index_name}' already exists"
1011 )));
1012 }
1013 let datatype = clone_datatype(&col.datatype);
1014
1015 let mut pairs = Vec::new();
1016 for rowid in table.rowids() {
1017 if let Some(v) = table.get_value(&column_name, rowid) {
1018 pairs.push((rowid, v));
1019 }
1020 }
1021 (datatype, pairs)
1022 };
1023
1024 match method {
1025 IndexMethod::Btree => create_btree_index(
1026 db,
1027 &table_name_str,
1028 &index_name,
1029 &column_name,
1030 &datatype,
1031 *unique,
1032 &existing_rowids_and_values,
1033 ),
1034 IndexMethod::Hnsw => create_hnsw_index(
1035 db,
1036 &table_name_str,
1037 &index_name,
1038 &column_name,
1039 &datatype,
1040 *unique,
1041 hnsw_metric.unwrap_or(DistanceMetric::L2),
1042 &existing_rowids_and_values,
1043 ),
1044 IndexMethod::Fts => create_fts_index(
1045 db,
1046 &table_name_str,
1047 &index_name,
1048 &column_name,
1049 &datatype,
1050 *unique,
1051 &existing_rowids_and_values,
1052 ),
1053 }
1054}
1055
1056pub fn execute_drop_table(
1067 names: &[ObjectName],
1068 if_exists: bool,
1069 db: &mut Database,
1070) -> Result<usize> {
1071 if names.len() != 1 {
1072 return Err(SQLRiteError::NotImplemented(
1073 "DROP TABLE supports a single table per statement".to_string(),
1074 ));
1075 }
1076 let name = names[0].to_string();
1077
1078 if name == crate::sql::pager::MASTER_TABLE_NAME {
1079 return Err(SQLRiteError::General(format!(
1080 "'{}' is a reserved name used by the internal schema catalog",
1081 crate::sql::pager::MASTER_TABLE_NAME
1082 )));
1083 }
1084
1085 if !db.contains_table(name.clone()) {
1086 return if if_exists {
1087 Ok(0)
1088 } else {
1089 Err(SQLRiteError::General(format!(
1090 "Table '{name}' does not exist"
1091 )))
1092 };
1093 }
1094
1095 db.tables.remove(&name);
1096 Ok(1)
1097}
1098
1099pub fn execute_drop_index(
1108 names: &[ObjectName],
1109 if_exists: bool,
1110 db: &mut Database,
1111) -> Result<usize> {
1112 if names.len() != 1 {
1113 return Err(SQLRiteError::NotImplemented(
1114 "DROP INDEX supports a single index per statement".to_string(),
1115 ));
1116 }
1117 let name = names[0].to_string();
1118
1119 for table in db.tables.values_mut() {
1120 if let Some(secondary) = table.secondary_indexes.iter().find(|i| i.name == name) {
1121 if secondary.origin == IndexOrigin::Auto {
1122 return Err(SQLRiteError::General(format!(
1123 "cannot drop auto-created index '{name}' (drop the column or table instead)"
1124 )));
1125 }
1126 table.secondary_indexes.retain(|i| i.name != name);
1127 return Ok(1);
1128 }
1129 if table.hnsw_indexes.iter().any(|i| i.name == name) {
1130 table.hnsw_indexes.retain(|i| i.name != name);
1131 return Ok(1);
1132 }
1133 if table.fts_indexes.iter().any(|i| i.name == name) {
1134 table.fts_indexes.retain(|i| i.name != name);
1135 return Ok(1);
1136 }
1137 }
1138
1139 if if_exists {
1140 Ok(0)
1141 } else {
1142 Err(SQLRiteError::General(format!(
1143 "Index '{name}' does not exist"
1144 )))
1145 }
1146}
1147
1148pub fn execute_alter_table(alter: AlterTable, db: &mut Database) -> Result<String> {
1160 let table_name = alter.name.to_string();
1161
1162 if table_name == crate::sql::pager::MASTER_TABLE_NAME {
1163 return Err(SQLRiteError::General(format!(
1164 "'{}' is a reserved name used by the internal schema catalog",
1165 crate::sql::pager::MASTER_TABLE_NAME
1166 )));
1167 }
1168
1169 if !db.contains_table(table_name.clone()) {
1170 return if alter.if_exists {
1171 Ok("ALTER TABLE: no-op (table does not exist)".to_string())
1172 } else {
1173 Err(SQLRiteError::General(format!(
1174 "Table '{table_name}' does not exist"
1175 )))
1176 };
1177 }
1178
1179 if alter.operations.len() != 1 {
1180 return Err(SQLRiteError::NotImplemented(
1181 "ALTER TABLE supports one operation per statement".to_string(),
1182 ));
1183 }
1184
1185 match &alter.operations[0] {
1186 AlterTableOperation::RenameTable { table_name: kind } => {
1187 let new_name = match kind {
1188 RenameTableNameKind::To(name) => name.to_string(),
1189 RenameTableNameKind::As(_) => {
1190 return Err(SQLRiteError::NotImplemented(
1191 "ALTER TABLE ... RENAME AS (MySQL-only) is not supported; use RENAME TO"
1192 .to_string(),
1193 ));
1194 }
1195 };
1196 alter_rename_table(db, &table_name, &new_name)?;
1197 Ok(format!(
1198 "ALTER TABLE '{table_name}' RENAME TO '{new_name}' executed."
1199 ))
1200 }
1201 AlterTableOperation::RenameColumn {
1202 old_column_name,
1203 new_column_name,
1204 } => {
1205 let old = old_column_name.value.clone();
1206 let new = new_column_name.value.clone();
1207 db.get_table_mut(table_name.clone())?
1208 .rename_column(&old, &new)?;
1209 Ok(format!(
1210 "ALTER TABLE '{table_name}' RENAME COLUMN '{old}' TO '{new}' executed."
1211 ))
1212 }
1213 AlterTableOperation::AddColumn {
1214 column_def,
1215 if_not_exists,
1216 ..
1217 } => {
1218 let parsed = crate::sql::parser::create::parse_one_column(column_def)?;
1219 let table = db.get_table_mut(table_name.clone())?;
1220 if *if_not_exists && table.contains_column(parsed.name.clone()) {
1221 return Ok(format!(
1222 "ALTER TABLE '{table_name}' ADD COLUMN: no-op (column '{}' already exists)",
1223 parsed.name
1224 ));
1225 }
1226 let col_name = parsed.name.clone();
1227 table.add_column(parsed)?;
1228 Ok(format!(
1229 "ALTER TABLE '{table_name}' ADD COLUMN '{col_name}' executed."
1230 ))
1231 }
1232 AlterTableOperation::DropColumn {
1233 column_names,
1234 if_exists,
1235 ..
1236 } => {
1237 if column_names.len() != 1 {
1238 return Err(SQLRiteError::NotImplemented(
1239 "ALTER TABLE DROP COLUMN supports a single column per statement".to_string(),
1240 ));
1241 }
1242 let col_name = column_names[0].value.clone();
1243 let table = db.get_table_mut(table_name.clone())?;
1244 if *if_exists && !table.contains_column(col_name.clone()) {
1245 return Ok(format!(
1246 "ALTER TABLE '{table_name}' DROP COLUMN: no-op (column '{col_name}' does not exist)"
1247 ));
1248 }
1249 table.drop_column(&col_name)?;
1250 Ok(format!(
1251 "ALTER TABLE '{table_name}' DROP COLUMN '{col_name}' executed."
1252 ))
1253 }
1254 other => Err(SQLRiteError::NotImplemented(format!(
1255 "ALTER TABLE operation {other:?} is not supported"
1256 ))),
1257 }
1258}
1259
1260pub fn execute_vacuum(db: &mut Database) -> Result<String> {
1270 if db.in_transaction() {
1271 return Err(SQLRiteError::General(
1272 "VACUUM cannot run inside a transaction".to_string(),
1273 ));
1274 }
1275 let path = match db.source_path.clone() {
1276 Some(p) => p,
1277 None => {
1278 return Ok("VACUUM is a no-op for in-memory databases".to_string());
1279 }
1280 };
1281 if let Some(pager) = db.pager.as_mut() {
1287 let _ = pager.checkpoint();
1288 }
1289 let size_before = std::fs::metadata(&path).ok().map(|m| m.len()).unwrap_or(0);
1290 let pages_before = db
1291 .pager
1292 .as_ref()
1293 .map(|p| p.header().page_count)
1294 .unwrap_or(0);
1295 crate::sql::pager::vacuum_database(db, &path)?;
1296 if let Some(pager) = db.pager.as_mut() {
1299 let _ = pager.checkpoint();
1300 }
1301 let size_after = std::fs::metadata(&path).ok().map(|m| m.len()).unwrap_or(0);
1302 let pages_after = db
1303 .pager
1304 .as_ref()
1305 .map(|p| p.header().page_count)
1306 .unwrap_or(0);
1307 let pages_reclaimed = pages_before.saturating_sub(pages_after);
1308 let bytes_reclaimed = size_before.saturating_sub(size_after);
1309 Ok(format!(
1310 "VACUUM completed. {pages_reclaimed} pages reclaimed ({bytes_reclaimed} bytes)."
1311 ))
1312}
1313
1314fn alter_rename_table(db: &mut Database, old: &str, new: &str) -> Result<()> {
1320 if new == crate::sql::pager::MASTER_TABLE_NAME {
1321 return Err(SQLRiteError::General(format!(
1322 "'{}' is a reserved name used by the internal schema catalog",
1323 crate::sql::pager::MASTER_TABLE_NAME
1324 )));
1325 }
1326 if old == new {
1327 return Ok(());
1328 }
1329 if db.contains_table(new.to_string()) {
1330 return Err(SQLRiteError::General(format!(
1331 "target table '{new}' already exists"
1332 )));
1333 }
1334
1335 let mut table = db
1336 .tables
1337 .remove(old)
1338 .ok_or_else(|| SQLRiteError::General(format!("Table '{old}' does not exist")))?;
1339 table.tb_name = new.to_string();
1340 for idx in table.secondary_indexes.iter_mut() {
1341 idx.table_name = new.to_string();
1342 if idx.origin == IndexOrigin::Auto
1343 && idx.name == SecondaryIndex::auto_name(old, &idx.column_name)
1344 {
1345 idx.name = SecondaryIndex::auto_name(new, &idx.column_name);
1346 }
1347 }
1348 db.tables.insert(new.to_string(), table);
1349 Ok(())
1350}
1351
/// Index implementation selected by `CREATE INDEX ... USING <method>`.
#[derive(Debug, Clone, Copy)]
enum IndexMethod {
    /// Ordinary secondary index; also used when no `USING` clause is given.
    Btree,
    /// Vector index, built by `create_hnsw_index`.
    Hnsw,
    /// Full-text index, built by `create_fts_index`.
    Fts,
}
1362
1363fn create_btree_index(
1365 db: &mut Database,
1366 table_name: &str,
1367 index_name: &str,
1368 column_name: &str,
1369 datatype: &DataType,
1370 unique: bool,
1371 existing: &[(i64, Value)],
1372) -> Result<String> {
1373 let mut idx = SecondaryIndex::new(
1374 index_name.to_string(),
1375 table_name.to_string(),
1376 column_name.to_string(),
1377 datatype,
1378 unique,
1379 IndexOrigin::Explicit,
1380 )?;
1381
1382 for (rowid, v) in existing {
1386 if unique && idx.would_violate_unique(v) {
1387 return Err(SQLRiteError::General(format!(
1388 "cannot create UNIQUE index '{index_name}': column '{column_name}' \
1389 already contains the duplicate value {}",
1390 v.to_display_string()
1391 )));
1392 }
1393 idx.insert(v, *rowid)?;
1394 }
1395
1396 let table_mut = db.get_table_mut(table_name.to_string())?;
1397 table_mut.secondary_indexes.push(idx);
1398 Ok(index_name.to_string())
1399}
1400
1401fn create_hnsw_index(
1403 db: &mut Database,
1404 table_name: &str,
1405 index_name: &str,
1406 column_name: &str,
1407 datatype: &DataType,
1408 unique: bool,
1409 metric: DistanceMetric,
1410 existing: &[(i64, Value)],
1411) -> Result<String> {
1412 let dim = match datatype {
1415 DataType::Vector(d) => *d,
1416 other => {
1417 return Err(SQLRiteError::General(format!(
1418 "USING hnsw requires a VECTOR column; '{column_name}' is {other}"
1419 )));
1420 }
1421 };
1422
1423 if unique {
1424 return Err(SQLRiteError::General(
1425 "UNIQUE has no meaning for HNSW indexes".to_string(),
1426 ));
1427 }
1428
1429 let seed = hash_str_to_seed(index_name);
1440 let mut idx = HnswIndex::new(metric, seed);
1441
1442 let mut vec_map: std::collections::HashMap<i64, Vec<f32>> =
1446 std::collections::HashMap::with_capacity(existing.len());
1447 for (rowid, v) in existing {
1448 match v {
1449 Value::Vector(vec) => {
1450 if vec.len() != dim {
1451 return Err(SQLRiteError::Internal(format!(
1452 "row {rowid} stores a {}-dim vector in column '{column_name}' \
1453 declared as VECTOR({dim}) — schema invariant violated",
1454 vec.len()
1455 )));
1456 }
1457 vec_map.insert(*rowid, vec.clone());
1458 }
1459 _ => continue,
1463 }
1464 }
1465
1466 for (rowid, _) in existing {
1467 if let Some(v) = vec_map.get(rowid) {
1468 let v_clone = v.clone();
1469 idx.insert(*rowid, &v_clone, |id| {
1470 vec_map.get(&id).cloned().unwrap_or_default()
1471 })?;
1472 }
1473 }
1474
1475 let table_mut = db.get_table_mut(table_name.to_string())?;
1476 table_mut.hnsw_indexes.push(HnswIndexEntry {
1477 name: index_name.to_string(),
1478 column_name: column_name.to_string(),
1479 metric,
1480 index: idx,
1481 needs_rebuild: false,
1483 });
1484 Ok(index_name.to_string())
1485}
1486
1487fn parse_hnsw_with_options(
1498 with: &[Expr],
1499 index_name: &str,
1500 method: IndexMethod,
1501) -> Result<Option<DistanceMetric>> {
1502 if with.is_empty() {
1503 return Ok(None);
1504 }
1505 if !matches!(method, IndexMethod::Hnsw) {
1506 return Err(SQLRiteError::General(format!(
1507 "CREATE INDEX '{index_name}' has a WITH (...) clause but its index method \
1508 doesn't support any options — only `USING hnsw` recognises `WITH (metric = ...)`"
1509 )));
1510 }
1511
1512 let mut metric: Option<DistanceMetric> = None;
1513 for opt in with {
1514 let Expr::BinaryOp { left, op, right } = opt else {
1515 return Err(SQLRiteError::General(format!(
1516 "CREATE INDEX '{index_name}': unsupported WITH option {opt:?} \
1517 (expected `key = 'value'`)"
1518 )));
1519 };
1520 if !matches!(op, BinaryOperator::Eq) {
1521 return Err(SQLRiteError::General(format!(
1522 "CREATE INDEX '{index_name}': WITH options must use `=` (got {op:?})"
1523 )));
1524 }
1525 let key = match left.as_ref() {
1526 Expr::Identifier(ident) => ident.value.clone(),
1527 other => {
1528 return Err(SQLRiteError::General(format!(
1529 "CREATE INDEX '{index_name}': WITH option key must be a bare identifier, \
1530 got {other:?}"
1531 )));
1532 }
1533 };
1534 let value = match right.as_ref() {
1535 Expr::Value(v) => match &v.value {
1536 AstValue::SingleQuotedString(s) => s.clone(),
1537 AstValue::DoubleQuotedString(s) => s.clone(),
1538 other => {
1539 return Err(SQLRiteError::General(format!(
1540 "CREATE INDEX '{index_name}': WITH option '{key}' value must be \
1541 a quoted string, got {other:?}"
1542 )));
1543 }
1544 },
1545 Expr::Identifier(ident) => ident.value.clone(),
1546 other => {
1547 return Err(SQLRiteError::General(format!(
1548 "CREATE INDEX '{index_name}': WITH option '{key}' value must be a \
1549 quoted string, got {other:?}"
1550 )));
1551 }
1552 };
1553
1554 if key.eq_ignore_ascii_case("metric") {
1555 let parsed = DistanceMetric::from_sql_name(&value).ok_or_else(|| {
1556 SQLRiteError::General(format!(
1557 "CREATE INDEX '{index_name}': unknown HNSW metric '{value}' \
1558 (try 'l2', 'cosine', or 'dot')"
1559 ))
1560 })?;
1561 if metric.is_some() {
1562 return Err(SQLRiteError::General(format!(
1563 "CREATE INDEX '{index_name}': metric specified more than once in WITH (...)"
1564 )));
1565 }
1566 metric = Some(parsed);
1567 } else {
1568 return Err(SQLRiteError::General(format!(
1569 "CREATE INDEX '{index_name}': unknown WITH option '{key}' \
1570 (only 'metric' is recognised on HNSW indexes)"
1571 )));
1572 }
1573 }
1574
1575 Ok(metric)
1576}
1577
1578fn create_fts_index(
1583 db: &mut Database,
1584 table_name: &str,
1585 index_name: &str,
1586 column_name: &str,
1587 datatype: &DataType,
1588 unique: bool,
1589 existing: &[(i64, Value)],
1590) -> Result<String> {
1591 match datatype {
1596 DataType::Text => {}
1597 other => {
1598 return Err(SQLRiteError::General(format!(
1599 "USING fts requires a TEXT column; '{column_name}' is {other}"
1600 )));
1601 }
1602 }
1603
1604 if unique {
1605 return Err(SQLRiteError::General(
1606 "UNIQUE has no meaning for FTS indexes".to_string(),
1607 ));
1608 }
1609
1610 let mut idx = PostingList::new();
1611 for (rowid, v) in existing {
1612 if let Value::Text(text) = v {
1613 idx.insert(*rowid, text);
1614 }
1615 }
1618
1619 let table_mut = db.get_table_mut(table_name.to_string())?;
1620 table_mut.fts_indexes.push(FtsIndexEntry {
1621 name: index_name.to_string(),
1622 column_name: column_name.to_string(),
1623 index: idx,
1624 needs_rebuild: false,
1625 });
1626 Ok(index_name.to_string())
1627}
1628
/// Hashes the UTF-8 bytes of `s` into a `u64` with 64-bit FNV-1a
/// (xor the byte in, then multiply by the FNV prime, wrapping on overflow).
///
/// The empty string hashes to the FNV offset basis.
fn hash_str_to_seed(s: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xCBF2_9CE4_8422_2325;
    const FNV_PRIME: u64 = 0x0100_0000_01B3;

    s.bytes().fold(FNV_OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
1640
1641fn clone_datatype(dt: &DataType) -> DataType {
1644 match dt {
1645 DataType::Integer => DataType::Integer,
1646 DataType::Text => DataType::Text,
1647 DataType::Real => DataType::Real,
1648 DataType::Bool => DataType::Bool,
1649 DataType::Vector(dim) => DataType::Vector(*dim),
1650 DataType::Json => DataType::Json,
1651 DataType::None => DataType::None,
1652 DataType::Invalid => DataType::Invalid,
1653 }
1654}
1655
1656fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
1657 if tables.len() != 1 {
1658 return Err(SQLRiteError::NotImplemented(
1659 "multi-table DELETE is not supported yet".to_string(),
1660 ));
1661 }
1662 extract_table_name(&tables[0])
1663}
1664
1665fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
1666 if !twj.joins.is_empty() {
1667 return Err(SQLRiteError::NotImplemented(
1668 "JOIN is not supported yet".to_string(),
1669 ));
1670 }
1671 match &twj.relation {
1672 TableFactor::Table { name, .. } => Ok(name.to_string()),
1673 _ => Err(SQLRiteError::NotImplemented(
1674 "only plain table references are supported".to_string(),
1675 )),
1676 }
1677}
1678
/// Outcome of planning how to find the rows a WHERE clause might match,
/// as produced by `select_rowids`.
enum RowidSource {
    /// Rowids returned by a secondary-index lookup for a `column = literal`
    /// predicate; `select_rowids` sorts them ascending before returning.
    IndexProbe(Vec<i64>),
    /// No index probe applies (no predicate, not a simple equality, no index
    /// on the column, or the literal could not be converted).
    /// NOTE(review): presumably the caller then walks every row; confirm at
    /// the call sites.
    FullScan,
}
1689
1690fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
1695 let Some(expr) = selection else {
1696 return Ok(RowidSource::FullScan);
1697 };
1698 let Some((col, literal)) = try_extract_equality(expr) else {
1699 return Ok(RowidSource::FullScan);
1700 };
1701 let Some(idx) = table.index_for_column(&col) else {
1702 return Ok(RowidSource::FullScan);
1703 };
1704
1705 let literal_value = match convert_literal(&literal) {
1709 Ok(v) => v,
1710 Err(_) => return Ok(RowidSource::FullScan),
1711 };
1712
1713 let mut rowids = idx.lookup(&literal_value);
1717 rowids.sort_unstable();
1718 Ok(RowidSource::IndexProbe(rowids))
1719}
1720
1721fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
1725 let peeled = match expr {
1727 Expr::Nested(inner) => inner.as_ref(),
1728 other => other,
1729 };
1730 let Expr::BinaryOp { left, op, right } = peeled else {
1731 return None;
1732 };
1733 if !matches!(op, BinaryOperator::Eq) {
1734 return None;
1735 }
1736 let col_from = |e: &Expr| -> Option<String> {
1737 match e {
1738 Expr::Identifier(ident) => Some(ident.value.clone()),
1739 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
1740 _ => None,
1741 }
1742 };
1743 let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
1744 if let Expr::Value(v) = e {
1745 Some(v.value.clone())
1746 } else {
1747 None
1748 }
1749 };
1750 if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
1751 return Some((c, l));
1752 }
1753 if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
1754 return Some((c, l));
1755 }
1756 None
1757}
1758
1759fn try_hnsw_probe(table: &Table, order_expr: &Expr, k: usize) -> Option<Vec<i64>> {
1784 if k == 0 {
1785 return None;
1786 }
1787
1788 let func = match order_expr {
1791 Expr::Function(f) => f,
1792 _ => return None,
1793 };
1794 let fname = match func.name.0.as_slice() {
1795 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
1796 _ => return None,
1797 };
1798 let query_metric = match fname.as_str() {
1799 "vec_distance_l2" => DistanceMetric::L2,
1800 "vec_distance_cosine" => DistanceMetric::Cosine,
1801 "vec_distance_dot" => DistanceMetric::Dot,
1802 _ => return None,
1803 };
1804
1805 let arg_list = match &func.args {
1807 FunctionArguments::List(l) => &l.args,
1808 _ => return None,
1809 };
1810 if arg_list.len() != 2 {
1811 return None;
1812 }
1813 let exprs: Vec<&Expr> = arg_list
1814 .iter()
1815 .filter_map(|a| match a {
1816 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
1817 _ => None,
1818 })
1819 .collect();
1820 if exprs.len() != 2 {
1821 return None;
1822 }
1823
1824 let (col_name, query_vec) = match identify_indexed_arg_and_literal(exprs[0], exprs[1]) {
1829 Some(v) => v,
1830 None => match identify_indexed_arg_and_literal(exprs[1], exprs[0]) {
1831 Some(v) => v,
1832 None => return None,
1833 },
1834 };
1835
1836 let entry = table
1841 .hnsw_indexes
1842 .iter()
1843 .find(|e| e.column_name == col_name && e.metric == query_metric)?;
1844
1845 let declared_dim = match table.columns.iter().find(|c| c.column_name == col_name) {
1851 Some(c) => match &c.datatype {
1852 DataType::Vector(d) => *d,
1853 _ => return None,
1854 },
1855 None => return None,
1856 };
1857 if query_vec.len() != declared_dim {
1858 return None;
1859 }
1860
1861 let column_for_closure = col_name.clone();
1865 let table_ref = table;
1866 let result = entry
1867 .index
1868 .search(&query_vec, k, |id| {
1869 match table_ref.get_value(&column_for_closure, id) {
1870 Some(Value::Vector(v)) => v,
1871 _ => Vec::new(),
1872 }
1873 })
1874 .ok()?;
1875 Some(result)
1876}
1877
1878fn try_fts_probe(table: &Table, order_expr: &Expr, ascending: bool, k: usize) -> Option<Vec<i64>> {
1894 if k == 0 || ascending {
1895 return None;
1899 }
1900
1901 let func = match order_expr {
1902 Expr::Function(f) => f,
1903 _ => return None,
1904 };
1905 let fname = match func.name.0.as_slice() {
1906 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
1907 _ => return None,
1908 };
1909 if fname != "bm25_score" {
1910 return None;
1911 }
1912
1913 let arg_list = match &func.args {
1914 FunctionArguments::List(l) => &l.args,
1915 _ => return None,
1916 };
1917 if arg_list.len() != 2 {
1918 return None;
1919 }
1920 let exprs: Vec<&Expr> = arg_list
1921 .iter()
1922 .filter_map(|a| match a {
1923 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
1924 _ => None,
1925 })
1926 .collect();
1927 if exprs.len() != 2 {
1928 return None;
1929 }
1930
1931 let col_name = match exprs[0] {
1933 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
1934 _ => return None,
1935 };
1936
1937 let query = match exprs[1] {
1941 Expr::Value(v) => match &v.value {
1942 AstValue::SingleQuotedString(s) => s.clone(),
1943 _ => return None,
1944 },
1945 _ => return None,
1946 };
1947
1948 let entry = table
1949 .fts_indexes
1950 .iter()
1951 .find(|e| e.column_name == col_name)?;
1952
1953 let scored = entry.index.query(&query, &Bm25Params::default());
1954 let mut out: Vec<i64> = scored.into_iter().map(|(id, _)| id).collect();
1955 if out.len() > k {
1956 out.truncate(k);
1957 }
1958 Some(out)
1959}
1960
1961fn identify_indexed_arg_and_literal(a: &Expr, b: &Expr) -> Option<(String, Vec<f32>)> {
1966 let col_name = match a {
1967 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
1968 _ => return None,
1969 };
1970 let lit_str = match b {
1971 Expr::Identifier(ident) if ident.quote_style == Some('[') => {
1972 format!("[{}]", ident.value)
1973 }
1974 _ => return None,
1975 };
1976 let v = parse_vector_literal(&lit_str).ok()?;
1977 Some((col_name, v))
1978}
1979
1980struct HeapEntry {
1993 key: Value,
1994 rowid: i64,
1995 asc: bool,
1996}
1997
1998impl PartialEq for HeapEntry {
1999 fn eq(&self, other: &Self) -> bool {
2000 self.cmp(other) == Ordering::Equal
2001 }
2002}
2003
2004impl Eq for HeapEntry {}
2005
2006impl PartialOrd for HeapEntry {
2007 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
2008 Some(self.cmp(other))
2009 }
2010}
2011
2012impl Ord for HeapEntry {
2013 fn cmp(&self, other: &Self) -> Ordering {
2014 let raw = compare_values(Some(&self.key), Some(&other.key));
2015 if self.asc { raw } else { raw.reverse() }
2016 }
2017}
2018
2019fn select_topk(
2028 matching: &[i64],
2029 table: &Table,
2030 order: &OrderByClause,
2031 k: usize,
2032) -> Result<Vec<i64>> {
2033 use std::collections::BinaryHeap;
2034
2035 if k == 0 || matching.is_empty() {
2036 return Ok(Vec::new());
2037 }
2038
2039 let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
2040
2041 for &rowid in matching {
2042 let key = eval_expr(&order.expr, table, rowid)?;
2043 let entry = HeapEntry {
2044 key,
2045 rowid,
2046 asc: order.ascending,
2047 };
2048
2049 if heap.len() < k {
2050 heap.push(entry);
2051 } else {
2052 if entry < *heap.peek().unwrap() {
2056 heap.pop();
2057 heap.push(entry);
2058 }
2059 }
2060 }
2061
2062 Ok(heap
2067 .into_sorted_vec()
2068 .into_iter()
2069 .map(|e| e.rowid)
2070 .collect())
2071}
2072
2073fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
2074 let mut keys: Vec<(i64, Result<Value>)> = rowids
2082 .iter()
2083 .map(|r| (*r, eval_expr(&order.expr, table, *r)))
2084 .collect();
2085
2086 for (_, k) in &keys {
2090 if let Err(e) = k {
2091 return Err(SQLRiteError::General(format!(
2092 "ORDER BY expression failed: {e}"
2093 )));
2094 }
2095 }
2096
2097 keys.sort_by(|(_, ka), (_, kb)| {
2098 let va = ka.as_ref().unwrap();
2101 let vb = kb.as_ref().unwrap();
2102 let ord = compare_values(Some(va), Some(vb));
2103 if order.ascending { ord } else { ord.reverse() }
2104 });
2105
2106 for (i, (rowid, _)) in keys.into_iter().enumerate() {
2108 rowids[i] = rowid;
2109 }
2110 Ok(())
2111}
2112
2113fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
2114 match (a, b) {
2115 (None, None) => Ordering::Equal,
2116 (None, _) => Ordering::Less,
2117 (_, None) => Ordering::Greater,
2118 (Some(a), Some(b)) => match (a, b) {
2119 (Value::Null, Value::Null) => Ordering::Equal,
2120 (Value::Null, _) => Ordering::Less,
2121 (_, Value::Null) => Ordering::Greater,
2122 (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
2123 (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
2124 (Value::Integer(x), Value::Real(y)) => {
2125 (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
2126 }
2127 (Value::Real(x), Value::Integer(y)) => {
2128 x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
2129 }
2130 (Value::Text(x), Value::Text(y)) => x.cmp(y),
2131 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
2132 (x, y) => x.to_display_string().cmp(&y.to_display_string()),
2134 },
2135 }
2136}
2137
2138pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
2140 eval_predicate_scope(expr, &SingleTableScope::new(table, rowid))
2141}
2142
2143pub(crate) fn eval_predicate_scope(expr: &Expr, scope: &dyn RowScope) -> Result<bool> {
2147 let v = eval_expr_scope(expr, scope)?;
2148 match v {
2149 Value::Bool(b) => Ok(b),
2150 Value::Null => Ok(false), Value::Integer(i) => Ok(i != 0),
2152 other => Err(SQLRiteError::Internal(format!(
2153 "WHERE clause must evaluate to boolean, got {}",
2154 other.to_display_string()
2155 ))),
2156 }
2157}
2158
2159fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
2161 eval_expr_scope(expr, &SingleTableScope::new(table, rowid))
2162}
2163
2164fn eval_expr_scope(expr: &Expr, scope: &dyn RowScope) -> Result<Value> {
2165 match expr {
2166 Expr::Nested(inner) => eval_expr_scope(inner, scope),
2167
2168 Expr::Identifier(ident) => {
2169 if ident.quote_style == Some('[') {
2179 let raw = format!("[{}]", ident.value);
2180 let v = parse_vector_literal(&raw)?;
2181 return Ok(Value::Vector(v));
2182 }
2183 scope.lookup(None, &ident.value)
2184 }
2185
2186 Expr::CompoundIdentifier(parts) => {
2187 match parts.as_slice() {
2193 [only] => scope.lookup(None, &only.value),
2194 [q, c] => scope.lookup(Some(&q.value), &c.value),
2195 _ => Err(SQLRiteError::NotImplemented(format!(
2196 "compound identifier with {} parts is not supported",
2197 parts.len()
2198 ))),
2199 }
2200 }
2201
2202 Expr::Value(v) => convert_literal(&v.value),
2203
2204 Expr::UnaryOp { op, expr } => {
2205 let inner = eval_expr_scope(expr, scope)?;
2206 match op {
2207 UnaryOperator::Not => match inner {
2208 Value::Bool(b) => Ok(Value::Bool(!b)),
2209 Value::Null => Ok(Value::Null),
2210 other => Err(SQLRiteError::Internal(format!(
2211 "NOT applied to non-boolean value: {}",
2212 other.to_display_string()
2213 ))),
2214 },
2215 UnaryOperator::Minus => match inner {
2216 Value::Integer(i) => Ok(Value::Integer(-i)),
2217 Value::Real(f) => Ok(Value::Real(-f)),
2218 Value::Null => Ok(Value::Null),
2219 other => Err(SQLRiteError::Internal(format!(
2220 "unary minus on non-numeric value: {}",
2221 other.to_display_string()
2222 ))),
2223 },
2224 UnaryOperator::Plus => Ok(inner),
2225 other => Err(SQLRiteError::NotImplemented(format!(
2226 "unary operator {other:?} is not supported"
2227 ))),
2228 }
2229 }
2230
2231 Expr::BinaryOp { left, op, right } => match op {
2232 BinaryOperator::And => {
2233 let l = eval_expr_scope(left, scope)?;
2234 let r = eval_expr_scope(right, scope)?;
2235 Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
2236 }
2237 BinaryOperator::Or => {
2238 let l = eval_expr_scope(left, scope)?;
2239 let r = eval_expr_scope(right, scope)?;
2240 Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
2241 }
2242 cmp @ (BinaryOperator::Eq
2243 | BinaryOperator::NotEq
2244 | BinaryOperator::Lt
2245 | BinaryOperator::LtEq
2246 | BinaryOperator::Gt
2247 | BinaryOperator::GtEq) => {
2248 let l = eval_expr_scope(left, scope)?;
2249 let r = eval_expr_scope(right, scope)?;
2250 if matches!(l, Value::Null) || matches!(r, Value::Null) {
2252 return Ok(Value::Bool(false));
2253 }
2254 let ord = compare_values(Some(&l), Some(&r));
2255 let result = match cmp {
2256 BinaryOperator::Eq => ord == Ordering::Equal,
2257 BinaryOperator::NotEq => ord != Ordering::Equal,
2258 BinaryOperator::Lt => ord == Ordering::Less,
2259 BinaryOperator::LtEq => ord != Ordering::Greater,
2260 BinaryOperator::Gt => ord == Ordering::Greater,
2261 BinaryOperator::GtEq => ord != Ordering::Less,
2262 _ => unreachable!(),
2263 };
2264 Ok(Value::Bool(result))
2265 }
2266 arith @ (BinaryOperator::Plus
2267 | BinaryOperator::Minus
2268 | BinaryOperator::Multiply
2269 | BinaryOperator::Divide
2270 | BinaryOperator::Modulo) => {
2271 let l = eval_expr_scope(left, scope)?;
2272 let r = eval_expr_scope(right, scope)?;
2273 eval_arith(arith, &l, &r)
2274 }
2275 BinaryOperator::StringConcat => {
2276 let l = eval_expr_scope(left, scope)?;
2277 let r = eval_expr_scope(right, scope)?;
2278 if matches!(l, Value::Null) || matches!(r, Value::Null) {
2279 return Ok(Value::Null);
2280 }
2281 Ok(Value::Text(format!(
2282 "{}{}",
2283 l.to_display_string(),
2284 r.to_display_string()
2285 )))
2286 }
2287 other => Err(SQLRiteError::NotImplemented(format!(
2288 "binary operator {other:?} is not supported yet"
2289 ))),
2290 },
2291
2292 Expr::IsNull(inner) => {
2300 let v = eval_expr_scope(inner, scope)?;
2301 Ok(Value::Bool(matches!(v, Value::Null)))
2302 }
2303 Expr::IsNotNull(inner) => {
2304 let v = eval_expr_scope(inner, scope)?;
2305 Ok(Value::Bool(!matches!(v, Value::Null)))
2306 }
2307
2308 Expr::Like {
2315 negated,
2316 any,
2317 expr: lhs,
2318 pattern,
2319 escape_char,
2320 } => eval_like(
2321 scope,
2322 *negated,
2323 *any,
2324 lhs,
2325 pattern,
2326 escape_char.as_ref(),
2327 true,
2328 ),
2329 Expr::ILike {
2330 negated,
2331 any,
2332 expr: lhs,
2333 pattern,
2334 escape_char,
2335 } => eval_like(
2336 scope,
2337 *negated,
2338 *any,
2339 lhs,
2340 pattern,
2341 escape_char.as_ref(),
2342 true,
2343 ),
2344
2345 Expr::InList {
2351 expr: lhs,
2352 list,
2353 negated,
2354 } => eval_in_list(scope, lhs, list, *negated),
2355 Expr::InSubquery { .. } => Err(SQLRiteError::NotImplemented(
2356 "IN (subquery) is not supported (only literal lists are)".to_string(),
2357 )),
2358
2359 Expr::Function(func) => eval_function(func, scope),
2370
2371 other => Err(SQLRiteError::NotImplemented(format!(
2372 "unsupported expression in WHERE/projection: {other:?}"
2373 ))),
2374 }
2375}
2376
2377fn eval_function(func: &sqlparser::ast::Function, scope: &dyn RowScope) -> Result<Value> {
2382 let name = match func.name.0.as_slice() {
2385 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
2386 _ => {
2387 return Err(SQLRiteError::NotImplemented(format!(
2388 "qualified function names not supported: {:?}",
2389 func.name
2390 )));
2391 }
2392 };
2393
2394 match name.as_str() {
2395 "vec_distance_l2" | "vec_distance_cosine" | "vec_distance_dot" => {
2396 let (a, b) = extract_two_vector_args(&name, &func.args, scope)?;
2397 let dist = match name.as_str() {
2398 "vec_distance_l2" => vec_distance_l2(&a, &b),
2399 "vec_distance_cosine" => vec_distance_cosine(&a, &b)?,
2400 "vec_distance_dot" => vec_distance_dot(&a, &b),
2401 _ => unreachable!(),
2402 };
2403 Ok(Value::Real(dist as f64))
2409 }
2410 "json_extract" => json_fn_extract(&name, &func.args, scope),
2415 "json_type" => json_fn_type(&name, &func.args, scope),
2416 "json_array_length" => json_fn_array_length(&name, &func.args, scope),
2417 "json_object_keys" => json_fn_object_keys(&name, &func.args, scope),
2418 "fts_match" | "bm25_score" => {
2429 let Some((table, rowid)) = scope.single_table_view() else {
2430 return Err(SQLRiteError::NotImplemented(format!(
2431 "{name}() is not yet supported inside a JOIN query — \
2432 use it on a single-table SELECT or move the FTS lookup into a subquery"
2433 )));
2434 };
2435 let (entry, query) = resolve_fts_args(&name, &func.args, table, scope)?;
2436 Ok(match name.as_str() {
2437 "fts_match" => Value::Bool(entry.index.matches(rowid, &query)),
2438 "bm25_score" => {
2439 Value::Real(entry.index.score(rowid, &query, &Bm25Params::default()))
2440 }
2441 _ => unreachable!(),
2442 })
2443 }
2444 "count" | "sum" | "avg" | "min" | "max" => Err(SQLRiteError::NotImplemented(format!(
2448 "aggregate function '{name}' is not allowed in WHERE / projection-scalar position; \
2449 use it as a top-level projection item (HAVING is not yet supported)"
2450 ))),
2451 other => Err(SQLRiteError::NotImplemented(format!(
2452 "unknown function: {other}(...)"
2453 ))),
2454 }
2455}
2456
2457fn resolve_fts_args<'t>(
2462 fn_name: &str,
2463 args: &FunctionArguments,
2464 table: &'t Table,
2465 scope: &dyn RowScope,
2466) -> Result<(&'t FtsIndexEntry, String)> {
2467 let arg_list = match args {
2468 FunctionArguments::List(l) => &l.args,
2469 _ => {
2470 return Err(SQLRiteError::General(format!(
2471 "{fn_name}() expects exactly two arguments: (column, query_text)"
2472 )));
2473 }
2474 };
2475 if arg_list.len() != 2 {
2476 return Err(SQLRiteError::General(format!(
2477 "{fn_name}() expects exactly 2 arguments, got {}",
2478 arg_list.len()
2479 )));
2480 }
2481
2482 let col_expr = match &arg_list[0] {
2486 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2487 other => {
2488 return Err(SQLRiteError::NotImplemented(format!(
2489 "{fn_name}() argument 0 must be a column name, got {other:?}"
2490 )));
2491 }
2492 };
2493 let col_name = match col_expr {
2494 Expr::Identifier(ident) => ident.value.clone(),
2495 Expr::CompoundIdentifier(parts) => parts
2496 .last()
2497 .map(|p| p.value.clone())
2498 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
2499 other => {
2500 return Err(SQLRiteError::General(format!(
2501 "{fn_name}() argument 0 must be a column reference, got {other:?}"
2502 )));
2503 }
2504 };
2505
2506 let q_expr = match &arg_list[1] {
2510 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2511 other => {
2512 return Err(SQLRiteError::NotImplemented(format!(
2513 "{fn_name}() argument 1 must be a text expression, got {other:?}"
2514 )));
2515 }
2516 };
2517 let query = match eval_expr_scope(q_expr, scope)? {
2518 Value::Text(s) => s,
2519 other => {
2520 return Err(SQLRiteError::General(format!(
2521 "{fn_name}() argument 1 must be TEXT, got {}",
2522 other.to_display_string()
2523 )));
2524 }
2525 };
2526
2527 let entry = table
2528 .fts_indexes
2529 .iter()
2530 .find(|e| e.column_name == col_name)
2531 .ok_or_else(|| {
2532 SQLRiteError::General(format!(
2533 "{fn_name}({col_name}, ...): no FTS index on column '{col_name}' \
2534 (run CREATE INDEX <name> ON <table> USING fts({col_name}) first)"
2535 ))
2536 })?;
2537 Ok((entry, query))
2538}
2539
2540fn extract_json_and_path(
2554 fn_name: &str,
2555 args: &FunctionArguments,
2556 scope: &dyn RowScope,
2557) -> Result<(String, String)> {
2558 let arg_list = match args {
2559 FunctionArguments::List(l) => &l.args,
2560 _ => {
2561 return Err(SQLRiteError::General(format!(
2562 "{fn_name}() expects 1 or 2 arguments"
2563 )));
2564 }
2565 };
2566 if !(arg_list.len() == 1 || arg_list.len() == 2) {
2567 return Err(SQLRiteError::General(format!(
2568 "{fn_name}() expects 1 or 2 arguments, got {}",
2569 arg_list.len()
2570 )));
2571 }
2572 let first_expr = match &arg_list[0] {
2574 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2575 other => {
2576 return Err(SQLRiteError::NotImplemented(format!(
2577 "{fn_name}() argument 0 has unsupported shape: {other:?}"
2578 )));
2579 }
2580 };
2581 let json_text = match eval_expr_scope(first_expr, scope)? {
2582 Value::Text(s) => s,
2583 Value::Null => {
2584 return Err(SQLRiteError::General(format!(
2585 "{fn_name}() called on NULL — JSON column has no value for this row"
2586 )));
2587 }
2588 other => {
2589 return Err(SQLRiteError::General(format!(
2590 "{fn_name}() argument 0 is not JSON-typed: got {}",
2591 other.to_display_string()
2592 )));
2593 }
2594 };
2595
2596 let path = if arg_list.len() == 2 {
2598 let path_expr = match &arg_list[1] {
2599 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2600 other => {
2601 return Err(SQLRiteError::NotImplemented(format!(
2602 "{fn_name}() argument 1 has unsupported shape: {other:?}"
2603 )));
2604 }
2605 };
2606 match eval_expr_scope(path_expr, scope)? {
2607 Value::Text(s) => s,
2608 other => {
2609 return Err(SQLRiteError::General(format!(
2610 "{fn_name}() path argument must be a string literal, got {}",
2611 other.to_display_string()
2612 )));
2613 }
2614 }
2615 } else {
2616 "$".to_string()
2617 };
2618
2619 Ok((json_text, path))
2620}
2621
2622fn walk_json_path<'a>(
2632 value: &'a serde_json::Value,
2633 path: &str,
2634) -> Result<Option<&'a serde_json::Value>> {
2635 let mut chars = path.chars().peekable();
2636 if chars.next() != Some('$') {
2637 return Err(SQLRiteError::General(format!(
2638 "JSON path must start with '$', got `{path}`"
2639 )));
2640 }
2641 let mut current = value;
2642 while let Some(&c) = chars.peek() {
2643 match c {
2644 '.' => {
2645 chars.next();
2646 let mut key = String::new();
2647 while let Some(&c) = chars.peek() {
2648 if c == '.' || c == '[' {
2649 break;
2650 }
2651 key.push(c);
2652 chars.next();
2653 }
2654 if key.is_empty() {
2655 return Err(SQLRiteError::General(format!(
2656 "JSON path has empty key after '.' in `{path}`"
2657 )));
2658 }
2659 match current.get(&key) {
2660 Some(v) => current = v,
2661 None => return Ok(None),
2662 }
2663 }
2664 '[' => {
2665 chars.next();
2666 let mut idx_str = String::new();
2667 while let Some(&c) = chars.peek() {
2668 if c == ']' {
2669 break;
2670 }
2671 idx_str.push(c);
2672 chars.next();
2673 }
2674 if chars.next() != Some(']') {
2675 return Err(SQLRiteError::General(format!(
2676 "JSON path has unclosed `[` in `{path}`"
2677 )));
2678 }
2679 let idx: usize = idx_str.trim().parse().map_err(|_| {
2680 SQLRiteError::General(format!(
2681 "JSON path has non-integer index `[{idx_str}]` in `{path}`"
2682 ))
2683 })?;
2684 match current.get(idx) {
2685 Some(v) => current = v,
2686 None => return Ok(None),
2687 }
2688 }
2689 other => {
2690 return Err(SQLRiteError::General(format!(
2691 "JSON path has unexpected character `{other}` in `{path}` \
2692 (expected `.`, `[`, or end-of-path)"
2693 )));
2694 }
2695 }
2696 }
2697 Ok(Some(current))
2698}
2699
2700fn json_value_to_sql(v: &serde_json::Value) -> Value {
2704 match v {
2705 serde_json::Value::Null => Value::Null,
2706 serde_json::Value::Bool(b) => Value::Bool(*b),
2707 serde_json::Value::Number(n) => {
2708 if let Some(i) = n.as_i64() {
2710 Value::Integer(i)
2711 } else if let Some(f) = n.as_f64() {
2712 Value::Real(f)
2713 } else {
2714 Value::Null
2715 }
2716 }
2717 serde_json::Value::String(s) => Value::Text(s.clone()),
2718 composite => Value::Text(composite.to_string()),
2722 }
2723}
2724
2725fn json_fn_extract(name: &str, args: &FunctionArguments, scope: &dyn RowScope) -> Result<Value> {
2726 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2727 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2728 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2729 })?;
2730 match walk_json_path(&parsed, &path)? {
2731 Some(v) => Ok(json_value_to_sql(v)),
2732 None => Ok(Value::Null),
2733 }
2734}
2735
2736fn json_fn_type(name: &str, args: &FunctionArguments, scope: &dyn RowScope) -> Result<Value> {
2737 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2738 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2739 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2740 })?;
2741 let resolved = match walk_json_path(&parsed, &path)? {
2742 Some(v) => v,
2743 None => return Ok(Value::Null),
2744 };
2745 let ty = match resolved {
2746 serde_json::Value::Null => "null",
2747 serde_json::Value::Bool(true) => "true",
2748 serde_json::Value::Bool(false) => "false",
2749 serde_json::Value::Number(n) => {
2750 if n.is_i64() || n.is_u64() {
2751 "integer"
2752 } else {
2753 "real"
2754 }
2755 }
2756 serde_json::Value::String(_) => "text",
2757 serde_json::Value::Array(_) => "array",
2758 serde_json::Value::Object(_) => "object",
2759 };
2760 Ok(Value::Text(ty.to_string()))
2761}
2762
2763fn json_fn_array_length(
2764 name: &str,
2765 args: &FunctionArguments,
2766 scope: &dyn RowScope,
2767) -> Result<Value> {
2768 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2769 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2770 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2771 })?;
2772 let resolved = match walk_json_path(&parsed, &path)? {
2773 Some(v) => v,
2774 None => return Ok(Value::Null),
2775 };
2776 match resolved.as_array() {
2777 Some(arr) => Ok(Value::Integer(arr.len() as i64)),
2778 None => Err(SQLRiteError::General(format!(
2779 "{name}() resolved to a non-array value at path `{path}`"
2780 ))),
2781 }
2782}
2783
2784fn json_fn_object_keys(
2785 name: &str,
2786 args: &FunctionArguments,
2787 scope: &dyn RowScope,
2788) -> Result<Value> {
2789 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2790 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2791 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2792 })?;
2793 let resolved = match walk_json_path(&parsed, &path)? {
2794 Some(v) => v,
2795 None => return Ok(Value::Null),
2796 };
2797 let obj = resolved.as_object().ok_or_else(|| {
2798 SQLRiteError::General(format!(
2799 "{name}() resolved to a non-object value at path `{path}`"
2800 ))
2801 })?;
2802 let keys: Vec<serde_json::Value> = obj
2809 .keys()
2810 .map(|k| serde_json::Value::String(k.clone()))
2811 .collect();
2812 Ok(Value::Text(serde_json::Value::Array(keys).to_string()))
2813}
2814
2815fn extract_two_vector_args(
2819 fn_name: &str,
2820 args: &FunctionArguments,
2821 scope: &dyn RowScope,
2822) -> Result<(Vec<f32>, Vec<f32>)> {
2823 let arg_list = match args {
2824 FunctionArguments::List(l) => &l.args,
2825 _ => {
2826 return Err(SQLRiteError::General(format!(
2827 "{fn_name}() expects exactly two vector arguments"
2828 )));
2829 }
2830 };
2831 if arg_list.len() != 2 {
2832 return Err(SQLRiteError::General(format!(
2833 "{fn_name}() expects exactly 2 arguments, got {}",
2834 arg_list.len()
2835 )));
2836 }
2837 let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
2838 for (i, arg) in arg_list.iter().enumerate() {
2839 let expr = match arg {
2840 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2841 other => {
2842 return Err(SQLRiteError::NotImplemented(format!(
2843 "{fn_name}() argument {i} has unsupported shape: {other:?}"
2844 )));
2845 }
2846 };
2847 let val = eval_expr_scope(expr, scope)?;
2848 match val {
2849 Value::Vector(v) => out.push(v),
2850 other => {
2851 return Err(SQLRiteError::General(format!(
2852 "{fn_name}() argument {i} is not a vector: got {}",
2853 other.to_display_string()
2854 )));
2855 }
2856 }
2857 }
2858 let b = out.pop().unwrap();
2859 let a = out.pop().unwrap();
2860 if a.len() != b.len() {
2861 return Err(SQLRiteError::General(format!(
2862 "{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
2863 a.len(),
2864 b.len()
2865 )));
2866 }
2867 Ok((a, b))
2868}
2869
/// Euclidean (L2) distance between `a` and `b`: the square root of the sum of
/// squared component differences.
///
/// Callers are expected to pass slices of equal length; this is checked with
/// a `debug_assert_eq!` only.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Pin `b` to `a`'s length once: still panics if `b` is shorter and ignores
    // any surplus if it is longer, exactly like indexing `b[i]` for
    // `i in 0..a.len()` would.
    let b = &b[..a.len()];
    let sum_of_squares = a.iter().zip(b).fold(0.0f32, |acc, (x, y)| {
        let diff = x - y;
        acc + diff * diff
    });
    sum_of_squares.sqrt()
}
2881
2882pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
2892 debug_assert_eq!(a.len(), b.len());
2893 let mut dot = 0.0f32;
2894 let mut norm_a_sq = 0.0f32;
2895 let mut norm_b_sq = 0.0f32;
2896 for i in 0..a.len() {
2897 dot += a[i] * b[i];
2898 norm_a_sq += a[i] * a[i];
2899 norm_b_sq += b[i] * b[i];
2900 }
2901 let denom = (norm_a_sq * norm_b_sq).sqrt();
2902 if denom == 0.0 {
2903 return Err(SQLRiteError::General(
2904 "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
2905 ));
2906 }
2907 Ok(1.0 - dot / denom)
2908}
2909
/// Negated dot product of `a` and `b`, so that a larger dot product maps to a
/// smaller "distance".
///
/// Callers are expected to pass slices of equal length; this is checked with
/// a `debug_assert_eq!` only.
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Pin `b` to `a`'s length once: same panic/ignore behavior as indexing
    // `b[i]` for `i in 0..a.len()`.
    let b = &b[..a.len()];
    let dot = a.iter().zip(b).fold(0.0f32, |acc, (x, y)| acc + x * y);
    -dot
}
2921
2922fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
2925 if matches!(l, Value::Null) || matches!(r, Value::Null) {
2926 return Ok(Value::Null);
2927 }
2928 match (l, r) {
2929 (Value::Integer(a), Value::Integer(b)) => match op {
2930 BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
2931 BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
2932 BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
2933 BinaryOperator::Divide => {
2934 if *b == 0 {
2935 Err(SQLRiteError::General("division by zero".to_string()))
2936 } else {
2937 Ok(Value::Integer(a / b))
2938 }
2939 }
2940 BinaryOperator::Modulo => {
2941 if *b == 0 {
2942 Err(SQLRiteError::General("modulo by zero".to_string()))
2943 } else {
2944 Ok(Value::Integer(a % b))
2945 }
2946 }
2947 _ => unreachable!(),
2948 },
2949 (a, b) => {
2951 let af = as_number(a)?;
2952 let bf = as_number(b)?;
2953 match op {
2954 BinaryOperator::Plus => Ok(Value::Real(af + bf)),
2955 BinaryOperator::Minus => Ok(Value::Real(af - bf)),
2956 BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
2957 BinaryOperator::Divide => {
2958 if bf == 0.0 {
2959 Err(SQLRiteError::General("division by zero".to_string()))
2960 } else {
2961 Ok(Value::Real(af / bf))
2962 }
2963 }
2964 BinaryOperator::Modulo => {
2965 if bf == 0.0 {
2966 Err(SQLRiteError::General("modulo by zero".to_string()))
2967 } else {
2968 Ok(Value::Real(af % bf))
2969 }
2970 }
2971 _ => unreachable!(),
2972 }
2973 }
2974 }
2975}
2976
2977fn as_number(v: &Value) -> Result<f64> {
2978 match v {
2979 Value::Integer(i) => Ok(*i as f64),
2980 Value::Real(f) => Ok(*f),
2981 Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
2982 other => Err(SQLRiteError::General(format!(
2983 "arithmetic on non-numeric value '{}'",
2984 other.to_display_string()
2985 ))),
2986 }
2987}
2988
2989fn as_bool(v: &Value) -> Result<bool> {
2990 match v {
2991 Value::Bool(b) => Ok(*b),
2992 Value::Null => Ok(false),
2993 Value::Integer(i) => Ok(*i != 0),
2994 other => Err(SQLRiteError::Internal(format!(
2995 "expected boolean, got {}",
2996 other.to_display_string()
2997 ))),
2998 }
2999}
3000
3001#[allow(clippy::too_many_arguments)]
3006fn eval_like(
3007 scope: &dyn RowScope,
3008 negated: bool,
3009 any: bool,
3010 lhs: &Expr,
3011 pattern: &Expr,
3012 escape_char: Option<&AstValue>,
3013 case_insensitive: bool,
3014) -> Result<Value> {
3015 if any {
3016 return Err(SQLRiteError::NotImplemented(
3017 "LIKE ANY (...) is not supported".to_string(),
3018 ));
3019 }
3020 if escape_char.is_some() {
3021 return Err(SQLRiteError::NotImplemented(
3022 "LIKE ... ESCAPE '<char>' is not supported (default `\\` escape only)".to_string(),
3023 ));
3024 }
3025
3026 let l = eval_expr_scope(lhs, scope)?;
3027 let p = eval_expr_scope(pattern, scope)?;
3028 if matches!(l, Value::Null) || matches!(p, Value::Null) {
3029 return Ok(Value::Null);
3030 }
3031 let text = match l {
3032 Value::Text(s) => s,
3033 other => other.to_display_string(),
3034 };
3035 let pat = match p {
3036 Value::Text(s) => s,
3037 other => other.to_display_string(),
3038 };
3039 let m = like_match(&text, &pat, case_insensitive);
3040 Ok(Value::Bool(if negated { !m } else { m }))
3041}
3042
3043fn eval_in_list(scope: &dyn RowScope, lhs: &Expr, list: &[Expr], negated: bool) -> Result<Value> {
3044 let l = eval_expr_scope(lhs, scope)?;
3045 if matches!(l, Value::Null) {
3046 return Ok(Value::Null);
3047 }
3048 let mut saw_null = false;
3049 for item in list {
3050 let r = eval_expr_scope(item, scope)?;
3051 if matches!(r, Value::Null) {
3052 saw_null = true;
3053 continue;
3054 }
3055 if compare_values(Some(&l), Some(&r)) == Ordering::Equal {
3056 return Ok(Value::Bool(!negated));
3057 }
3058 }
3059 if saw_null {
3060 Ok(Value::Null)
3063 } else {
3064 Ok(Value::Bool(negated))
3065 }
3066}
3067
3068fn aggregate_rows(
3079 table: &Table,
3080 matching: &[i64],
3081 group_by: &[String],
3082 proj_items: &[ProjectionItem],
3083) -> Result<Vec<Vec<Value>>> {
3084 let template: Vec<Option<AggState>> = proj_items
3088 .iter()
3089 .map(|i| match &i.kind {
3090 ProjectionKind::Aggregate(call) => Some(AggState::new(call)),
3091 ProjectionKind::Column { .. } => None,
3092 })
3093 .collect();
3094
3095 let mut keys: Vec<Vec<DistinctKey>> = Vec::new();
3101 let mut group_states: Vec<Vec<Option<AggState>>> = Vec::new();
3102 let mut group_key_values: Vec<Vec<Value>> = Vec::new();
3103
3104 for &rowid in matching {
3105 let mut key_values: Vec<Value> = Vec::with_capacity(group_by.len());
3106 let mut key: Vec<DistinctKey> = Vec::with_capacity(group_by.len());
3107 for col in group_by {
3108 let v = table.get_value(col, rowid).unwrap_or(Value::Null);
3109 key.push(DistinctKey::from_value(&v));
3110 key_values.push(v);
3111 }
3112 let idx = match keys.iter().position(|k| k == &key) {
3113 Some(i) => i,
3114 None => {
3115 keys.push(key);
3116 group_states.push(template.clone());
3117 group_key_values.push(key_values);
3118 keys.len() - 1
3119 }
3120 };
3121
3122 for (slot, item) in proj_items.iter().enumerate() {
3123 if let ProjectionKind::Aggregate(call) = &item.kind {
3124 let v = match &call.arg {
3125 AggregateArg::Star => Value::Null,
3126 AggregateArg::Column(c) => table.get_value(c, rowid).unwrap_or(Value::Null),
3127 };
3128 if let Some(state) = group_states[idx][slot].as_mut() {
3129 state.update(&v)?;
3130 }
3131 }
3132 }
3133 }
3134
3135 if keys.is_empty() && group_by.is_empty() {
3141 keys.push(Vec::new());
3144 group_states.push(template.clone());
3145 group_key_values.push(Vec::new());
3146 }
3147
3148 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(keys.len());
3150 for (group_idx, _) in keys.iter().enumerate() {
3151 let mut row: Vec<Value> = Vec::with_capacity(proj_items.len());
3152 for (slot, item) in proj_items.iter().enumerate() {
3153 match &item.kind {
3154 ProjectionKind::Column { name: c, .. } => {
3155 let pos = group_by
3158 .iter()
3159 .position(|g| g == c)
3160 .expect("validated to be in GROUP BY");
3161 row.push(group_key_values[group_idx][pos].clone());
3162 }
3163 ProjectionKind::Aggregate(_) => {
3164 let state = group_states[group_idx][slot]
3165 .as_ref()
3166 .expect("aggregate slot has state");
3167 row.push(state.finalize());
3168 }
3169 }
3170 }
3171 rows.push(row);
3172 }
3173 Ok(rows)
3174}
3175
3176fn dedupe_rows(rows: Vec<Vec<Value>>) -> Vec<Vec<Value>> {
3180 use std::collections::HashSet;
3181 let mut seen: HashSet<Vec<DistinctKey>> = HashSet::new();
3182 let mut out = Vec::with_capacity(rows.len());
3183 for row in rows {
3184 let key: Vec<DistinctKey> = row.iter().map(DistinctKey::from_value).collect();
3185 if seen.insert(key) {
3186 out.push(row);
3187 }
3188 }
3189 out
3190}
3191
3192fn sort_output_rows(
3196 rows: &mut [Vec<Value>],
3197 columns: &[String],
3198 proj_items: &[ProjectionItem],
3199 order: &OrderByClause,
3200) -> Result<()> {
3201 let target_idx = resolve_order_by_index(&order.expr, columns, proj_items)?;
3202 rows.sort_by(|a, b| {
3203 let va = &a[target_idx];
3204 let vb = &b[target_idx];
3205 let ord = compare_values(Some(va), Some(vb));
3206 if order.ascending { ord } else { ord.reverse() }
3207 });
3208 Ok(())
3209}
3210
3211fn resolve_order_by_index(
3214 expr: &Expr,
3215 columns: &[String],
3216 proj_items: &[ProjectionItem],
3217) -> Result<usize> {
3218 let target_name: Option<String> = match expr {
3220 Expr::Identifier(ident) => Some(ident.value.clone()),
3221 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
3222 Expr::Function(_) => None,
3223 Expr::Nested(inner) => return resolve_order_by_index(inner, columns, proj_items),
3224 other => {
3225 return Err(SQLRiteError::NotImplemented(format!(
3226 "ORDER BY expression not supported on aggregating queries: {other:?}"
3227 )));
3228 }
3229 };
3230 if let Some(name) = target_name {
3231 if let Some(i) = columns.iter().position(|c| c.eq_ignore_ascii_case(&name)) {
3232 return Ok(i);
3233 }
3234 return Err(SQLRiteError::Internal(format!(
3235 "ORDER BY references unknown column '{name}' in the SELECT output"
3236 )));
3237 }
3238 if let Expr::Function(func) = expr {
3242 let user_disp = format_function_display(func);
3243 for (i, item) in proj_items.iter().enumerate() {
3244 if let ProjectionKind::Aggregate(call) = &item.kind
3245 && call.display_name().eq_ignore_ascii_case(&user_disp)
3246 {
3247 return Ok(i);
3248 }
3249 }
3250 return Err(SQLRiteError::Internal(format!(
3251 "ORDER BY references aggregate '{user_disp}' that isn't in the SELECT output"
3252 )));
3253 }
3254 Err(SQLRiteError::Internal(
3255 "ORDER BY expression could not be resolved against the output columns".to_string(),
3256 ))
3257}
3258
3259fn format_function_display(func: &sqlparser::ast::Function) -> String {
3263 let name = match func.name.0.as_slice() {
3264 [ObjectNamePart::Identifier(ident)] => ident.value.to_uppercase(),
3265 _ => format!("{:?}", func.name).to_uppercase(),
3266 };
3267 let inner = match &func.args {
3268 FunctionArguments::List(l) => {
3269 let distinct = matches!(
3270 l.duplicate_treatment,
3271 Some(sqlparser::ast::DuplicateTreatment::Distinct)
3272 );
3273 let arg = l.args.first().map(|a| match a {
3274 FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => "*".to_string(),
3275 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier(i))) => i.value.clone(),
3276 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::CompoundIdentifier(parts))) => {
3277 parts.last().map(|p| p.value.clone()).unwrap_or_default()
3278 }
3279 _ => String::new(),
3280 });
3281 match (distinct, arg) {
3282 (true, Some(a)) if a != "*" => format!("DISTINCT {a}"),
3283 (_, Some(a)) => a,
3284 _ => String::new(),
3285 }
3286 }
3287 _ => String::new(),
3288 };
3289 format!("{name}({inner})")
3290}
3291
3292fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
3293 use sqlparser::ast::Value as AstValue;
3294 match v {
3295 AstValue::Number(n, _) => {
3296 if let Ok(i) = n.parse::<i64>() {
3297 Ok(Value::Integer(i))
3298 } else if let Ok(f) = n.parse::<f64>() {
3299 Ok(Value::Real(f))
3300 } else {
3301 Err(SQLRiteError::Internal(format!(
3302 "could not parse numeric literal '{n}'"
3303 )))
3304 }
3305 }
3306 AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
3307 AstValue::Boolean(b) => Ok(Value::Bool(*b)),
3308 AstValue::Null => Ok(Value::Null),
3309 other => Err(SQLRiteError::NotImplemented(format!(
3310 "unsupported literal value: {other:?}"
3311 ))),
3312 }
3313}
3314
3315#[cfg(test)]
3316mod tests {
3317 use super::*;
3318
3319 fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
3326 (a - b).abs() < eps
3327 }
3328
3329 #[test]
3330 fn vec_distance_l2_identical_is_zero() {
3331 let v = vec![0.1, 0.2, 0.3];
3332 assert_eq!(vec_distance_l2(&v, &v), 0.0);
3333 }
3334
3335 #[test]
3336 fn vec_distance_l2_unit_basis_is_sqrt2() {
3337 let a = vec![1.0, 0.0];
3339 let b = vec![0.0, 1.0];
3340 assert!(approx_eq(vec_distance_l2(&a, &b), 2.0_f32.sqrt(), 1e-6));
3341 }
3342
3343 #[test]
3344 fn vec_distance_l2_known_value() {
3345 let a = vec![0.0, 0.0, 0.0];
3347 let b = vec![3.0, 4.0, 0.0];
3348 assert!(approx_eq(vec_distance_l2(&a, &b), 5.0, 1e-6));
3349 }
3350
3351 #[test]
3352 fn vec_distance_cosine_identical_is_zero() {
3353 let v = vec![0.1, 0.2, 0.3];
3354 let d = vec_distance_cosine(&v, &v).unwrap();
3355 assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
3356 }
3357
3358 #[test]
3359 fn vec_distance_cosine_orthogonal_is_one() {
3360 let a = vec![1.0, 0.0];
3363 let b = vec![0.0, 1.0];
3364 assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 1.0, 1e-6));
3365 }
3366
3367 #[test]
3368 fn vec_distance_cosine_opposite_is_two() {
3369 let a = vec![1.0, 0.0, 0.0];
3371 let b = vec![-1.0, 0.0, 0.0];
3372 assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 2.0, 1e-6));
3373 }
3374
3375 #[test]
3376 fn vec_distance_cosine_zero_magnitude_errors() {
3377 let a = vec![0.0, 0.0];
3379 let b = vec![1.0, 0.0];
3380 let err = vec_distance_cosine(&a, &b).unwrap_err();
3381 assert!(format!("{err}").contains("zero-magnitude"));
3382 }
3383
3384 #[test]
3385 fn vec_distance_dot_negates() {
3386 let a = vec![1.0, 2.0, 3.0];
3388 let b = vec![4.0, 5.0, 6.0];
3389 assert!(approx_eq(vec_distance_dot(&a, &b), -32.0, 1e-6));
3390 }
3391
3392 #[test]
3393 fn vec_distance_dot_orthogonal_is_zero() {
3394 let a = vec![1.0, 0.0];
3396 let b = vec![0.0, 1.0];
3397 assert_eq!(vec_distance_dot(&a, &b), 0.0);
3398 }
3399
3400 #[test]
3401 fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
3402 let a = vec![0.6f32, 0.8]; let b = vec![0.8f32, 0.6]; let dot = vec_distance_dot(&a, &b);
3408 let cos = vec_distance_cosine(&a, &b).unwrap();
3409 assert!(approx_eq(dot, cos - 1.0, 1e-5));
3410 }
3411
3412 use crate::sql::db::database::Database;
3417 use crate::sql::dialect::SqlriteDialect;
3418 use crate::sql::parser::select::SelectQuery;
3419 use sqlparser::parser::Parser;
3420
3421 fn seed_score_table(n: usize) -> Database {
3434 let mut db = Database::new("tempdb".to_string());
3435 crate::sql::process_command(
3436 "CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
3437 &mut db,
3438 )
3439 .expect("create");
3440 for i in 0..n {
3441 let score = ((i as u64).wrapping_mul(2_654_435_761) % 1_000_000) as f64;
3445 let sql = format!("INSERT INTO docs (score) VALUES ({score});");
3446 crate::sql::process_command(&sql, &mut db).expect("insert");
3447 }
3448 db
3449 }
3450
3451 fn parse_select(sql: &str) -> SelectQuery {
3455 let dialect = SqlriteDialect::new();
3456 let mut ast = Parser::parse_sql(&dialect, sql).expect("parse");
3457 let stmt = ast.pop().expect("one statement");
3458 SelectQuery::new(&stmt).expect("select-query")
3459 }
3460
3461 #[test]
3462 fn topk_matches_full_sort_asc() {
3463 let db = seed_score_table(200);
3466 let table = db.get_table("docs".to_string()).unwrap();
3467 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
3468 let order = q.order_by.as_ref().unwrap();
3469 let all_rowids = table.rowids();
3470
3471 let mut full = all_rowids.clone();
3473 sort_rowids(&mut full, table, order).unwrap();
3474 full.truncate(10);
3475
3476 let topk = select_topk(&all_rowids, table, order, 10).unwrap();
3478
3479 assert_eq!(topk, full, "top-k via heap should match full-sort+truncate");
3480 }
3481
3482 #[test]
3483 fn topk_matches_full_sort_desc() {
3484 let db = seed_score_table(200);
3486 let table = db.get_table("docs".to_string()).unwrap();
3487 let q = parse_select("SELECT * FROM docs ORDER BY score DESC LIMIT 10;");
3488 let order = q.order_by.as_ref().unwrap();
3489 let all_rowids = table.rowids();
3490
3491 let mut full = all_rowids.clone();
3492 sort_rowids(&mut full, table, order).unwrap();
3493 full.truncate(10);
3494
3495 let topk = select_topk(&all_rowids, table, order, 10).unwrap();
3496
3497 assert_eq!(
3498 topk, full,
3499 "top-k DESC via heap should match full-sort+truncate"
3500 );
3501 }
3502
3503 #[test]
3504 fn topk_k_larger_than_n_returns_everything_sorted() {
3505 let db = seed_score_table(50);
3510 let table = db.get_table("docs".to_string()).unwrap();
3511 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
3512 let order = q.order_by.as_ref().unwrap();
3513 let topk = select_topk(&table.rowids(), table, order, 1000).unwrap();
3514 assert_eq!(topk.len(), 50);
3515 let scores: Vec<f64> = topk
3517 .iter()
3518 .filter_map(|r| match table.get_value("score", *r) {
3519 Some(Value::Real(f)) => Some(f),
3520 _ => None,
3521 })
3522 .collect();
3523 assert!(scores.windows(2).all(|w| w[0] <= w[1]));
3524 }
3525
3526 #[test]
3527 fn topk_k_zero_returns_empty() {
3528 let db = seed_score_table(10);
3529 let table = db.get_table("docs".to_string()).unwrap();
3530 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
3531 let order = q.order_by.as_ref().unwrap();
3532 let topk = select_topk(&table.rowids(), table, order, 0).unwrap();
3533 assert!(topk.is_empty());
3534 }
3535
3536 #[test]
3537 fn topk_empty_input_returns_empty() {
3538 let db = seed_score_table(0);
3539 let table = db.get_table("docs".to_string()).unwrap();
3540 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
3541 let order = q.order_by.as_ref().unwrap();
3542 let topk = select_topk(&[], table, order, 5).unwrap();
3543 assert!(topk.is_empty());
3544 }
3545
3546 #[test]
3547 fn topk_works_through_select_executor_with_distance_function() {
3548 let mut db = Database::new("tempdb".to_string());
3552 crate::sql::process_command(
3553 "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
3554 &mut db,
3555 )
3556 .unwrap();
3557 for v in &[
3564 "[1.0, 0.0]",
3565 "[2.0, 0.0]",
3566 "[0.0, 3.0]",
3567 "[1.0, 4.0]",
3568 "[10.0, 10.0]",
3569 ] {
3570 crate::sql::process_command(&format!("INSERT INTO docs (e) VALUES ({v});"), &mut db)
3571 .unwrap();
3572 }
3573 let resp = crate::sql::process_command(
3574 "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;",
3575 &mut db,
3576 )
3577 .unwrap();
3578 assert!(resp.contains("3 rows returned"), "got: {resp}");
3581 }
3582
3583 #[test]
3606 #[ignore]
3607 fn topk_benchmark() {
3608 use std::time::Instant;
3609 const N: usize = 10_000;
3610 const K: usize = 10;
3611
3612 let db = seed_score_table(N);
3613 let table = db.get_table("docs".to_string()).unwrap();
3614 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
3615 let order = q.order_by.as_ref().unwrap();
3616 let all_rowids = table.rowids();
3617
3618 let t0 = Instant::now();
3620 let _topk = select_topk(&all_rowids, table, order, K).unwrap();
3621 let heap_dur = t0.elapsed();
3622
3623 let t1 = Instant::now();
3625 let mut full = all_rowids.clone();
3626 sort_rowids(&mut full, table, order).unwrap();
3627 full.truncate(K);
3628 let sort_dur = t1.elapsed();
3629
3630 let ratio = sort_dur.as_secs_f64() / heap_dur.as_secs_f64().max(1e-9);
3631 println!("\n--- topk_benchmark (N={N}, k={K}) ---");
3632 println!(" bounded heap: {heap_dur:?}");
3633 println!(" full sort+trunc: {sort_dur:?}");
3634 println!(" speedup ratio: {ratio:.2}×");
3635
3636 assert!(
3643 ratio > 1.4,
3644 "bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
3645 );
3646 }
3647
3648 fn run_select(db: &mut Database, sql: &str) -> String {
3656 crate::sql::process_command(sql, db).expect("select")
3657 }
3658
3659 #[test]
3660 fn where_is_null_returns_null_rows() {
3661 let mut db = Database::new("t".to_string());
3662 crate::sql::process_command(
3663 "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
3664 &mut db,
3665 )
3666 .unwrap();
3667 crate::sql::process_command("INSERT INTO t (id, n) VALUES (1, 10);", &mut db).unwrap();
3668 crate::sql::process_command("INSERT INTO t (id, n) VALUES (2, NULL);", &mut db).unwrap();
3669 crate::sql::process_command("INSERT INTO t (id, n) VALUES (3, 30);", &mut db).unwrap();
3670 crate::sql::process_command("INSERT INTO t (id, n) VALUES (4, NULL);", &mut db).unwrap();
3671
3672 let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL;");
3673 assert!(
3674 response.contains("2 rows returned"),
3675 "IS NULL should return 2 rows, got: {response}"
3676 );
3677 }
3678
3679 #[test]
3680 fn where_is_not_null_returns_non_null_rows() {
3681 let mut db = Database::new("t".to_string());
3682 crate::sql::process_command(
3683 "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
3684 &mut db,
3685 )
3686 .unwrap();
3687 crate::sql::process_command("INSERT INTO t (id, n) VALUES (1, 10);", &mut db).unwrap();
3688 crate::sql::process_command("INSERT INTO t (id, n) VALUES (2, NULL);", &mut db).unwrap();
3689 crate::sql::process_command("INSERT INTO t (id, n) VALUES (3, 30);", &mut db).unwrap();
3690
3691 let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NOT NULL;");
3692 assert!(
3693 response.contains("2 rows returned"),
3694 "IS NOT NULL should return 2 rows, got: {response}"
3695 );
3696 }
3697
3698 #[test]
3699 fn where_is_null_on_indexed_column() {
3700 let mut db = Database::new("t".to_string());
3705 crate::sql::process_command(
3706 "CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT UNIQUE);",
3707 &mut db,
3708 )
3709 .unwrap();
3710 crate::sql::process_command("INSERT INTO t (id, name) VALUES (1, 'alice');", &mut db)
3711 .unwrap();
3712 crate::sql::process_command("INSERT INTO t (id, name) VALUES (2, NULL);", &mut db).unwrap();
3713 crate::sql::process_command("INSERT INTO t (id, name) VALUES (3, 'bob');", &mut db)
3714 .unwrap();
3715
3716 let null_rows = run_select(&mut db, "SELECT id FROM t WHERE name IS NULL;");
3717 assert!(
3718 null_rows.contains("1 row returned"),
3719 "indexed IS NULL should return 1 row, got: {null_rows}"
3720 );
3721 let not_null_rows = run_select(&mut db, "SELECT id FROM t WHERE name IS NOT NULL;");
3722 assert!(
3723 not_null_rows.contains("2 rows returned"),
3724 "indexed IS NOT NULL should return 2 rows, got: {not_null_rows}"
3725 );
3726 }
3727
3728 #[test]
3729 fn where_is_null_works_on_omitted_column() {
3730 let mut db = Database::new("t".to_string());
3734 crate::sql::process_command(
3735 "CREATE TABLE t (id INTEGER PRIMARY KEY, qty INTEGER, label TEXT);",
3736 &mut db,
3737 )
3738 .unwrap();
3739 crate::sql::process_command(
3740 "INSERT INTO t (id, qty, label) VALUES (1, 7, 'a');",
3741 &mut db,
3742 )
3743 .unwrap();
3744 crate::sql::process_command("INSERT INTO t (id, label) VALUES (2, 'b');", &mut db).unwrap();
3746
3747 let response = run_select(&mut db, "SELECT id FROM t WHERE qty IS NULL;");
3748 assert!(
3749 response.contains("1 row returned"),
3750 "IS NULL should match the omitted-column row, got: {response}"
3751 );
3752 }
3753
3754 #[test]
3755 fn where_is_null_combines_with_and_or() {
3756 let mut db = Database::new("t".to_string());
3760 crate::sql::process_command(
3761 "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
3762 &mut db,
3763 )
3764 .unwrap();
3765 crate::sql::process_command("INSERT INTO t (id, n) VALUES (1, NULL);", &mut db).unwrap();
3766 crate::sql::process_command("INSERT INTO t (id, n) VALUES (2, NULL);", &mut db).unwrap();
3767 crate::sql::process_command("INSERT INTO t (id, n) VALUES (3, 30);", &mut db).unwrap();
3768
3769 let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL AND id > 1;");
3770 assert!(
3771 response.contains("1 row returned"),
3772 "IS NULL combined with AND should match exactly row 2, got: {response}"
3773 );
3774 }
3775
3776 fn seed_employees() -> Database {
3782 let mut db = Database::new("t".to_string());
3783 crate::sql::process_command(
3784 "CREATE TABLE emp (id INTEGER PRIMARY KEY, name TEXT, dept TEXT, salary INTEGER);",
3785 &mut db,
3786 )
3787 .unwrap();
3788 let rows = [
3789 "INSERT INTO emp (name, dept, salary) VALUES ('Alice', 'eng', 100);",
3790 "INSERT INTO emp (name, dept, salary) VALUES ('alex', 'eng', 120);",
3791 "INSERT INTO emp (name, dept, salary) VALUES ('Bob', 'eng', 100);",
3792 "INSERT INTO emp (name, dept, salary) VALUES ('Carol', 'sales', 90);",
3793 "INSERT INTO emp (name, dept, salary) VALUES ('Dave', 'sales', NULL);",
3794 "INSERT INTO emp (name, dept, salary) VALUES ('Eve', 'ops', 80);",
3795 ];
3796 for sql in rows {
3797 crate::sql::process_command(sql, &mut db).unwrap();
3798 }
3799 db
3800 }
3801
3802 fn run_rows(db: &Database, sql: &str) -> SelectResult {
3804 let q = parse_select(sql);
3805 execute_select_rows(q, db).expect("select")
3806 }
3807
3808 #[test]
3811 fn like_percent_prefix_case_insensitive() {
3812 let db = seed_employees();
3813 let r = run_rows(&db, "SELECT name FROM emp WHERE name LIKE 'a%';");
3814 let names: Vec<_> = r.rows.iter().map(|r| r[0].to_display_string()).collect();
3816 assert_eq!(names.len(), 2, "expected 2 rows, got {names:?}");
3817 assert!(names.contains(&"Alice".to_string()));
3818 assert!(names.contains(&"alex".to_string()));
3819 }
3820
3821 #[test]
3822 fn like_underscore_singlechar() {
3823 let db = seed_employees();
3824 let r = run_rows(&db, "SELECT name FROM emp WHERE name LIKE '_ve';");
3825 let names: Vec<_> = r.rows.iter().map(|r| r[0].to_display_string()).collect();
3827 assert_eq!(names, vec!["Eve".to_string()]);
3828 }
3829
3830 #[test]
3831 fn not_like_excludes_match() {
3832 let db = seed_employees();
3833 let r = run_rows(&db, "SELECT name FROM emp WHERE name NOT LIKE 'a%';");
3834 assert_eq!(r.rows.len(), 4);
3836 }
3837
3838 #[test]
3839 fn like_with_null_excludes_row() {
3840 let db = seed_employees();
3841 let r = run_rows(
3843 &db,
3844 "SELECT name FROM emp WHERE dept LIKE 'sales' AND salary IS NULL;",
3845 );
3846 assert_eq!(r.rows.len(), 1);
3847 assert_eq!(r.rows[0][0].to_display_string(), "Dave");
3848 }
3849
3850 #[test]
3853 fn in_list_positive() {
3854 let db = seed_employees();
3855 let r = run_rows(&db, "SELECT name FROM emp WHERE id IN (1, 3, 5);");
3856 let names: Vec<_> = r.rows.iter().map(|r| r[0].to_display_string()).collect();
3857 assert_eq!(names.len(), 3);
3858 assert!(names.contains(&"Alice".to_string()));
3859 assert!(names.contains(&"Bob".to_string()));
3860 assert!(names.contains(&"Dave".to_string()));
3861 }
3862
3863 #[test]
3864 fn not_in_excludes_listed() {
3865 let db = seed_employees();
3866 let r = run_rows(&db, "SELECT name FROM emp WHERE id NOT IN (1, 2);");
3867 assert_eq!(r.rows.len(), 4);
3869 }
3870
3871 #[test]
3872 fn in_list_with_null_three_valued() {
3873 let db = seed_employees();
3874 let r = run_rows(&db, "SELECT name FROM emp WHERE id IN (1, NULL);");
3877 assert_eq!(r.rows.len(), 1);
3878 assert_eq!(r.rows[0][0].to_display_string(), "Alice");
3879 }
3880
3881 #[test]
3884 fn distinct_single_column() {
3885 let db = seed_employees();
3886 let r = run_rows(&db, "SELECT DISTINCT dept FROM emp;");
3887 assert_eq!(r.rows.len(), 3);
3889 }
3890
3891 #[test]
3892 fn distinct_multi_column_with_null() {
3893 let db = seed_employees();
3894 let r = run_rows(&db, "SELECT DISTINCT dept, salary FROM emp;");
3896 assert_eq!(r.rows.len(), 5);
3898 }
3899
3900 #[test]
3903 fn count_star_no_groupby() {
3904 let db = seed_employees();
3905 let r = run_rows(&db, "SELECT COUNT(*) FROM emp;");
3906 assert_eq!(r.rows.len(), 1);
3907 assert_eq!(r.rows[0][0], Value::Integer(6));
3908 }
3909
3910 #[test]
3911 fn count_col_skips_nulls() {
3912 let db = seed_employees();
3913 let r = run_rows(&db, "SELECT COUNT(salary) FROM emp;");
3914 assert_eq!(r.rows[0][0], Value::Integer(5));
3916 }
3917
3918 #[test]
3919 fn count_distinct_dedupes_and_skips_nulls() {
3920 let db = seed_employees();
3921 let r = run_rows(&db, "SELECT COUNT(DISTINCT salary) FROM emp;");
3922 assert_eq!(r.rows[0][0], Value::Integer(4));
3924 }
3925
3926 #[test]
3927 fn sum_int_stays_integer() {
3928 let db = seed_employees();
3929 let r = run_rows(&db, "SELECT SUM(salary) FROM emp;");
3930 assert_eq!(r.rows[0][0], Value::Integer(490));
3932 }
3933
3934 #[test]
3935 fn avg_returns_real() {
3936 let db = seed_employees();
3937 let r = run_rows(&db, "SELECT AVG(salary) FROM emp;");
3938 match &r.rows[0][0] {
3940 Value::Real(v) => assert!((v - 98.0).abs() < 1e-9),
3941 other => panic!("expected Real, got {other:?}"),
3942 }
3943 }
3944
3945 #[test]
3946 fn min_max_skip_nulls() {
3947 let db = seed_employees();
3948 let r = run_rows(&db, "SELECT MIN(salary), MAX(salary) FROM emp;");
3949 assert_eq!(r.rows[0][0], Value::Integer(80));
3950 assert_eq!(r.rows[0][1], Value::Integer(120));
3951 }
3952
3953 #[test]
3954 fn aggregates_on_empty_table_emit_one_row() {
3955 let mut db = Database::new("t".to_string());
3956 crate::sql::process_command("CREATE TABLE t (x INTEGER);", &mut db).unwrap();
3957 let r = run_rows(
3958 &db,
3959 "SELECT COUNT(*), SUM(x), AVG(x), MIN(x), MAX(x) FROM t;",
3960 );
3961 assert_eq!(r.rows.len(), 1);
3962 assert_eq!(r.rows[0][0], Value::Integer(0));
3963 assert_eq!(r.rows[0][1], Value::Null);
3964 assert_eq!(r.rows[0][2], Value::Null);
3965 assert_eq!(r.rows[0][3], Value::Null);
3966 assert_eq!(r.rows[0][4], Value::Null);
3967 }
3968
3969 #[test]
3972 fn group_by_single_col_with_count() {
3973 let db = seed_employees();
3974 let r = run_rows(&db, "SELECT dept, COUNT(*) FROM emp GROUP BY dept;");
3975 assert_eq!(r.rows.len(), 3);
3976 let mut by_dept: std::collections::HashMap<String, i64> = Default::default();
3978 for row in &r.rows {
3979 let d = row[0].to_display_string();
3980 let c = match &row[1] {
3981 Value::Integer(i) => *i,
3982 v => panic!("expected Integer count, got {v:?}"),
3983 };
3984 by_dept.insert(d, c);
3985 }
3986 assert_eq!(by_dept["eng"], 3);
3987 assert_eq!(by_dept["sales"], 2);
3988 assert_eq!(by_dept["ops"], 1);
3989 }
3990
3991 #[test]
3992 fn group_by_with_where_filter() {
3993 let db = seed_employees();
3994 let r = run_rows(
3995 &db,
3996 "SELECT dept, SUM(salary) FROM emp WHERE salary > 80 GROUP BY dept;",
3997 );
3998 let by: std::collections::HashMap<String, i64> = r
4001 .rows
4002 .iter()
4003 .map(|row| {
4004 (
4005 row[0].to_display_string(),
4006 match &row[1] {
4007 Value::Integer(i) => *i,
4008 v => panic!("expected Integer sum, got {v:?}"),
4009 },
4010 )
4011 })
4012 .collect();
4013 assert_eq!(by.len(), 2);
4014 assert_eq!(by["eng"], 320);
4015 assert_eq!(by["sales"], 90);
4016 }
4017
4018 #[test]
4019 fn group_by_without_aggregates_is_distinct() {
4020 let db = seed_employees();
4021 let r = run_rows(&db, "SELECT dept FROM emp GROUP BY dept;");
4022 assert_eq!(r.rows.len(), 3);
4023 }
4024
4025 #[test]
4026 fn order_by_count_desc() {
4027 let db = seed_employees();
4028 let r = run_rows(
4029 &db,
4030 "SELECT dept, COUNT(*) AS n FROM emp GROUP BY dept ORDER BY n DESC LIMIT 2;",
4031 );
4032 assert_eq!(r.rows.len(), 2);
4033 assert_eq!(r.rows[0][0].to_display_string(), "eng");
4035 assert_eq!(r.rows[0][1], Value::Integer(3));
4036 }
4037
4038 #[test]
4039 fn order_by_aggregate_call_form() {
4040 let db = seed_employees();
4041 let r = run_rows(
4043 &db,
4044 "SELECT dept, COUNT(*) FROM emp GROUP BY dept ORDER BY COUNT(*) DESC;",
4045 );
4046 assert_eq!(r.rows.len(), 3);
4047 assert_eq!(r.rows[0][0].to_display_string(), "eng");
4048 }
4049
4050 #[test]
4051 fn group_by_invalid_bare_column_errors() {
4052 let mut db = Database::new("t".to_string());
4054 crate::sql::process_command(
4055 "CREATE TABLE t (id INTEGER PRIMARY KEY, dept TEXT, name TEXT);",
4056 &mut db,
4057 )
4058 .unwrap();
4059 let err = crate::sql::process_command("SELECT dept, name FROM t GROUP BY dept;", &mut db);
4060 assert!(err.is_err(), "should reject bare 'name' not in GROUP BY");
4061 }
4062
4063 #[test]
4064 fn aggregate_in_where_errors_friendly() {
4065 let mut db = Database::new("t".to_string());
4066 crate::sql::process_command("CREATE TABLE t (x INTEGER);", &mut db).unwrap();
4067 crate::sql::process_command("INSERT INTO t (x) VALUES (1);", &mut db).unwrap();
4068 let err = crate::sql::process_command("SELECT x FROM t WHERE COUNT(*) > 0;", &mut db);
4069 assert!(err.is_err(), "aggregates must not be allowed in WHERE");
4070 }
4071
4072 fn seed_join_fixture() -> Database {
4083 let mut db = Database::new("t".to_string());
4084 for sql in [
4085 "CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT);",
4086 "CREATE TABLE orders (id INTEGER PRIMARY KEY, customer_id INTEGER, amount INTEGER);",
4087 "INSERT INTO customers (name) VALUES ('Alice');",
4088 "INSERT INTO customers (name) VALUES ('Bob');",
4089 "INSERT INTO customers (name) VALUES ('Carol');",
4090 "INSERT INTO orders (customer_id, amount) VALUES (1, 100);",
4091 "INSERT INTO orders (customer_id, amount) VALUES (1, 200);",
4092 "INSERT INTO orders (customer_id, amount) VALUES (2, 50);",
4093 "INSERT INTO orders (customer_id, amount) VALUES (4, 999);",
4094 ] {
4095 crate::sql::process_command(sql, &mut db).unwrap();
4096 }
4097 db
4098 }
4099
4100 #[test]
4101 fn inner_join_returns_only_matched_rows() {
4102 let db = seed_join_fixture();
4103 let r = run_rows(
4104 &db,
4105 "SELECT customers.name, orders.amount FROM customers \
4106 INNER JOIN orders ON customers.id = orders.customer_id;",
4107 );
4108 assert_eq!(r.columns, vec!["name".to_string(), "amount".to_string()]);
4109 let pairs: Vec<(String, i64)> = r
4112 .rows
4113 .iter()
4114 .map(|row| {
4115 (
4116 row[0].to_display_string(),
4117 match row[1] {
4118 Value::Integer(i) => i,
4119 ref v => panic!("expected integer amount, got {v:?}"),
4120 },
4121 )
4122 })
4123 .collect();
4124 assert_eq!(pairs.len(), 3);
4125 assert!(pairs.contains(&("Alice".to_string(), 100)));
4126 assert!(pairs.contains(&("Alice".to_string(), 200)));
4127 assert!(pairs.contains(&("Bob".to_string(), 50)));
4128 }
4129
4130 #[test]
4131 fn bare_join_defaults_to_inner() {
4132 let db = seed_join_fixture();
4133 let r = run_rows(
4134 &db,
4135 "SELECT customers.name FROM customers \
4136 JOIN orders ON customers.id = orders.customer_id;",
4137 );
4138 assert_eq!(r.rows.len(), 3, "JOIN without prefix should be INNER");
4139 }
4140
4141 #[test]
4142 fn left_outer_join_preserves_unmatched_left() {
4143 let db = seed_join_fixture();
4144 let r = run_rows(
4145 &db,
4146 "SELECT customers.name, orders.amount FROM customers \
4147 LEFT OUTER JOIN orders ON customers.id = orders.customer_id;",
4148 );
4149 assert_eq!(r.rows.len(), 4);
4152 let carol = r
4153 .rows
4154 .iter()
4155 .find(|row| row[0].to_display_string() == "Carol")
4156 .expect("Carol should appear with a NULL-padded right side");
4157 assert_eq!(carol[1], Value::Null);
4158 }
4159
4160 #[test]
4161 fn right_outer_join_preserves_unmatched_right() {
4162 let db = seed_join_fixture();
4163 let r = run_rows(
4164 &db,
4165 "SELECT customers.name, orders.amount FROM customers \
4166 RIGHT OUTER JOIN orders ON customers.id = orders.customer_id;",
4167 );
4168 assert_eq!(r.rows.len(), 4);
4172 let dangling = r
4173 .rows
4174 .iter()
4175 .find(|row| matches!(row[1], Value::Integer(999)))
4176 .expect("dangling order 999 should appear with a NULL-padded customer name");
4177 assert_eq!(dangling[0], Value::Null);
4178 }
4179
4180 #[test]
4181 fn full_outer_join_preserves_both_sides() {
4182 let db = seed_join_fixture();
4183 let r = run_rows(
4184 &db,
4185 "SELECT customers.name, orders.amount FROM customers \
4186 FULL OUTER JOIN orders ON customers.id = orders.customer_id;",
4187 );
4188 assert_eq!(r.rows.len(), 5);
4191 assert!(
4193 r.rows
4194 .iter()
4195 .any(|row| row[0].to_display_string() == "Carol" && matches!(row[1], Value::Null))
4196 );
4197 assert!(
4199 r.rows
4200 .iter()
4201 .any(|row| matches!(row[1], Value::Integer(999)) && matches!(row[0], Value::Null))
4202 );
4203 }
4204
4205 #[test]
4206 fn join_with_table_aliases_resolves_qualifiers() {
4207 let db = seed_join_fixture();
4208 let r = run_rows(
4209 &db,
4210 "SELECT c.name, o.amount FROM customers AS c \
4211 INNER JOIN orders AS o ON c.id = o.customer_id;",
4212 );
4213 assert_eq!(r.rows.len(), 3);
4214 assert_eq!(r.columns, vec!["name".to_string(), "amount".to_string()]);
4215 }
4216
4217 #[test]
4218 fn join_with_where_filter_applies_after_join() {
4219 let db = seed_join_fixture();
4220 let r = run_rows(
4223 &db,
4224 "SELECT customers.name, orders.amount FROM customers \
4225 INNER JOIN orders ON customers.id = orders.customer_id \
4226 WHERE orders.amount >= 100;",
4227 );
4228 assert_eq!(r.rows.len(), 2);
4229 assert!(
4230 r.rows
4231 .iter()
4232 .all(|row| row[0].to_display_string() == "Alice")
4233 );
4234 }
4235
4236 #[test]
4237 fn left_join_with_where_on_right_side_is_not_inner() {
4238 let db = seed_join_fixture();
4242 let r = run_rows(
4243 &db,
4244 "SELECT customers.name, orders.amount FROM customers \
4245 LEFT OUTER JOIN orders ON customers.id = orders.customer_id \
4246 WHERE orders.amount IS NULL;",
4247 );
4248 assert_eq!(r.rows.len(), 1);
4250 assert_eq!(r.rows[0][0].to_display_string(), "Carol");
4251 assert_eq!(r.rows[0][1], Value::Null);
4252 }
4253
4254 #[test]
4255 fn select_star_over_join_emits_all_columns_from_both_tables() {
4256 let db = seed_join_fixture();
4257 let r = run_rows(
4258 &db,
4259 "SELECT * FROM customers \
4260 INNER JOIN orders ON customers.id = orders.customer_id;",
4261 );
4262 assert_eq!(
4266 r.columns,
4267 vec![
4268 "id".to_string(),
4269 "name".to_string(),
4270 "id".to_string(),
4271 "customer_id".to_string(),
4272 "amount".to_string(),
4273 ]
4274 );
4275 assert_eq!(r.rows.len(), 3);
4276 }
4277
4278 #[test]
4279 fn join_order_by_sorts_full_joined_rows() {
4280 let db = seed_join_fixture();
4281 let r = run_rows(
4282 &db,
4283 "SELECT c.name, o.amount FROM customers AS c \
4284 INNER JOIN orders AS o ON c.id = o.customer_id \
4285 ORDER BY o.amount;",
4286 );
4287 let amounts: Vec<i64> = r
4288 .rows
4289 .iter()
4290 .map(|row| match row[1] {
4291 Value::Integer(i) => i,
4292 ref v => panic!("expected integer, got {v:?}"),
4293 })
4294 .collect();
4295 assert_eq!(amounts, vec![50, 100, 200]);
4296 }
4297
4298 #[test]
4299 fn join_limit_truncates_after_join_and_sort() {
4300 let db = seed_join_fixture();
4301 let r = run_rows(
4302 &db,
4303 "SELECT c.name, o.amount FROM customers AS c \
4304 INNER JOIN orders AS o ON c.id = o.customer_id \
4305 ORDER BY o.amount DESC LIMIT 2;",
4306 );
4307 assert_eq!(r.rows.len(), 2);
4308 let amounts: Vec<i64> = r
4310 .rows
4311 .iter()
4312 .map(|row| match row[1] {
4313 Value::Integer(i) => i,
4314 ref v => panic!("expected integer, got {v:?}"),
4315 })
4316 .collect();
4317 assert_eq!(amounts, vec![200, 100]);
4318 }
4319
4320 #[test]
4321 fn three_table_join_chains_correctly() {
4322 let mut db = Database::new("t".to_string());
4323 for sql in [
4324 "CREATE TABLE a (id INTEGER PRIMARY KEY, label TEXT);",
4325 "CREATE TABLE b (id INTEGER PRIMARY KEY, a_id INTEGER, tag TEXT);",
4326 "CREATE TABLE c (id INTEGER PRIMARY KEY, b_id INTEGER, note TEXT);",
4327 "INSERT INTO a (label) VALUES ('a-one');",
4328 "INSERT INTO a (label) VALUES ('a-two');",
4329 "INSERT INTO b (a_id, tag) VALUES (1, 'b1');",
4330 "INSERT INTO b (a_id, tag) VALUES (2, 'b2');",
4331 "INSERT INTO c (b_id, note) VALUES (1, 'c1');",
4332 ] {
4333 crate::sql::process_command(sql, &mut db).unwrap();
4334 }
4335 let r = run_rows(
4336 &db,
4337 "SELECT a.label, b.tag, c.note FROM a \
4338 INNER JOIN b ON a.id = b.a_id \
4339 INNER JOIN c ON b.id = c.b_id;",
4340 );
4341 assert_eq!(r.rows.len(), 1);
4343 assert_eq!(r.rows[0][0].to_display_string(), "a-one");
4344 assert_eq!(r.rows[0][1].to_display_string(), "b1");
4345 assert_eq!(r.rows[0][2].to_display_string(), "c1");
4346 }
4347
4348 #[test]
4349 fn ambiguous_unqualified_column_in_join_errors() {
4350 let db = seed_join_fixture();
4354 let q = parse_select(
4355 "SELECT id FROM customers INNER JOIN orders ON customers.id = orders.customer_id;",
4356 );
4357 let res = execute_select_rows(q, &db);
4358 assert!(res.is_err(), "unqualified ambiguous 'id' should error");
4359 }
4360
4361 #[test]
4362 fn join_self_without_alias_is_rejected() {
4363 let mut db = Database::new("t".to_string());
4364 crate::sql::process_command(
4365 "CREATE TABLE n (id INTEGER PRIMARY KEY, parent INTEGER);",
4366 &mut db,
4367 )
4368 .unwrap();
4369 let q = parse_select("SELECT n.id FROM n INNER JOIN n ON n.id = n.parent;");
4370 let res = execute_select_rows(q, &db);
4371 assert!(
4372 res.is_err(),
4373 "self-join without an alias should error on duplicate qualifier"
4374 );
4375 }
4376
4377 #[test]
4378 fn using_or_natural_join_returns_not_implemented() {
4379 let mut db = Database::new("t".to_string());
4380 crate::sql::process_command("CREATE TABLE a (id INTEGER PRIMARY KEY);", &mut db).unwrap();
4381 crate::sql::process_command("CREATE TABLE b (id INTEGER PRIMARY KEY);", &mut db).unwrap();
4382 let err = crate::sql::process_command("SELECT * FROM a INNER JOIN b USING (id);", &mut db);
4383 assert!(err.is_err(), "USING is not yet supported");
4384
4385 let err = crate::sql::process_command("SELECT * FROM a NATURAL JOIN b;", &mut db);
4386 assert!(err.is_err(), "NATURAL is not supported");
4387 }
4388
4389 #[test]
4390 fn aggregates_over_join_are_rejected() {
4391 let db = seed_join_fixture();
4392 let err = crate::sql::process_command(
4393 "SELECT COUNT(*) FROM customers \
4394 INNER JOIN orders ON customers.id = orders.customer_id;",
4395 &mut seed_join_fixture(),
4396 );
4397 assert!(err.is_err(), "aggregates over JOIN are not yet supported");
4398 let _ = db; }
4400
4401 #[test]
4402 fn left_join_with_no_matches_pads_every_row() {
4403 let mut db = Database::new("t".to_string());
4404 for sql in [
4405 "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
4406 "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
4407 "INSERT INTO a (x) VALUES (1);",
4408 "INSERT INTO a (x) VALUES (2);",
4409 "INSERT INTO b (y) VALUES (10);",
4410 ] {
4411 crate::sql::process_command(sql, &mut db).unwrap();
4412 }
4413 let r = run_rows(
4415 &db,
4416 "SELECT a.x, b.y FROM a LEFT OUTER JOIN b ON a.x = b.y;",
4417 );
4418 assert_eq!(r.rows.len(), 2);
4419 for row in &r.rows {
4420 assert_eq!(row[1], Value::Null);
4421 }
4422 }
4423
4424 #[test]
4425 fn left_outer_join_order_by_places_nulls_first() {
4426 let db = seed_join_fixture();
4431 let r = run_rows(
4432 &db,
4433 "SELECT c.name, o.amount FROM customers AS c \
4434 LEFT OUTER JOIN orders AS o ON c.id = o.customer_id \
4435 ORDER BY o.amount ASC;",
4436 );
4437 assert_eq!(r.rows.len(), 4);
4438 assert_eq!(r.rows[0][0].to_display_string(), "Carol");
4440 assert_eq!(r.rows[0][1], Value::Null);
4441 }
4442
4443 #[test]
4444 fn chained_left_outer_join_preserves_left_through_two_levels() {
4445 let mut db = Database::new("t".to_string());
4448 for sql in [
4449 "CREATE TABLE a (id INTEGER PRIMARY KEY, label TEXT);",
4450 "CREATE TABLE b (id INTEGER PRIMARY KEY, a_id INTEGER, tag TEXT);",
4451 "CREATE TABLE c (id INTEGER PRIMARY KEY, b_id INTEGER, note TEXT);",
4452 "INSERT INTO a (label) VALUES ('a-one');",
4453 "INSERT INTO a (label) VALUES ('a-two');",
4454 "INSERT INTO b (a_id, tag) VALUES (1, 'b1');",
4456 ] {
4458 crate::sql::process_command(sql, &mut db).unwrap();
4459 }
4460 let r = run_rows(
4461 &db,
4462 "SELECT a.label, b.tag, c.note FROM a \
4463 LEFT OUTER JOIN b ON a.id = b.a_id \
4464 LEFT OUTER JOIN c ON b.id = c.b_id;",
4465 );
4466 assert_eq!(r.rows.len(), 2);
4468 let by_label: std::collections::HashMap<String, &Vec<Value>> = r
4469 .rows
4470 .iter()
4471 .map(|row| (row[0].to_display_string(), row))
4472 .collect();
4473 assert_eq!(by_label["a-one"][1].to_display_string(), "b1");
4474 assert_eq!(by_label["a-one"][2], Value::Null);
4475 assert_eq!(by_label["a-two"][1], Value::Null);
4476 assert_eq!(by_label["a-two"][2], Value::Null);
4477 }
4478
4479 #[test]
4480 fn on_clause_referencing_not_yet_joined_table_errors_clearly() {
4481 let mut db = Database::new("t".to_string());
4485 for sql in [
4486 "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
4487 "CREATE TABLE b (id INTEGER PRIMARY KEY, x INTEGER);",
4488 "CREATE TABLE c (id INTEGER PRIMARY KEY, x INTEGER);",
4489 "INSERT INTO a (x) VALUES (1);",
4490 "INSERT INTO b (x) VALUES (1);",
4491 "INSERT INTO c (x) VALUES (1);",
4492 ] {
4493 crate::sql::process_command(sql, &mut db).unwrap();
4494 }
4495 let q =
4496 parse_select("SELECT a.x FROM a INNER JOIN b ON a.x = c.x INNER JOIN c ON b.x = c.x;");
4497 let res = execute_select_rows(q, &db);
4498 assert!(
4499 res.is_err(),
4500 "ON referencing not-yet-joined table 'c' should error"
4501 );
4502 }
4503
4504 #[test]
4505 fn join_on_truthy_integer_is_accepted() {
4506 let mut db = Database::new("t".to_string());
4510 for sql in [
4511 "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
4512 "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
4513 "INSERT INTO a (x) VALUES (1);",
4514 "INSERT INTO a (x) VALUES (2);",
4515 "INSERT INTO b (y) VALUES (10);",
4516 "INSERT INTO b (y) VALUES (20);",
4517 ] {
4518 crate::sql::process_command(sql, &mut db).unwrap();
4519 }
4520 let r = run_rows(&db, "SELECT a.x, b.y FROM a INNER JOIN b ON 1;");
4521 assert_eq!(r.rows.len(), 4);
4523 }
4524
4525 #[test]
4526 fn full_join_on_empty_tables_returns_empty() {
4527 let mut db = Database::new("t".to_string());
4528 for sql in [
4529 "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
4530 "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
4531 ] {
4532 crate::sql::process_command(sql, &mut db).unwrap();
4533 }
4534 let r = run_rows(
4535 &db,
4536 "SELECT a.x, b.y FROM a FULL OUTER JOIN b ON a.x = b.y;",
4537 );
4538 assert!(r.rows.is_empty());
4539 }
4540}