1use std::cmp::Ordering;
5
6use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
7use sqlparser::ast::{
8 AlterTable, AlterTableOperation, AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr,
9 FromTable, FunctionArg, FunctionArgExpr, FunctionArguments, IndexType, ObjectName,
10 ObjectNamePart, RenameTableNameKind, Statement, TableFactor, TableWithJoins, UnaryOperator,
11 Update, Value as AstValue,
12};
13
14use crate::error::{Result, SQLRiteError};
15use crate::sql::agg::{AggState, DistinctKey, like_match};
16use crate::sql::db::database::Database;
17use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
18use crate::sql::db::table::{
19 DataType, FtsIndexEntry, HnswIndexEntry, Table, Value, parse_vector_literal,
20};
21use crate::sql::fts::{Bm25Params, PostingList};
22use crate::sql::hnsw::{DistanceMetric, HnswIndex};
23use crate::sql::parser::select::{
24 AggregateArg, JoinType, OrderByClause, Projection, ProjectionItem, ProjectionKind, SelectQuery,
25};
26
27pub(crate) trait RowScope {
56 fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value>;
57
58 fn single_table_view(&self) -> Option<(&Table, i64)>;
64}
65
66pub(crate) struct SingleTableScope<'a> {
68 table: &'a Table,
69 rowid: i64,
70}
71
72impl<'a> SingleTableScope<'a> {
73 pub(crate) fn new(table: &'a Table, rowid: i64) -> Self {
74 Self { table, rowid }
75 }
76}
77
78impl RowScope for SingleTableScope<'_> {
79 fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value> {
80 let _ = qualifier;
85 Ok(self.table.get_value(col, self.rowid).unwrap_or(Value::Null))
86 }
87
88 fn single_table_view(&self) -> Option<(&Table, i64)> {
89 Some((self.table, self.rowid))
90 }
91}
92
93pub(crate) struct JoinedTableRef<'a> {
97 pub table: &'a Table,
98 pub scope_name: String,
99}
100
101pub(crate) struct JoinedScope<'a> {
105 pub tables: &'a [JoinedTableRef<'a>],
106 pub rowids: &'a [Option<i64>],
107}
108
109impl RowScope for JoinedScope<'_> {
110 fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value> {
111 if let Some(q) = qualifier {
112 let pos = self
115 .tables
116 .iter()
117 .position(|t| t.scope_name.eq_ignore_ascii_case(q))
118 .ok_or_else(|| {
119 SQLRiteError::Internal(format!(
120 "unknown table qualifier '{q}' in column reference '{q}.{col}'"
121 ))
122 })?;
123 if !self.tables[pos].table.contains_column(col.to_string()) {
124 return Err(SQLRiteError::Internal(format!(
125 "column '{col}' does not exist on '{}'",
126 self.tables[pos].scope_name
127 )));
128 }
129 return Ok(match self.rowids[pos] {
130 None => Value::Null,
131 Some(r) => self.tables[pos]
132 .table
133 .get_value(col, r)
134 .unwrap_or(Value::Null),
135 });
136 }
137 let mut hit: Option<usize> = None;
141 for (i, t) in self.tables.iter().enumerate() {
142 if t.table.contains_column(col.to_string()) {
143 if hit.is_some() {
144 return Err(SQLRiteError::Internal(format!(
145 "column reference '{col}' is ambiguous — qualify it as <table>.{col}"
146 )));
147 }
148 hit = Some(i);
149 }
150 }
151 let i = hit.ok_or_else(|| {
152 SQLRiteError::Internal(format!(
153 "unknown column '{col}' in joined SELECT (no in-scope table has it)"
154 ))
155 })?;
156 Ok(match self.rowids[i] {
157 None => Value::Null,
158 Some(r) => self.tables[i]
159 .table
160 .get_value(col, r)
161 .unwrap_or(Value::Null),
162 })
163 }
164
165 fn single_table_view(&self) -> Option<(&Table, i64)> {
166 None
167 }
168}
169
170pub struct SelectResult {
179 pub columns: Vec<String>,
180 pub rows: Vec<Vec<Value>>,
181}
182
183pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
187 if !query.joins.is_empty() {
192 return execute_select_rows_joined(query, db);
193 }
194
195 let table = db
196 .get_table(query.table_name.clone())
197 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
198
199 let proj_items: Vec<ProjectionItem> = match &query.projection {
204 Projection::All => table
205 .column_names()
206 .into_iter()
207 .map(|c| ProjectionItem {
208 kind: ProjectionKind::Column {
209 qualifier: None,
210 name: c,
211 },
212 alias: None,
213 })
214 .collect(),
215 Projection::Items(items) => items.clone(),
216 };
217 let has_aggregates = proj_items
218 .iter()
219 .any(|i| matches!(i.kind, ProjectionKind::Aggregate(_)));
220 for item in &proj_items {
222 if let ProjectionKind::Column { name: c, .. } = &item.kind
223 && !table.contains_column(c.clone())
224 {
225 return Err(SQLRiteError::Internal(format!(
226 "Column '{c}' does not exist on table '{}'",
227 query.table_name
228 )));
229 }
230 }
231 for c in &query.group_by {
232 if !table.contains_column(c.clone()) {
233 return Err(SQLRiteError::Internal(format!(
234 "GROUP BY references unknown column '{c}' on table '{}'",
235 query.table_name
236 )));
237 }
238 }
239 let matching = match select_rowids(table, query.selection.as_ref())? {
243 RowidSource::IndexProbe(rowids) => rowids,
244 RowidSource::FullScan => {
245 let mut out = Vec::new();
246 for rowid in table.rowids() {
247 if let Some(expr) = &query.selection
248 && !eval_predicate(expr, table, rowid)?
249 {
250 continue;
251 }
252 out.push(rowid);
253 }
254 out
255 }
256 };
257 let mut matching = matching;
258
259 let aggregating = has_aggregates || !query.group_by.is_empty();
260
261 if aggregating {
267 for item in &proj_items {
269 if let ProjectionKind::Aggregate(call) = &item.kind
270 && let AggregateArg::Column(c) = &call.arg
271 && !table.contains_column(c.clone())
272 {
273 return Err(SQLRiteError::Internal(format!(
274 "{}({}) references unknown column '{c}' on table '{}'",
275 call.func.as_str(),
276 c,
277 query.table_name
278 )));
279 }
280 }
281
282 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
283 let mut rows = aggregate_rows(table, &matching, &query.group_by, &proj_items)?;
284
285 if query.distinct {
286 rows = dedupe_rows(rows);
287 }
288
289 if let Some(order) = &query.order_by {
290 sort_output_rows(&mut rows, &columns, &proj_items, order)?;
291 }
292 if let Some(k) = query.limit {
293 rows.truncate(k);
294 }
295
296 return Ok(SelectResult { columns, rows });
297 }
298
299 let defer_limit_for_distinct = query.distinct;
337 match (&query.order_by, query.limit) {
338 (Some(order), Some(k)) if try_hnsw_probe(table, &order.expr, k).is_some() => {
339 matching = try_hnsw_probe(table, &order.expr, k).unwrap();
340 }
341 (Some(order), Some(k))
342 if try_fts_probe(table, &order.expr, order.ascending, k).is_some() =>
343 {
344 matching = try_fts_probe(table, &order.expr, order.ascending, k).unwrap();
345 }
346 (Some(order), Some(k)) if !defer_limit_for_distinct && k < matching.len() => {
347 matching = select_topk(&matching, table, order, k)?;
348 }
349 (Some(order), _) => {
350 sort_rowids(&mut matching, table, order)?;
351 if let Some(k) = query.limit
352 && !defer_limit_for_distinct
353 {
354 matching.truncate(k);
355 }
356 }
357 (None, Some(k)) if !defer_limit_for_distinct => {
358 matching.truncate(k);
359 }
360 _ => {}
361 }
362
363 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
364 let projected_cols: Vec<String> = proj_items
365 .iter()
366 .map(|i| match &i.kind {
367 ProjectionKind::Column { name, .. } => name.clone(),
368 ProjectionKind::Aggregate(_) => unreachable!("aggregation handled above"),
369 })
370 .collect();
371
372 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
376 for rowid in &matching {
377 let row: Vec<Value> = projected_cols
378 .iter()
379 .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
380 .collect();
381 rows.push(row);
382 }
383
384 if query.distinct {
385 rows = dedupe_rows(rows);
386 if let Some(k) = query.limit {
387 rows.truncate(k);
388 }
389 }
390
391 Ok(SelectResult { columns, rows })
392}
393
394fn execute_select_rows_joined(query: SelectQuery, db: &Database) -> Result<SelectResult> {
421 let mut joined_tables: Vec<JoinedTableRef<'_>> = Vec::with_capacity(1 + query.joins.len());
428
429 let primary = db
430 .get_table(query.table_name.clone())
431 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
432 joined_tables.push(JoinedTableRef {
433 table: primary,
434 scope_name: query
435 .table_alias
436 .clone()
437 .unwrap_or_else(|| query.table_name.clone()),
438 });
439 for j in &query.joins {
440 let t = db
441 .get_table(j.right_table.clone())
442 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", j.right_table)))?;
443 joined_tables.push(JoinedTableRef {
444 table: t,
445 scope_name: j
446 .right_alias
447 .clone()
448 .unwrap_or_else(|| j.right_table.clone()),
449 });
450 }
451
452 {
457 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
458 for t in &joined_tables {
459 let key = t.scope_name.to_ascii_lowercase();
460 if !seen.insert(key) {
461 return Err(SQLRiteError::Internal(format!(
462 "duplicate table reference '{}' in FROM/JOIN — use AS to alias one side",
463 t.scope_name
464 )));
465 }
466 }
467 }
468
469 let proj_items: Vec<ProjectionItem> = match &query.projection {
475 Projection::All => {
476 let mut all = Vec::new();
485 for t in &joined_tables {
486 for col in t.table.column_names() {
487 all.push(ProjectionItem {
488 kind: ProjectionKind::Column {
489 qualifier: Some(t.scope_name.clone()),
494 name: col,
495 },
496 alias: None,
497 });
498 }
499 }
500 all
501 }
502 Projection::Items(items) => items.clone(),
503 };
504
505 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
506
507 let mut acc: Vec<Vec<Option<i64>>> = primary
512 .rowids()
513 .into_iter()
514 .map(|r| {
515 let mut row = Vec::with_capacity(joined_tables.len());
516 row.push(Some(r));
517 row
518 })
519 .collect();
520
521 for (j_idx, join) in query.joins.iter().enumerate() {
526 let right_pos = j_idx + 1;
527 let right_table = joined_tables[right_pos].table;
528 let right_rowids: Vec<i64> = right_table.rowids();
529
530 let mut right_matched: Vec<bool> = vec![false; right_rowids.len()];
534
535 let mut next_acc: Vec<Vec<Option<i64>>> = Vec::with_capacity(acc.len());
536
537 let on_scope_tables: &[JoinedTableRef<'_>] = &joined_tables[..=right_pos];
545
546 for left_row in acc.into_iter() {
547 let mut left_match_count = 0usize;
551 for (r_idx, &rrid) in right_rowids.iter().enumerate() {
552 let mut on_rowids: Vec<Option<i64>> = left_row.clone();
553 on_rowids.push(Some(rrid));
554 debug_assert_eq!(on_rowids.len(), on_scope_tables.len());
555 let scope = JoinedScope {
556 tables: on_scope_tables,
557 rowids: &on_rowids,
558 };
559 if eval_predicate_scope(&join.on, &scope)? {
564 left_match_count += 1;
565 right_matched[r_idx] = true;
566 next_acc.push(on_rowids);
571 }
572 }
573
574 if left_match_count == 0
575 && matches!(join.join_type, JoinType::LeftOuter | JoinType::FullOuter)
576 {
577 let mut padded = left_row;
580 padded.push(None);
581 next_acc.push(padded);
582 }
583 }
584
585 if matches!(join.join_type, JoinType::RightOuter | JoinType::FullOuter) {
589 for (r_idx, matched) in right_matched.iter().enumerate() {
590 if *matched {
591 continue;
592 }
593 let mut row: Vec<Option<i64>> = vec![None; right_pos];
594 row.push(Some(right_rowids[r_idx]));
595 next_acc.push(row);
596 }
597 }
598
599 acc = next_acc;
600 }
601
602 let mut filtered: Vec<Vec<Option<i64>>> = if let Some(where_expr) = &query.selection {
607 let mut out = Vec::with_capacity(acc.len());
608 for row in acc {
609 let scope = JoinedScope {
610 tables: &joined_tables,
611 rowids: &row,
612 };
613 if eval_predicate_scope(where_expr, &scope)? {
614 out.push(row);
615 }
616 }
617 out
618 } else {
619 acc
620 };
621
622 if let Some(order) = &query.order_by {
626 let mut keys: Vec<(usize, Value)> = Vec::with_capacity(filtered.len());
629 for (i, row) in filtered.iter().enumerate() {
630 let scope = JoinedScope {
631 tables: &joined_tables,
632 rowids: row,
633 };
634 let v = eval_expr_scope(&order.expr, &scope)?;
635 keys.push((i, v));
636 }
637 keys.sort_by(|(_, a), (_, b)| {
638 let ord = compare_values(Some(a), Some(b));
639 if order.ascending { ord } else { ord.reverse() }
640 });
641 let mut sorted = Vec::with_capacity(filtered.len());
642 for (i, _) in keys {
643 sorted.push(filtered[i].clone());
644 }
645 filtered = sorted;
646 }
647
648 if let Some(k) = query.limit {
650 filtered.truncate(k);
651 }
652
653 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(filtered.len());
656 for row in &filtered {
657 let scope = JoinedScope {
658 tables: &joined_tables,
659 rowids: row,
660 };
661 let mut out_row = Vec::with_capacity(proj_items.len());
662 for item in &proj_items {
663 let v = match &item.kind {
664 ProjectionKind::Column { qualifier, name } => {
665 scope.lookup(qualifier.as_deref(), name)?
666 }
667 ProjectionKind::Aggregate(_) => {
668 return Err(SQLRiteError::Internal(
671 "aggregate functions over JOIN are not supported".to_string(),
672 ));
673 }
674 };
675 out_row.push(v);
676 }
677 rows.push(out_row);
678 }
679
680 Ok(SelectResult { columns, rows })
681}
682
683pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
688 let result = execute_select_rows(query, db)?;
689 let row_count = result.rows.len();
690
691 let mut print_table = PrintTable::new();
692 let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
693 print_table.add_row(PrintRow::new(header_cells));
694
695 for row in &result.rows {
696 let cells: Vec<PrintCell> = row
697 .iter()
698 .map(|v| PrintCell::new(&v.to_display_string()))
699 .collect();
700 print_table.add_row(PrintRow::new(cells));
701 }
702
703 Ok((print_table.to_string(), row_count))
704}
705
706pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
708 let Statement::Delete(Delete {
709 from, selection, ..
710 }) = stmt
711 else {
712 return Err(SQLRiteError::Internal(
713 "execute_delete called on a non-DELETE statement".to_string(),
714 ));
715 };
716
717 let tables = match from {
718 FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
719 };
720 let table_name = extract_single_table_name(tables)?;
721
722 let matching: Vec<i64> = {
724 let table = db
725 .get_table(table_name.clone())
726 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
727 match select_rowids(table, selection.as_ref())? {
728 RowidSource::IndexProbe(rowids) => rowids,
729 RowidSource::FullScan => {
730 let mut out = Vec::new();
731 for rowid in table.rowids() {
732 if let Some(expr) = selection {
733 if !eval_predicate(expr, table, rowid)? {
734 continue;
735 }
736 }
737 out.push(rowid);
738 }
739 out
740 }
741 }
742 };
743
744 let table = db.get_table_mut(table_name)?;
745 for rowid in &matching {
746 table.delete_row(*rowid);
747 }
748 if !matching.is_empty() {
757 for entry in &mut table.hnsw_indexes {
758 entry.needs_rebuild = true;
759 }
760 for entry in &mut table.fts_indexes {
761 entry.needs_rebuild = true;
762 }
763 }
764 Ok(matching.len())
765}
766
767pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
769 let Statement::Update(Update {
770 table,
771 assignments,
772 from,
773 selection,
774 ..
775 }) = stmt
776 else {
777 return Err(SQLRiteError::Internal(
778 "execute_update called on a non-UPDATE statement".to_string(),
779 ));
780 };
781
782 if from.is_some() {
783 return Err(SQLRiteError::NotImplemented(
784 "UPDATE ... FROM is not supported yet".to_string(),
785 ));
786 }
787
788 let table_name = extract_table_name(table)?;
789
790 let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
792 {
793 let tbl = db
794 .get_table(table_name.clone())
795 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
796 for a in assignments {
797 let col = match &a.target {
798 AssignmentTarget::ColumnName(name) => name
799 .0
800 .last()
801 .map(|p| p.to_string())
802 .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
803 AssignmentTarget::Tuple(_) => {
804 return Err(SQLRiteError::NotImplemented(
805 "tuple assignment targets are not supported".to_string(),
806 ));
807 }
808 };
809 if !tbl.contains_column(col.clone()) {
810 return Err(SQLRiteError::Internal(format!(
811 "UPDATE references unknown column '{col}'"
812 )));
813 }
814 parsed_assignments.push((col, a.value.clone()));
815 }
816 }
817
818 let work: Vec<(i64, Vec<(String, Value)>)> = {
822 let tbl = db.get_table(table_name.clone())?;
823 let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
824 RowidSource::IndexProbe(rowids) => rowids,
825 RowidSource::FullScan => {
826 let mut out = Vec::new();
827 for rowid in tbl.rowids() {
828 if let Some(expr) = selection {
829 if !eval_predicate(expr, tbl, rowid)? {
830 continue;
831 }
832 }
833 out.push(rowid);
834 }
835 out
836 }
837 };
838 let mut rows_to_update = Vec::new();
839 for rowid in matched_rowids {
840 let mut values = Vec::with_capacity(parsed_assignments.len());
841 for (col, expr) in &parsed_assignments {
842 let v = eval_expr(expr, tbl, rowid)?;
845 values.push((col.clone(), v));
846 }
847 rows_to_update.push((rowid, values));
848 }
849 rows_to_update
850 };
851
852 let tbl = db.get_table_mut(table_name)?;
853 for (rowid, values) in &work {
854 for (col, v) in values {
855 tbl.set_value(col, *rowid, v.clone())?;
856 }
857 }
858
859 if !work.is_empty() {
868 let updated_columns: std::collections::HashSet<&str> = work
869 .iter()
870 .flat_map(|(_, values)| values.iter().map(|(c, _)| c.as_str()))
871 .collect();
872 for entry in &mut tbl.hnsw_indexes {
873 if updated_columns.contains(entry.column_name.as_str()) {
874 entry.needs_rebuild = true;
875 }
876 }
877 for entry in &mut tbl.fts_indexes {
878 if updated_columns.contains(entry.column_name.as_str()) {
879 entry.needs_rebuild = true;
880 }
881 }
882 }
883 Ok(work.len())
884}
885
886pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
898 let Statement::CreateIndex(CreateIndex {
899 name,
900 table_name,
901 columns,
902 using,
903 unique,
904 if_not_exists,
905 predicate,
906 ..
907 }) = stmt
908 else {
909 return Err(SQLRiteError::Internal(
910 "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
911 ));
912 };
913
914 if predicate.is_some() {
915 return Err(SQLRiteError::NotImplemented(
916 "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
917 ));
918 }
919
920 if columns.len() != 1 {
921 return Err(SQLRiteError::NotImplemented(format!(
922 "multi-column indexes are not supported yet ({} columns given)",
923 columns.len()
924 )));
925 }
926
927 let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
928 SQLRiteError::NotImplemented(
929 "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
930 )
931 })?;
932
933 let method = match using {
939 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("hnsw") => {
940 IndexMethod::Hnsw
941 }
942 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("fts") => {
943 IndexMethod::Fts
944 }
945 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("btree") => {
946 IndexMethod::Btree
947 }
948 Some(other) => {
949 return Err(SQLRiteError::NotImplemented(format!(
950 "CREATE INDEX … USING {other:?} is not supported \
951 (try `hnsw`, `fts`, or no USING clause)"
952 )));
953 }
954 None => IndexMethod::Btree,
955 };
956
957 let table_name_str = table_name.to_string();
958 let column_name = match &columns[0].column.expr {
959 Expr::Identifier(ident) => ident.value.clone(),
960 Expr::CompoundIdentifier(parts) => parts
961 .last()
962 .map(|p| p.value.clone())
963 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
964 other => {
965 return Err(SQLRiteError::NotImplemented(format!(
966 "CREATE INDEX only supports simple column references, got {other:?}"
967 )));
968 }
969 };
970
971 let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
976 let table = db.get_table(table_name_str.clone()).map_err(|_| {
977 SQLRiteError::General(format!(
978 "CREATE INDEX references unknown table '{table_name_str}'"
979 ))
980 })?;
981 if !table.contains_column(column_name.clone()) {
982 return Err(SQLRiteError::General(format!(
983 "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
984 )));
985 }
986 let col = table
987 .columns
988 .iter()
989 .find(|c| c.column_name == column_name)
990 .expect("we just verified the column exists");
991
992 if table.index_by_name(&index_name).is_some()
995 || table.hnsw_indexes.iter().any(|i| i.name == index_name)
996 || table.fts_indexes.iter().any(|i| i.name == index_name)
997 {
998 if *if_not_exists {
999 return Ok(index_name);
1000 }
1001 return Err(SQLRiteError::General(format!(
1002 "index '{index_name}' already exists"
1003 )));
1004 }
1005 let datatype = clone_datatype(&col.datatype);
1006
1007 let mut pairs = Vec::new();
1008 for rowid in table.rowids() {
1009 if let Some(v) = table.get_value(&column_name, rowid) {
1010 pairs.push((rowid, v));
1011 }
1012 }
1013 (datatype, pairs)
1014 };
1015
1016 match method {
1017 IndexMethod::Btree => create_btree_index(
1018 db,
1019 &table_name_str,
1020 &index_name,
1021 &column_name,
1022 &datatype,
1023 *unique,
1024 &existing_rowids_and_values,
1025 ),
1026 IndexMethod::Hnsw => create_hnsw_index(
1027 db,
1028 &table_name_str,
1029 &index_name,
1030 &column_name,
1031 &datatype,
1032 *unique,
1033 &existing_rowids_and_values,
1034 ),
1035 IndexMethod::Fts => create_fts_index(
1036 db,
1037 &table_name_str,
1038 &index_name,
1039 &column_name,
1040 &datatype,
1041 *unique,
1042 &existing_rowids_and_values,
1043 ),
1044 }
1045}
1046
1047pub fn execute_drop_table(
1058 names: &[ObjectName],
1059 if_exists: bool,
1060 db: &mut Database,
1061) -> Result<usize> {
1062 if names.len() != 1 {
1063 return Err(SQLRiteError::NotImplemented(
1064 "DROP TABLE supports a single table per statement".to_string(),
1065 ));
1066 }
1067 let name = names[0].to_string();
1068
1069 if name == crate::sql::pager::MASTER_TABLE_NAME {
1070 return Err(SQLRiteError::General(format!(
1071 "'{}' is a reserved name used by the internal schema catalog",
1072 crate::sql::pager::MASTER_TABLE_NAME
1073 )));
1074 }
1075
1076 if !db.contains_table(name.clone()) {
1077 return if if_exists {
1078 Ok(0)
1079 } else {
1080 Err(SQLRiteError::General(format!(
1081 "Table '{name}' does not exist"
1082 )))
1083 };
1084 }
1085
1086 db.tables.remove(&name);
1087 Ok(1)
1088}
1089
1090pub fn execute_drop_index(
1099 names: &[ObjectName],
1100 if_exists: bool,
1101 db: &mut Database,
1102) -> Result<usize> {
1103 if names.len() != 1 {
1104 return Err(SQLRiteError::NotImplemented(
1105 "DROP INDEX supports a single index per statement".to_string(),
1106 ));
1107 }
1108 let name = names[0].to_string();
1109
1110 for table in db.tables.values_mut() {
1111 if let Some(secondary) = table.secondary_indexes.iter().find(|i| i.name == name) {
1112 if secondary.origin == IndexOrigin::Auto {
1113 return Err(SQLRiteError::General(format!(
1114 "cannot drop auto-created index '{name}' (drop the column or table instead)"
1115 )));
1116 }
1117 table.secondary_indexes.retain(|i| i.name != name);
1118 return Ok(1);
1119 }
1120 if table.hnsw_indexes.iter().any(|i| i.name == name) {
1121 table.hnsw_indexes.retain(|i| i.name != name);
1122 return Ok(1);
1123 }
1124 if table.fts_indexes.iter().any(|i| i.name == name) {
1125 table.fts_indexes.retain(|i| i.name != name);
1126 return Ok(1);
1127 }
1128 }
1129
1130 if if_exists {
1131 Ok(0)
1132 } else {
1133 Err(SQLRiteError::General(format!(
1134 "Index '{name}' does not exist"
1135 )))
1136 }
1137}
1138
1139pub fn execute_alter_table(alter: AlterTable, db: &mut Database) -> Result<String> {
1151 let table_name = alter.name.to_string();
1152
1153 if table_name == crate::sql::pager::MASTER_TABLE_NAME {
1154 return Err(SQLRiteError::General(format!(
1155 "'{}' is a reserved name used by the internal schema catalog",
1156 crate::sql::pager::MASTER_TABLE_NAME
1157 )));
1158 }
1159
1160 if !db.contains_table(table_name.clone()) {
1161 return if alter.if_exists {
1162 Ok("ALTER TABLE: no-op (table does not exist)".to_string())
1163 } else {
1164 Err(SQLRiteError::General(format!(
1165 "Table '{table_name}' does not exist"
1166 )))
1167 };
1168 }
1169
1170 if alter.operations.len() != 1 {
1171 return Err(SQLRiteError::NotImplemented(
1172 "ALTER TABLE supports one operation per statement".to_string(),
1173 ));
1174 }
1175
1176 match &alter.operations[0] {
1177 AlterTableOperation::RenameTable { table_name: kind } => {
1178 let new_name = match kind {
1179 RenameTableNameKind::To(name) => name.to_string(),
1180 RenameTableNameKind::As(_) => {
1181 return Err(SQLRiteError::NotImplemented(
1182 "ALTER TABLE ... RENAME AS (MySQL-only) is not supported; use RENAME TO"
1183 .to_string(),
1184 ));
1185 }
1186 };
1187 alter_rename_table(db, &table_name, &new_name)?;
1188 Ok(format!(
1189 "ALTER TABLE '{table_name}' RENAME TO '{new_name}' executed."
1190 ))
1191 }
1192 AlterTableOperation::RenameColumn {
1193 old_column_name,
1194 new_column_name,
1195 } => {
1196 let old = old_column_name.value.clone();
1197 let new = new_column_name.value.clone();
1198 db.get_table_mut(table_name.clone())?
1199 .rename_column(&old, &new)?;
1200 Ok(format!(
1201 "ALTER TABLE '{table_name}' RENAME COLUMN '{old}' TO '{new}' executed."
1202 ))
1203 }
1204 AlterTableOperation::AddColumn {
1205 column_def,
1206 if_not_exists,
1207 ..
1208 } => {
1209 let parsed = crate::sql::parser::create::parse_one_column(column_def)?;
1210 let table = db.get_table_mut(table_name.clone())?;
1211 if *if_not_exists && table.contains_column(parsed.name.clone()) {
1212 return Ok(format!(
1213 "ALTER TABLE '{table_name}' ADD COLUMN: no-op (column '{}' already exists)",
1214 parsed.name
1215 ));
1216 }
1217 let col_name = parsed.name.clone();
1218 table.add_column(parsed)?;
1219 Ok(format!(
1220 "ALTER TABLE '{table_name}' ADD COLUMN '{col_name}' executed."
1221 ))
1222 }
1223 AlterTableOperation::DropColumn {
1224 column_names,
1225 if_exists,
1226 ..
1227 } => {
1228 if column_names.len() != 1 {
1229 return Err(SQLRiteError::NotImplemented(
1230 "ALTER TABLE DROP COLUMN supports a single column per statement".to_string(),
1231 ));
1232 }
1233 let col_name = column_names[0].value.clone();
1234 let table = db.get_table_mut(table_name.clone())?;
1235 if *if_exists && !table.contains_column(col_name.clone()) {
1236 return Ok(format!(
1237 "ALTER TABLE '{table_name}' DROP COLUMN: no-op (column '{col_name}' does not exist)"
1238 ));
1239 }
1240 table.drop_column(&col_name)?;
1241 Ok(format!(
1242 "ALTER TABLE '{table_name}' DROP COLUMN '{col_name}' executed."
1243 ))
1244 }
1245 other => Err(SQLRiteError::NotImplemented(format!(
1246 "ALTER TABLE operation {other:?} is not supported"
1247 ))),
1248 }
1249}
1250
1251pub fn execute_vacuum(db: &mut Database) -> Result<String> {
1261 if db.in_transaction() {
1262 return Err(SQLRiteError::General(
1263 "VACUUM cannot run inside a transaction".to_string(),
1264 ));
1265 }
1266 let path = match db.source_path.clone() {
1267 Some(p) => p,
1268 None => {
1269 return Ok("VACUUM is a no-op for in-memory databases".to_string());
1270 }
1271 };
1272 if let Some(pager) = db.pager.as_mut() {
1278 let _ = pager.checkpoint();
1279 }
1280 let size_before = std::fs::metadata(&path).ok().map(|m| m.len()).unwrap_or(0);
1281 let pages_before = db
1282 .pager
1283 .as_ref()
1284 .map(|p| p.header().page_count)
1285 .unwrap_or(0);
1286 crate::sql::pager::vacuum_database(db, &path)?;
1287 if let Some(pager) = db.pager.as_mut() {
1290 let _ = pager.checkpoint();
1291 }
1292 let size_after = std::fs::metadata(&path).ok().map(|m| m.len()).unwrap_or(0);
1293 let pages_after = db
1294 .pager
1295 .as_ref()
1296 .map(|p| p.header().page_count)
1297 .unwrap_or(0);
1298 let pages_reclaimed = pages_before.saturating_sub(pages_after);
1299 let bytes_reclaimed = size_before.saturating_sub(size_after);
1300 Ok(format!(
1301 "VACUUM completed. {pages_reclaimed} pages reclaimed ({bytes_reclaimed} bytes)."
1302 ))
1303}
1304
1305fn alter_rename_table(db: &mut Database, old: &str, new: &str) -> Result<()> {
1311 if new == crate::sql::pager::MASTER_TABLE_NAME {
1312 return Err(SQLRiteError::General(format!(
1313 "'{}' is a reserved name used by the internal schema catalog",
1314 crate::sql::pager::MASTER_TABLE_NAME
1315 )));
1316 }
1317 if old == new {
1318 return Ok(());
1319 }
1320 if db.contains_table(new.to_string()) {
1321 return Err(SQLRiteError::General(format!(
1322 "target table '{new}' already exists"
1323 )));
1324 }
1325
1326 let mut table = db
1327 .tables
1328 .remove(old)
1329 .ok_or_else(|| SQLRiteError::General(format!("Table '{old}' does not exist")))?;
1330 table.tb_name = new.to_string();
1331 for idx in table.secondary_indexes.iter_mut() {
1332 idx.table_name = new.to_string();
1333 if idx.origin == IndexOrigin::Auto
1334 && idx.name == SecondaryIndex::auto_name(old, &idx.column_name)
1335 {
1336 idx.name = SecondaryIndex::auto_name(new, &idx.column_name);
1337 }
1338 }
1339 db.tables.insert(new.to_string(), table);
1340 Ok(())
1341}
1342
/// Index implementation selected by the `USING` clause of `CREATE INDEX`.
#[derive(Debug, Clone, Copy)]
enum IndexMethod {
    /// Ordinary secondary index; also used when no `USING` clause is given.
    Btree,
    /// Vector index (`USING hnsw`).
    Hnsw,
    /// Full-text index (`USING fts`).
    Fts,
}
1353
1354fn create_btree_index(
1356 db: &mut Database,
1357 table_name: &str,
1358 index_name: &str,
1359 column_name: &str,
1360 datatype: &DataType,
1361 unique: bool,
1362 existing: &[(i64, Value)],
1363) -> Result<String> {
1364 let mut idx = SecondaryIndex::new(
1365 index_name.to_string(),
1366 table_name.to_string(),
1367 column_name.to_string(),
1368 datatype,
1369 unique,
1370 IndexOrigin::Explicit,
1371 )?;
1372
1373 for (rowid, v) in existing {
1377 if unique && idx.would_violate_unique(v) {
1378 return Err(SQLRiteError::General(format!(
1379 "cannot create UNIQUE index '{index_name}': column '{column_name}' \
1380 already contains the duplicate value {}",
1381 v.to_display_string()
1382 )));
1383 }
1384 idx.insert(v, *rowid)?;
1385 }
1386
1387 let table_mut = db.get_table_mut(table_name.to_string())?;
1388 table_mut.secondary_indexes.push(idx);
1389 Ok(index_name.to_string())
1390}
1391
1392fn create_hnsw_index(
1394 db: &mut Database,
1395 table_name: &str,
1396 index_name: &str,
1397 column_name: &str,
1398 datatype: &DataType,
1399 unique: bool,
1400 existing: &[(i64, Value)],
1401) -> Result<String> {
1402 let dim = match datatype {
1405 DataType::Vector(d) => *d,
1406 other => {
1407 return Err(SQLRiteError::General(format!(
1408 "USING hnsw requires a VECTOR column; '{column_name}' is {other}"
1409 )));
1410 }
1411 };
1412
1413 if unique {
1414 return Err(SQLRiteError::General(
1415 "UNIQUE has no meaning for HNSW indexes".to_string(),
1416 ));
1417 }
1418
1419 let seed = hash_str_to_seed(index_name);
1427 let mut idx = HnswIndex::new(DistanceMetric::L2, seed);
1428
1429 let mut vec_map: std::collections::HashMap<i64, Vec<f32>> =
1433 std::collections::HashMap::with_capacity(existing.len());
1434 for (rowid, v) in existing {
1435 match v {
1436 Value::Vector(vec) => {
1437 if vec.len() != dim {
1438 return Err(SQLRiteError::Internal(format!(
1439 "row {rowid} stores a {}-dim vector in column '{column_name}' \
1440 declared as VECTOR({dim}) — schema invariant violated",
1441 vec.len()
1442 )));
1443 }
1444 vec_map.insert(*rowid, vec.clone());
1445 }
1446 _ => continue,
1450 }
1451 }
1452
1453 for (rowid, _) in existing {
1454 if let Some(v) = vec_map.get(rowid) {
1455 let v_clone = v.clone();
1456 idx.insert(*rowid, &v_clone, |id| {
1457 vec_map.get(&id).cloned().unwrap_or_default()
1458 });
1459 }
1460 }
1461
1462 let table_mut = db.get_table_mut(table_name.to_string())?;
1463 table_mut.hnsw_indexes.push(HnswIndexEntry {
1464 name: index_name.to_string(),
1465 column_name: column_name.to_string(),
1466 index: idx,
1467 needs_rebuild: false,
1469 });
1470 Ok(index_name.to_string())
1471}
1472
1473fn create_fts_index(
1478 db: &mut Database,
1479 table_name: &str,
1480 index_name: &str,
1481 column_name: &str,
1482 datatype: &DataType,
1483 unique: bool,
1484 existing: &[(i64, Value)],
1485) -> Result<String> {
1486 match datatype {
1491 DataType::Text => {}
1492 other => {
1493 return Err(SQLRiteError::General(format!(
1494 "USING fts requires a TEXT column; '{column_name}' is {other}"
1495 )));
1496 }
1497 }
1498
1499 if unique {
1500 return Err(SQLRiteError::General(
1501 "UNIQUE has no meaning for FTS indexes".to_string(),
1502 ));
1503 }
1504
1505 let mut idx = PostingList::new();
1506 for (rowid, v) in existing {
1507 if let Value::Text(text) = v {
1508 idx.insert(*rowid, text);
1509 }
1510 }
1513
1514 let table_mut = db.get_table_mut(table_name.to_string())?;
1515 table_mut.fts_indexes.push(FtsIndexEntry {
1516 name: index_name.to_string(),
1517 column_name: column_name.to_string(),
1518 index: idx,
1519 needs_rebuild: false,
1520 });
1521 Ok(index_name.to_string())
1522}
1523
/// Hashes `s` into a 64-bit seed using 64-bit FNV-1a: start from the offset
/// basis, then for each byte XOR it in and multiply by the FNV prime
/// (wrapping on overflow). The same string always yields the same seed.
fn hash_str_to_seed(s: &str) -> u64 {
    const OFFSET_BASIS: u64 = 0xCBF29CE484222325;
    const PRIME: u64 = 0x100000001B3;

    s.bytes().fold(OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
1535
1536fn clone_datatype(dt: &DataType) -> DataType {
1539 match dt {
1540 DataType::Integer => DataType::Integer,
1541 DataType::Text => DataType::Text,
1542 DataType::Real => DataType::Real,
1543 DataType::Bool => DataType::Bool,
1544 DataType::Vector(dim) => DataType::Vector(*dim),
1545 DataType::Json => DataType::Json,
1546 DataType::None => DataType::None,
1547 DataType::Invalid => DataType::Invalid,
1548 }
1549}
1550
1551fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
1552 if tables.len() != 1 {
1553 return Err(SQLRiteError::NotImplemented(
1554 "multi-table DELETE is not supported yet".to_string(),
1555 ));
1556 }
1557 extract_table_name(&tables[0])
1558}
1559
1560fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
1561 if !twj.joins.is_empty() {
1562 return Err(SQLRiteError::NotImplemented(
1563 "JOIN is not supported yet".to_string(),
1564 ));
1565 }
1566 match &twj.relation {
1567 TableFactor::Table { name, .. } => Ok(name.to_string()),
1568 _ => Err(SQLRiteError::NotImplemented(
1569 "only plain table references are supported".to_string(),
1570 )),
1571 }
1572}
1573
/// Outcome of `select_rowids`: where the candidate rowids for a statement
/// should come from.
enum RowidSource {
    /// Rowids obtained from a secondary-index lookup, sorted ascending.
    IndexProbe(Vec<i64>),
    /// No usable index probe was found (no WHERE clause, not a
    /// `column = literal` shape, no index on the column, or the literal could
    /// not be converted). Presumably the caller then visits every row —
    /// confirm at the call sites.
    FullScan,
}
1584
1585fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
1590 let Some(expr) = selection else {
1591 return Ok(RowidSource::FullScan);
1592 };
1593 let Some((col, literal)) = try_extract_equality(expr) else {
1594 return Ok(RowidSource::FullScan);
1595 };
1596 let Some(idx) = table.index_for_column(&col) else {
1597 return Ok(RowidSource::FullScan);
1598 };
1599
1600 let literal_value = match convert_literal(&literal) {
1604 Ok(v) => v,
1605 Err(_) => return Ok(RowidSource::FullScan),
1606 };
1607
1608 let mut rowids = idx.lookup(&literal_value);
1612 rowids.sort_unstable();
1613 Ok(RowidSource::IndexProbe(rowids))
1614}
1615
1616fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
1620 let peeled = match expr {
1622 Expr::Nested(inner) => inner.as_ref(),
1623 other => other,
1624 };
1625 let Expr::BinaryOp { left, op, right } = peeled else {
1626 return None;
1627 };
1628 if !matches!(op, BinaryOperator::Eq) {
1629 return None;
1630 }
1631 let col_from = |e: &Expr| -> Option<String> {
1632 match e {
1633 Expr::Identifier(ident) => Some(ident.value.clone()),
1634 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
1635 _ => None,
1636 }
1637 };
1638 let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
1639 if let Expr::Value(v) = e {
1640 Some(v.value.clone())
1641 } else {
1642 None
1643 }
1644 };
1645 if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
1646 return Some((c, l));
1647 }
1648 if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
1649 return Some((c, l));
1650 }
1651 None
1652}
1653
1654fn try_hnsw_probe(table: &Table, order_expr: &Expr, k: usize) -> Option<Vec<i64>> {
1676 if k == 0 {
1677 return None;
1678 }
1679
1680 let func = match order_expr {
1682 Expr::Function(f) => f,
1683 _ => return None,
1684 };
1685 let fname = match func.name.0.as_slice() {
1686 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
1687 _ => return None,
1688 };
1689 if fname != "vec_distance_l2" {
1690 return None;
1691 }
1692
1693 let arg_list = match &func.args {
1695 FunctionArguments::List(l) => &l.args,
1696 _ => return None,
1697 };
1698 if arg_list.len() != 2 {
1699 return None;
1700 }
1701 let exprs: Vec<&Expr> = arg_list
1702 .iter()
1703 .filter_map(|a| match a {
1704 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
1705 _ => None,
1706 })
1707 .collect();
1708 if exprs.len() != 2 {
1709 return None;
1710 }
1711
1712 let (col_name, query_vec) = match identify_indexed_arg_and_literal(exprs[0], exprs[1]) {
1717 Some(v) => v,
1718 None => match identify_indexed_arg_and_literal(exprs[1], exprs[0]) {
1719 Some(v) => v,
1720 None => return None,
1721 },
1722 };
1723
1724 let entry = table
1726 .hnsw_indexes
1727 .iter()
1728 .find(|e| e.column_name == col_name)?;
1729
1730 let declared_dim = match table.columns.iter().find(|c| c.column_name == col_name) {
1736 Some(c) => match &c.datatype {
1737 DataType::Vector(d) => *d,
1738 _ => return None,
1739 },
1740 None => return None,
1741 };
1742 if query_vec.len() != declared_dim {
1743 return None;
1744 }
1745
1746 let column_for_closure = col_name.clone();
1750 let table_ref = table;
1751 let result = entry.index.search(&query_vec, k, |id| {
1752 match table_ref.get_value(&column_for_closure, id) {
1753 Some(Value::Vector(v)) => v,
1754 _ => Vec::new(),
1755 }
1756 });
1757 Some(result)
1758}
1759
1760fn try_fts_probe(table: &Table, order_expr: &Expr, ascending: bool, k: usize) -> Option<Vec<i64>> {
1776 if k == 0 || ascending {
1777 return None;
1781 }
1782
1783 let func = match order_expr {
1784 Expr::Function(f) => f,
1785 _ => return None,
1786 };
1787 let fname = match func.name.0.as_slice() {
1788 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
1789 _ => return None,
1790 };
1791 if fname != "bm25_score" {
1792 return None;
1793 }
1794
1795 let arg_list = match &func.args {
1796 FunctionArguments::List(l) => &l.args,
1797 _ => return None,
1798 };
1799 if arg_list.len() != 2 {
1800 return None;
1801 }
1802 let exprs: Vec<&Expr> = arg_list
1803 .iter()
1804 .filter_map(|a| match a {
1805 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
1806 _ => None,
1807 })
1808 .collect();
1809 if exprs.len() != 2 {
1810 return None;
1811 }
1812
1813 let col_name = match exprs[0] {
1815 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
1816 _ => return None,
1817 };
1818
1819 let query = match exprs[1] {
1823 Expr::Value(v) => match &v.value {
1824 AstValue::SingleQuotedString(s) => s.clone(),
1825 _ => return None,
1826 },
1827 _ => return None,
1828 };
1829
1830 let entry = table
1831 .fts_indexes
1832 .iter()
1833 .find(|e| e.column_name == col_name)?;
1834
1835 let scored = entry.index.query(&query, &Bm25Params::default());
1836 let mut out: Vec<i64> = scored.into_iter().map(|(id, _)| id).collect();
1837 if out.len() > k {
1838 out.truncate(k);
1839 }
1840 Some(out)
1841}
1842
1843fn identify_indexed_arg_and_literal(a: &Expr, b: &Expr) -> Option<(String, Vec<f32>)> {
1848 let col_name = match a {
1849 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
1850 _ => return None,
1851 };
1852 let lit_str = match b {
1853 Expr::Identifier(ident) if ident.quote_style == Some('[') => {
1854 format!("[{}]", ident.value)
1855 }
1856 _ => return None,
1857 };
1858 let v = parse_vector_literal(&lit_str).ok()?;
1859 Some((col_name, v))
1860}
1861
1862struct HeapEntry {
1875 key: Value,
1876 rowid: i64,
1877 asc: bool,
1878}
1879
1880impl PartialEq for HeapEntry {
1881 fn eq(&self, other: &Self) -> bool {
1882 self.cmp(other) == Ordering::Equal
1883 }
1884}
1885
1886impl Eq for HeapEntry {}
1887
1888impl PartialOrd for HeapEntry {
1889 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1890 Some(self.cmp(other))
1891 }
1892}
1893
1894impl Ord for HeapEntry {
1895 fn cmp(&self, other: &Self) -> Ordering {
1896 let raw = compare_values(Some(&self.key), Some(&other.key));
1897 if self.asc { raw } else { raw.reverse() }
1898 }
1899}
1900
1901fn select_topk(
1910 matching: &[i64],
1911 table: &Table,
1912 order: &OrderByClause,
1913 k: usize,
1914) -> Result<Vec<i64>> {
1915 use std::collections::BinaryHeap;
1916
1917 if k == 0 || matching.is_empty() {
1918 return Ok(Vec::new());
1919 }
1920
1921 let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
1922
1923 for &rowid in matching {
1924 let key = eval_expr(&order.expr, table, rowid)?;
1925 let entry = HeapEntry {
1926 key,
1927 rowid,
1928 asc: order.ascending,
1929 };
1930
1931 if heap.len() < k {
1932 heap.push(entry);
1933 } else {
1934 if entry < *heap.peek().unwrap() {
1938 heap.pop();
1939 heap.push(entry);
1940 }
1941 }
1942 }
1943
1944 Ok(heap
1949 .into_sorted_vec()
1950 .into_iter()
1951 .map(|e| e.rowid)
1952 .collect())
1953}
1954
1955fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
1956 let mut keys: Vec<(i64, Result<Value>)> = rowids
1964 .iter()
1965 .map(|r| (*r, eval_expr(&order.expr, table, *r)))
1966 .collect();
1967
1968 for (_, k) in &keys {
1972 if let Err(e) = k {
1973 return Err(SQLRiteError::General(format!(
1974 "ORDER BY expression failed: {e}"
1975 )));
1976 }
1977 }
1978
1979 keys.sort_by(|(_, ka), (_, kb)| {
1980 let va = ka.as_ref().unwrap();
1983 let vb = kb.as_ref().unwrap();
1984 let ord = compare_values(Some(va), Some(vb));
1985 if order.ascending { ord } else { ord.reverse() }
1986 });
1987
1988 for (i, (rowid, _)) in keys.into_iter().enumerate() {
1990 rowids[i] = rowid;
1991 }
1992 Ok(())
1993}
1994
1995fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
1996 match (a, b) {
1997 (None, None) => Ordering::Equal,
1998 (None, _) => Ordering::Less,
1999 (_, None) => Ordering::Greater,
2000 (Some(a), Some(b)) => match (a, b) {
2001 (Value::Null, Value::Null) => Ordering::Equal,
2002 (Value::Null, _) => Ordering::Less,
2003 (_, Value::Null) => Ordering::Greater,
2004 (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
2005 (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
2006 (Value::Integer(x), Value::Real(y)) => {
2007 (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
2008 }
2009 (Value::Real(x), Value::Integer(y)) => {
2010 x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
2011 }
2012 (Value::Text(x), Value::Text(y)) => x.cmp(y),
2013 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
2014 (x, y) => x.to_display_string().cmp(&y.to_display_string()),
2016 },
2017 }
2018}
2019
2020pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
2022 eval_predicate_scope(expr, &SingleTableScope::new(table, rowid))
2023}
2024
2025pub(crate) fn eval_predicate_scope(expr: &Expr, scope: &dyn RowScope) -> Result<bool> {
2029 let v = eval_expr_scope(expr, scope)?;
2030 match v {
2031 Value::Bool(b) => Ok(b),
2032 Value::Null => Ok(false), Value::Integer(i) => Ok(i != 0),
2034 other => Err(SQLRiteError::Internal(format!(
2035 "WHERE clause must evaluate to boolean, got {}",
2036 other.to_display_string()
2037 ))),
2038 }
2039}
2040
2041fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
2043 eval_expr_scope(expr, &SingleTableScope::new(table, rowid))
2044}
2045
2046fn eval_expr_scope(expr: &Expr, scope: &dyn RowScope) -> Result<Value> {
2047 match expr {
2048 Expr::Nested(inner) => eval_expr_scope(inner, scope),
2049
2050 Expr::Identifier(ident) => {
2051 if ident.quote_style == Some('[') {
2061 let raw = format!("[{}]", ident.value);
2062 let v = parse_vector_literal(&raw)?;
2063 return Ok(Value::Vector(v));
2064 }
2065 scope.lookup(None, &ident.value)
2066 }
2067
2068 Expr::CompoundIdentifier(parts) => {
2069 match parts.as_slice() {
2075 [only] => scope.lookup(None, &only.value),
2076 [q, c] => scope.lookup(Some(&q.value), &c.value),
2077 _ => Err(SQLRiteError::NotImplemented(format!(
2078 "compound identifier with {} parts is not supported",
2079 parts.len()
2080 ))),
2081 }
2082 }
2083
2084 Expr::Value(v) => convert_literal(&v.value),
2085
2086 Expr::UnaryOp { op, expr } => {
2087 let inner = eval_expr_scope(expr, scope)?;
2088 match op {
2089 UnaryOperator::Not => match inner {
2090 Value::Bool(b) => Ok(Value::Bool(!b)),
2091 Value::Null => Ok(Value::Null),
2092 other => Err(SQLRiteError::Internal(format!(
2093 "NOT applied to non-boolean value: {}",
2094 other.to_display_string()
2095 ))),
2096 },
2097 UnaryOperator::Minus => match inner {
2098 Value::Integer(i) => Ok(Value::Integer(-i)),
2099 Value::Real(f) => Ok(Value::Real(-f)),
2100 Value::Null => Ok(Value::Null),
2101 other => Err(SQLRiteError::Internal(format!(
2102 "unary minus on non-numeric value: {}",
2103 other.to_display_string()
2104 ))),
2105 },
2106 UnaryOperator::Plus => Ok(inner),
2107 other => Err(SQLRiteError::NotImplemented(format!(
2108 "unary operator {other:?} is not supported"
2109 ))),
2110 }
2111 }
2112
2113 Expr::BinaryOp { left, op, right } => match op {
2114 BinaryOperator::And => {
2115 let l = eval_expr_scope(left, scope)?;
2116 let r = eval_expr_scope(right, scope)?;
2117 Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
2118 }
2119 BinaryOperator::Or => {
2120 let l = eval_expr_scope(left, scope)?;
2121 let r = eval_expr_scope(right, scope)?;
2122 Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
2123 }
2124 cmp @ (BinaryOperator::Eq
2125 | BinaryOperator::NotEq
2126 | BinaryOperator::Lt
2127 | BinaryOperator::LtEq
2128 | BinaryOperator::Gt
2129 | BinaryOperator::GtEq) => {
2130 let l = eval_expr_scope(left, scope)?;
2131 let r = eval_expr_scope(right, scope)?;
2132 if matches!(l, Value::Null) || matches!(r, Value::Null) {
2134 return Ok(Value::Bool(false));
2135 }
2136 let ord = compare_values(Some(&l), Some(&r));
2137 let result = match cmp {
2138 BinaryOperator::Eq => ord == Ordering::Equal,
2139 BinaryOperator::NotEq => ord != Ordering::Equal,
2140 BinaryOperator::Lt => ord == Ordering::Less,
2141 BinaryOperator::LtEq => ord != Ordering::Greater,
2142 BinaryOperator::Gt => ord == Ordering::Greater,
2143 BinaryOperator::GtEq => ord != Ordering::Less,
2144 _ => unreachable!(),
2145 };
2146 Ok(Value::Bool(result))
2147 }
2148 arith @ (BinaryOperator::Plus
2149 | BinaryOperator::Minus
2150 | BinaryOperator::Multiply
2151 | BinaryOperator::Divide
2152 | BinaryOperator::Modulo) => {
2153 let l = eval_expr_scope(left, scope)?;
2154 let r = eval_expr_scope(right, scope)?;
2155 eval_arith(arith, &l, &r)
2156 }
2157 BinaryOperator::StringConcat => {
2158 let l = eval_expr_scope(left, scope)?;
2159 let r = eval_expr_scope(right, scope)?;
2160 if matches!(l, Value::Null) || matches!(r, Value::Null) {
2161 return Ok(Value::Null);
2162 }
2163 Ok(Value::Text(format!(
2164 "{}{}",
2165 l.to_display_string(),
2166 r.to_display_string()
2167 )))
2168 }
2169 other => Err(SQLRiteError::NotImplemented(format!(
2170 "binary operator {other:?} is not supported yet"
2171 ))),
2172 },
2173
2174 Expr::IsNull(inner) => {
2182 let v = eval_expr_scope(inner, scope)?;
2183 Ok(Value::Bool(matches!(v, Value::Null)))
2184 }
2185 Expr::IsNotNull(inner) => {
2186 let v = eval_expr_scope(inner, scope)?;
2187 Ok(Value::Bool(!matches!(v, Value::Null)))
2188 }
2189
2190 Expr::Like {
2197 negated,
2198 any,
2199 expr: lhs,
2200 pattern,
2201 escape_char,
2202 } => eval_like(
2203 scope,
2204 *negated,
2205 *any,
2206 lhs,
2207 pattern,
2208 escape_char.as_ref(),
2209 true,
2210 ),
2211 Expr::ILike {
2212 negated,
2213 any,
2214 expr: lhs,
2215 pattern,
2216 escape_char,
2217 } => eval_like(
2218 scope,
2219 *negated,
2220 *any,
2221 lhs,
2222 pattern,
2223 escape_char.as_ref(),
2224 true,
2225 ),
2226
2227 Expr::InList {
2233 expr: lhs,
2234 list,
2235 negated,
2236 } => eval_in_list(scope, lhs, list, *negated),
2237 Expr::InSubquery { .. } => Err(SQLRiteError::NotImplemented(
2238 "IN (subquery) is not supported (only literal lists are)".to_string(),
2239 )),
2240
2241 Expr::Function(func) => eval_function(func, scope),
2252
2253 other => Err(SQLRiteError::NotImplemented(format!(
2254 "unsupported expression in WHERE/projection: {other:?}"
2255 ))),
2256 }
2257}
2258
2259fn eval_function(func: &sqlparser::ast::Function, scope: &dyn RowScope) -> Result<Value> {
2264 let name = match func.name.0.as_slice() {
2267 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
2268 _ => {
2269 return Err(SQLRiteError::NotImplemented(format!(
2270 "qualified function names not supported: {:?}",
2271 func.name
2272 )));
2273 }
2274 };
2275
2276 match name.as_str() {
2277 "vec_distance_l2" | "vec_distance_cosine" | "vec_distance_dot" => {
2278 let (a, b) = extract_two_vector_args(&name, &func.args, scope)?;
2279 let dist = match name.as_str() {
2280 "vec_distance_l2" => vec_distance_l2(&a, &b),
2281 "vec_distance_cosine" => vec_distance_cosine(&a, &b)?,
2282 "vec_distance_dot" => vec_distance_dot(&a, &b),
2283 _ => unreachable!(),
2284 };
2285 Ok(Value::Real(dist as f64))
2291 }
2292 "json_extract" => json_fn_extract(&name, &func.args, scope),
2297 "json_type" => json_fn_type(&name, &func.args, scope),
2298 "json_array_length" => json_fn_array_length(&name, &func.args, scope),
2299 "json_object_keys" => json_fn_object_keys(&name, &func.args, scope),
2300 "fts_match" | "bm25_score" => {
2311 let Some((table, rowid)) = scope.single_table_view() else {
2312 return Err(SQLRiteError::NotImplemented(format!(
2313 "{name}() is not yet supported inside a JOIN query — \
2314 use it on a single-table SELECT or move the FTS lookup into a subquery"
2315 )));
2316 };
2317 let (entry, query) = resolve_fts_args(&name, &func.args, table, scope)?;
2318 Ok(match name.as_str() {
2319 "fts_match" => Value::Bool(entry.index.matches(rowid, &query)),
2320 "bm25_score" => {
2321 Value::Real(entry.index.score(rowid, &query, &Bm25Params::default()))
2322 }
2323 _ => unreachable!(),
2324 })
2325 }
2326 "count" | "sum" | "avg" | "min" | "max" => Err(SQLRiteError::NotImplemented(format!(
2330 "aggregate function '{name}' is not allowed in WHERE / projection-scalar position; \
2331 use it as a top-level projection item (HAVING is not yet supported)"
2332 ))),
2333 other => Err(SQLRiteError::NotImplemented(format!(
2334 "unknown function: {other}(...)"
2335 ))),
2336 }
2337}
2338
2339fn resolve_fts_args<'t>(
2344 fn_name: &str,
2345 args: &FunctionArguments,
2346 table: &'t Table,
2347 scope: &dyn RowScope,
2348) -> Result<(&'t FtsIndexEntry, String)> {
2349 let arg_list = match args {
2350 FunctionArguments::List(l) => &l.args,
2351 _ => {
2352 return Err(SQLRiteError::General(format!(
2353 "{fn_name}() expects exactly two arguments: (column, query_text)"
2354 )));
2355 }
2356 };
2357 if arg_list.len() != 2 {
2358 return Err(SQLRiteError::General(format!(
2359 "{fn_name}() expects exactly 2 arguments, got {}",
2360 arg_list.len()
2361 )));
2362 }
2363
2364 let col_expr = match &arg_list[0] {
2368 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2369 other => {
2370 return Err(SQLRiteError::NotImplemented(format!(
2371 "{fn_name}() argument 0 must be a column name, got {other:?}"
2372 )));
2373 }
2374 };
2375 let col_name = match col_expr {
2376 Expr::Identifier(ident) => ident.value.clone(),
2377 Expr::CompoundIdentifier(parts) => parts
2378 .last()
2379 .map(|p| p.value.clone())
2380 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
2381 other => {
2382 return Err(SQLRiteError::General(format!(
2383 "{fn_name}() argument 0 must be a column reference, got {other:?}"
2384 )));
2385 }
2386 };
2387
2388 let q_expr = match &arg_list[1] {
2392 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2393 other => {
2394 return Err(SQLRiteError::NotImplemented(format!(
2395 "{fn_name}() argument 1 must be a text expression, got {other:?}"
2396 )));
2397 }
2398 };
2399 let query = match eval_expr_scope(q_expr, scope)? {
2400 Value::Text(s) => s,
2401 other => {
2402 return Err(SQLRiteError::General(format!(
2403 "{fn_name}() argument 1 must be TEXT, got {}",
2404 other.to_display_string()
2405 )));
2406 }
2407 };
2408
2409 let entry = table
2410 .fts_indexes
2411 .iter()
2412 .find(|e| e.column_name == col_name)
2413 .ok_or_else(|| {
2414 SQLRiteError::General(format!(
2415 "{fn_name}({col_name}, ...): no FTS index on column '{col_name}' \
2416 (run CREATE INDEX <name> ON <table> USING fts({col_name}) first)"
2417 ))
2418 })?;
2419 Ok((entry, query))
2420}
2421
2422fn extract_json_and_path(
2436 fn_name: &str,
2437 args: &FunctionArguments,
2438 scope: &dyn RowScope,
2439) -> Result<(String, String)> {
2440 let arg_list = match args {
2441 FunctionArguments::List(l) => &l.args,
2442 _ => {
2443 return Err(SQLRiteError::General(format!(
2444 "{fn_name}() expects 1 or 2 arguments"
2445 )));
2446 }
2447 };
2448 if !(arg_list.len() == 1 || arg_list.len() == 2) {
2449 return Err(SQLRiteError::General(format!(
2450 "{fn_name}() expects 1 or 2 arguments, got {}",
2451 arg_list.len()
2452 )));
2453 }
2454 let first_expr = match &arg_list[0] {
2456 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2457 other => {
2458 return Err(SQLRiteError::NotImplemented(format!(
2459 "{fn_name}() argument 0 has unsupported shape: {other:?}"
2460 )));
2461 }
2462 };
2463 let json_text = match eval_expr_scope(first_expr, scope)? {
2464 Value::Text(s) => s,
2465 Value::Null => {
2466 return Err(SQLRiteError::General(format!(
2467 "{fn_name}() called on NULL — JSON column has no value for this row"
2468 )));
2469 }
2470 other => {
2471 return Err(SQLRiteError::General(format!(
2472 "{fn_name}() argument 0 is not JSON-typed: got {}",
2473 other.to_display_string()
2474 )));
2475 }
2476 };
2477
2478 let path = if arg_list.len() == 2 {
2480 let path_expr = match &arg_list[1] {
2481 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2482 other => {
2483 return Err(SQLRiteError::NotImplemented(format!(
2484 "{fn_name}() argument 1 has unsupported shape: {other:?}"
2485 )));
2486 }
2487 };
2488 match eval_expr_scope(path_expr, scope)? {
2489 Value::Text(s) => s,
2490 other => {
2491 return Err(SQLRiteError::General(format!(
2492 "{fn_name}() path argument must be a string literal, got {}",
2493 other.to_display_string()
2494 )));
2495 }
2496 }
2497 } else {
2498 "$".to_string()
2499 };
2500
2501 Ok((json_text, path))
2502}
2503
2504fn walk_json_path<'a>(
2514 value: &'a serde_json::Value,
2515 path: &str,
2516) -> Result<Option<&'a serde_json::Value>> {
2517 let mut chars = path.chars().peekable();
2518 if chars.next() != Some('$') {
2519 return Err(SQLRiteError::General(format!(
2520 "JSON path must start with '$', got `{path}`"
2521 )));
2522 }
2523 let mut current = value;
2524 while let Some(&c) = chars.peek() {
2525 match c {
2526 '.' => {
2527 chars.next();
2528 let mut key = String::new();
2529 while let Some(&c) = chars.peek() {
2530 if c == '.' || c == '[' {
2531 break;
2532 }
2533 key.push(c);
2534 chars.next();
2535 }
2536 if key.is_empty() {
2537 return Err(SQLRiteError::General(format!(
2538 "JSON path has empty key after '.' in `{path}`"
2539 )));
2540 }
2541 match current.get(&key) {
2542 Some(v) => current = v,
2543 None => return Ok(None),
2544 }
2545 }
2546 '[' => {
2547 chars.next();
2548 let mut idx_str = String::new();
2549 while let Some(&c) = chars.peek() {
2550 if c == ']' {
2551 break;
2552 }
2553 idx_str.push(c);
2554 chars.next();
2555 }
2556 if chars.next() != Some(']') {
2557 return Err(SQLRiteError::General(format!(
2558 "JSON path has unclosed `[` in `{path}`"
2559 )));
2560 }
2561 let idx: usize = idx_str.trim().parse().map_err(|_| {
2562 SQLRiteError::General(format!(
2563 "JSON path has non-integer index `[{idx_str}]` in `{path}`"
2564 ))
2565 })?;
2566 match current.get(idx) {
2567 Some(v) => current = v,
2568 None => return Ok(None),
2569 }
2570 }
2571 other => {
2572 return Err(SQLRiteError::General(format!(
2573 "JSON path has unexpected character `{other}` in `{path}` \
2574 (expected `.`, `[`, or end-of-path)"
2575 )));
2576 }
2577 }
2578 }
2579 Ok(Some(current))
2580}
2581
2582fn json_value_to_sql(v: &serde_json::Value) -> Value {
2586 match v {
2587 serde_json::Value::Null => Value::Null,
2588 serde_json::Value::Bool(b) => Value::Bool(*b),
2589 serde_json::Value::Number(n) => {
2590 if let Some(i) = n.as_i64() {
2592 Value::Integer(i)
2593 } else if let Some(f) = n.as_f64() {
2594 Value::Real(f)
2595 } else {
2596 Value::Null
2597 }
2598 }
2599 serde_json::Value::String(s) => Value::Text(s.clone()),
2600 composite => Value::Text(composite.to_string()),
2604 }
2605}
2606
2607fn json_fn_extract(name: &str, args: &FunctionArguments, scope: &dyn RowScope) -> Result<Value> {
2608 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2609 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2610 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2611 })?;
2612 match walk_json_path(&parsed, &path)? {
2613 Some(v) => Ok(json_value_to_sql(v)),
2614 None => Ok(Value::Null),
2615 }
2616}
2617
2618fn json_fn_type(name: &str, args: &FunctionArguments, scope: &dyn RowScope) -> Result<Value> {
2619 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2620 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2621 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2622 })?;
2623 let resolved = match walk_json_path(&parsed, &path)? {
2624 Some(v) => v,
2625 None => return Ok(Value::Null),
2626 };
2627 let ty = match resolved {
2628 serde_json::Value::Null => "null",
2629 serde_json::Value::Bool(true) => "true",
2630 serde_json::Value::Bool(false) => "false",
2631 serde_json::Value::Number(n) => {
2632 if n.is_i64() || n.is_u64() {
2633 "integer"
2634 } else {
2635 "real"
2636 }
2637 }
2638 serde_json::Value::String(_) => "text",
2639 serde_json::Value::Array(_) => "array",
2640 serde_json::Value::Object(_) => "object",
2641 };
2642 Ok(Value::Text(ty.to_string()))
2643}
2644
2645fn json_fn_array_length(
2646 name: &str,
2647 args: &FunctionArguments,
2648 scope: &dyn RowScope,
2649) -> Result<Value> {
2650 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2651 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2652 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2653 })?;
2654 let resolved = match walk_json_path(&parsed, &path)? {
2655 Some(v) => v,
2656 None => return Ok(Value::Null),
2657 };
2658 match resolved.as_array() {
2659 Some(arr) => Ok(Value::Integer(arr.len() as i64)),
2660 None => Err(SQLRiteError::General(format!(
2661 "{name}() resolved to a non-array value at path `{path}`"
2662 ))),
2663 }
2664}
2665
2666fn json_fn_object_keys(
2667 name: &str,
2668 args: &FunctionArguments,
2669 scope: &dyn RowScope,
2670) -> Result<Value> {
2671 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2672 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2673 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2674 })?;
2675 let resolved = match walk_json_path(&parsed, &path)? {
2676 Some(v) => v,
2677 None => return Ok(Value::Null),
2678 };
2679 let obj = resolved.as_object().ok_or_else(|| {
2680 SQLRiteError::General(format!(
2681 "{name}() resolved to a non-object value at path `{path}`"
2682 ))
2683 })?;
2684 let keys: Vec<serde_json::Value> = obj
2691 .keys()
2692 .map(|k| serde_json::Value::String(k.clone()))
2693 .collect();
2694 Ok(Value::Text(serde_json::Value::Array(keys).to_string()))
2695}
2696
2697fn extract_two_vector_args(
2701 fn_name: &str,
2702 args: &FunctionArguments,
2703 scope: &dyn RowScope,
2704) -> Result<(Vec<f32>, Vec<f32>)> {
2705 let arg_list = match args {
2706 FunctionArguments::List(l) => &l.args,
2707 _ => {
2708 return Err(SQLRiteError::General(format!(
2709 "{fn_name}() expects exactly two vector arguments"
2710 )));
2711 }
2712 };
2713 if arg_list.len() != 2 {
2714 return Err(SQLRiteError::General(format!(
2715 "{fn_name}() expects exactly 2 arguments, got {}",
2716 arg_list.len()
2717 )));
2718 }
2719 let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
2720 for (i, arg) in arg_list.iter().enumerate() {
2721 let expr = match arg {
2722 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2723 other => {
2724 return Err(SQLRiteError::NotImplemented(format!(
2725 "{fn_name}() argument {i} has unsupported shape: {other:?}"
2726 )));
2727 }
2728 };
2729 let val = eval_expr_scope(expr, scope)?;
2730 match val {
2731 Value::Vector(v) => out.push(v),
2732 other => {
2733 return Err(SQLRiteError::General(format!(
2734 "{fn_name}() argument {i} is not a vector: got {}",
2735 other.to_display_string()
2736 )));
2737 }
2738 }
2739 }
2740 let b = out.pop().unwrap();
2741 let a = out.pop().unwrap();
2742 if a.len() != b.len() {
2743 return Err(SQLRiteError::General(format!(
2744 "{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
2745 a.len(),
2746 b.len()
2747 )));
2748 }
2749 Ok((a, b))
2750}
2751
/// Euclidean (L2) distance between `a` and `b`: the square root of the sum of
/// squared component differences, accumulated in `f32`.
///
/// Callers must pass slices of equal length (checked with a debug assertion).
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Restrict `b` to `a`'s length up front; this keeps the out-of-bounds
    // panic for a too-short `b` while letting the loop run without indexing.
    let rhs = &b[..a.len()];
    a.iter()
        .zip(rhs)
        .fold(0.0f32, |acc, (x, y)| {
            let diff = x - y;
            acc + diff * diff
        })
        .sqrt()
}
2763
2764pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
2774 debug_assert_eq!(a.len(), b.len());
2775 let mut dot = 0.0f32;
2776 let mut norm_a_sq = 0.0f32;
2777 let mut norm_b_sq = 0.0f32;
2778 for i in 0..a.len() {
2779 dot += a[i] * b[i];
2780 norm_a_sq += a[i] * a[i];
2781 norm_b_sq += b[i] * b[i];
2782 }
2783 let denom = (norm_a_sq * norm_b_sq).sqrt();
2784 if denom == 0.0 {
2785 return Err(SQLRiteError::General(
2786 "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
2787 ));
2788 }
2789 Ok(1.0 - dot / denom)
2790}
2791
/// Negated dot product of `a` and `b`, accumulated in `f32`, so that a larger
/// inner product gives a smaller "distance".
///
/// Callers must pass slices of equal length (checked with a debug assertion).
///
/// # Panics
///
/// Panics if `b` is shorter than `a`.
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Restrict `b` to `a`'s length up front; this keeps the out-of-bounds
    // panic for a too-short `b` while letting the loop run without indexing.
    let rhs = &b[..a.len()];
    let dot = a.iter().zip(rhs).fold(0.0f32, |acc, (x, y)| acc + x * y);
    -dot
}
2803
2804fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
2807 if matches!(l, Value::Null) || matches!(r, Value::Null) {
2808 return Ok(Value::Null);
2809 }
2810 match (l, r) {
2811 (Value::Integer(a), Value::Integer(b)) => match op {
2812 BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
2813 BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
2814 BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
2815 BinaryOperator::Divide => {
2816 if *b == 0 {
2817 Err(SQLRiteError::General("division by zero".to_string()))
2818 } else {
2819 Ok(Value::Integer(a / b))
2820 }
2821 }
2822 BinaryOperator::Modulo => {
2823 if *b == 0 {
2824 Err(SQLRiteError::General("modulo by zero".to_string()))
2825 } else {
2826 Ok(Value::Integer(a % b))
2827 }
2828 }
2829 _ => unreachable!(),
2830 },
2831 (a, b) => {
2833 let af = as_number(a)?;
2834 let bf = as_number(b)?;
2835 match op {
2836 BinaryOperator::Plus => Ok(Value::Real(af + bf)),
2837 BinaryOperator::Minus => Ok(Value::Real(af - bf)),
2838 BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
2839 BinaryOperator::Divide => {
2840 if bf == 0.0 {
2841 Err(SQLRiteError::General("division by zero".to_string()))
2842 } else {
2843 Ok(Value::Real(af / bf))
2844 }
2845 }
2846 BinaryOperator::Modulo => {
2847 if bf == 0.0 {
2848 Err(SQLRiteError::General("modulo by zero".to_string()))
2849 } else {
2850 Ok(Value::Real(af % bf))
2851 }
2852 }
2853 _ => unreachable!(),
2854 }
2855 }
2856 }
2857}
2858
2859fn as_number(v: &Value) -> Result<f64> {
2860 match v {
2861 Value::Integer(i) => Ok(*i as f64),
2862 Value::Real(f) => Ok(*f),
2863 Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
2864 other => Err(SQLRiteError::General(format!(
2865 "arithmetic on non-numeric value '{}'",
2866 other.to_display_string()
2867 ))),
2868 }
2869}
2870
2871fn as_bool(v: &Value) -> Result<bool> {
2872 match v {
2873 Value::Bool(b) => Ok(*b),
2874 Value::Null => Ok(false),
2875 Value::Integer(i) => Ok(*i != 0),
2876 other => Err(SQLRiteError::Internal(format!(
2877 "expected boolean, got {}",
2878 other.to_display_string()
2879 ))),
2880 }
2881}
2882
2883#[allow(clippy::too_many_arguments)]
2888fn eval_like(
2889 scope: &dyn RowScope,
2890 negated: bool,
2891 any: bool,
2892 lhs: &Expr,
2893 pattern: &Expr,
2894 escape_char: Option<&AstValue>,
2895 case_insensitive: bool,
2896) -> Result<Value> {
2897 if any {
2898 return Err(SQLRiteError::NotImplemented(
2899 "LIKE ANY (...) is not supported".to_string(),
2900 ));
2901 }
2902 if escape_char.is_some() {
2903 return Err(SQLRiteError::NotImplemented(
2904 "LIKE ... ESCAPE '<char>' is not supported (default `\\` escape only)".to_string(),
2905 ));
2906 }
2907
2908 let l = eval_expr_scope(lhs, scope)?;
2909 let p = eval_expr_scope(pattern, scope)?;
2910 if matches!(l, Value::Null) || matches!(p, Value::Null) {
2911 return Ok(Value::Null);
2912 }
2913 let text = match l {
2914 Value::Text(s) => s,
2915 other => other.to_display_string(),
2916 };
2917 let pat = match p {
2918 Value::Text(s) => s,
2919 other => other.to_display_string(),
2920 };
2921 let m = like_match(&text, &pat, case_insensitive);
2922 Ok(Value::Bool(if negated { !m } else { m }))
2923}
2924
2925fn eval_in_list(scope: &dyn RowScope, lhs: &Expr, list: &[Expr], negated: bool) -> Result<Value> {
2926 let l = eval_expr_scope(lhs, scope)?;
2927 if matches!(l, Value::Null) {
2928 return Ok(Value::Null);
2929 }
2930 let mut saw_null = false;
2931 for item in list {
2932 let r = eval_expr_scope(item, scope)?;
2933 if matches!(r, Value::Null) {
2934 saw_null = true;
2935 continue;
2936 }
2937 if compare_values(Some(&l), Some(&r)) == Ordering::Equal {
2938 return Ok(Value::Bool(!negated));
2939 }
2940 }
2941 if saw_null {
2942 Ok(Value::Null)
2945 } else {
2946 Ok(Value::Bool(negated))
2947 }
2948}
2949
2950fn aggregate_rows(
2961 table: &Table,
2962 matching: &[i64],
2963 group_by: &[String],
2964 proj_items: &[ProjectionItem],
2965) -> Result<Vec<Vec<Value>>> {
2966 let template: Vec<Option<AggState>> = proj_items
2970 .iter()
2971 .map(|i| match &i.kind {
2972 ProjectionKind::Aggregate(call) => Some(AggState::new(call)),
2973 ProjectionKind::Column { .. } => None,
2974 })
2975 .collect();
2976
2977 let mut keys: Vec<Vec<DistinctKey>> = Vec::new();
2983 let mut group_states: Vec<Vec<Option<AggState>>> = Vec::new();
2984 let mut group_key_values: Vec<Vec<Value>> = Vec::new();
2985
2986 for &rowid in matching {
2987 let mut key_values: Vec<Value> = Vec::with_capacity(group_by.len());
2988 let mut key: Vec<DistinctKey> = Vec::with_capacity(group_by.len());
2989 for col in group_by {
2990 let v = table.get_value(col, rowid).unwrap_or(Value::Null);
2991 key.push(DistinctKey::from_value(&v));
2992 key_values.push(v);
2993 }
2994 let idx = match keys.iter().position(|k| k == &key) {
2995 Some(i) => i,
2996 None => {
2997 keys.push(key);
2998 group_states.push(template.clone());
2999 group_key_values.push(key_values);
3000 keys.len() - 1
3001 }
3002 };
3003
3004 for (slot, item) in proj_items.iter().enumerate() {
3005 if let ProjectionKind::Aggregate(call) = &item.kind {
3006 let v = match &call.arg {
3007 AggregateArg::Star => Value::Null,
3008 AggregateArg::Column(c) => table.get_value(c, rowid).unwrap_or(Value::Null),
3009 };
3010 if let Some(state) = group_states[idx][slot].as_mut() {
3011 state.update(&v)?;
3012 }
3013 }
3014 }
3015 }
3016
3017 if keys.is_empty() && group_by.is_empty() {
3023 keys.push(Vec::new());
3026 group_states.push(template.clone());
3027 group_key_values.push(Vec::new());
3028 }
3029
3030 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(keys.len());
3032 for (group_idx, _) in keys.iter().enumerate() {
3033 let mut row: Vec<Value> = Vec::with_capacity(proj_items.len());
3034 for (slot, item) in proj_items.iter().enumerate() {
3035 match &item.kind {
3036 ProjectionKind::Column { name: c, .. } => {
3037 let pos = group_by
3040 .iter()
3041 .position(|g| g == c)
3042 .expect("validated to be in GROUP BY");
3043 row.push(group_key_values[group_idx][pos].clone());
3044 }
3045 ProjectionKind::Aggregate(_) => {
3046 let state = group_states[group_idx][slot]
3047 .as_ref()
3048 .expect("aggregate slot has state");
3049 row.push(state.finalize());
3050 }
3051 }
3052 }
3053 rows.push(row);
3054 }
3055 Ok(rows)
3056}
3057
3058fn dedupe_rows(rows: Vec<Vec<Value>>) -> Vec<Vec<Value>> {
3062 use std::collections::HashSet;
3063 let mut seen: HashSet<Vec<DistinctKey>> = HashSet::new();
3064 let mut out = Vec::with_capacity(rows.len());
3065 for row in rows {
3066 let key: Vec<DistinctKey> = row.iter().map(DistinctKey::from_value).collect();
3067 if seen.insert(key) {
3068 out.push(row);
3069 }
3070 }
3071 out
3072}
3073
3074fn sort_output_rows(
3078 rows: &mut [Vec<Value>],
3079 columns: &[String],
3080 proj_items: &[ProjectionItem],
3081 order: &OrderByClause,
3082) -> Result<()> {
3083 let target_idx = resolve_order_by_index(&order.expr, columns, proj_items)?;
3084 rows.sort_by(|a, b| {
3085 let va = &a[target_idx];
3086 let vb = &b[target_idx];
3087 let ord = compare_values(Some(va), Some(vb));
3088 if order.ascending { ord } else { ord.reverse() }
3089 });
3090 Ok(())
3091}
3092
3093fn resolve_order_by_index(
3096 expr: &Expr,
3097 columns: &[String],
3098 proj_items: &[ProjectionItem],
3099) -> Result<usize> {
3100 let target_name: Option<String> = match expr {
3102 Expr::Identifier(ident) => Some(ident.value.clone()),
3103 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
3104 Expr::Function(_) => None,
3105 Expr::Nested(inner) => return resolve_order_by_index(inner, columns, proj_items),
3106 other => {
3107 return Err(SQLRiteError::NotImplemented(format!(
3108 "ORDER BY expression not supported on aggregating queries: {other:?}"
3109 )));
3110 }
3111 };
3112 if let Some(name) = target_name {
3113 if let Some(i) = columns.iter().position(|c| c.eq_ignore_ascii_case(&name)) {
3114 return Ok(i);
3115 }
3116 return Err(SQLRiteError::Internal(format!(
3117 "ORDER BY references unknown column '{name}' in the SELECT output"
3118 )));
3119 }
3120 if let Expr::Function(func) = expr {
3124 let user_disp = format_function_display(func);
3125 for (i, item) in proj_items.iter().enumerate() {
3126 if let ProjectionKind::Aggregate(call) = &item.kind
3127 && call.display_name().eq_ignore_ascii_case(&user_disp)
3128 {
3129 return Ok(i);
3130 }
3131 }
3132 return Err(SQLRiteError::Internal(format!(
3133 "ORDER BY references aggregate '{user_disp}' that isn't in the SELECT output"
3134 )));
3135 }
3136 Err(SQLRiteError::Internal(
3137 "ORDER BY expression could not be resolved against the output columns".to_string(),
3138 ))
3139}
3140
3141fn format_function_display(func: &sqlparser::ast::Function) -> String {
3145 let name = match func.name.0.as_slice() {
3146 [ObjectNamePart::Identifier(ident)] => ident.value.to_uppercase(),
3147 _ => format!("{:?}", func.name).to_uppercase(),
3148 };
3149 let inner = match &func.args {
3150 FunctionArguments::List(l) => {
3151 let distinct = matches!(
3152 l.duplicate_treatment,
3153 Some(sqlparser::ast::DuplicateTreatment::Distinct)
3154 );
3155 let arg = l.args.first().map(|a| match a {
3156 FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => "*".to_string(),
3157 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier(i))) => i.value.clone(),
3158 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::CompoundIdentifier(parts))) => {
3159 parts.last().map(|p| p.value.clone()).unwrap_or_default()
3160 }
3161 _ => String::new(),
3162 });
3163 match (distinct, arg) {
3164 (true, Some(a)) if a != "*" => format!("DISTINCT {a}"),
3165 (_, Some(a)) => a,
3166 _ => String::new(),
3167 }
3168 }
3169 _ => String::new(),
3170 };
3171 format!("{name}({inner})")
3172}
3173
3174fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
3175 use sqlparser::ast::Value as AstValue;
3176 match v {
3177 AstValue::Number(n, _) => {
3178 if let Ok(i) = n.parse::<i64>() {
3179 Ok(Value::Integer(i))
3180 } else if let Ok(f) = n.parse::<f64>() {
3181 Ok(Value::Real(f))
3182 } else {
3183 Err(SQLRiteError::Internal(format!(
3184 "could not parse numeric literal '{n}'"
3185 )))
3186 }
3187 }
3188 AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
3189 AstValue::Boolean(b) => Ok(Value::Bool(*b)),
3190 AstValue::Null => Ok(Value::Null),
3191 other => Err(SQLRiteError::NotImplemented(format!(
3192 "unsupported literal value: {other:?}"
3193 ))),
3194 }
3195}
3196
3197#[cfg(test)]
3198mod tests {
3199 use super::*;
3200
3201 fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
3208 (a - b).abs() < eps
3209 }
3210
3211 #[test]
3212 fn vec_distance_l2_identical_is_zero() {
3213 let v = vec![0.1, 0.2, 0.3];
3214 assert_eq!(vec_distance_l2(&v, &v), 0.0);
3215 }
3216
3217 #[test]
3218 fn vec_distance_l2_unit_basis_is_sqrt2() {
3219 let a = vec![1.0, 0.0];
3221 let b = vec![0.0, 1.0];
3222 assert!(approx_eq(vec_distance_l2(&a, &b), 2.0_f32.sqrt(), 1e-6));
3223 }
3224
3225 #[test]
3226 fn vec_distance_l2_known_value() {
3227 let a = vec![0.0, 0.0, 0.0];
3229 let b = vec![3.0, 4.0, 0.0];
3230 assert!(approx_eq(vec_distance_l2(&a, &b), 5.0, 1e-6));
3231 }
3232
3233 #[test]
3234 fn vec_distance_cosine_identical_is_zero() {
3235 let v = vec![0.1, 0.2, 0.3];
3236 let d = vec_distance_cosine(&v, &v).unwrap();
3237 assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
3238 }
3239
3240 #[test]
3241 fn vec_distance_cosine_orthogonal_is_one() {
3242 let a = vec![1.0, 0.0];
3245 let b = vec![0.0, 1.0];
3246 assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 1.0, 1e-6));
3247 }
3248
3249 #[test]
3250 fn vec_distance_cosine_opposite_is_two() {
3251 let a = vec![1.0, 0.0, 0.0];
3253 let b = vec![-1.0, 0.0, 0.0];
3254 assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 2.0, 1e-6));
3255 }
3256
3257 #[test]
3258 fn vec_distance_cosine_zero_magnitude_errors() {
3259 let a = vec![0.0, 0.0];
3261 let b = vec![1.0, 0.0];
3262 let err = vec_distance_cosine(&a, &b).unwrap_err();
3263 assert!(format!("{err}").contains("zero-magnitude"));
3264 }
3265
3266 #[test]
3267 fn vec_distance_dot_negates() {
3268 let a = vec![1.0, 2.0, 3.0];
3270 let b = vec![4.0, 5.0, 6.0];
3271 assert!(approx_eq(vec_distance_dot(&a, &b), -32.0, 1e-6));
3272 }
3273
3274 #[test]
3275 fn vec_distance_dot_orthogonal_is_zero() {
3276 let a = vec![1.0, 0.0];
3278 let b = vec![0.0, 1.0];
3279 assert_eq!(vec_distance_dot(&a, &b), 0.0);
3280 }
3281
3282 #[test]
3283 fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
3284 let a = vec![0.6f32, 0.8]; let b = vec![0.8f32, 0.6]; let dot = vec_distance_dot(&a, &b);
3290 let cos = vec_distance_cosine(&a, &b).unwrap();
3291 assert!(approx_eq(dot, cos - 1.0, 1e-5));
3292 }
3293
3294 use crate::sql::db::database::Database;
3299 use crate::sql::parser::select::SelectQuery;
3300 use sqlparser::dialect::SQLiteDialect;
3301 use sqlparser::parser::Parser;
3302
3303 fn seed_score_table(n: usize) -> Database {
3316 let mut db = Database::new("tempdb".to_string());
3317 crate::sql::process_command(
3318 "CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
3319 &mut db,
3320 )
3321 .expect("create");
3322 for i in 0..n {
3323 let score = ((i as u64).wrapping_mul(2_654_435_761) % 1_000_000) as f64;
3327 let sql = format!("INSERT INTO docs (score) VALUES ({score});");
3328 crate::sql::process_command(&sql, &mut db).expect("insert");
3329 }
3330 db
3331 }
3332
3333 fn parse_select(sql: &str) -> SelectQuery {
3337 let dialect = SQLiteDialect {};
3338 let mut ast = Parser::parse_sql(&dialect, sql).expect("parse");
3339 let stmt = ast.pop().expect("one statement");
3340 SelectQuery::new(&stmt).expect("select-query")
3341 }
3342
3343 #[test]
3344 fn topk_matches_full_sort_asc() {
3345 let db = seed_score_table(200);
3348 let table = db.get_table("docs".to_string()).unwrap();
3349 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
3350 let order = q.order_by.as_ref().unwrap();
3351 let all_rowids = table.rowids();
3352
3353 let mut full = all_rowids.clone();
3355 sort_rowids(&mut full, table, order).unwrap();
3356 full.truncate(10);
3357
3358 let topk = select_topk(&all_rowids, table, order, 10).unwrap();
3360
3361 assert_eq!(topk, full, "top-k via heap should match full-sort+truncate");
3362 }
3363
3364 #[test]
3365 fn topk_matches_full_sort_desc() {
3366 let db = seed_score_table(200);
3368 let table = db.get_table("docs".to_string()).unwrap();
3369 let q = parse_select("SELECT * FROM docs ORDER BY score DESC LIMIT 10;");
3370 let order = q.order_by.as_ref().unwrap();
3371 let all_rowids = table.rowids();
3372
3373 let mut full = all_rowids.clone();
3374 sort_rowids(&mut full, table, order).unwrap();
3375 full.truncate(10);
3376
3377 let topk = select_topk(&all_rowids, table, order, 10).unwrap();
3378
3379 assert_eq!(
3380 topk, full,
3381 "top-k DESC via heap should match full-sort+truncate"
3382 );
3383 }
3384
3385 #[test]
3386 fn topk_k_larger_than_n_returns_everything_sorted() {
3387 let db = seed_score_table(50);
3392 let table = db.get_table("docs".to_string()).unwrap();
3393 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
3394 let order = q.order_by.as_ref().unwrap();
3395 let topk = select_topk(&table.rowids(), table, order, 1000).unwrap();
3396 assert_eq!(topk.len(), 50);
3397 let scores: Vec<f64> = topk
3399 .iter()
3400 .filter_map(|r| match table.get_value("score", *r) {
3401 Some(Value::Real(f)) => Some(f),
3402 _ => None,
3403 })
3404 .collect();
3405 assert!(scores.windows(2).all(|w| w[0] <= w[1]));
3406 }
3407
3408 #[test]
3409 fn topk_k_zero_returns_empty() {
3410 let db = seed_score_table(10);
3411 let table = db.get_table("docs".to_string()).unwrap();
3412 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
3413 let order = q.order_by.as_ref().unwrap();
3414 let topk = select_topk(&table.rowids(), table, order, 0).unwrap();
3415 assert!(topk.is_empty());
3416 }
3417
3418 #[test]
3419 fn topk_empty_input_returns_empty() {
3420 let db = seed_score_table(0);
3421 let table = db.get_table("docs".to_string()).unwrap();
3422 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
3423 let order = q.order_by.as_ref().unwrap();
3424 let topk = select_topk(&[], table, order, 5).unwrap();
3425 assert!(topk.is_empty());
3426 }
3427
3428 #[test]
3429 fn topk_works_through_select_executor_with_distance_function() {
3430 let mut db = Database::new("tempdb".to_string());
3434 crate::sql::process_command(
3435 "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
3436 &mut db,
3437 )
3438 .unwrap();
3439 for v in &[
3446 "[1.0, 0.0]",
3447 "[2.0, 0.0]",
3448 "[0.0, 3.0]",
3449 "[1.0, 4.0]",
3450 "[10.0, 10.0]",
3451 ] {
3452 crate::sql::process_command(&format!("INSERT INTO docs (e) VALUES ({v});"), &mut db)
3453 .unwrap();
3454 }
3455 let resp = crate::sql::process_command(
3456 "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;",
3457 &mut db,
3458 )
3459 .unwrap();
3460 assert!(resp.contains("3 rows returned"), "got: {resp}");
3463 }
3464
    /// Manual benchmark comparing `select_topk` against a full sort followed
    /// by truncation on 10,000 rows with k = 10.
    ///
    /// Marked `#[ignore]`, so it only runs when ignored tests are requested.
    /// It prints both timings and asserts the bounded approach is at least
    /// 1.4× faster; being wall-clock based, the result depends on the machine
    /// and build profile.
    #[test]
    #[ignore]
    fn topk_benchmark() {
        use std::time::Instant;
        const N: usize = 10_000;
        const K: usize = 10;

        let db = seed_score_table(N);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        // Timed section 1: top-k selection.
        let t0 = Instant::now();
        let _topk = select_topk(&all_rowids, table, order, K).unwrap();
        let heap_dur = t0.elapsed();

        // Timed section 2: clone + full sort + truncate (the clone is part of
        // the measured cost).
        let t1 = Instant::now();
        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(K);
        let sort_dur = t1.elapsed();

        // Clamp the denominator so a near-zero heap time cannot divide by zero.
        let ratio = sort_dur.as_secs_f64() / heap_dur.as_secs_f64().max(1e-9);
        println!("\n--- topk_benchmark (N={N}, k={K}) ---");
        println!(" bounded heap: {heap_dur:?}");
        println!(" full sort+trunc: {sort_dur:?}");
        println!(" speedup ratio: {ratio:.2}×");

        assert!(
            ratio > 1.4,
            "bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
        );
    }
3529
3530 fn run_select(db: &mut Database, sql: &str) -> String {
3538 crate::sql::process_command(sql, db).expect("select")
3539 }
3540
3541 #[test]
3542 fn where_is_null_returns_null_rows() {
3543 let mut db = Database::new("t".to_string());
3544 crate::sql::process_command(
3545 "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
3546 &mut db,
3547 )
3548 .unwrap();
3549 crate::sql::process_command("INSERT INTO t (id, n) VALUES (1, 10);", &mut db).unwrap();
3550 crate::sql::process_command("INSERT INTO t (id, n) VALUES (2, NULL);", &mut db).unwrap();
3551 crate::sql::process_command("INSERT INTO t (id, n) VALUES (3, 30);", &mut db).unwrap();
3552 crate::sql::process_command("INSERT INTO t (id, n) VALUES (4, NULL);", &mut db).unwrap();
3553
3554 let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL;");
3555 assert!(
3556 response.contains("2 rows returned"),
3557 "IS NULL should return 2 rows, got: {response}"
3558 );
3559 }
3560
3561 #[test]
3562 fn where_is_not_null_returns_non_null_rows() {
3563 let mut db = Database::new("t".to_string());
3564 crate::sql::process_command(
3565 "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
3566 &mut db,
3567 )
3568 .unwrap();
3569 crate::sql::process_command("INSERT INTO t (id, n) VALUES (1, 10);", &mut db).unwrap();
3570 crate::sql::process_command("INSERT INTO t (id, n) VALUES (2, NULL);", &mut db).unwrap();
3571 crate::sql::process_command("INSERT INTO t (id, n) VALUES (3, 30);", &mut db).unwrap();
3572
3573 let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NOT NULL;");
3574 assert!(
3575 response.contains("2 rows returned"),
3576 "IS NOT NULL should return 2 rows, got: {response}"
3577 );
3578 }
3579
3580 #[test]
3581 fn where_is_null_on_indexed_column() {
3582 let mut db = Database::new("t".to_string());
3587 crate::sql::process_command(
3588 "CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT UNIQUE);",
3589 &mut db,
3590 )
3591 .unwrap();
3592 crate::sql::process_command("INSERT INTO t (id, name) VALUES (1, 'alice');", &mut db)
3593 .unwrap();
3594 crate::sql::process_command("INSERT INTO t (id, name) VALUES (2, NULL);", &mut db).unwrap();
3595 crate::sql::process_command("INSERT INTO t (id, name) VALUES (3, 'bob');", &mut db)
3596 .unwrap();
3597
3598 let null_rows = run_select(&mut db, "SELECT id FROM t WHERE name IS NULL;");
3599 assert!(
3600 null_rows.contains("1 row returned"),
3601 "indexed IS NULL should return 1 row, got: {null_rows}"
3602 );
3603 let not_null_rows = run_select(&mut db, "SELECT id FROM t WHERE name IS NOT NULL;");
3604 assert!(
3605 not_null_rows.contains("2 rows returned"),
3606 "indexed IS NOT NULL should return 2 rows, got: {not_null_rows}"
3607 );
3608 }
3609
3610 #[test]
3611 fn where_is_null_works_on_omitted_column() {
3612 let mut db = Database::new("t".to_string());
3616 crate::sql::process_command(
3617 "CREATE TABLE t (id INTEGER PRIMARY KEY, qty INTEGER, label TEXT);",
3618 &mut db,
3619 )
3620 .unwrap();
3621 crate::sql::process_command(
3622 "INSERT INTO t (id, qty, label) VALUES (1, 7, 'a');",
3623 &mut db,
3624 )
3625 .unwrap();
3626 crate::sql::process_command("INSERT INTO t (id, label) VALUES (2, 'b');", &mut db).unwrap();
3628
3629 let response = run_select(&mut db, "SELECT id FROM t WHERE qty IS NULL;");
3630 assert!(
3631 response.contains("1 row returned"),
3632 "IS NULL should match the omitted-column row, got: {response}"
3633 );
3634 }
3635
3636 #[test]
3637 fn where_is_null_combines_with_and_or() {
3638 let mut db = Database::new("t".to_string());
3642 crate::sql::process_command(
3643 "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
3644 &mut db,
3645 )
3646 .unwrap();
3647 crate::sql::process_command("INSERT INTO t (id, n) VALUES (1, NULL);", &mut db).unwrap();
3648 crate::sql::process_command("INSERT INTO t (id, n) VALUES (2, NULL);", &mut db).unwrap();
3649 crate::sql::process_command("INSERT INTO t (id, n) VALUES (3, 30);", &mut db).unwrap();
3650
3651 let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL AND id > 1;");
3652 assert!(
3653 response.contains("1 row returned"),
3654 "IS NULL combined with AND should match exactly row 2, got: {response}"
3655 );
3656 }
3657
3658 fn seed_employees() -> Database {
3664 let mut db = Database::new("t".to_string());
3665 crate::sql::process_command(
3666 "CREATE TABLE emp (id INTEGER PRIMARY KEY, name TEXT, dept TEXT, salary INTEGER);",
3667 &mut db,
3668 )
3669 .unwrap();
3670 let rows = [
3671 "INSERT INTO emp (name, dept, salary) VALUES ('Alice', 'eng', 100);",
3672 "INSERT INTO emp (name, dept, salary) VALUES ('alex', 'eng', 120);",
3673 "INSERT INTO emp (name, dept, salary) VALUES ('Bob', 'eng', 100);",
3674 "INSERT INTO emp (name, dept, salary) VALUES ('Carol', 'sales', 90);",
3675 "INSERT INTO emp (name, dept, salary) VALUES ('Dave', 'sales', NULL);",
3676 "INSERT INTO emp (name, dept, salary) VALUES ('Eve', 'ops', 80);",
3677 ];
3678 for sql in rows {
3679 crate::sql::process_command(sql, &mut db).unwrap();
3680 }
3681 db
3682 }
3683
3684 fn run_rows(db: &Database, sql: &str) -> SelectResult {
3686 let q = parse_select(sql);
3687 execute_select_rows(q, db).expect("select")
3688 }
3689
3690 #[test]
3693 fn like_percent_prefix_case_insensitive() {
3694 let db = seed_employees();
3695 let r = run_rows(&db, "SELECT name FROM emp WHERE name LIKE 'a%';");
3696 let names: Vec<_> = r.rows.iter().map(|r| r[0].to_display_string()).collect();
3698 assert_eq!(names.len(), 2, "expected 2 rows, got {names:?}");
3699 assert!(names.contains(&"Alice".to_string()));
3700 assert!(names.contains(&"alex".to_string()));
3701 }
3702
3703 #[test]
3704 fn like_underscore_singlechar() {
3705 let db = seed_employees();
3706 let r = run_rows(&db, "SELECT name FROM emp WHERE name LIKE '_ve';");
3707 let names: Vec<_> = r.rows.iter().map(|r| r[0].to_display_string()).collect();
3709 assert_eq!(names, vec!["Eve".to_string()]);
3710 }
3711
3712 #[test]
3713 fn not_like_excludes_match() {
3714 let db = seed_employees();
3715 let r = run_rows(&db, "SELECT name FROM emp WHERE name NOT LIKE 'a%';");
3716 assert_eq!(r.rows.len(), 4);
3718 }
3719
3720 #[test]
3721 fn like_with_null_excludes_row() {
3722 let db = seed_employees();
3723 let r = run_rows(
3725 &db,
3726 "SELECT name FROM emp WHERE dept LIKE 'sales' AND salary IS NULL;",
3727 );
3728 assert_eq!(r.rows.len(), 1);
3729 assert_eq!(r.rows[0][0].to_display_string(), "Dave");
3730 }
3731
3732 #[test]
3735 fn in_list_positive() {
3736 let db = seed_employees();
3737 let r = run_rows(&db, "SELECT name FROM emp WHERE id IN (1, 3, 5);");
3738 let names: Vec<_> = r.rows.iter().map(|r| r[0].to_display_string()).collect();
3739 assert_eq!(names.len(), 3);
3740 assert!(names.contains(&"Alice".to_string()));
3741 assert!(names.contains(&"Bob".to_string()));
3742 assert!(names.contains(&"Dave".to_string()));
3743 }
3744
3745 #[test]
3746 fn not_in_excludes_listed() {
3747 let db = seed_employees();
3748 let r = run_rows(&db, "SELECT name FROM emp WHERE id NOT IN (1, 2);");
3749 assert_eq!(r.rows.len(), 4);
3751 }
3752
3753 #[test]
3754 fn in_list_with_null_three_valued() {
3755 let db = seed_employees();
3756 let r = run_rows(&db, "SELECT name FROM emp WHERE id IN (1, NULL);");
3759 assert_eq!(r.rows.len(), 1);
3760 assert_eq!(r.rows[0][0].to_display_string(), "Alice");
3761 }
3762
3763 #[test]
3766 fn distinct_single_column() {
3767 let db = seed_employees();
3768 let r = run_rows(&db, "SELECT DISTINCT dept FROM emp;");
3769 assert_eq!(r.rows.len(), 3);
3771 }
3772
3773 #[test]
3774 fn distinct_multi_column_with_null() {
3775 let db = seed_employees();
3776 let r = run_rows(&db, "SELECT DISTINCT dept, salary FROM emp;");
3778 assert_eq!(r.rows.len(), 5);
3780 }
3781
3782 #[test]
3785 fn count_star_no_groupby() {
3786 let db = seed_employees();
3787 let r = run_rows(&db, "SELECT COUNT(*) FROM emp;");
3788 assert_eq!(r.rows.len(), 1);
3789 assert_eq!(r.rows[0][0], Value::Integer(6));
3790 }
3791
3792 #[test]
3793 fn count_col_skips_nulls() {
3794 let db = seed_employees();
3795 let r = run_rows(&db, "SELECT COUNT(salary) FROM emp;");
3796 assert_eq!(r.rows[0][0], Value::Integer(5));
3798 }
3799
3800 #[test]
3801 fn count_distinct_dedupes_and_skips_nulls() {
3802 let db = seed_employees();
3803 let r = run_rows(&db, "SELECT COUNT(DISTINCT salary) FROM emp;");
3804 assert_eq!(r.rows[0][0], Value::Integer(4));
3806 }
3807
3808 #[test]
3809 fn sum_int_stays_integer() {
3810 let db = seed_employees();
3811 let r = run_rows(&db, "SELECT SUM(salary) FROM emp;");
3812 assert_eq!(r.rows[0][0], Value::Integer(490));
3814 }
3815
3816 #[test]
3817 fn avg_returns_real() {
3818 let db = seed_employees();
3819 let r = run_rows(&db, "SELECT AVG(salary) FROM emp;");
3820 match &r.rows[0][0] {
3822 Value::Real(v) => assert!((v - 98.0).abs() < 1e-9),
3823 other => panic!("expected Real, got {other:?}"),
3824 }
3825 }
3826
3827 #[test]
3828 fn min_max_skip_nulls() {
3829 let db = seed_employees();
3830 let r = run_rows(&db, "SELECT MIN(salary), MAX(salary) FROM emp;");
3831 assert_eq!(r.rows[0][0], Value::Integer(80));
3832 assert_eq!(r.rows[0][1], Value::Integer(120));
3833 }
3834
3835 #[test]
3836 fn aggregates_on_empty_table_emit_one_row() {
3837 let mut db = Database::new("t".to_string());
3838 crate::sql::process_command("CREATE TABLE t (x INTEGER);", &mut db).unwrap();
3839 let r = run_rows(
3840 &db,
3841 "SELECT COUNT(*), SUM(x), AVG(x), MIN(x), MAX(x) FROM t;",
3842 );
3843 assert_eq!(r.rows.len(), 1);
3844 assert_eq!(r.rows[0][0], Value::Integer(0));
3845 assert_eq!(r.rows[0][1], Value::Null);
3846 assert_eq!(r.rows[0][2], Value::Null);
3847 assert_eq!(r.rows[0][3], Value::Null);
3848 assert_eq!(r.rows[0][4], Value::Null);
3849 }
3850
3851 #[test]
3854 fn group_by_single_col_with_count() {
3855 let db = seed_employees();
3856 let r = run_rows(&db, "SELECT dept, COUNT(*) FROM emp GROUP BY dept;");
3857 assert_eq!(r.rows.len(), 3);
3858 let mut by_dept: std::collections::HashMap<String, i64> = Default::default();
3860 for row in &r.rows {
3861 let d = row[0].to_display_string();
3862 let c = match &row[1] {
3863 Value::Integer(i) => *i,
3864 v => panic!("expected Integer count, got {v:?}"),
3865 };
3866 by_dept.insert(d, c);
3867 }
3868 assert_eq!(by_dept["eng"], 3);
3869 assert_eq!(by_dept["sales"], 2);
3870 assert_eq!(by_dept["ops"], 1);
3871 }
3872
3873 #[test]
3874 fn group_by_with_where_filter() {
3875 let db = seed_employees();
3876 let r = run_rows(
3877 &db,
3878 "SELECT dept, SUM(salary) FROM emp WHERE salary > 80 GROUP BY dept;",
3879 );
3880 let by: std::collections::HashMap<String, i64> = r
3883 .rows
3884 .iter()
3885 .map(|row| {
3886 (
3887 row[0].to_display_string(),
3888 match &row[1] {
3889 Value::Integer(i) => *i,
3890 v => panic!("expected Integer sum, got {v:?}"),
3891 },
3892 )
3893 })
3894 .collect();
3895 assert_eq!(by.len(), 2);
3896 assert_eq!(by["eng"], 320);
3897 assert_eq!(by["sales"], 90);
3898 }
3899
3900 #[test]
3901 fn group_by_without_aggregates_is_distinct() {
3902 let db = seed_employees();
3903 let r = run_rows(&db, "SELECT dept FROM emp GROUP BY dept;");
3904 assert_eq!(r.rows.len(), 3);
3905 }
3906
3907 #[test]
3908 fn order_by_count_desc() {
3909 let db = seed_employees();
3910 let r = run_rows(
3911 &db,
3912 "SELECT dept, COUNT(*) AS n FROM emp GROUP BY dept ORDER BY n DESC LIMIT 2;",
3913 );
3914 assert_eq!(r.rows.len(), 2);
3915 assert_eq!(r.rows[0][0].to_display_string(), "eng");
3917 assert_eq!(r.rows[0][1], Value::Integer(3));
3918 }
3919
3920 #[test]
3921 fn order_by_aggregate_call_form() {
3922 let db = seed_employees();
3923 let r = run_rows(
3925 &db,
3926 "SELECT dept, COUNT(*) FROM emp GROUP BY dept ORDER BY COUNT(*) DESC;",
3927 );
3928 assert_eq!(r.rows.len(), 3);
3929 assert_eq!(r.rows[0][0].to_display_string(), "eng");
3930 }
3931
3932 #[test]
3933 fn group_by_invalid_bare_column_errors() {
3934 let mut db = Database::new("t".to_string());
3936 crate::sql::process_command(
3937 "CREATE TABLE t (id INTEGER PRIMARY KEY, dept TEXT, name TEXT);",
3938 &mut db,
3939 )
3940 .unwrap();
3941 let err = crate::sql::process_command("SELECT dept, name FROM t GROUP BY dept;", &mut db);
3942 assert!(err.is_err(), "should reject bare 'name' not in GROUP BY");
3943 }
3944
3945 #[test]
3946 fn aggregate_in_where_errors_friendly() {
3947 let mut db = Database::new("t".to_string());
3948 crate::sql::process_command("CREATE TABLE t (x INTEGER);", &mut db).unwrap();
3949 crate::sql::process_command("INSERT INTO t (x) VALUES (1);", &mut db).unwrap();
3950 let err = crate::sql::process_command("SELECT x FROM t WHERE COUNT(*) > 0;", &mut db);
3951 assert!(err.is_err(), "aggregates must not be allowed in WHERE");
3952 }
3953
3954 fn seed_join_fixture() -> Database {
3965 let mut db = Database::new("t".to_string());
3966 for sql in [
3967 "CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT);",
3968 "CREATE TABLE orders (id INTEGER PRIMARY KEY, customer_id INTEGER, amount INTEGER);",
3969 "INSERT INTO customers (name) VALUES ('Alice');",
3970 "INSERT INTO customers (name) VALUES ('Bob');",
3971 "INSERT INTO customers (name) VALUES ('Carol');",
3972 "INSERT INTO orders (customer_id, amount) VALUES (1, 100);",
3973 "INSERT INTO orders (customer_id, amount) VALUES (1, 200);",
3974 "INSERT INTO orders (customer_id, amount) VALUES (2, 50);",
3975 "INSERT INTO orders (customer_id, amount) VALUES (4, 999);",
3976 ] {
3977 crate::sql::process_command(sql, &mut db).unwrap();
3978 }
3979 db
3980 }
3981
3982 #[test]
3983 fn inner_join_returns_only_matched_rows() {
3984 let db = seed_join_fixture();
3985 let r = run_rows(
3986 &db,
3987 "SELECT customers.name, orders.amount FROM customers \
3988 INNER JOIN orders ON customers.id = orders.customer_id;",
3989 );
3990 assert_eq!(r.columns, vec!["name".to_string(), "amount".to_string()]);
3991 let pairs: Vec<(String, i64)> = r
3994 .rows
3995 .iter()
3996 .map(|row| {
3997 (
3998 row[0].to_display_string(),
3999 match row[1] {
4000 Value::Integer(i) => i,
4001 ref v => panic!("expected integer amount, got {v:?}"),
4002 },
4003 )
4004 })
4005 .collect();
4006 assert_eq!(pairs.len(), 3);
4007 assert!(pairs.contains(&("Alice".to_string(), 100)));
4008 assert!(pairs.contains(&("Alice".to_string(), 200)));
4009 assert!(pairs.contains(&("Bob".to_string(), 50)));
4010 }
4011
4012 #[test]
4013 fn bare_join_defaults_to_inner() {
4014 let db = seed_join_fixture();
4015 let r = run_rows(
4016 &db,
4017 "SELECT customers.name FROM customers \
4018 JOIN orders ON customers.id = orders.customer_id;",
4019 );
4020 assert_eq!(r.rows.len(), 3, "JOIN without prefix should be INNER");
4021 }
4022
4023 #[test]
4024 fn left_outer_join_preserves_unmatched_left() {
4025 let db = seed_join_fixture();
4026 let r = run_rows(
4027 &db,
4028 "SELECT customers.name, orders.amount FROM customers \
4029 LEFT OUTER JOIN orders ON customers.id = orders.customer_id;",
4030 );
4031 assert_eq!(r.rows.len(), 4);
4034 let carol = r
4035 .rows
4036 .iter()
4037 .find(|row| row[0].to_display_string() == "Carol")
4038 .expect("Carol should appear with a NULL-padded right side");
4039 assert_eq!(carol[1], Value::Null);
4040 }
4041
/// RIGHT OUTER JOIN keeps right rows that have no match; the order with
/// amount 999 (customer_id 4 in the fixture) is expected to surface with a
/// NULL customer name.
#[test]
fn right_outer_join_preserves_unmatched_right() {
    let fixture = seed_join_fixture();
    let result = run_rows(
        &fixture,
        "SELECT customers.name, orders.amount FROM customers \
         RIGHT OUTER JOIN orders ON customers.id = orders.customer_id;",
    );
    assert_eq!(result.rows.len(), 4);

    // Find the dangling order by its amount, then check the left-side padding.
    let dangling_idx = result
        .rows
        .iter()
        .position(|cells| matches!(cells[1], Value::Integer(999)))
        .expect("dangling order 999 should appear with a NULL-padded customer name");
    let dangling = &result.rows[dangling_idx];
    assert_eq!(dangling[0], Value::Null);
}
4061
/// FULL OUTER JOIN emits matched rows plus the unmatched rows from both
/// sides, each NULL-padded on the side that is missing.
#[test]
fn full_outer_join_preserves_both_sides() {
    let fixture = seed_join_fixture();
    let result = run_rows(
        &fixture,
        "SELECT customers.name, orders.amount FROM customers \
         FULL OUTER JOIN orders ON customers.id = orders.customer_id;",
    );
    assert_eq!(result.rows.len(), 5);

    // Unmatched left row: Carol with a NULL amount.
    let carol_padded = result
        .rows
        .iter()
        .any(|cells| cells[0].to_display_string() == "Carol" && matches!(cells[1], Value::Null));
    assert!(carol_padded);

    // Unmatched right row: amount 999 with a NULL customer name.
    let dangling_padded = result
        .rows
        .iter()
        .any(|cells| matches!(cells[1], Value::Integer(999)) && matches!(cells[0], Value::Null));
    assert!(dangling_padded);
}
4086
/// Qualifiers written against `AS` aliases (`c.name`, `o.amount`) must
/// resolve, and the output headers stay the bare column names.
#[test]
fn join_with_table_aliases_resolves_qualifiers() {
    let fixture = seed_join_fixture();
    let result = run_rows(
        &fixture,
        "SELECT c.name, o.amount FROM customers AS c \
         INNER JOIN orders AS o ON c.id = o.customer_id;",
    );
    assert_eq!(result.rows.len(), 3);

    let expected_columns = vec![String::from("name"), String::from("amount")];
    assert_eq!(result.columns, expected_columns);
}
4098
/// A WHERE clause over a joined result filters the joined rows: only the
/// two orders with amount >= 100 (both expected to be Alice's) survive.
#[test]
fn join_with_where_filter_applies_after_join() {
    let fixture = seed_join_fixture();
    let result = run_rows(
        &fixture,
        "SELECT customers.name, orders.amount FROM customers \
         INNER JOIN orders ON customers.id = orders.customer_id \
         WHERE orders.amount >= 100;",
    );
    assert_eq!(result.rows.len(), 2);

    let every_row_is_alice = result
        .rows
        .iter()
        .all(|cells| cells[0].to_display_string() == "Alice");
    assert!(every_row_is_alice);
}
4117
/// `WHERE right.col IS NULL` over a LEFT OUTER JOIN must see the NULL-padded
/// rows, so it selects exactly the unmatched left row rather than nothing.
#[test]
fn left_join_with_where_on_right_side_is_not_inner() {
    let fixture = seed_join_fixture();
    let result = run_rows(
        &fixture,
        "SELECT customers.name, orders.amount FROM customers \
         LEFT OUTER JOIN orders ON customers.id = orders.customer_id \
         WHERE orders.amount IS NULL;",
    );
    assert_eq!(result.rows.len(), 1);

    // The single surviving row is Carol with a NULL amount.
    let only_row = &result.rows[0];
    assert_eq!(only_row[0].to_display_string(), "Carol");
    assert_eq!(only_row[1], Value::Null);
}
4135
/// `SELECT *` over a join emits the left table's columns followed by the
/// right table's, with bare (unqualified) headers, so `id` appears twice.
#[test]
fn select_star_over_join_emits_all_columns_from_both_tables() {
    let fixture = seed_join_fixture();
    let result = run_rows(
        &fixture,
        "SELECT * FROM customers \
         INNER JOIN orders ON customers.id = orders.customer_id;",
    );

    let expected_columns = vec![
        String::from("id"),
        String::from("name"),
        String::from("id"),
        String::from("customer_id"),
        String::from("amount"),
    ];
    assert_eq!(result.columns, expected_columns);
    assert_eq!(result.rows.len(), 3);
}
4159
/// ORDER BY on a right-table column sorts the joined rows; the amounts
/// must come back ascending.
#[test]
fn join_order_by_sorts_full_joined_rows() {
    let fixture = seed_join_fixture();
    let result = run_rows(
        &fixture,
        "SELECT c.name, o.amount FROM customers AS c \
         INNER JOIN orders AS o ON c.id = o.customer_id \
         ORDER BY o.amount;",
    );

    // Pull out the amount column in output order.
    let mut amounts: Vec<i64> = Vec::new();
    for cells in result.rows.iter() {
        let amount = match cells[1] {
            Value::Integer(i) => i,
            ref v => panic!("expected integer, got {v:?}"),
        };
        amounts.push(amount);
    }
    assert_eq!(amounts, vec![50, 100, 200]);
}
4179
/// LIMIT applies after the join and the sort: with `ORDER BY ... DESC
/// LIMIT 2` the two largest amounts are expected, in descending order.
#[test]
fn join_limit_truncates_after_join_and_sort() {
    let fixture = seed_join_fixture();
    let result = run_rows(
        &fixture,
        "SELECT c.name, o.amount FROM customers AS c \
         INNER JOIN orders AS o ON c.id = o.customer_id \
         ORDER BY o.amount DESC LIMIT 2;",
    );
    assert_eq!(result.rows.len(), 2);

    // Pull out the amount column in output order.
    let mut amounts: Vec<i64> = Vec::new();
    for cells in result.rows.iter() {
        let amount = match cells[1] {
            Value::Integer(i) => i,
            ref v => panic!("expected integer, got {v:?}"),
        };
        amounts.push(amount);
    }
    assert_eq!(amounts, vec![200, 100]);
}
4201
/// Two chained INNER JOINs (a -> b -> c): only the a-one/b1/c1 chain has a
/// row in every table, so exactly one output row is expected.
#[test]
fn three_table_join_chains_correctly() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, label TEXT);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, a_id INTEGER, tag TEXT);",
        "CREATE TABLE c (id INTEGER PRIMARY KEY, b_id INTEGER, note TEXT);",
        "INSERT INTO a (label) VALUES ('a-one');",
        "INSERT INTO a (label) VALUES ('a-two');",
        "INSERT INTO b (a_id, tag) VALUES (1, 'b1');",
        "INSERT INTO b (a_id, tag) VALUES (2, 'b2');",
        "INSERT INTO c (b_id, note) VALUES (1, 'c1');",
    ];
    for statement in setup {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    let result = run_rows(
        &db,
        "SELECT a.label, b.tag, c.note FROM a \
         INNER JOIN b ON a.id = b.a_id \
         INNER JOIN c ON b.id = c.b_id;",
    );
    assert_eq!(result.rows.len(), 1);

    let only_row = &result.rows[0];
    assert_eq!(only_row[0].to_display_string(), "a-one");
    assert_eq!(only_row[1].to_display_string(), "b1");
    assert_eq!(only_row[2].to_display_string(), "c1");
}
4229
/// An unqualified `id` is ambiguous when both joined tables have an `id`
/// column (see the SELECT * header test); execution must return an error.
#[test]
fn ambiguous_unqualified_column_in_join_errors() {
    let fixture = seed_join_fixture();
    let query = parse_select(
        "SELECT id FROM customers INNER JOIN orders ON customers.id = orders.customer_id;",
    );
    let outcome = execute_select_rows(query, &fixture);
    assert!(outcome.is_err(), "unqualified ambiguous 'id' should error");
}
4242
/// Joining a table to itself without aliases leaves two scopes with the
/// same qualifier; execution must return an error.
#[test]
fn join_self_without_alias_is_rejected() {
    let mut db = Database::new("t".to_string());
    let create_sql = "CREATE TABLE n (id INTEGER PRIMARY KEY, parent INTEGER);";
    crate::sql::process_command(create_sql, &mut db).unwrap();

    let query = parse_select("SELECT n.id FROM n INNER JOIN n ON n.id = n.parent;");
    let outcome = execute_select_rows(query, &db);
    assert!(
        outcome.is_err(),
        "self-join without an alias should error on duplicate qualifier"
    );
}
4258
/// `JOIN ... USING (...)` and `NATURAL JOIN` are both expected to be
/// rejected with an error by `process_command`.
#[test]
fn using_or_natural_join_returns_not_implemented() {
    let mut db = Database::new("t".to_string());
    crate::sql::process_command("CREATE TABLE a (id INTEGER PRIMARY KEY);", &mut db).unwrap();
    crate::sql::process_command("CREATE TABLE b (id INTEGER PRIMARY KEY);", &mut db).unwrap();

    let using_outcome =
        crate::sql::process_command("SELECT * FROM a INNER JOIN b USING (id);", &mut db);
    assert!(using_outcome.is_err(), "USING is not yet supported");

    let natural_outcome = crate::sql::process_command("SELECT * FROM a NATURAL JOIN b;", &mut db);
    assert!(natural_outcome.is_err(), "NATURAL is not supported");
}
4270
/// An aggregate (`COUNT(*)`) over a JOIN is expected to be rejected with
/// an error by `process_command`.
#[test]
fn aggregates_over_join_are_rejected() {
    // `process_command` takes `&mut Database`, so bind the fixture mutably
    // once and pass it directly instead of seeding a second throwaway
    // fixture just to obtain a mutable reference.
    let mut db = seed_join_fixture();
    let outcome = crate::sql::process_command(
        "SELECT COUNT(*) FROM customers \
         INNER JOIN orders ON customers.id = orders.customer_id;",
        &mut db,
    );
    assert!(outcome.is_err(), "aggregates over JOIN are not yet supported");
}
4282
/// When no right row satisfies the ON predicate (a.x is 1 or 2, b.y is
/// 10), LEFT OUTER JOIN still emits every left row, each with NULL for b.y.
#[test]
fn left_join_with_no_matches_pads_every_row() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
        "INSERT INTO a (x) VALUES (1);",
        "INSERT INTO a (x) VALUES (2);",
        "INSERT INTO b (y) VALUES (10);",
    ];
    for statement in setup {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    let result = run_rows(
        &db,
        "SELECT a.x, b.y FROM a LEFT OUTER JOIN b ON a.x = b.y;",
    );
    assert_eq!(result.rows.len(), 2);
    result
        .rows
        .iter()
        .for_each(|cells| assert_eq!(cells[1], Value::Null));
}
4305
/// With `ORDER BY o.amount ASC` over a LEFT OUTER JOIN, the NULL-padded
/// row (Carol) is expected to sort ahead of every non-NULL amount.
#[test]
fn left_outer_join_order_by_places_nulls_first() {
    let fixture = seed_join_fixture();
    let result = run_rows(
        &fixture,
        "SELECT c.name, o.amount FROM customers AS c \
         LEFT OUTER JOIN orders AS o ON c.id = o.customer_id \
         ORDER BY o.amount ASC;",
    );
    assert_eq!(result.rows.len(), 4);

    let first_row = &result.rows[0];
    assert_eq!(first_row[0].to_display_string(), "Carol");
    assert_eq!(first_row[1], Value::Null);
}
4324
/// Two chained LEFT OUTER JOINs: 'a-one' matches b1 but nothing in the
/// empty table c, and 'a-two' matches nothing at all. Both left rows must
/// survive, NULL-padded wherever a level failed to match.
#[test]
fn chained_left_outer_join_preserves_left_through_two_levels() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, label TEXT);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, a_id INTEGER, tag TEXT);",
        "CREATE TABLE c (id INTEGER PRIMARY KEY, b_id INTEGER, note TEXT);",
        "INSERT INTO a (label) VALUES ('a-one');",
        "INSERT INTO a (label) VALUES ('a-two');",
        "INSERT INTO b (a_id, tag) VALUES (1, 'b1');",
    ];
    for statement in setup {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    let result = run_rows(
        &db,
        "SELECT a.label, b.tag, c.note FROM a \
         LEFT OUTER JOIN b ON a.id = b.a_id \
         LEFT OUTER JOIN c ON b.id = c.b_id;",
    );
    assert_eq!(result.rows.len(), 2);

    // Index the rows by label so the assertions do not depend on row order.
    let mut by_label: std::collections::HashMap<String, &Vec<Value>> =
        std::collections::HashMap::new();
    for cells in result.rows.iter() {
        by_label.insert(cells[0].to_display_string(), cells);
    }

    assert_eq!(by_label["a-one"][1].to_display_string(), "b1");
    assert_eq!(by_label["a-one"][2], Value::Null);
    assert_eq!(by_label["a-two"][1], Value::Null);
    assert_eq!(by_label["a-two"][2], Value::Null);
}
4360
/// The first ON clause mentions `c.x` before table `c` has been joined;
/// execution must return an error rather than resolve it.
#[test]
fn on_clause_referencing_not_yet_joined_table_errors_clearly() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE c (id INTEGER PRIMARY KEY, x INTEGER);",
        "INSERT INTO a (x) VALUES (1);",
        "INSERT INTO b (x) VALUES (1);",
        "INSERT INTO c (x) VALUES (1);",
    ];
    for statement in setup {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    let query =
        parse_select("SELECT a.x FROM a INNER JOIN b ON a.x = c.x INNER JOIN c ON b.x = c.x;");
    let outcome = execute_select_rows(query, &db);
    assert!(
        outcome.is_err(),
        "ON referencing not-yet-joined table 'c' should error"
    );
}
4385
/// `ON 1` is accepted as an always-true predicate: with two rows in each
/// table the join yields all 2 x 2 = 4 combinations.
#[test]
fn join_on_truthy_integer_is_accepted() {
    let mut db = Database::new("t".to_string());
    let setup = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
        "INSERT INTO a (x) VALUES (1);",
        "INSERT INTO a (x) VALUES (2);",
        "INSERT INTO b (y) VALUES (10);",
        "INSERT INTO b (y) VALUES (20);",
    ];
    for statement in setup {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    let result = run_rows(&db, "SELECT a.x, b.y FROM a INNER JOIN b ON 1;");
    let row_count = result.rows.len();
    assert_eq!(row_count, 4);
}
4406
/// FULL OUTER JOIN over two empty tables has nothing to match and nothing
/// to pad, so the result must contain no rows.
#[test]
fn full_join_on_empty_tables_returns_empty() {
    let mut db = Database::new("t".to_string());
    let schema = [
        "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
    ];
    for statement in schema {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    let result = run_rows(
        &db,
        "SELECT a.x, b.y FROM a FULL OUTER JOIN b ON a.x = b.y;",
    );
    assert!(result.rows.is_empty());
}
4422}