1use std::cmp::Ordering;
5
6use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
7use sqlparser::ast::{
8 AlterTable, AlterTableOperation, AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr,
9 FromTable, FunctionArg, FunctionArgExpr, FunctionArguments, IndexType, ObjectName,
10 ObjectNamePart, RenameTableNameKind, Statement, TableFactor, TableWithJoins, UnaryOperator,
11 Update, Value as AstValue,
12};
13
14use crate::error::{Result, SQLRiteError};
15use crate::sql::agg::{AggState, DistinctKey, like_match};
16use crate::sql::db::database::Database;
17use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
18use crate::sql::db::table::{
19 DataType, FtsIndexEntry, HnswIndexEntry, Table, Value, parse_vector_literal,
20};
21use crate::sql::fts::{Bm25Params, PostingList};
22use crate::sql::hnsw::{DistanceMetric, HnswIndex};
23use crate::sql::parser::select::{
24 AggregateArg, JoinType, OrderByClause, Projection, ProjectionItem, ProjectionKind, SelectQuery,
25};
26
27pub(crate) trait RowScope {
56 fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value>;
57
58 fn single_table_view(&self) -> Option<(&Table, i64)>;
64}
65
66pub(crate) struct SingleTableScope<'a> {
68 table: &'a Table,
69 rowid: i64,
70}
71
72impl<'a> SingleTableScope<'a> {
73 pub(crate) fn new(table: &'a Table, rowid: i64) -> Self {
74 Self { table, rowid }
75 }
76}
77
78impl RowScope for SingleTableScope<'_> {
79 fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value> {
80 let _ = qualifier;
85 Ok(self.table.get_value(col, self.rowid).unwrap_or(Value::Null))
86 }
87
88 fn single_table_view(&self) -> Option<(&Table, i64)> {
89 Some((self.table, self.rowid))
90 }
91}
92
93pub(crate) struct JoinedTableRef<'a> {
97 pub table: &'a Table,
98 pub scope_name: String,
99}
100
101pub(crate) struct JoinedScope<'a> {
105 pub tables: &'a [JoinedTableRef<'a>],
106 pub rowids: &'a [Option<i64>],
107}
108
109impl RowScope for JoinedScope<'_> {
110 fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value> {
111 if let Some(q) = qualifier {
112 let pos = self
115 .tables
116 .iter()
117 .position(|t| t.scope_name.eq_ignore_ascii_case(q))
118 .ok_or_else(|| {
119 SQLRiteError::Internal(format!(
120 "unknown table qualifier '{q}' in column reference '{q}.{col}'"
121 ))
122 })?;
123 if !self.tables[pos].table.contains_column(col.to_string()) {
124 return Err(SQLRiteError::Internal(format!(
125 "column '{col}' does not exist on '{}'",
126 self.tables[pos].scope_name
127 )));
128 }
129 return Ok(match self.rowids[pos] {
130 None => Value::Null,
131 Some(r) => self.tables[pos]
132 .table
133 .get_value(col, r)
134 .unwrap_or(Value::Null),
135 });
136 }
137 let mut hit: Option<usize> = None;
141 for (i, t) in self.tables.iter().enumerate() {
142 if t.table.contains_column(col.to_string()) {
143 if hit.is_some() {
144 return Err(SQLRiteError::Internal(format!(
145 "column reference '{col}' is ambiguous — qualify it as <table>.{col}"
146 )));
147 }
148 hit = Some(i);
149 }
150 }
151 let i = hit.ok_or_else(|| {
152 SQLRiteError::Internal(format!(
153 "unknown column '{col}' in joined SELECT (no in-scope table has it)"
154 ))
155 })?;
156 Ok(match self.rowids[i] {
157 None => Value::Null,
158 Some(r) => self.tables[i]
159 .table
160 .get_value(col, r)
161 .unwrap_or(Value::Null),
162 })
163 }
164
165 fn single_table_view(&self) -> Option<(&Table, i64)> {
166 None
167 }
168}
169
170pub struct SelectResult {
179 pub columns: Vec<String>,
180 pub rows: Vec<Vec<Value>>,
181}
182
183pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
187 if !query.joins.is_empty() {
192 return execute_select_rows_joined(query, db);
193 }
194
195 let master_snapshot;
204 let table: &Table = if query.table_name == crate::sql::pager::MASTER_TABLE_NAME {
205 master_snapshot = crate::sql::pager::build_master_table_snapshot(db)?;
206 &master_snapshot
207 } else {
208 db.get_table(query.table_name.clone()).map_err(|_| {
209 SQLRiteError::Internal(format!("Table '{}' not found", query.table_name))
210 })?
211 };
212
213 let proj_items: Vec<ProjectionItem> = match &query.projection {
218 Projection::All => table
219 .column_names()
220 .into_iter()
221 .map(|c| ProjectionItem {
222 kind: ProjectionKind::Column {
223 qualifier: None,
224 name: c,
225 },
226 alias: None,
227 })
228 .collect(),
229 Projection::Items(items) => items.clone(),
230 };
231 let has_aggregates = proj_items
232 .iter()
233 .any(|i| matches!(i.kind, ProjectionKind::Aggregate(_)));
234 for item in &proj_items {
236 if let ProjectionKind::Column { name: c, .. } = &item.kind
237 && !table.contains_column(c.clone())
238 {
239 return Err(SQLRiteError::Internal(format!(
240 "Column '{c}' does not exist on table '{}'",
241 query.table_name
242 )));
243 }
244 }
245 for c in &query.group_by {
246 if !table.contains_column(c.clone()) {
247 return Err(SQLRiteError::Internal(format!(
248 "GROUP BY references unknown column '{c}' on table '{}'",
249 query.table_name
250 )));
251 }
252 }
253 let matching = match select_rowids(table, query.selection.as_ref())? {
257 RowidSource::IndexProbe(rowids) => rowids,
258 RowidSource::FullScan => {
259 let mut out = Vec::new();
260 for rowid in table.rowids() {
261 if let Some(expr) = &query.selection
262 && !eval_predicate(expr, table, rowid)?
263 {
264 continue;
265 }
266 out.push(rowid);
267 }
268 out
269 }
270 };
271 let mut matching = matching;
272
273 let aggregating = has_aggregates || !query.group_by.is_empty();
274
275 if aggregating {
281 for item in &proj_items {
283 if let ProjectionKind::Aggregate(call) = &item.kind
284 && let AggregateArg::Column(c) = &call.arg
285 && !table.contains_column(c.clone())
286 {
287 return Err(SQLRiteError::Internal(format!(
288 "{}({}) references unknown column '{c}' on table '{}'",
289 call.func.as_str(),
290 c,
291 query.table_name
292 )));
293 }
294 }
295
296 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
297 let mut rows = aggregate_rows(table, &matching, &query.group_by, &proj_items)?;
298
299 if query.distinct {
300 rows = dedupe_rows(rows);
301 }
302
303 if let Some(order) = &query.order_by {
304 sort_output_rows(&mut rows, &columns, &proj_items, order)?;
305 }
306 if let Some(k) = query.limit {
307 rows.truncate(k);
308 }
309
310 return Ok(SelectResult { columns, rows });
311 }
312
313 let defer_limit_for_distinct = query.distinct;
351 match (&query.order_by, query.limit) {
352 (Some(order), Some(k)) if try_hnsw_probe(table, &order.expr, k).is_some() => {
353 matching = try_hnsw_probe(table, &order.expr, k).unwrap();
354 }
355 (Some(order), Some(k))
356 if try_fts_probe(table, &order.expr, order.ascending, k).is_some() =>
357 {
358 matching = try_fts_probe(table, &order.expr, order.ascending, k).unwrap();
359 }
360 (Some(order), Some(k)) if !defer_limit_for_distinct && k < matching.len() => {
361 matching = select_topk(&matching, table, order, k)?;
362 }
363 (Some(order), _) => {
364 sort_rowids(&mut matching, table, order)?;
365 if let Some(k) = query.limit
366 && !defer_limit_for_distinct
367 {
368 matching.truncate(k);
369 }
370 }
371 (None, Some(k)) if !defer_limit_for_distinct => {
372 matching.truncate(k);
373 }
374 _ => {}
375 }
376
377 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
378 let projected_cols: Vec<String> = proj_items
379 .iter()
380 .map(|i| match &i.kind {
381 ProjectionKind::Column { name, .. } => name.clone(),
382 ProjectionKind::Aggregate(_) => unreachable!("aggregation handled above"),
383 })
384 .collect();
385
386 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
390 for rowid in &matching {
391 let row: Vec<Value> = projected_cols
392 .iter()
393 .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
394 .collect();
395 rows.push(row);
396 }
397
398 if query.distinct {
399 rows = dedupe_rows(rows);
400 if let Some(k) = query.limit {
401 rows.truncate(k);
402 }
403 }
404
405 Ok(SelectResult { columns, rows })
406}
407
408fn execute_select_rows_joined(query: SelectQuery, db: &Database) -> Result<SelectResult> {
435 let mut joined_tables: Vec<JoinedTableRef<'_>> = Vec::with_capacity(1 + query.joins.len());
442
443 let primary = db
444 .get_table(query.table_name.clone())
445 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
446 joined_tables.push(JoinedTableRef {
447 table: primary,
448 scope_name: query
449 .table_alias
450 .clone()
451 .unwrap_or_else(|| query.table_name.clone()),
452 });
453 for j in &query.joins {
454 let t = db
455 .get_table(j.right_table.clone())
456 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", j.right_table)))?;
457 joined_tables.push(JoinedTableRef {
458 table: t,
459 scope_name: j
460 .right_alias
461 .clone()
462 .unwrap_or_else(|| j.right_table.clone()),
463 });
464 }
465
466 {
471 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
472 for t in &joined_tables {
473 let key = t.scope_name.to_ascii_lowercase();
474 if !seen.insert(key) {
475 return Err(SQLRiteError::Internal(format!(
476 "duplicate table reference '{}' in FROM/JOIN — use AS to alias one side",
477 t.scope_name
478 )));
479 }
480 }
481 }
482
483 let proj_items: Vec<ProjectionItem> = match &query.projection {
489 Projection::All => {
490 let mut all = Vec::new();
499 for t in &joined_tables {
500 for col in t.table.column_names() {
501 all.push(ProjectionItem {
502 kind: ProjectionKind::Column {
503 qualifier: Some(t.scope_name.clone()),
508 name: col,
509 },
510 alias: None,
511 });
512 }
513 }
514 all
515 }
516 Projection::Items(items) => items.clone(),
517 };
518
519 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
520
521 let mut acc: Vec<Vec<Option<i64>>> = primary
526 .rowids()
527 .into_iter()
528 .map(|r| {
529 let mut row = Vec::with_capacity(joined_tables.len());
530 row.push(Some(r));
531 row
532 })
533 .collect();
534
535 for (j_idx, join) in query.joins.iter().enumerate() {
540 let right_pos = j_idx + 1;
541 let right_table = joined_tables[right_pos].table;
542 let right_rowids: Vec<i64> = right_table.rowids();
543
544 let mut right_matched: Vec<bool> = vec![false; right_rowids.len()];
548
549 let mut next_acc: Vec<Vec<Option<i64>>> = Vec::with_capacity(acc.len());
550
551 let on_scope_tables: &[JoinedTableRef<'_>] = &joined_tables[..=right_pos];
559
560 for left_row in acc.into_iter() {
561 let mut left_match_count = 0usize;
565 for (r_idx, &rrid) in right_rowids.iter().enumerate() {
566 let mut on_rowids: Vec<Option<i64>> = left_row.clone();
567 on_rowids.push(Some(rrid));
568 debug_assert_eq!(on_rowids.len(), on_scope_tables.len());
569 let scope = JoinedScope {
570 tables: on_scope_tables,
571 rowids: &on_rowids,
572 };
573 if eval_predicate_scope(&join.on, &scope)? {
578 left_match_count += 1;
579 right_matched[r_idx] = true;
580 next_acc.push(on_rowids);
585 }
586 }
587
588 if left_match_count == 0
589 && matches!(join.join_type, JoinType::LeftOuter | JoinType::FullOuter)
590 {
591 let mut padded = left_row;
594 padded.push(None);
595 next_acc.push(padded);
596 }
597 }
598
599 if matches!(join.join_type, JoinType::RightOuter | JoinType::FullOuter) {
603 for (r_idx, matched) in right_matched.iter().enumerate() {
604 if *matched {
605 continue;
606 }
607 let mut row: Vec<Option<i64>> = vec![None; right_pos];
608 row.push(Some(right_rowids[r_idx]));
609 next_acc.push(row);
610 }
611 }
612
613 acc = next_acc;
614 }
615
616 let mut filtered: Vec<Vec<Option<i64>>> = if let Some(where_expr) = &query.selection {
621 let mut out = Vec::with_capacity(acc.len());
622 for row in acc {
623 let scope = JoinedScope {
624 tables: &joined_tables,
625 rowids: &row,
626 };
627 if eval_predicate_scope(where_expr, &scope)? {
628 out.push(row);
629 }
630 }
631 out
632 } else {
633 acc
634 };
635
636 if let Some(order) = &query.order_by {
640 let mut keys: Vec<(usize, Value)> = Vec::with_capacity(filtered.len());
643 for (i, row) in filtered.iter().enumerate() {
644 let scope = JoinedScope {
645 tables: &joined_tables,
646 rowids: row,
647 };
648 let v = eval_expr_scope(&order.expr, &scope)?;
649 keys.push((i, v));
650 }
651 keys.sort_by(|(_, a), (_, b)| {
652 let ord = compare_values(Some(a), Some(b));
653 if order.ascending { ord } else { ord.reverse() }
654 });
655 let mut sorted = Vec::with_capacity(filtered.len());
656 for (i, _) in keys {
657 sorted.push(filtered[i].clone());
658 }
659 filtered = sorted;
660 }
661
662 if let Some(k) = query.limit {
664 filtered.truncate(k);
665 }
666
667 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(filtered.len());
670 for row in &filtered {
671 let scope = JoinedScope {
672 tables: &joined_tables,
673 rowids: row,
674 };
675 let mut out_row = Vec::with_capacity(proj_items.len());
676 for item in &proj_items {
677 let v = match &item.kind {
678 ProjectionKind::Column { qualifier, name } => {
679 scope.lookup(qualifier.as_deref(), name)?
680 }
681 ProjectionKind::Aggregate(_) => {
682 return Err(SQLRiteError::Internal(
685 "aggregate functions over JOIN are not supported".to_string(),
686 ));
687 }
688 };
689 out_row.push(v);
690 }
691 rows.push(out_row);
692 }
693
694 Ok(SelectResult { columns, rows })
695}
696
697pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
702 let result = execute_select_rows(query, db)?;
703 let row_count = result.rows.len();
704
705 let mut print_table = PrintTable::new();
706 let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
707 print_table.add_row(PrintRow::new(header_cells));
708
709 for row in &result.rows {
710 let cells: Vec<PrintCell> = row
711 .iter()
712 .map(|v| PrintCell::new(&v.to_display_string()))
713 .collect();
714 print_table.add_row(PrintRow::new(cells));
715 }
716
717 Ok((print_table.to_string(), row_count))
718}
719
720pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
722 let Statement::Delete(Delete {
723 from, selection, ..
724 }) = stmt
725 else {
726 return Err(SQLRiteError::Internal(
727 "execute_delete called on a non-DELETE statement".to_string(),
728 ));
729 };
730
731 let tables = match from {
732 FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
733 };
734 let table_name = extract_single_table_name(tables)?;
735
736 let matching: Vec<i64> = {
738 let table = db
739 .get_table(table_name.clone())
740 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
741 match select_rowids(table, selection.as_ref())? {
742 RowidSource::IndexProbe(rowids) => rowids,
743 RowidSource::FullScan => {
744 let mut out = Vec::new();
745 for rowid in table.rowids() {
746 if let Some(expr) = selection {
747 if !eval_predicate(expr, table, rowid)? {
748 continue;
749 }
750 }
751 out.push(rowid);
752 }
753 out
754 }
755 }
756 };
757
758 let table = db.get_table_mut(table_name)?;
759 for rowid in &matching {
760 table.delete_row(*rowid);
761 }
762 if !matching.is_empty() {
771 for entry in &mut table.hnsw_indexes {
772 entry.needs_rebuild = true;
773 }
774 for entry in &mut table.fts_indexes {
775 entry.needs_rebuild = true;
776 }
777 }
778 Ok(matching.len())
779}
780
781pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
783 let Statement::Update(Update {
784 table,
785 assignments,
786 from,
787 selection,
788 ..
789 }) = stmt
790 else {
791 return Err(SQLRiteError::Internal(
792 "execute_update called on a non-UPDATE statement".to_string(),
793 ));
794 };
795
796 if from.is_some() {
797 return Err(SQLRiteError::NotImplemented(
798 "UPDATE ... FROM is not supported yet".to_string(),
799 ));
800 }
801
802 let table_name = extract_table_name(table)?;
803
804 let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
806 {
807 let tbl = db
808 .get_table(table_name.clone())
809 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
810 for a in assignments {
811 let col = match &a.target {
812 AssignmentTarget::ColumnName(name) => name
813 .0
814 .last()
815 .map(|p| p.to_string())
816 .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
817 AssignmentTarget::Tuple(_) => {
818 return Err(SQLRiteError::NotImplemented(
819 "tuple assignment targets are not supported".to_string(),
820 ));
821 }
822 };
823 if !tbl.contains_column(col.clone()) {
824 return Err(SQLRiteError::Internal(format!(
825 "UPDATE references unknown column '{col}'"
826 )));
827 }
828 parsed_assignments.push((col, a.value.clone()));
829 }
830 }
831
832 let work: Vec<(i64, Vec<(String, Value)>)> = {
836 let tbl = db.get_table(table_name.clone())?;
837 let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
838 RowidSource::IndexProbe(rowids) => rowids,
839 RowidSource::FullScan => {
840 let mut out = Vec::new();
841 for rowid in tbl.rowids() {
842 if let Some(expr) = selection {
843 if !eval_predicate(expr, tbl, rowid)? {
844 continue;
845 }
846 }
847 out.push(rowid);
848 }
849 out
850 }
851 };
852 let mut rows_to_update = Vec::new();
853 for rowid in matched_rowids {
854 let mut values = Vec::with_capacity(parsed_assignments.len());
855 for (col, expr) in &parsed_assignments {
856 let v = eval_expr(expr, tbl, rowid)?;
859 values.push((col.clone(), v));
860 }
861 rows_to_update.push((rowid, values));
862 }
863 rows_to_update
864 };
865
866 let tbl = db.get_table_mut(table_name)?;
867 for (rowid, values) in &work {
868 for (col, v) in values {
869 tbl.set_value(col, *rowid, v.clone())?;
870 }
871 }
872
873 if !work.is_empty() {
882 let updated_columns: std::collections::HashSet<&str> = work
883 .iter()
884 .flat_map(|(_, values)| values.iter().map(|(c, _)| c.as_str()))
885 .collect();
886 for entry in &mut tbl.hnsw_indexes {
887 if updated_columns.contains(entry.column_name.as_str()) {
888 entry.needs_rebuild = true;
889 }
890 }
891 for entry in &mut tbl.fts_indexes {
892 if updated_columns.contains(entry.column_name.as_str()) {
893 entry.needs_rebuild = true;
894 }
895 }
896 }
897 Ok(work.len())
898}
899
900pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
912 let Statement::CreateIndex(CreateIndex {
913 name,
914 table_name,
915 columns,
916 using,
917 unique,
918 if_not_exists,
919 predicate,
920 with,
921 ..
922 }) = stmt
923 else {
924 return Err(SQLRiteError::Internal(
925 "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
926 ));
927 };
928
929 if predicate.is_some() {
930 return Err(SQLRiteError::NotImplemented(
931 "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
932 ));
933 }
934
935 if columns.len() != 1 {
936 return Err(SQLRiteError::NotImplemented(format!(
937 "multi-column indexes are not supported yet ({} columns given)",
938 columns.len()
939 )));
940 }
941
942 let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
943 SQLRiteError::NotImplemented(
944 "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
945 )
946 })?;
947
948 let method = match using {
954 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("hnsw") => {
955 IndexMethod::Hnsw
956 }
957 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("fts") => {
958 IndexMethod::Fts
959 }
960 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("btree") => {
961 IndexMethod::Btree
962 }
963 Some(other) => {
964 return Err(SQLRiteError::NotImplemented(format!(
965 "CREATE INDEX … USING {other:?} is not supported \
966 (try `hnsw`, `fts`, or no USING clause)"
967 )));
968 }
969 None => IndexMethod::Btree,
970 };
971
972 let hnsw_metric = parse_hnsw_with_options(with, &index_name, method)?;
978
979 let table_name_str = table_name.to_string();
980 let column_name = match &columns[0].column.expr {
981 Expr::Identifier(ident) => ident.value.clone(),
982 Expr::CompoundIdentifier(parts) => parts
983 .last()
984 .map(|p| p.value.clone())
985 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
986 other => {
987 return Err(SQLRiteError::NotImplemented(format!(
988 "CREATE INDEX only supports simple column references, got {other:?}"
989 )));
990 }
991 };
992
993 let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
998 let table = db.get_table(table_name_str.clone()).map_err(|_| {
999 SQLRiteError::General(format!(
1000 "CREATE INDEX references unknown table '{table_name_str}'"
1001 ))
1002 })?;
1003 if !table.contains_column(column_name.clone()) {
1004 return Err(SQLRiteError::General(format!(
1005 "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
1006 )));
1007 }
1008 let col = table
1009 .columns
1010 .iter()
1011 .find(|c| c.column_name == column_name)
1012 .expect("we just verified the column exists");
1013
1014 if table.index_by_name(&index_name).is_some()
1017 || table.hnsw_indexes.iter().any(|i| i.name == index_name)
1018 || table.fts_indexes.iter().any(|i| i.name == index_name)
1019 {
1020 if *if_not_exists {
1021 return Ok(index_name);
1022 }
1023 return Err(SQLRiteError::General(format!(
1024 "index '{index_name}' already exists"
1025 )));
1026 }
1027 let datatype = clone_datatype(&col.datatype);
1028
1029 let mut pairs = Vec::new();
1030 for rowid in table.rowids() {
1031 if let Some(v) = table.get_value(&column_name, rowid) {
1032 pairs.push((rowid, v));
1033 }
1034 }
1035 (datatype, pairs)
1036 };
1037
1038 match method {
1039 IndexMethod::Btree => create_btree_index(
1040 db,
1041 &table_name_str,
1042 &index_name,
1043 &column_name,
1044 &datatype,
1045 *unique,
1046 &existing_rowids_and_values,
1047 ),
1048 IndexMethod::Hnsw => create_hnsw_index(
1049 db,
1050 &table_name_str,
1051 &index_name,
1052 &column_name,
1053 &datatype,
1054 *unique,
1055 hnsw_metric.unwrap_or(DistanceMetric::L2),
1056 &existing_rowids_and_values,
1057 ),
1058 IndexMethod::Fts => create_fts_index(
1059 db,
1060 &table_name_str,
1061 &index_name,
1062 &column_name,
1063 &datatype,
1064 *unique,
1065 &existing_rowids_and_values,
1066 ),
1067 }
1068}
1069
1070pub fn execute_drop_table(
1081 names: &[ObjectName],
1082 if_exists: bool,
1083 db: &mut Database,
1084) -> Result<usize> {
1085 if names.len() != 1 {
1086 return Err(SQLRiteError::NotImplemented(
1087 "DROP TABLE supports a single table per statement".to_string(),
1088 ));
1089 }
1090 let name = names[0].to_string();
1091
1092 if name == crate::sql::pager::MASTER_TABLE_NAME {
1093 return Err(SQLRiteError::General(format!(
1094 "'{}' is a reserved name used by the internal schema catalog",
1095 crate::sql::pager::MASTER_TABLE_NAME
1096 )));
1097 }
1098
1099 if !db.contains_table(name.clone()) {
1100 return if if_exists {
1101 Ok(0)
1102 } else {
1103 Err(SQLRiteError::General(format!(
1104 "Table '{name}' does not exist"
1105 )))
1106 };
1107 }
1108
1109 db.tables.remove(&name);
1110 Ok(1)
1111}
1112
1113pub fn execute_drop_index(
1122 names: &[ObjectName],
1123 if_exists: bool,
1124 db: &mut Database,
1125) -> Result<usize> {
1126 if names.len() != 1 {
1127 return Err(SQLRiteError::NotImplemented(
1128 "DROP INDEX supports a single index per statement".to_string(),
1129 ));
1130 }
1131 let name = names[0].to_string();
1132
1133 for table in db.tables.values_mut() {
1134 if let Some(secondary) = table.secondary_indexes.iter().find(|i| i.name == name) {
1135 if secondary.origin == IndexOrigin::Auto {
1136 return Err(SQLRiteError::General(format!(
1137 "cannot drop auto-created index '{name}' (drop the column or table instead)"
1138 )));
1139 }
1140 table.secondary_indexes.retain(|i| i.name != name);
1141 return Ok(1);
1142 }
1143 if table.hnsw_indexes.iter().any(|i| i.name == name) {
1144 table.hnsw_indexes.retain(|i| i.name != name);
1145 return Ok(1);
1146 }
1147 if table.fts_indexes.iter().any(|i| i.name == name) {
1148 table.fts_indexes.retain(|i| i.name != name);
1149 return Ok(1);
1150 }
1151 }
1152
1153 if if_exists {
1154 Ok(0)
1155 } else {
1156 Err(SQLRiteError::General(format!(
1157 "Index '{name}' does not exist"
1158 )))
1159 }
1160}
1161
1162pub fn execute_alter_table(alter: AlterTable, db: &mut Database) -> Result<String> {
1174 let table_name = alter.name.to_string();
1175
1176 if table_name == crate::sql::pager::MASTER_TABLE_NAME {
1177 return Err(SQLRiteError::General(format!(
1178 "'{}' is a reserved name used by the internal schema catalog",
1179 crate::sql::pager::MASTER_TABLE_NAME
1180 )));
1181 }
1182
1183 if !db.contains_table(table_name.clone()) {
1184 return if alter.if_exists {
1185 Ok("ALTER TABLE: no-op (table does not exist)".to_string())
1186 } else {
1187 Err(SQLRiteError::General(format!(
1188 "Table '{table_name}' does not exist"
1189 )))
1190 };
1191 }
1192
1193 if alter.operations.len() != 1 {
1194 return Err(SQLRiteError::NotImplemented(
1195 "ALTER TABLE supports one operation per statement".to_string(),
1196 ));
1197 }
1198
1199 match &alter.operations[0] {
1200 AlterTableOperation::RenameTable { table_name: kind } => {
1201 let new_name = match kind {
1202 RenameTableNameKind::To(name) => name.to_string(),
1203 RenameTableNameKind::As(_) => {
1204 return Err(SQLRiteError::NotImplemented(
1205 "ALTER TABLE ... RENAME AS (MySQL-only) is not supported; use RENAME TO"
1206 .to_string(),
1207 ));
1208 }
1209 };
1210 alter_rename_table(db, &table_name, &new_name)?;
1211 Ok(format!(
1212 "ALTER TABLE '{table_name}' RENAME TO '{new_name}' executed."
1213 ))
1214 }
1215 AlterTableOperation::RenameColumn {
1216 old_column_name,
1217 new_column_name,
1218 } => {
1219 let old = old_column_name.value.clone();
1220 let new = new_column_name.value.clone();
1221 db.get_table_mut(table_name.clone())?
1222 .rename_column(&old, &new)?;
1223 Ok(format!(
1224 "ALTER TABLE '{table_name}' RENAME COLUMN '{old}' TO '{new}' executed."
1225 ))
1226 }
1227 AlterTableOperation::AddColumn {
1228 column_def,
1229 if_not_exists,
1230 ..
1231 } => {
1232 let parsed = crate::sql::parser::create::parse_one_column(column_def)?;
1233 let table = db.get_table_mut(table_name.clone())?;
1234 if *if_not_exists && table.contains_column(parsed.name.clone()) {
1235 return Ok(format!(
1236 "ALTER TABLE '{table_name}' ADD COLUMN: no-op (column '{}' already exists)",
1237 parsed.name
1238 ));
1239 }
1240 let col_name = parsed.name.clone();
1241 table.add_column(parsed)?;
1242 Ok(format!(
1243 "ALTER TABLE '{table_name}' ADD COLUMN '{col_name}' executed."
1244 ))
1245 }
1246 AlterTableOperation::DropColumn {
1247 column_names,
1248 if_exists,
1249 ..
1250 } => {
1251 if column_names.len() != 1 {
1252 return Err(SQLRiteError::NotImplemented(
1253 "ALTER TABLE DROP COLUMN supports a single column per statement".to_string(),
1254 ));
1255 }
1256 let col_name = column_names[0].value.clone();
1257 let table = db.get_table_mut(table_name.clone())?;
1258 if *if_exists && !table.contains_column(col_name.clone()) {
1259 return Ok(format!(
1260 "ALTER TABLE '{table_name}' DROP COLUMN: no-op (column '{col_name}' does not exist)"
1261 ));
1262 }
1263 table.drop_column(&col_name)?;
1264 Ok(format!(
1265 "ALTER TABLE '{table_name}' DROP COLUMN '{col_name}' executed."
1266 ))
1267 }
1268 other => Err(SQLRiteError::NotImplemented(format!(
1269 "ALTER TABLE operation {other:?} is not supported"
1270 ))),
1271 }
1272}
1273
1274pub fn execute_vacuum(db: &mut Database) -> Result<String> {
1284 if db.in_transaction() {
1285 return Err(SQLRiteError::General(
1286 "VACUUM cannot run inside a transaction".to_string(),
1287 ));
1288 }
1289 let path = match db.source_path.clone() {
1290 Some(p) => p,
1291 None => {
1292 return Ok("VACUUM is a no-op for in-memory databases".to_string());
1293 }
1294 };
1295 if let Some(pager) = db.pager.as_mut() {
1301 let _ = pager.checkpoint();
1302 }
1303 let size_before = std::fs::metadata(&path).ok().map(|m| m.len()).unwrap_or(0);
1304 let pages_before = db
1305 .pager
1306 .as_ref()
1307 .map(|p| p.header().page_count)
1308 .unwrap_or(0);
1309 crate::sql::pager::vacuum_database(db, &path)?;
1310 if let Some(pager) = db.pager.as_mut() {
1313 let _ = pager.checkpoint();
1314 }
1315 let size_after = std::fs::metadata(&path).ok().map(|m| m.len()).unwrap_or(0);
1316 let pages_after = db
1317 .pager
1318 .as_ref()
1319 .map(|p| p.header().page_count)
1320 .unwrap_or(0);
1321 let pages_reclaimed = pages_before.saturating_sub(pages_after);
1322 let bytes_reclaimed = size_before.saturating_sub(size_after);
1323 Ok(format!(
1324 "VACUUM completed. {pages_reclaimed} pages reclaimed ({bytes_reclaimed} bytes)."
1325 ))
1326}
1327
1328fn alter_rename_table(db: &mut Database, old: &str, new: &str) -> Result<()> {
1334 if new == crate::sql::pager::MASTER_TABLE_NAME {
1335 return Err(SQLRiteError::General(format!(
1336 "'{}' is a reserved name used by the internal schema catalog",
1337 crate::sql::pager::MASTER_TABLE_NAME
1338 )));
1339 }
1340 if old == new {
1341 return Ok(());
1342 }
1343 if db.contains_table(new.to_string()) {
1344 return Err(SQLRiteError::General(format!(
1345 "target table '{new}' already exists"
1346 )));
1347 }
1348
1349 let mut table = db
1350 .tables
1351 .remove(old)
1352 .ok_or_else(|| SQLRiteError::General(format!("Table '{old}' does not exist")))?;
1353 table.tb_name = new.to_string();
1354 for idx in table.secondary_indexes.iter_mut() {
1355 idx.table_name = new.to_string();
1356 if idx.origin == IndexOrigin::Auto
1357 && idx.name == SecondaryIndex::auto_name(old, &idx.column_name)
1358 {
1359 idx.name = SecondaryIndex::auto_name(new, &idx.column_name);
1360 }
1361 }
1362 db.tables.insert(new.to_string(), table);
1363 Ok(())
1364}
1365
/// Index implementation selected by the `USING` clause of `CREATE INDEX`.
#[derive(Clone, Copy, Debug)]
enum IndexMethod {
    /// B-tree secondary index; also the choice when no `USING` clause is given.
    Btree,
    /// HNSW index, selected by `USING hnsw`.
    Hnsw,
    /// Full-text-search index, selected by `USING fts`.
    Fts,
}
1376
1377fn create_btree_index(
1379 db: &mut Database,
1380 table_name: &str,
1381 index_name: &str,
1382 column_name: &str,
1383 datatype: &DataType,
1384 unique: bool,
1385 existing: &[(i64, Value)],
1386) -> Result<String> {
1387 let mut idx = SecondaryIndex::new(
1388 index_name.to_string(),
1389 table_name.to_string(),
1390 column_name.to_string(),
1391 datatype,
1392 unique,
1393 IndexOrigin::Explicit,
1394 )?;
1395
1396 for (rowid, v) in existing {
1400 if unique && idx.would_violate_unique(v) {
1401 return Err(SQLRiteError::General(format!(
1402 "cannot create UNIQUE index '{index_name}': column '{column_name}' \
1403 already contains the duplicate value {}",
1404 v.to_display_string()
1405 )));
1406 }
1407 idx.insert(v, *rowid)?;
1408 }
1409
1410 let table_mut = db.get_table_mut(table_name.to_string())?;
1411 table_mut.secondary_indexes.push(idx);
1412 Ok(index_name.to_string())
1413}
1414
1415fn create_hnsw_index(
1417 db: &mut Database,
1418 table_name: &str,
1419 index_name: &str,
1420 column_name: &str,
1421 datatype: &DataType,
1422 unique: bool,
1423 metric: DistanceMetric,
1424 existing: &[(i64, Value)],
1425) -> Result<String> {
1426 let dim = match datatype {
1429 DataType::Vector(d) => *d,
1430 other => {
1431 return Err(SQLRiteError::General(format!(
1432 "USING hnsw requires a VECTOR column; '{column_name}' is {other}"
1433 )));
1434 }
1435 };
1436
1437 if unique {
1438 return Err(SQLRiteError::General(
1439 "UNIQUE has no meaning for HNSW indexes".to_string(),
1440 ));
1441 }
1442
1443 let seed = hash_str_to_seed(index_name);
1454 let mut idx = HnswIndex::new(metric, seed);
1455
1456 let mut vec_map: std::collections::HashMap<i64, Vec<f32>> =
1460 std::collections::HashMap::with_capacity(existing.len());
1461 for (rowid, v) in existing {
1462 match v {
1463 Value::Vector(vec) => {
1464 if vec.len() != dim {
1465 return Err(SQLRiteError::Internal(format!(
1466 "row {rowid} stores a {}-dim vector in column '{column_name}' \
1467 declared as VECTOR({dim}) — schema invariant violated",
1468 vec.len()
1469 )));
1470 }
1471 vec_map.insert(*rowid, vec.clone());
1472 }
1473 _ => continue,
1477 }
1478 }
1479
1480 for (rowid, _) in existing {
1481 if let Some(v) = vec_map.get(rowid) {
1482 let v_clone = v.clone();
1483 idx.insert(*rowid, &v_clone, |id| {
1484 vec_map.get(&id).cloned().unwrap_or_default()
1485 })?;
1486 }
1487 }
1488
1489 let table_mut = db.get_table_mut(table_name.to_string())?;
1490 table_mut.hnsw_indexes.push(HnswIndexEntry {
1491 name: index_name.to_string(),
1492 column_name: column_name.to_string(),
1493 metric,
1494 index: idx,
1495 needs_rebuild: false,
1497 });
1498 Ok(index_name.to_string())
1499}
1500
1501fn parse_hnsw_with_options(
1512 with: &[Expr],
1513 index_name: &str,
1514 method: IndexMethod,
1515) -> Result<Option<DistanceMetric>> {
1516 if with.is_empty() {
1517 return Ok(None);
1518 }
1519 if !matches!(method, IndexMethod::Hnsw) {
1520 return Err(SQLRiteError::General(format!(
1521 "CREATE INDEX '{index_name}' has a WITH (...) clause but its index method \
1522 doesn't support any options — only `USING hnsw` recognises `WITH (metric = ...)`"
1523 )));
1524 }
1525
1526 let mut metric: Option<DistanceMetric> = None;
1527 for opt in with {
1528 let Expr::BinaryOp { left, op, right } = opt else {
1529 return Err(SQLRiteError::General(format!(
1530 "CREATE INDEX '{index_name}': unsupported WITH option {opt:?} \
1531 (expected `key = 'value'`)"
1532 )));
1533 };
1534 if !matches!(op, BinaryOperator::Eq) {
1535 return Err(SQLRiteError::General(format!(
1536 "CREATE INDEX '{index_name}': WITH options must use `=` (got {op:?})"
1537 )));
1538 }
1539 let key = match left.as_ref() {
1540 Expr::Identifier(ident) => ident.value.clone(),
1541 other => {
1542 return Err(SQLRiteError::General(format!(
1543 "CREATE INDEX '{index_name}': WITH option key must be a bare identifier, \
1544 got {other:?}"
1545 )));
1546 }
1547 };
1548 let value = match right.as_ref() {
1549 Expr::Value(v) => match &v.value {
1550 AstValue::SingleQuotedString(s) => s.clone(),
1551 AstValue::DoubleQuotedString(s) => s.clone(),
1552 other => {
1553 return Err(SQLRiteError::General(format!(
1554 "CREATE INDEX '{index_name}': WITH option '{key}' value must be \
1555 a quoted string, got {other:?}"
1556 )));
1557 }
1558 },
1559 Expr::Identifier(ident) => ident.value.clone(),
1560 other => {
1561 return Err(SQLRiteError::General(format!(
1562 "CREATE INDEX '{index_name}': WITH option '{key}' value must be a \
1563 quoted string, got {other:?}"
1564 )));
1565 }
1566 };
1567
1568 if key.eq_ignore_ascii_case("metric") {
1569 let parsed = DistanceMetric::from_sql_name(&value).ok_or_else(|| {
1570 SQLRiteError::General(format!(
1571 "CREATE INDEX '{index_name}': unknown HNSW metric '{value}' \
1572 (try 'l2', 'cosine', or 'dot')"
1573 ))
1574 })?;
1575 if metric.is_some() {
1576 return Err(SQLRiteError::General(format!(
1577 "CREATE INDEX '{index_name}': metric specified more than once in WITH (...)"
1578 )));
1579 }
1580 metric = Some(parsed);
1581 } else {
1582 return Err(SQLRiteError::General(format!(
1583 "CREATE INDEX '{index_name}': unknown WITH option '{key}' \
1584 (only 'metric' is recognised on HNSW indexes)"
1585 )));
1586 }
1587 }
1588
1589 Ok(metric)
1590}
1591
1592fn create_fts_index(
1597 db: &mut Database,
1598 table_name: &str,
1599 index_name: &str,
1600 column_name: &str,
1601 datatype: &DataType,
1602 unique: bool,
1603 existing: &[(i64, Value)],
1604) -> Result<String> {
1605 match datatype {
1610 DataType::Text => {}
1611 other => {
1612 return Err(SQLRiteError::General(format!(
1613 "USING fts requires a TEXT column; '{column_name}' is {other}"
1614 )));
1615 }
1616 }
1617
1618 if unique {
1619 return Err(SQLRiteError::General(
1620 "UNIQUE has no meaning for FTS indexes".to_string(),
1621 ));
1622 }
1623
1624 let mut idx = PostingList::new();
1625 for (rowid, v) in existing {
1626 if let Value::Text(text) = v {
1627 idx.insert(*rowid, text);
1628 }
1629 }
1632
1633 let table_mut = db.get_table_mut(table_name.to_string())?;
1634 table_mut.fts_indexes.push(FtsIndexEntry {
1635 name: index_name.to_string(),
1636 column_name: column_name.to_string(),
1637 index: idx,
1638 needs_rebuild: false,
1639 });
1640 Ok(index_name.to_string())
1641}
1642
/// Hashes `s` to a 64-bit seed using FNV-1a (xor each byte in, then multiply
/// by the FNV prime with wrapping arithmetic).
///
/// The empty string maps to the FNV offset basis.
fn hash_str_to_seed(s: &str) -> u64 {
    const OFFSET_BASIS: u64 = 0xCBF2_9CE4_8422_2325;
    const PRIME: u64 = 0x0100_0000_01B3;

    s.bytes().fold(OFFSET_BASIS, |acc, byte| {
        (acc ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
1654
1655fn clone_datatype(dt: &DataType) -> DataType {
1658 match dt {
1659 DataType::Integer => DataType::Integer,
1660 DataType::Text => DataType::Text,
1661 DataType::Real => DataType::Real,
1662 DataType::Bool => DataType::Bool,
1663 DataType::Vector(dim) => DataType::Vector(*dim),
1664 DataType::Json => DataType::Json,
1665 DataType::None => DataType::None,
1666 DataType::Invalid => DataType::Invalid,
1667 }
1668}
1669
1670fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
1671 if tables.len() != 1 {
1672 return Err(SQLRiteError::NotImplemented(
1673 "multi-table DELETE is not supported yet".to_string(),
1674 ));
1675 }
1676 extract_table_name(&tables[0])
1677}
1678
1679fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
1680 if !twj.joins.is_empty() {
1681 return Err(SQLRiteError::NotImplemented(
1682 "JOIN is not supported yet".to_string(),
1683 ));
1684 }
1685 match &twj.relation {
1686 TableFactor::Table { name, .. } => Ok(name.to_string()),
1687 _ => Err(SQLRiteError::NotImplemented(
1688 "only plain table references are supported".to_string(),
1689 )),
1690 }
1691}
1692
/// Outcome of planning how to find the rows a `WHERE` clause may match.
enum RowidSource {
    /// Rowids returned by a secondary-index lookup.
    IndexProbe(Vec<i64>),
    /// No index lookup was performed. Presumably the caller then visits every
    /// row and applies the predicate — confirm at the call sites.
    FullScan,
}
1703
1704fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
1709 let Some(expr) = selection else {
1710 return Ok(RowidSource::FullScan);
1711 };
1712 let Some((col, literal)) = try_extract_equality(expr) else {
1713 return Ok(RowidSource::FullScan);
1714 };
1715 let Some(idx) = table.index_for_column(&col) else {
1716 return Ok(RowidSource::FullScan);
1717 };
1718
1719 let literal_value = match convert_literal(&literal) {
1723 Ok(v) => v,
1724 Err(_) => return Ok(RowidSource::FullScan),
1725 };
1726
1727 let mut rowids = idx.lookup(&literal_value);
1731 rowids.sort_unstable();
1732 Ok(RowidSource::IndexProbe(rowids))
1733}
1734
1735fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
1739 let peeled = match expr {
1741 Expr::Nested(inner) => inner.as_ref(),
1742 other => other,
1743 };
1744 let Expr::BinaryOp { left, op, right } = peeled else {
1745 return None;
1746 };
1747 if !matches!(op, BinaryOperator::Eq) {
1748 return None;
1749 }
1750 let col_from = |e: &Expr| -> Option<String> {
1751 match e {
1752 Expr::Identifier(ident) => Some(ident.value.clone()),
1753 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
1754 _ => None,
1755 }
1756 };
1757 let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
1758 if let Expr::Value(v) = e {
1759 Some(v.value.clone())
1760 } else {
1761 None
1762 }
1763 };
1764 if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
1765 return Some((c, l));
1766 }
1767 if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
1768 return Some((c, l));
1769 }
1770 None
1771}
1772
1773fn try_hnsw_probe(table: &Table, order_expr: &Expr, k: usize) -> Option<Vec<i64>> {
1798 if k == 0 {
1799 return None;
1800 }
1801
1802 let func = match order_expr {
1805 Expr::Function(f) => f,
1806 _ => return None,
1807 };
1808 let fname = match func.name.0.as_slice() {
1809 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
1810 _ => return None,
1811 };
1812 let query_metric = match fname.as_str() {
1813 "vec_distance_l2" => DistanceMetric::L2,
1814 "vec_distance_cosine" => DistanceMetric::Cosine,
1815 "vec_distance_dot" => DistanceMetric::Dot,
1816 _ => return None,
1817 };
1818
1819 let arg_list = match &func.args {
1821 FunctionArguments::List(l) => &l.args,
1822 _ => return None,
1823 };
1824 if arg_list.len() != 2 {
1825 return None;
1826 }
1827 let exprs: Vec<&Expr> = arg_list
1828 .iter()
1829 .filter_map(|a| match a {
1830 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
1831 _ => None,
1832 })
1833 .collect();
1834 if exprs.len() != 2 {
1835 return None;
1836 }
1837
1838 let (col_name, query_vec) = match identify_indexed_arg_and_literal(exprs[0], exprs[1]) {
1843 Some(v) => v,
1844 None => match identify_indexed_arg_and_literal(exprs[1], exprs[0]) {
1845 Some(v) => v,
1846 None => return None,
1847 },
1848 };
1849
1850 let entry = table
1855 .hnsw_indexes
1856 .iter()
1857 .find(|e| e.column_name == col_name && e.metric == query_metric)?;
1858
1859 let declared_dim = match table.columns.iter().find(|c| c.column_name == col_name) {
1865 Some(c) => match &c.datatype {
1866 DataType::Vector(d) => *d,
1867 _ => return None,
1868 },
1869 None => return None,
1870 };
1871 if query_vec.len() != declared_dim {
1872 return None;
1873 }
1874
1875 let column_for_closure = col_name.clone();
1879 let table_ref = table;
1880 let result = entry
1881 .index
1882 .search(&query_vec, k, |id| {
1883 match table_ref.get_value(&column_for_closure, id) {
1884 Some(Value::Vector(v)) => v,
1885 _ => Vec::new(),
1886 }
1887 })
1888 .ok()?;
1889 Some(result)
1890}
1891
1892fn try_fts_probe(table: &Table, order_expr: &Expr, ascending: bool, k: usize) -> Option<Vec<i64>> {
1908 if k == 0 || ascending {
1909 return None;
1913 }
1914
1915 let func = match order_expr {
1916 Expr::Function(f) => f,
1917 _ => return None,
1918 };
1919 let fname = match func.name.0.as_slice() {
1920 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
1921 _ => return None,
1922 };
1923 if fname != "bm25_score" {
1924 return None;
1925 }
1926
1927 let arg_list = match &func.args {
1928 FunctionArguments::List(l) => &l.args,
1929 _ => return None,
1930 };
1931 if arg_list.len() != 2 {
1932 return None;
1933 }
1934 let exprs: Vec<&Expr> = arg_list
1935 .iter()
1936 .filter_map(|a| match a {
1937 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
1938 _ => None,
1939 })
1940 .collect();
1941 if exprs.len() != 2 {
1942 return None;
1943 }
1944
1945 let col_name = match exprs[0] {
1947 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
1948 _ => return None,
1949 };
1950
1951 let query = match exprs[1] {
1955 Expr::Value(v) => match &v.value {
1956 AstValue::SingleQuotedString(s) => s.clone(),
1957 _ => return None,
1958 },
1959 _ => return None,
1960 };
1961
1962 let entry = table
1963 .fts_indexes
1964 .iter()
1965 .find(|e| e.column_name == col_name)?;
1966
1967 let scored = entry.index.query(&query, &Bm25Params::default());
1968 let mut out: Vec<i64> = scored.into_iter().map(|(id, _)| id).collect();
1969 if out.len() > k {
1970 out.truncate(k);
1971 }
1972 Some(out)
1973}
1974
1975fn identify_indexed_arg_and_literal(a: &Expr, b: &Expr) -> Option<(String, Vec<f32>)> {
1980 let col_name = match a {
1981 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
1982 _ => return None,
1983 };
1984 let lit_str = match b {
1985 Expr::Identifier(ident) if ident.quote_style == Some('[') => {
1986 format!("[{}]", ident.value)
1987 }
1988 _ => return None,
1989 };
1990 let v = parse_vector_literal(&lit_str).ok()?;
1991 Some((col_name, v))
1992}
1993
1994struct HeapEntry {
2007 key: Value,
2008 rowid: i64,
2009 asc: bool,
2010}
2011
2012impl PartialEq for HeapEntry {
2013 fn eq(&self, other: &Self) -> bool {
2014 self.cmp(other) == Ordering::Equal
2015 }
2016}
2017
2018impl Eq for HeapEntry {}
2019
2020impl PartialOrd for HeapEntry {
2021 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
2022 Some(self.cmp(other))
2023 }
2024}
2025
2026impl Ord for HeapEntry {
2027 fn cmp(&self, other: &Self) -> Ordering {
2028 let raw = compare_values(Some(&self.key), Some(&other.key));
2029 if self.asc { raw } else { raw.reverse() }
2030 }
2031}
2032
2033fn select_topk(
2042 matching: &[i64],
2043 table: &Table,
2044 order: &OrderByClause,
2045 k: usize,
2046) -> Result<Vec<i64>> {
2047 use std::collections::BinaryHeap;
2048
2049 if k == 0 || matching.is_empty() {
2050 return Ok(Vec::new());
2051 }
2052
2053 let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
2054
2055 for &rowid in matching {
2056 let key = eval_expr(&order.expr, table, rowid)?;
2057 let entry = HeapEntry {
2058 key,
2059 rowid,
2060 asc: order.ascending,
2061 };
2062
2063 if heap.len() < k {
2064 heap.push(entry);
2065 } else {
2066 if entry < *heap.peek().unwrap() {
2070 heap.pop();
2071 heap.push(entry);
2072 }
2073 }
2074 }
2075
2076 Ok(heap
2081 .into_sorted_vec()
2082 .into_iter()
2083 .map(|e| e.rowid)
2084 .collect())
2085}
2086
2087fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
2088 let mut keys: Vec<(i64, Result<Value>)> = rowids
2096 .iter()
2097 .map(|r| (*r, eval_expr(&order.expr, table, *r)))
2098 .collect();
2099
2100 for (_, k) in &keys {
2104 if let Err(e) = k {
2105 return Err(SQLRiteError::General(format!(
2106 "ORDER BY expression failed: {e}"
2107 )));
2108 }
2109 }
2110
2111 keys.sort_by(|(_, ka), (_, kb)| {
2112 let va = ka.as_ref().unwrap();
2115 let vb = kb.as_ref().unwrap();
2116 let ord = compare_values(Some(va), Some(vb));
2117 if order.ascending { ord } else { ord.reverse() }
2118 });
2119
2120 for (i, (rowid, _)) in keys.into_iter().enumerate() {
2122 rowids[i] = rowid;
2123 }
2124 Ok(())
2125}
2126
2127fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
2128 match (a, b) {
2129 (None, None) => Ordering::Equal,
2130 (None, _) => Ordering::Less,
2131 (_, None) => Ordering::Greater,
2132 (Some(a), Some(b)) => match (a, b) {
2133 (Value::Null, Value::Null) => Ordering::Equal,
2134 (Value::Null, _) => Ordering::Less,
2135 (_, Value::Null) => Ordering::Greater,
2136 (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
2137 (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
2138 (Value::Integer(x), Value::Real(y)) => {
2139 (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
2140 }
2141 (Value::Real(x), Value::Integer(y)) => {
2142 x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
2143 }
2144 (Value::Text(x), Value::Text(y)) => x.cmp(y),
2145 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
2146 (x, y) => x.to_display_string().cmp(&y.to_display_string()),
2148 },
2149 }
2150}
2151
2152pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
2154 eval_predicate_scope(expr, &SingleTableScope::new(table, rowid))
2155}
2156
2157pub(crate) fn eval_predicate_scope(expr: &Expr, scope: &dyn RowScope) -> Result<bool> {
2161 let v = eval_expr_scope(expr, scope)?;
2162 match v {
2163 Value::Bool(b) => Ok(b),
2164 Value::Null => Ok(false), Value::Integer(i) => Ok(i != 0),
2166 other => Err(SQLRiteError::Internal(format!(
2167 "WHERE clause must evaluate to boolean, got {}",
2168 other.to_display_string()
2169 ))),
2170 }
2171}
2172
2173fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
2175 eval_expr_scope(expr, &SingleTableScope::new(table, rowid))
2176}
2177
2178fn eval_expr_scope(expr: &Expr, scope: &dyn RowScope) -> Result<Value> {
2179 match expr {
2180 Expr::Nested(inner) => eval_expr_scope(inner, scope),
2181
2182 Expr::Identifier(ident) => {
2183 if ident.quote_style == Some('[') {
2193 let raw = format!("[{}]", ident.value);
2194 let v = parse_vector_literal(&raw)?;
2195 return Ok(Value::Vector(v));
2196 }
2197 scope.lookup(None, &ident.value)
2198 }
2199
2200 Expr::CompoundIdentifier(parts) => {
2201 match parts.as_slice() {
2207 [only] => scope.lookup(None, &only.value),
2208 [q, c] => scope.lookup(Some(&q.value), &c.value),
2209 _ => Err(SQLRiteError::NotImplemented(format!(
2210 "compound identifier with {} parts is not supported",
2211 parts.len()
2212 ))),
2213 }
2214 }
2215
2216 Expr::Value(v) => convert_literal(&v.value),
2217
2218 Expr::UnaryOp { op, expr } => {
2219 let inner = eval_expr_scope(expr, scope)?;
2220 match op {
2221 UnaryOperator::Not => match inner {
2222 Value::Bool(b) => Ok(Value::Bool(!b)),
2223 Value::Null => Ok(Value::Null),
2224 other => Err(SQLRiteError::Internal(format!(
2225 "NOT applied to non-boolean value: {}",
2226 other.to_display_string()
2227 ))),
2228 },
2229 UnaryOperator::Minus => match inner {
2230 Value::Integer(i) => Ok(Value::Integer(-i)),
2231 Value::Real(f) => Ok(Value::Real(-f)),
2232 Value::Null => Ok(Value::Null),
2233 other => Err(SQLRiteError::Internal(format!(
2234 "unary minus on non-numeric value: {}",
2235 other.to_display_string()
2236 ))),
2237 },
2238 UnaryOperator::Plus => Ok(inner),
2239 other => Err(SQLRiteError::NotImplemented(format!(
2240 "unary operator {other:?} is not supported"
2241 ))),
2242 }
2243 }
2244
2245 Expr::BinaryOp { left, op, right } => match op {
2246 BinaryOperator::And => {
2247 let l = eval_expr_scope(left, scope)?;
2248 let r = eval_expr_scope(right, scope)?;
2249 Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
2250 }
2251 BinaryOperator::Or => {
2252 let l = eval_expr_scope(left, scope)?;
2253 let r = eval_expr_scope(right, scope)?;
2254 Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
2255 }
2256 cmp @ (BinaryOperator::Eq
2257 | BinaryOperator::NotEq
2258 | BinaryOperator::Lt
2259 | BinaryOperator::LtEq
2260 | BinaryOperator::Gt
2261 | BinaryOperator::GtEq) => {
2262 let l = eval_expr_scope(left, scope)?;
2263 let r = eval_expr_scope(right, scope)?;
2264 if matches!(l, Value::Null) || matches!(r, Value::Null) {
2266 return Ok(Value::Bool(false));
2267 }
2268 let ord = compare_values(Some(&l), Some(&r));
2269 let result = match cmp {
2270 BinaryOperator::Eq => ord == Ordering::Equal,
2271 BinaryOperator::NotEq => ord != Ordering::Equal,
2272 BinaryOperator::Lt => ord == Ordering::Less,
2273 BinaryOperator::LtEq => ord != Ordering::Greater,
2274 BinaryOperator::Gt => ord == Ordering::Greater,
2275 BinaryOperator::GtEq => ord != Ordering::Less,
2276 _ => unreachable!(),
2277 };
2278 Ok(Value::Bool(result))
2279 }
2280 arith @ (BinaryOperator::Plus
2281 | BinaryOperator::Minus
2282 | BinaryOperator::Multiply
2283 | BinaryOperator::Divide
2284 | BinaryOperator::Modulo) => {
2285 let l = eval_expr_scope(left, scope)?;
2286 let r = eval_expr_scope(right, scope)?;
2287 eval_arith(arith, &l, &r)
2288 }
2289 BinaryOperator::StringConcat => {
2290 let l = eval_expr_scope(left, scope)?;
2291 let r = eval_expr_scope(right, scope)?;
2292 if matches!(l, Value::Null) || matches!(r, Value::Null) {
2293 return Ok(Value::Null);
2294 }
2295 Ok(Value::Text(format!(
2296 "{}{}",
2297 l.to_display_string(),
2298 r.to_display_string()
2299 )))
2300 }
2301 other => Err(SQLRiteError::NotImplemented(format!(
2302 "binary operator {other:?} is not supported yet"
2303 ))),
2304 },
2305
2306 Expr::IsNull(inner) => {
2314 let v = eval_expr_scope(inner, scope)?;
2315 Ok(Value::Bool(matches!(v, Value::Null)))
2316 }
2317 Expr::IsNotNull(inner) => {
2318 let v = eval_expr_scope(inner, scope)?;
2319 Ok(Value::Bool(!matches!(v, Value::Null)))
2320 }
2321
2322 Expr::Like {
2329 negated,
2330 any,
2331 expr: lhs,
2332 pattern,
2333 escape_char,
2334 } => eval_like(
2335 scope,
2336 *negated,
2337 *any,
2338 lhs,
2339 pattern,
2340 escape_char.as_ref(),
2341 true,
2342 ),
2343 Expr::ILike {
2344 negated,
2345 any,
2346 expr: lhs,
2347 pattern,
2348 escape_char,
2349 } => eval_like(
2350 scope,
2351 *negated,
2352 *any,
2353 lhs,
2354 pattern,
2355 escape_char.as_ref(),
2356 true,
2357 ),
2358
2359 Expr::InList {
2365 expr: lhs,
2366 list,
2367 negated,
2368 } => eval_in_list(scope, lhs, list, *negated),
2369 Expr::InSubquery { .. } => Err(SQLRiteError::NotImplemented(
2370 "IN (subquery) is not supported (only literal lists are)".to_string(),
2371 )),
2372
2373 Expr::Function(func) => eval_function(func, scope),
2384
2385 other => Err(SQLRiteError::NotImplemented(format!(
2386 "unsupported expression in WHERE/projection: {other:?}"
2387 ))),
2388 }
2389}
2390
2391fn eval_function(func: &sqlparser::ast::Function, scope: &dyn RowScope) -> Result<Value> {
2396 let name = match func.name.0.as_slice() {
2399 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
2400 _ => {
2401 return Err(SQLRiteError::NotImplemented(format!(
2402 "qualified function names not supported: {:?}",
2403 func.name
2404 )));
2405 }
2406 };
2407
2408 match name.as_str() {
2409 "vec_distance_l2" | "vec_distance_cosine" | "vec_distance_dot" => {
2410 let (a, b) = extract_two_vector_args(&name, &func.args, scope)?;
2411 let dist = match name.as_str() {
2412 "vec_distance_l2" => vec_distance_l2(&a, &b),
2413 "vec_distance_cosine" => vec_distance_cosine(&a, &b)?,
2414 "vec_distance_dot" => vec_distance_dot(&a, &b),
2415 _ => unreachable!(),
2416 };
2417 Ok(Value::Real(dist as f64))
2423 }
2424 "json_extract" => json_fn_extract(&name, &func.args, scope),
2429 "json_type" => json_fn_type(&name, &func.args, scope),
2430 "json_array_length" => json_fn_array_length(&name, &func.args, scope),
2431 "json_object_keys" => json_fn_object_keys(&name, &func.args, scope),
2432 "fts_match" | "bm25_score" => {
2443 let Some((table, rowid)) = scope.single_table_view() else {
2444 return Err(SQLRiteError::NotImplemented(format!(
2445 "{name}() is not yet supported inside a JOIN query — \
2446 use it on a single-table SELECT or move the FTS lookup into a subquery"
2447 )));
2448 };
2449 let (entry, query) = resolve_fts_args(&name, &func.args, table, scope)?;
2450 Ok(match name.as_str() {
2451 "fts_match" => Value::Bool(entry.index.matches(rowid, &query)),
2452 "bm25_score" => {
2453 Value::Real(entry.index.score(rowid, &query, &Bm25Params::default()))
2454 }
2455 _ => unreachable!(),
2456 })
2457 }
2458 "count" | "sum" | "avg" | "min" | "max" => Err(SQLRiteError::NotImplemented(format!(
2462 "aggregate function '{name}' is not allowed in WHERE / projection-scalar position; \
2463 use it as a top-level projection item (HAVING is not yet supported)"
2464 ))),
2465 other => Err(SQLRiteError::NotImplemented(format!(
2466 "unknown function: {other}(...)"
2467 ))),
2468 }
2469}
2470
2471fn resolve_fts_args<'t>(
2476 fn_name: &str,
2477 args: &FunctionArguments,
2478 table: &'t Table,
2479 scope: &dyn RowScope,
2480) -> Result<(&'t FtsIndexEntry, String)> {
2481 let arg_list = match args {
2482 FunctionArguments::List(l) => &l.args,
2483 _ => {
2484 return Err(SQLRiteError::General(format!(
2485 "{fn_name}() expects exactly two arguments: (column, query_text)"
2486 )));
2487 }
2488 };
2489 if arg_list.len() != 2 {
2490 return Err(SQLRiteError::General(format!(
2491 "{fn_name}() expects exactly 2 arguments, got {}",
2492 arg_list.len()
2493 )));
2494 }
2495
2496 let col_expr = match &arg_list[0] {
2500 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2501 other => {
2502 return Err(SQLRiteError::NotImplemented(format!(
2503 "{fn_name}() argument 0 must be a column name, got {other:?}"
2504 )));
2505 }
2506 };
2507 let col_name = match col_expr {
2508 Expr::Identifier(ident) => ident.value.clone(),
2509 Expr::CompoundIdentifier(parts) => parts
2510 .last()
2511 .map(|p| p.value.clone())
2512 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
2513 other => {
2514 return Err(SQLRiteError::General(format!(
2515 "{fn_name}() argument 0 must be a column reference, got {other:?}"
2516 )));
2517 }
2518 };
2519
2520 let q_expr = match &arg_list[1] {
2524 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2525 other => {
2526 return Err(SQLRiteError::NotImplemented(format!(
2527 "{fn_name}() argument 1 must be a text expression, got {other:?}"
2528 )));
2529 }
2530 };
2531 let query = match eval_expr_scope(q_expr, scope)? {
2532 Value::Text(s) => s,
2533 other => {
2534 return Err(SQLRiteError::General(format!(
2535 "{fn_name}() argument 1 must be TEXT, got {}",
2536 other.to_display_string()
2537 )));
2538 }
2539 };
2540
2541 let entry = table
2542 .fts_indexes
2543 .iter()
2544 .find(|e| e.column_name == col_name)
2545 .ok_or_else(|| {
2546 SQLRiteError::General(format!(
2547 "{fn_name}({col_name}, ...): no FTS index on column '{col_name}' \
2548 (run CREATE INDEX <name> ON <table> USING fts({col_name}) first)"
2549 ))
2550 })?;
2551 Ok((entry, query))
2552}
2553
2554fn extract_json_and_path(
2568 fn_name: &str,
2569 args: &FunctionArguments,
2570 scope: &dyn RowScope,
2571) -> Result<(String, String)> {
2572 let arg_list = match args {
2573 FunctionArguments::List(l) => &l.args,
2574 _ => {
2575 return Err(SQLRiteError::General(format!(
2576 "{fn_name}() expects 1 or 2 arguments"
2577 )));
2578 }
2579 };
2580 if !(arg_list.len() == 1 || arg_list.len() == 2) {
2581 return Err(SQLRiteError::General(format!(
2582 "{fn_name}() expects 1 or 2 arguments, got {}",
2583 arg_list.len()
2584 )));
2585 }
2586 let first_expr = match &arg_list[0] {
2588 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2589 other => {
2590 return Err(SQLRiteError::NotImplemented(format!(
2591 "{fn_name}() argument 0 has unsupported shape: {other:?}"
2592 )));
2593 }
2594 };
2595 let json_text = match eval_expr_scope(first_expr, scope)? {
2596 Value::Text(s) => s,
2597 Value::Null => {
2598 return Err(SQLRiteError::General(format!(
2599 "{fn_name}() called on NULL — JSON column has no value for this row"
2600 )));
2601 }
2602 other => {
2603 return Err(SQLRiteError::General(format!(
2604 "{fn_name}() argument 0 is not JSON-typed: got {}",
2605 other.to_display_string()
2606 )));
2607 }
2608 };
2609
2610 let path = if arg_list.len() == 2 {
2612 let path_expr = match &arg_list[1] {
2613 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2614 other => {
2615 return Err(SQLRiteError::NotImplemented(format!(
2616 "{fn_name}() argument 1 has unsupported shape: {other:?}"
2617 )));
2618 }
2619 };
2620 match eval_expr_scope(path_expr, scope)? {
2621 Value::Text(s) => s,
2622 other => {
2623 return Err(SQLRiteError::General(format!(
2624 "{fn_name}() path argument must be a string literal, got {}",
2625 other.to_display_string()
2626 )));
2627 }
2628 }
2629 } else {
2630 "$".to_string()
2631 };
2632
2633 Ok((json_text, path))
2634}
2635
2636fn walk_json_path<'a>(
2646 value: &'a serde_json::Value,
2647 path: &str,
2648) -> Result<Option<&'a serde_json::Value>> {
2649 let mut chars = path.chars().peekable();
2650 if chars.next() != Some('$') {
2651 return Err(SQLRiteError::General(format!(
2652 "JSON path must start with '$', got `{path}`"
2653 )));
2654 }
2655 let mut current = value;
2656 while let Some(&c) = chars.peek() {
2657 match c {
2658 '.' => {
2659 chars.next();
2660 let mut key = String::new();
2661 while let Some(&c) = chars.peek() {
2662 if c == '.' || c == '[' {
2663 break;
2664 }
2665 key.push(c);
2666 chars.next();
2667 }
2668 if key.is_empty() {
2669 return Err(SQLRiteError::General(format!(
2670 "JSON path has empty key after '.' in `{path}`"
2671 )));
2672 }
2673 match current.get(&key) {
2674 Some(v) => current = v,
2675 None => return Ok(None),
2676 }
2677 }
2678 '[' => {
2679 chars.next();
2680 let mut idx_str = String::new();
2681 while let Some(&c) = chars.peek() {
2682 if c == ']' {
2683 break;
2684 }
2685 idx_str.push(c);
2686 chars.next();
2687 }
2688 if chars.next() != Some(']') {
2689 return Err(SQLRiteError::General(format!(
2690 "JSON path has unclosed `[` in `{path}`"
2691 )));
2692 }
2693 let idx: usize = idx_str.trim().parse().map_err(|_| {
2694 SQLRiteError::General(format!(
2695 "JSON path has non-integer index `[{idx_str}]` in `{path}`"
2696 ))
2697 })?;
2698 match current.get(idx) {
2699 Some(v) => current = v,
2700 None => return Ok(None),
2701 }
2702 }
2703 other => {
2704 return Err(SQLRiteError::General(format!(
2705 "JSON path has unexpected character `{other}` in `{path}` \
2706 (expected `.`, `[`, or end-of-path)"
2707 )));
2708 }
2709 }
2710 }
2711 Ok(Some(current))
2712}
2713
2714fn json_value_to_sql(v: &serde_json::Value) -> Value {
2718 match v {
2719 serde_json::Value::Null => Value::Null,
2720 serde_json::Value::Bool(b) => Value::Bool(*b),
2721 serde_json::Value::Number(n) => {
2722 if let Some(i) = n.as_i64() {
2724 Value::Integer(i)
2725 } else if let Some(f) = n.as_f64() {
2726 Value::Real(f)
2727 } else {
2728 Value::Null
2729 }
2730 }
2731 serde_json::Value::String(s) => Value::Text(s.clone()),
2732 composite => Value::Text(composite.to_string()),
2736 }
2737}
2738
2739fn json_fn_extract(name: &str, args: &FunctionArguments, scope: &dyn RowScope) -> Result<Value> {
2740 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2741 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2742 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2743 })?;
2744 match walk_json_path(&parsed, &path)? {
2745 Some(v) => Ok(json_value_to_sql(v)),
2746 None => Ok(Value::Null),
2747 }
2748}
2749
2750fn json_fn_type(name: &str, args: &FunctionArguments, scope: &dyn RowScope) -> Result<Value> {
2751 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2752 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2753 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2754 })?;
2755 let resolved = match walk_json_path(&parsed, &path)? {
2756 Some(v) => v,
2757 None => return Ok(Value::Null),
2758 };
2759 let ty = match resolved {
2760 serde_json::Value::Null => "null",
2761 serde_json::Value::Bool(true) => "true",
2762 serde_json::Value::Bool(false) => "false",
2763 serde_json::Value::Number(n) => {
2764 if n.is_i64() || n.is_u64() {
2765 "integer"
2766 } else {
2767 "real"
2768 }
2769 }
2770 serde_json::Value::String(_) => "text",
2771 serde_json::Value::Array(_) => "array",
2772 serde_json::Value::Object(_) => "object",
2773 };
2774 Ok(Value::Text(ty.to_string()))
2775}
2776
2777fn json_fn_array_length(
2778 name: &str,
2779 args: &FunctionArguments,
2780 scope: &dyn RowScope,
2781) -> Result<Value> {
2782 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2783 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2784 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2785 })?;
2786 let resolved = match walk_json_path(&parsed, &path)? {
2787 Some(v) => v,
2788 None => return Ok(Value::Null),
2789 };
2790 match resolved.as_array() {
2791 Some(arr) => Ok(Value::Integer(arr.len() as i64)),
2792 None => Err(SQLRiteError::General(format!(
2793 "{name}() resolved to a non-array value at path `{path}`"
2794 ))),
2795 }
2796}
2797
2798fn json_fn_object_keys(
2799 name: &str,
2800 args: &FunctionArguments,
2801 scope: &dyn RowScope,
2802) -> Result<Value> {
2803 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2804 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2805 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2806 })?;
2807 let resolved = match walk_json_path(&parsed, &path)? {
2808 Some(v) => v,
2809 None => return Ok(Value::Null),
2810 };
2811 let obj = resolved.as_object().ok_or_else(|| {
2812 SQLRiteError::General(format!(
2813 "{name}() resolved to a non-object value at path `{path}`"
2814 ))
2815 })?;
2816 let keys: Vec<serde_json::Value> = obj
2823 .keys()
2824 .map(|k| serde_json::Value::String(k.clone()))
2825 .collect();
2826 Ok(Value::Text(serde_json::Value::Array(keys).to_string()))
2827}
2828
2829fn extract_two_vector_args(
2833 fn_name: &str,
2834 args: &FunctionArguments,
2835 scope: &dyn RowScope,
2836) -> Result<(Vec<f32>, Vec<f32>)> {
2837 let arg_list = match args {
2838 FunctionArguments::List(l) => &l.args,
2839 _ => {
2840 return Err(SQLRiteError::General(format!(
2841 "{fn_name}() expects exactly two vector arguments"
2842 )));
2843 }
2844 };
2845 if arg_list.len() != 2 {
2846 return Err(SQLRiteError::General(format!(
2847 "{fn_name}() expects exactly 2 arguments, got {}",
2848 arg_list.len()
2849 )));
2850 }
2851 let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
2852 for (i, arg) in arg_list.iter().enumerate() {
2853 let expr = match arg {
2854 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2855 other => {
2856 return Err(SQLRiteError::NotImplemented(format!(
2857 "{fn_name}() argument {i} has unsupported shape: {other:?}"
2858 )));
2859 }
2860 };
2861 let val = eval_expr_scope(expr, scope)?;
2862 match val {
2863 Value::Vector(v) => out.push(v),
2864 other => {
2865 return Err(SQLRiteError::General(format!(
2866 "{fn_name}() argument {i} is not a vector: got {}",
2867 other.to_display_string()
2868 )));
2869 }
2870 }
2871 }
2872 let b = out.pop().unwrap();
2873 let a = out.pop().unwrap();
2874 if a.len() != b.len() {
2875 return Err(SQLRiteError::General(format!(
2876 "{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
2877 a.len(),
2878 b.len()
2879 )));
2880 }
2881 Ok((a, b))
2882}
2883
/// Euclidean (L2) distance between `a` and `b`.
///
/// The slices are expected to have equal length; this is only checked with
/// a `debug_assert_eq!`. Elements of `b` are read by index for every element
/// of `a`, so a shorter `b` panics on the out-of-range access.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let squared_sum = a.iter().enumerate().fold(0.0f32, |acc, (i, &x)| {
        let delta = x - b[i];
        acc + delta * delta
    });
    squared_sum.sqrt()
}
2895
2896pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
2906 debug_assert_eq!(a.len(), b.len());
2907 let mut dot = 0.0f32;
2908 let mut norm_a_sq = 0.0f32;
2909 let mut norm_b_sq = 0.0f32;
2910 for i in 0..a.len() {
2911 dot += a[i] * b[i];
2912 norm_a_sq += a[i] * a[i];
2913 norm_b_sq += b[i] * b[i];
2914 }
2915 let denom = (norm_a_sq * norm_b_sq).sqrt();
2916 if denom == 0.0 {
2917 return Err(SQLRiteError::General(
2918 "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
2919 ));
2920 }
2921 Ok(1.0 - dot / denom)
2922}
2923
/// Negated dot product of `a` and `b`, so that a larger inner product yields
/// a smaller "distance".
///
/// The slices are expected to have equal length (checked only by a
/// `debug_assert_eq!`).
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    let inner_product = a
        .iter()
        .enumerate()
        .fold(0.0f32, |acc, (i, &x)| acc + x * b[i]);
    -inner_product
}
2935
2936fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
2939 if matches!(l, Value::Null) || matches!(r, Value::Null) {
2940 return Ok(Value::Null);
2941 }
2942 match (l, r) {
2943 (Value::Integer(a), Value::Integer(b)) => match op {
2944 BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
2945 BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
2946 BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
2947 BinaryOperator::Divide => {
2948 if *b == 0 {
2949 Err(SQLRiteError::General("division by zero".to_string()))
2950 } else {
2951 Ok(Value::Integer(a / b))
2952 }
2953 }
2954 BinaryOperator::Modulo => {
2955 if *b == 0 {
2956 Err(SQLRiteError::General("modulo by zero".to_string()))
2957 } else {
2958 Ok(Value::Integer(a % b))
2959 }
2960 }
2961 _ => unreachable!(),
2962 },
2963 (a, b) => {
2965 let af = as_number(a)?;
2966 let bf = as_number(b)?;
2967 match op {
2968 BinaryOperator::Plus => Ok(Value::Real(af + bf)),
2969 BinaryOperator::Minus => Ok(Value::Real(af - bf)),
2970 BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
2971 BinaryOperator::Divide => {
2972 if bf == 0.0 {
2973 Err(SQLRiteError::General("division by zero".to_string()))
2974 } else {
2975 Ok(Value::Real(af / bf))
2976 }
2977 }
2978 BinaryOperator::Modulo => {
2979 if bf == 0.0 {
2980 Err(SQLRiteError::General("modulo by zero".to_string()))
2981 } else {
2982 Ok(Value::Real(af % bf))
2983 }
2984 }
2985 _ => unreachable!(),
2986 }
2987 }
2988 }
2989}
2990
2991fn as_number(v: &Value) -> Result<f64> {
2992 match v {
2993 Value::Integer(i) => Ok(*i as f64),
2994 Value::Real(f) => Ok(*f),
2995 Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
2996 other => Err(SQLRiteError::General(format!(
2997 "arithmetic on non-numeric value '{}'",
2998 other.to_display_string()
2999 ))),
3000 }
3001}
3002
3003fn as_bool(v: &Value) -> Result<bool> {
3004 match v {
3005 Value::Bool(b) => Ok(*b),
3006 Value::Null => Ok(false),
3007 Value::Integer(i) => Ok(*i != 0),
3008 other => Err(SQLRiteError::Internal(format!(
3009 "expected boolean, got {}",
3010 other.to_display_string()
3011 ))),
3012 }
3013}
3014
3015#[allow(clippy::too_many_arguments)]
3020fn eval_like(
3021 scope: &dyn RowScope,
3022 negated: bool,
3023 any: bool,
3024 lhs: &Expr,
3025 pattern: &Expr,
3026 escape_char: Option<&AstValue>,
3027 case_insensitive: bool,
3028) -> Result<Value> {
3029 if any {
3030 return Err(SQLRiteError::NotImplemented(
3031 "LIKE ANY (...) is not supported".to_string(),
3032 ));
3033 }
3034 if escape_char.is_some() {
3035 return Err(SQLRiteError::NotImplemented(
3036 "LIKE ... ESCAPE '<char>' is not supported (default `\\` escape only)".to_string(),
3037 ));
3038 }
3039
3040 let l = eval_expr_scope(lhs, scope)?;
3041 let p = eval_expr_scope(pattern, scope)?;
3042 if matches!(l, Value::Null) || matches!(p, Value::Null) {
3043 return Ok(Value::Null);
3044 }
3045 let text = match l {
3046 Value::Text(s) => s,
3047 other => other.to_display_string(),
3048 };
3049 let pat = match p {
3050 Value::Text(s) => s,
3051 other => other.to_display_string(),
3052 };
3053 let m = like_match(&text, &pat, case_insensitive);
3054 Ok(Value::Bool(if negated { !m } else { m }))
3055}
3056
3057fn eval_in_list(scope: &dyn RowScope, lhs: &Expr, list: &[Expr], negated: bool) -> Result<Value> {
3058 let l = eval_expr_scope(lhs, scope)?;
3059 if matches!(l, Value::Null) {
3060 return Ok(Value::Null);
3061 }
3062 let mut saw_null = false;
3063 for item in list {
3064 let r = eval_expr_scope(item, scope)?;
3065 if matches!(r, Value::Null) {
3066 saw_null = true;
3067 continue;
3068 }
3069 if compare_values(Some(&l), Some(&r)) == Ordering::Equal {
3070 return Ok(Value::Bool(!negated));
3071 }
3072 }
3073 if saw_null {
3074 Ok(Value::Null)
3077 } else {
3078 Ok(Value::Bool(negated))
3079 }
3080}
3081
3082fn aggregate_rows(
3093 table: &Table,
3094 matching: &[i64],
3095 group_by: &[String],
3096 proj_items: &[ProjectionItem],
3097) -> Result<Vec<Vec<Value>>> {
3098 let template: Vec<Option<AggState>> = proj_items
3102 .iter()
3103 .map(|i| match &i.kind {
3104 ProjectionKind::Aggregate(call) => Some(AggState::new(call)),
3105 ProjectionKind::Column { .. } => None,
3106 })
3107 .collect();
3108
3109 let mut keys: Vec<Vec<DistinctKey>> = Vec::new();
3115 let mut group_states: Vec<Vec<Option<AggState>>> = Vec::new();
3116 let mut group_key_values: Vec<Vec<Value>> = Vec::new();
3117
3118 for &rowid in matching {
3119 let mut key_values: Vec<Value> = Vec::with_capacity(group_by.len());
3120 let mut key: Vec<DistinctKey> = Vec::with_capacity(group_by.len());
3121 for col in group_by {
3122 let v = table.get_value(col, rowid).unwrap_or(Value::Null);
3123 key.push(DistinctKey::from_value(&v));
3124 key_values.push(v);
3125 }
3126 let idx = match keys.iter().position(|k| k == &key) {
3127 Some(i) => i,
3128 None => {
3129 keys.push(key);
3130 group_states.push(template.clone());
3131 group_key_values.push(key_values);
3132 keys.len() - 1
3133 }
3134 };
3135
3136 for (slot, item) in proj_items.iter().enumerate() {
3137 if let ProjectionKind::Aggregate(call) = &item.kind {
3138 let v = match &call.arg {
3139 AggregateArg::Star => Value::Null,
3140 AggregateArg::Column(c) => table.get_value(c, rowid).unwrap_or(Value::Null),
3141 };
3142 if let Some(state) = group_states[idx][slot].as_mut() {
3143 state.update(&v)?;
3144 }
3145 }
3146 }
3147 }
3148
3149 if keys.is_empty() && group_by.is_empty() {
3155 keys.push(Vec::new());
3158 group_states.push(template.clone());
3159 group_key_values.push(Vec::new());
3160 }
3161
3162 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(keys.len());
3164 for (group_idx, _) in keys.iter().enumerate() {
3165 let mut row: Vec<Value> = Vec::with_capacity(proj_items.len());
3166 for (slot, item) in proj_items.iter().enumerate() {
3167 match &item.kind {
3168 ProjectionKind::Column { name: c, .. } => {
3169 let pos = group_by
3172 .iter()
3173 .position(|g| g == c)
3174 .expect("validated to be in GROUP BY");
3175 row.push(group_key_values[group_idx][pos].clone());
3176 }
3177 ProjectionKind::Aggregate(_) => {
3178 let state = group_states[group_idx][slot]
3179 .as_ref()
3180 .expect("aggregate slot has state");
3181 row.push(state.finalize());
3182 }
3183 }
3184 }
3185 rows.push(row);
3186 }
3187 Ok(rows)
3188}
3189
3190fn dedupe_rows(rows: Vec<Vec<Value>>) -> Vec<Vec<Value>> {
3194 use std::collections::HashSet;
3195 let mut seen: HashSet<Vec<DistinctKey>> = HashSet::new();
3196 let mut out = Vec::with_capacity(rows.len());
3197 for row in rows {
3198 let key: Vec<DistinctKey> = row.iter().map(DistinctKey::from_value).collect();
3199 if seen.insert(key) {
3200 out.push(row);
3201 }
3202 }
3203 out
3204}
3205
3206fn sort_output_rows(
3210 rows: &mut [Vec<Value>],
3211 columns: &[String],
3212 proj_items: &[ProjectionItem],
3213 order: &OrderByClause,
3214) -> Result<()> {
3215 let target_idx = resolve_order_by_index(&order.expr, columns, proj_items)?;
3216 rows.sort_by(|a, b| {
3217 let va = &a[target_idx];
3218 let vb = &b[target_idx];
3219 let ord = compare_values(Some(va), Some(vb));
3220 if order.ascending { ord } else { ord.reverse() }
3221 });
3222 Ok(())
3223}
3224
3225fn resolve_order_by_index(
3228 expr: &Expr,
3229 columns: &[String],
3230 proj_items: &[ProjectionItem],
3231) -> Result<usize> {
3232 let target_name: Option<String> = match expr {
3234 Expr::Identifier(ident) => Some(ident.value.clone()),
3235 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
3236 Expr::Function(_) => None,
3237 Expr::Nested(inner) => return resolve_order_by_index(inner, columns, proj_items),
3238 other => {
3239 return Err(SQLRiteError::NotImplemented(format!(
3240 "ORDER BY expression not supported on aggregating queries: {other:?}"
3241 )));
3242 }
3243 };
3244 if let Some(name) = target_name {
3245 if let Some(i) = columns.iter().position(|c| c.eq_ignore_ascii_case(&name)) {
3246 return Ok(i);
3247 }
3248 return Err(SQLRiteError::Internal(format!(
3249 "ORDER BY references unknown column '{name}' in the SELECT output"
3250 )));
3251 }
3252 if let Expr::Function(func) = expr {
3256 let user_disp = format_function_display(func);
3257 for (i, item) in proj_items.iter().enumerate() {
3258 if let ProjectionKind::Aggregate(call) = &item.kind
3259 && call.display_name().eq_ignore_ascii_case(&user_disp)
3260 {
3261 return Ok(i);
3262 }
3263 }
3264 return Err(SQLRiteError::Internal(format!(
3265 "ORDER BY references aggregate '{user_disp}' that isn't in the SELECT output"
3266 )));
3267 }
3268 Err(SQLRiteError::Internal(
3269 "ORDER BY expression could not be resolved against the output columns".to_string(),
3270 ))
3271}
3272
3273fn format_function_display(func: &sqlparser::ast::Function) -> String {
3277 let name = match func.name.0.as_slice() {
3278 [ObjectNamePart::Identifier(ident)] => ident.value.to_uppercase(),
3279 _ => format!("{:?}", func.name).to_uppercase(),
3280 };
3281 let inner = match &func.args {
3282 FunctionArguments::List(l) => {
3283 let distinct = matches!(
3284 l.duplicate_treatment,
3285 Some(sqlparser::ast::DuplicateTreatment::Distinct)
3286 );
3287 let arg = l.args.first().map(|a| match a {
3288 FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => "*".to_string(),
3289 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier(i))) => i.value.clone(),
3290 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::CompoundIdentifier(parts))) => {
3291 parts.last().map(|p| p.value.clone()).unwrap_or_default()
3292 }
3293 _ => String::new(),
3294 });
3295 match (distinct, arg) {
3296 (true, Some(a)) if a != "*" => format!("DISTINCT {a}"),
3297 (_, Some(a)) => a,
3298 _ => String::new(),
3299 }
3300 }
3301 _ => String::new(),
3302 };
3303 format!("{name}({inner})")
3304}
3305
3306fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
3307 use sqlparser::ast::Value as AstValue;
3308 match v {
3309 AstValue::Number(n, _) => {
3310 if let Ok(i) = n.parse::<i64>() {
3311 Ok(Value::Integer(i))
3312 } else if let Ok(f) = n.parse::<f64>() {
3313 Ok(Value::Real(f))
3314 } else {
3315 Err(SQLRiteError::Internal(format!(
3316 "could not parse numeric literal '{n}'"
3317 )))
3318 }
3319 }
3320 AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
3321 AstValue::Boolean(b) => Ok(Value::Bool(*b)),
3322 AstValue::Null => Ok(Value::Null),
3323 other => Err(SQLRiteError::NotImplemented(format!(
3324 "unsupported literal value: {other:?}"
3325 ))),
3326 }
3327}
3328
3329#[cfg(test)]
3330mod tests {
3331 use super::*;
3332
3333 fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
3340 (a - b).abs() < eps
3341 }
3342
3343 #[test]
3344 fn vec_distance_l2_identical_is_zero() {
3345 let v = vec![0.1, 0.2, 0.3];
3346 assert_eq!(vec_distance_l2(&v, &v), 0.0);
3347 }
3348
3349 #[test]
3350 fn vec_distance_l2_unit_basis_is_sqrt2() {
3351 let a = vec![1.0, 0.0];
3353 let b = vec![0.0, 1.0];
3354 assert!(approx_eq(vec_distance_l2(&a, &b), 2.0_f32.sqrt(), 1e-6));
3355 }
3356
3357 #[test]
3358 fn vec_distance_l2_known_value() {
3359 let a = vec![0.0, 0.0, 0.0];
3361 let b = vec![3.0, 4.0, 0.0];
3362 assert!(approx_eq(vec_distance_l2(&a, &b), 5.0, 1e-6));
3363 }
3364
3365 #[test]
3366 fn vec_distance_cosine_identical_is_zero() {
3367 let v = vec![0.1, 0.2, 0.3];
3368 let d = vec_distance_cosine(&v, &v).unwrap();
3369 assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
3370 }
3371
3372 #[test]
3373 fn vec_distance_cosine_orthogonal_is_one() {
3374 let a = vec![1.0, 0.0];
3377 let b = vec![0.0, 1.0];
3378 assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 1.0, 1e-6));
3379 }
3380
3381 #[test]
3382 fn vec_distance_cosine_opposite_is_two() {
3383 let a = vec![1.0, 0.0, 0.0];
3385 let b = vec![-1.0, 0.0, 0.0];
3386 assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 2.0, 1e-6));
3387 }
3388
3389 #[test]
3390 fn vec_distance_cosine_zero_magnitude_errors() {
3391 let a = vec![0.0, 0.0];
3393 let b = vec![1.0, 0.0];
3394 let err = vec_distance_cosine(&a, &b).unwrap_err();
3395 assert!(format!("{err}").contains("zero-magnitude"));
3396 }
3397
3398 #[test]
3399 fn vec_distance_dot_negates() {
3400 let a = vec![1.0, 2.0, 3.0];
3402 let b = vec![4.0, 5.0, 6.0];
3403 assert!(approx_eq(vec_distance_dot(&a, &b), -32.0, 1e-6));
3404 }
3405
3406 #[test]
3407 fn vec_distance_dot_orthogonal_is_zero() {
3408 let a = vec![1.0, 0.0];
3410 let b = vec![0.0, 1.0];
3411 assert_eq!(vec_distance_dot(&a, &b), 0.0);
3412 }
3413
3414 #[test]
3415 fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
3416 let a = vec![0.6f32, 0.8]; let b = vec![0.8f32, 0.6]; let dot = vec_distance_dot(&a, &b);
3422 let cos = vec_distance_cosine(&a, &b).unwrap();
3423 assert!(approx_eq(dot, cos - 1.0, 1e-5));
3424 }
3425
3426 use crate::sql::db::database::Database;
3431 use crate::sql::dialect::SqlriteDialect;
3432 use crate::sql::parser::select::SelectQuery;
3433 use sqlparser::parser::Parser;
3434
3435 fn seed_score_table(n: usize) -> Database {
3448 let mut db = Database::new("tempdb".to_string());
3449 crate::sql::process_command(
3450 "CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
3451 &mut db,
3452 )
3453 .expect("create");
3454 for i in 0..n {
3455 let score = ((i as u64).wrapping_mul(2_654_435_761) % 1_000_000) as f64;
3459 let sql = format!("INSERT INTO docs (score) VALUES ({score});");
3460 crate::sql::process_command(&sql, &mut db).expect("insert");
3461 }
3462 db
3463 }
3464
3465 fn parse_select(sql: &str) -> SelectQuery {
3469 let dialect = SqlriteDialect::new();
3470 let mut ast = Parser::parse_sql(&dialect, sql).expect("parse");
3471 let stmt = ast.pop().expect("one statement");
3472 SelectQuery::new(&stmt).expect("select-query")
3473 }
3474
3475 #[test]
3476 fn topk_matches_full_sort_asc() {
3477 let db = seed_score_table(200);
3480 let table = db.get_table("docs".to_string()).unwrap();
3481 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
3482 let order = q.order_by.as_ref().unwrap();
3483 let all_rowids = table.rowids();
3484
3485 let mut full = all_rowids.clone();
3487 sort_rowids(&mut full, table, order).unwrap();
3488 full.truncate(10);
3489
3490 let topk = select_topk(&all_rowids, table, order, 10).unwrap();
3492
3493 assert_eq!(topk, full, "top-k via heap should match full-sort+truncate");
3494 }
3495
3496 #[test]
3497 fn topk_matches_full_sort_desc() {
3498 let db = seed_score_table(200);
3500 let table = db.get_table("docs".to_string()).unwrap();
3501 let q = parse_select("SELECT * FROM docs ORDER BY score DESC LIMIT 10;");
3502 let order = q.order_by.as_ref().unwrap();
3503 let all_rowids = table.rowids();
3504
3505 let mut full = all_rowids.clone();
3506 sort_rowids(&mut full, table, order).unwrap();
3507 full.truncate(10);
3508
3509 let topk = select_topk(&all_rowids, table, order, 10).unwrap();
3510
3511 assert_eq!(
3512 topk, full,
3513 "top-k DESC via heap should match full-sort+truncate"
3514 );
3515 }
3516
3517 #[test]
3518 fn topk_k_larger_than_n_returns_everything_sorted() {
3519 let db = seed_score_table(50);
3524 let table = db.get_table("docs".to_string()).unwrap();
3525 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
3526 let order = q.order_by.as_ref().unwrap();
3527 let topk = select_topk(&table.rowids(), table, order, 1000).unwrap();
3528 assert_eq!(topk.len(), 50);
3529 let scores: Vec<f64> = topk
3531 .iter()
3532 .filter_map(|r| match table.get_value("score", *r) {
3533 Some(Value::Real(f)) => Some(f),
3534 _ => None,
3535 })
3536 .collect();
3537 assert!(scores.windows(2).all(|w| w[0] <= w[1]));
3538 }
3539
3540 #[test]
3541 fn topk_k_zero_returns_empty() {
3542 let db = seed_score_table(10);
3543 let table = db.get_table("docs".to_string()).unwrap();
3544 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
3545 let order = q.order_by.as_ref().unwrap();
3546 let topk = select_topk(&table.rowids(), table, order, 0).unwrap();
3547 assert!(topk.is_empty());
3548 }
3549
3550 #[test]
3551 fn topk_empty_input_returns_empty() {
3552 let db = seed_score_table(0);
3553 let table = db.get_table("docs".to_string()).unwrap();
3554 let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
3555 let order = q.order_by.as_ref().unwrap();
3556 let topk = select_topk(&[], table, order, 5).unwrap();
3557 assert!(topk.is_empty());
3558 }
3559
3560 #[test]
3561 fn topk_works_through_select_executor_with_distance_function() {
3562 let mut db = Database::new("tempdb".to_string());
3566 crate::sql::process_command(
3567 "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
3568 &mut db,
3569 )
3570 .unwrap();
3571 for v in &[
3578 "[1.0, 0.0]",
3579 "[2.0, 0.0]",
3580 "[0.0, 3.0]",
3581 "[1.0, 4.0]",
3582 "[10.0, 10.0]",
3583 ] {
3584 crate::sql::process_command(&format!("INSERT INTO docs (e) VALUES ({v});"), &mut db)
3585 .unwrap();
3586 }
3587 let resp = crate::sql::process_command(
3588 "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;",
3589 &mut db,
3590 )
3591 .unwrap();
3592 assert!(resp.contains("3 rows returned"), "got: {resp}");
3595 }
3596
    // Timing comparison between `select_topk` and a full sort followed by
    // truncate, over N seeded rows with k = K. Marked `#[ignore]`, so it
    // only runs when ignored tests are requested explicitly.
    //
    // NOTE(review): the final assertion depends on wall-clock timing of a
    // single run and may be noisy on loaded machines.
    #[test]
    #[ignore]
    fn topk_benchmark() {
        use std::time::Instant;
        const N: usize = 10_000;
        const K: usize = 10;

        let db = seed_score_table(N);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        // Measure the top-k path first.
        let t0 = Instant::now();
        let _topk = select_topk(&all_rowids, table, order, K).unwrap();
        let heap_dur = t0.elapsed();

        // Then the baseline; the clone is part of the timed baseline work.
        let t1 = Instant::now();
        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(K);
        let sort_dur = t1.elapsed();

        // Clamp the divisor so a near-zero top-k duration cannot divide by zero.
        let ratio = sort_dur.as_secs_f64() / heap_dur.as_secs_f64().max(1e-9);
        println!("\n--- topk_benchmark (N={N}, k={K}) ---");
        println!(" bounded heap: {heap_dur:?}");
        println!(" full sort+trunc: {sort_dur:?}");
        println!(" speedup ratio: {ratio:.2}×");

        // The top-k path is expected to be at least 1.4x faster than the baseline.
        assert!(
            ratio > 1.4,
            "bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
        );
    }
3661
3662 fn run_select(db: &mut Database, sql: &str) -> String {
3670 crate::sql::process_command(sql, db).expect("select")
3671 }
3672
3673 #[test]
3674 fn where_is_null_returns_null_rows() {
3675 let mut db = Database::new("t".to_string());
3676 crate::sql::process_command(
3677 "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
3678 &mut db,
3679 )
3680 .unwrap();
3681 crate::sql::process_command("INSERT INTO t (id, n) VALUES (1, 10);", &mut db).unwrap();
3682 crate::sql::process_command("INSERT INTO t (id, n) VALUES (2, NULL);", &mut db).unwrap();
3683 crate::sql::process_command("INSERT INTO t (id, n) VALUES (3, 30);", &mut db).unwrap();
3684 crate::sql::process_command("INSERT INTO t (id, n) VALUES (4, NULL);", &mut db).unwrap();
3685
3686 let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL;");
3687 assert!(
3688 response.contains("2 rows returned"),
3689 "IS NULL should return 2 rows, got: {response}"
3690 );
3691 }
3692
3693 #[test]
3694 fn where_is_not_null_returns_non_null_rows() {
3695 let mut db = Database::new("t".to_string());
3696 crate::sql::process_command(
3697 "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
3698 &mut db,
3699 )
3700 .unwrap();
3701 crate::sql::process_command("INSERT INTO t (id, n) VALUES (1, 10);", &mut db).unwrap();
3702 crate::sql::process_command("INSERT INTO t (id, n) VALUES (2, NULL);", &mut db).unwrap();
3703 crate::sql::process_command("INSERT INTO t (id, n) VALUES (3, 30);", &mut db).unwrap();
3704
3705 let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NOT NULL;");
3706 assert!(
3707 response.contains("2 rows returned"),
3708 "IS NOT NULL should return 2 rows, got: {response}"
3709 );
3710 }
3711
3712 #[test]
3713 fn where_is_null_on_indexed_column() {
3714 let mut db = Database::new("t".to_string());
3719 crate::sql::process_command(
3720 "CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT UNIQUE);",
3721 &mut db,
3722 )
3723 .unwrap();
3724 crate::sql::process_command("INSERT INTO t (id, name) VALUES (1, 'alice');", &mut db)
3725 .unwrap();
3726 crate::sql::process_command("INSERT INTO t (id, name) VALUES (2, NULL);", &mut db).unwrap();
3727 crate::sql::process_command("INSERT INTO t (id, name) VALUES (3, 'bob');", &mut db)
3728 .unwrap();
3729
3730 let null_rows = run_select(&mut db, "SELECT id FROM t WHERE name IS NULL;");
3731 assert!(
3732 null_rows.contains("1 row returned"),
3733 "indexed IS NULL should return 1 row, got: {null_rows}"
3734 );
3735 let not_null_rows = run_select(&mut db, "SELECT id FROM t WHERE name IS NOT NULL;");
3736 assert!(
3737 not_null_rows.contains("2 rows returned"),
3738 "indexed IS NOT NULL should return 2 rows, got: {not_null_rows}"
3739 );
3740 }
3741
3742 #[test]
3743 fn where_is_null_works_on_omitted_column() {
3744 let mut db = Database::new("t".to_string());
3748 crate::sql::process_command(
3749 "CREATE TABLE t (id INTEGER PRIMARY KEY, qty INTEGER, label TEXT);",
3750 &mut db,
3751 )
3752 .unwrap();
3753 crate::sql::process_command(
3754 "INSERT INTO t (id, qty, label) VALUES (1, 7, 'a');",
3755 &mut db,
3756 )
3757 .unwrap();
3758 crate::sql::process_command("INSERT INTO t (id, label) VALUES (2, 'b');", &mut db).unwrap();
3760
3761 let response = run_select(&mut db, "SELECT id FROM t WHERE qty IS NULL;");
3762 assert!(
3763 response.contains("1 row returned"),
3764 "IS NULL should match the omitted-column row, got: {response}"
3765 );
3766 }
3767
3768 #[test]
3769 fn where_is_null_combines_with_and_or() {
3770 let mut db = Database::new("t".to_string());
3774 crate::sql::process_command(
3775 "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
3776 &mut db,
3777 )
3778 .unwrap();
3779 crate::sql::process_command("INSERT INTO t (id, n) VALUES (1, NULL);", &mut db).unwrap();
3780 crate::sql::process_command("INSERT INTO t (id, n) VALUES (2, NULL);", &mut db).unwrap();
3781 crate::sql::process_command("INSERT INTO t (id, n) VALUES (3, 30);", &mut db).unwrap();
3782
3783 let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL AND id > 1;");
3784 assert!(
3785 response.contains("1 row returned"),
3786 "IS NULL combined with AND should match exactly row 2, got: {response}"
3787 );
3788 }
3789
3790 fn seed_employees() -> Database {
3796 let mut db = Database::new("t".to_string());
3797 crate::sql::process_command(
3798 "CREATE TABLE emp (id INTEGER PRIMARY KEY, name TEXT, dept TEXT, salary INTEGER);",
3799 &mut db,
3800 )
3801 .unwrap();
3802 let rows = [
3803 "INSERT INTO emp (name, dept, salary) VALUES ('Alice', 'eng', 100);",
3804 "INSERT INTO emp (name, dept, salary) VALUES ('alex', 'eng', 120);",
3805 "INSERT INTO emp (name, dept, salary) VALUES ('Bob', 'eng', 100);",
3806 "INSERT INTO emp (name, dept, salary) VALUES ('Carol', 'sales', 90);",
3807 "INSERT INTO emp (name, dept, salary) VALUES ('Dave', 'sales', NULL);",
3808 "INSERT INTO emp (name, dept, salary) VALUES ('Eve', 'ops', 80);",
3809 ];
3810 for sql in rows {
3811 crate::sql::process_command(sql, &mut db).unwrap();
3812 }
3813 db
3814 }
3815
3816 fn run_rows(db: &Database, sql: &str) -> SelectResult {
3818 let q = parse_select(sql);
3819 execute_select_rows(q, db).expect("select")
3820 }
3821
3822 #[test]
3825 fn like_percent_prefix_case_insensitive() {
3826 let db = seed_employees();
3827 let r = run_rows(&db, "SELECT name FROM emp WHERE name LIKE 'a%';");
3828 let names: Vec<_> = r.rows.iter().map(|r| r[0].to_display_string()).collect();
3830 assert_eq!(names.len(), 2, "expected 2 rows, got {names:?}");
3831 assert!(names.contains(&"Alice".to_string()));
3832 assert!(names.contains(&"alex".to_string()));
3833 }
3834
3835 #[test]
3836 fn like_underscore_singlechar() {
3837 let db = seed_employees();
3838 let r = run_rows(&db, "SELECT name FROM emp WHERE name LIKE '_ve';");
3839 let names: Vec<_> = r.rows.iter().map(|r| r[0].to_display_string()).collect();
3841 assert_eq!(names, vec!["Eve".to_string()]);
3842 }
3843
3844 #[test]
3845 fn not_like_excludes_match() {
3846 let db = seed_employees();
3847 let r = run_rows(&db, "SELECT name FROM emp WHERE name NOT LIKE 'a%';");
3848 assert_eq!(r.rows.len(), 4);
3850 }
3851
3852 #[test]
3853 fn like_with_null_excludes_row() {
3854 let db = seed_employees();
3855 let r = run_rows(
3857 &db,
3858 "SELECT name FROM emp WHERE dept LIKE 'sales' AND salary IS NULL;",
3859 );
3860 assert_eq!(r.rows.len(), 1);
3861 assert_eq!(r.rows[0][0].to_display_string(), "Dave");
3862 }
3863
3864 #[test]
3867 fn in_list_positive() {
3868 let db = seed_employees();
3869 let r = run_rows(&db, "SELECT name FROM emp WHERE id IN (1, 3, 5);");
3870 let names: Vec<_> = r.rows.iter().map(|r| r[0].to_display_string()).collect();
3871 assert_eq!(names.len(), 3);
3872 assert!(names.contains(&"Alice".to_string()));
3873 assert!(names.contains(&"Bob".to_string()));
3874 assert!(names.contains(&"Dave".to_string()));
3875 }
3876
3877 #[test]
3878 fn not_in_excludes_listed() {
3879 let db = seed_employees();
3880 let r = run_rows(&db, "SELECT name FROM emp WHERE id NOT IN (1, 2);");
3881 assert_eq!(r.rows.len(), 4);
3883 }
3884
3885 #[test]
3886 fn in_list_with_null_three_valued() {
3887 let db = seed_employees();
3888 let r = run_rows(&db, "SELECT name FROM emp WHERE id IN (1, NULL);");
3891 assert_eq!(r.rows.len(), 1);
3892 assert_eq!(r.rows[0][0].to_display_string(), "Alice");
3893 }
3894
3895 #[test]
3898 fn distinct_single_column() {
3899 let db = seed_employees();
3900 let r = run_rows(&db, "SELECT DISTINCT dept FROM emp;");
3901 assert_eq!(r.rows.len(), 3);
3903 }
3904
3905 #[test]
3906 fn distinct_multi_column_with_null() {
3907 let db = seed_employees();
3908 let r = run_rows(&db, "SELECT DISTINCT dept, salary FROM emp;");
3910 assert_eq!(r.rows.len(), 5);
3912 }
3913
3914 #[test]
3917 fn count_star_no_groupby() {
3918 let db = seed_employees();
3919 let r = run_rows(&db, "SELECT COUNT(*) FROM emp;");
3920 assert_eq!(r.rows.len(), 1);
3921 assert_eq!(r.rows[0][0], Value::Integer(6));
3922 }
3923
3924 #[test]
3925 fn count_col_skips_nulls() {
3926 let db = seed_employees();
3927 let r = run_rows(&db, "SELECT COUNT(salary) FROM emp;");
3928 assert_eq!(r.rows[0][0], Value::Integer(5));
3930 }
3931
3932 #[test]
3933 fn count_distinct_dedupes_and_skips_nulls() {
3934 let db = seed_employees();
3935 let r = run_rows(&db, "SELECT COUNT(DISTINCT salary) FROM emp;");
3936 assert_eq!(r.rows[0][0], Value::Integer(4));
3938 }
3939
3940 #[test]
3941 fn sum_int_stays_integer() {
3942 let db = seed_employees();
3943 let r = run_rows(&db, "SELECT SUM(salary) FROM emp;");
3944 assert_eq!(r.rows[0][0], Value::Integer(490));
3946 }
3947
3948 #[test]
3949 fn avg_returns_real() {
3950 let db = seed_employees();
3951 let r = run_rows(&db, "SELECT AVG(salary) FROM emp;");
3952 match &r.rows[0][0] {
3954 Value::Real(v) => assert!((v - 98.0).abs() < 1e-9),
3955 other => panic!("expected Real, got {other:?}"),
3956 }
3957 }
3958
3959 #[test]
3960 fn min_max_skip_nulls() {
3961 let db = seed_employees();
3962 let r = run_rows(&db, "SELECT MIN(salary), MAX(salary) FROM emp;");
3963 assert_eq!(r.rows[0][0], Value::Integer(80));
3964 assert_eq!(r.rows[0][1], Value::Integer(120));
3965 }
3966
3967 #[test]
3968 fn aggregates_on_empty_table_emit_one_row() {
3969 let mut db = Database::new("t".to_string());
3970 crate::sql::process_command("CREATE TABLE t (x INTEGER);", &mut db).unwrap();
3971 let r = run_rows(
3972 &db,
3973 "SELECT COUNT(*), SUM(x), AVG(x), MIN(x), MAX(x) FROM t;",
3974 );
3975 assert_eq!(r.rows.len(), 1);
3976 assert_eq!(r.rows[0][0], Value::Integer(0));
3977 assert_eq!(r.rows[0][1], Value::Null);
3978 assert_eq!(r.rows[0][2], Value::Null);
3979 assert_eq!(r.rows[0][3], Value::Null);
3980 assert_eq!(r.rows[0][4], Value::Null);
3981 }
3982
3983 #[test]
3986 fn group_by_single_col_with_count() {
3987 let db = seed_employees();
3988 let r = run_rows(&db, "SELECT dept, COUNT(*) FROM emp GROUP BY dept;");
3989 assert_eq!(r.rows.len(), 3);
3990 let mut by_dept: std::collections::HashMap<String, i64> = Default::default();
3992 for row in &r.rows {
3993 let d = row[0].to_display_string();
3994 let c = match &row[1] {
3995 Value::Integer(i) => *i,
3996 v => panic!("expected Integer count, got {v:?}"),
3997 };
3998 by_dept.insert(d, c);
3999 }
4000 assert_eq!(by_dept["eng"], 3);
4001 assert_eq!(by_dept["sales"], 2);
4002 assert_eq!(by_dept["ops"], 1);
4003 }
4004
4005 #[test]
4006 fn group_by_with_where_filter() {
4007 let db = seed_employees();
4008 let r = run_rows(
4009 &db,
4010 "SELECT dept, SUM(salary) FROM emp WHERE salary > 80 GROUP BY dept;",
4011 );
4012 let by: std::collections::HashMap<String, i64> = r
4015 .rows
4016 .iter()
4017 .map(|row| {
4018 (
4019 row[0].to_display_string(),
4020 match &row[1] {
4021 Value::Integer(i) => *i,
4022 v => panic!("expected Integer sum, got {v:?}"),
4023 },
4024 )
4025 })
4026 .collect();
4027 assert_eq!(by.len(), 2);
4028 assert_eq!(by["eng"], 320);
4029 assert_eq!(by["sales"], 90);
4030 }
4031
4032 #[test]
4033 fn group_by_without_aggregates_is_distinct() {
4034 let db = seed_employees();
4035 let r = run_rows(&db, "SELECT dept FROM emp GROUP BY dept;");
4036 assert_eq!(r.rows.len(), 3);
4037 }
4038
4039 #[test]
4040 fn order_by_count_desc() {
4041 let db = seed_employees();
4042 let r = run_rows(
4043 &db,
4044 "SELECT dept, COUNT(*) AS n FROM emp GROUP BY dept ORDER BY n DESC LIMIT 2;",
4045 );
4046 assert_eq!(r.rows.len(), 2);
4047 assert_eq!(r.rows[0][0].to_display_string(), "eng");
4049 assert_eq!(r.rows[0][1], Value::Integer(3));
4050 }
4051
4052 #[test]
4053 fn order_by_aggregate_call_form() {
4054 let db = seed_employees();
4055 let r = run_rows(
4057 &db,
4058 "SELECT dept, COUNT(*) FROM emp GROUP BY dept ORDER BY COUNT(*) DESC;",
4059 );
4060 assert_eq!(r.rows.len(), 3);
4061 assert_eq!(r.rows[0][0].to_display_string(), "eng");
4062 }
4063
#[test]
fn group_by_invalid_bare_column_errors() {
    let mut db = Database::new(String::from("t"));
    let create = "CREATE TABLE t (id INTEGER PRIMARY KEY, dept TEXT, name TEXT);";
    crate::sql::process_command(create, &mut db).unwrap();

    // `name` is projected bare but is neither grouped nor aggregated.
    let outcome = crate::sql::process_command("SELECT dept, name FROM t GROUP BY dept;", &mut db);
    assert!(outcome.is_err(), "should reject bare 'name' not in GROUP BY");
}
4076
#[test]
fn aggregate_in_where_errors_friendly() {
    let mut db = Database::new(String::from("t"));
    // Minimal one-column, one-row table; setup must succeed.
    for setup in ["CREATE TABLE t (x INTEGER);", "INSERT INTO t (x) VALUES (1);"] {
        crate::sql::process_command(setup, &mut db).unwrap();
    }

    let outcome = crate::sql::process_command("SELECT x FROM t WHERE COUNT(*) > 0;", &mut db);
    assert!(outcome.is_err(), "aggregates must not be allowed in WHERE");
}
4085
4086 fn seed_join_fixture() -> Database {
4097 let mut db = Database::new("t".to_string());
4098 for sql in [
4099 "CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT);",
4100 "CREATE TABLE orders (id INTEGER PRIMARY KEY, customer_id INTEGER, amount INTEGER);",
4101 "INSERT INTO customers (name) VALUES ('Alice');",
4102 "INSERT INTO customers (name) VALUES ('Bob');",
4103 "INSERT INTO customers (name) VALUES ('Carol');",
4104 "INSERT INTO orders (customer_id, amount) VALUES (1, 100);",
4105 "INSERT INTO orders (customer_id, amount) VALUES (1, 200);",
4106 "INSERT INTO orders (customer_id, amount) VALUES (2, 50);",
4107 "INSERT INTO orders (customer_id, amount) VALUES (4, 999);",
4108 ] {
4109 crate::sql::process_command(sql, &mut db).unwrap();
4110 }
4111 db
4112 }
4113
#[test]
fn inner_join_returns_only_matched_rows() {
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT customers.name, orders.amount FROM customers \
         INNER JOIN orders ON customers.id = orders.customer_id;",
    );

    // Output headers are the bare column names, without the table qualifier.
    let expected_columns = vec!["name".to_string(), "amount".to_string()];
    assert_eq!(result.columns, expected_columns);

    // Flatten rows into (name, amount) pairs; row order is not asserted.
    let mut pairs: Vec<(String, i64)> = Vec::with_capacity(result.rows.len());
    for row in &result.rows {
        let name = row[0].to_display_string();
        let v = &row[1];
        let Value::Integer(amount) = v else {
            panic!("expected integer amount, got {v:?}");
        };
        pairs.push((name, *amount));
    }

    assert_eq!(pairs.len(), 3);
    for expected in [("Alice", 100), ("Alice", 200), ("Bob", 50)] {
        assert!(pairs.contains(&(expected.0.to_string(), expected.1)));
    }
}
4143
#[test]
fn bare_join_defaults_to_inner() {
    // Plain `JOIN` should yield the same three matched rows as `INNER JOIN`.
    let db = seed_join_fixture();
    let row_count = run_rows(
        &db,
        "SELECT customers.name FROM customers \
         JOIN orders ON customers.id = orders.customer_id;",
    )
    .rows
    .len();
    assert_eq!(row_count, 3, "JOIN without prefix should be INNER");
}
4154
#[test]
fn left_outer_join_preserves_unmatched_left() {
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT customers.name, orders.amount FROM customers \
         LEFT OUTER JOIN orders ON customers.id = orders.customer_id;",
    );
    // Three matched rows plus one padded row for the customer with no orders.
    assert_eq!(result.rows.len(), 4);

    let mut carol_rows = result
        .rows
        .iter()
        .filter(|row| row[0].to_display_string() == "Carol");
    let carol = carol_rows
        .next()
        .expect("Carol should appear with a NULL-padded right side");
    assert_eq!(carol[1], Value::Null);
}
4173
#[test]
fn right_outer_join_preserves_unmatched_right() {
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT customers.name, orders.amount FROM customers \
         RIGHT OUTER JOIN orders ON customers.id = orders.customer_id;",
    );
    // Three matched rows plus the order whose customer_id has no customer.
    assert_eq!(result.rows.len(), 4);

    let mut dangling_rows = result
        .rows
        .iter()
        .filter(|row| matches!(row[1], Value::Integer(999)));
    let dangling = dangling_rows
        .next()
        .expect("dangling order 999 should appear with a NULL-padded customer name");
    assert_eq!(dangling[0], Value::Null);
}
4193
#[test]
fn full_outer_join_preserves_both_sides() {
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT customers.name, orders.amount FROM customers \
         FULL OUTER JOIN orders ON customers.id = orders.customer_id;",
    );
    // Three matches, one unmatched customer, one unmatched order.
    assert_eq!(result.rows.len(), 5);

    // Unmatched left row: Carol with a NULL amount.
    let carol_is_padded = result
        .rows
        .iter()
        .any(|row| row[0].to_display_string() == "Carol" && matches!(row[1], Value::Null));
    assert!(carol_is_padded);

    // Unmatched right row: amount 999 with a NULL customer name.
    let dangling_order_is_padded = result
        .rows
        .iter()
        .any(|row| matches!(row[1], Value::Integer(999)) && matches!(row[0], Value::Null));
    assert!(dangling_order_is_padded);
}
4218
#[test]
fn join_with_table_aliases_resolves_qualifiers() {
    // Both tables are referenced only through their `AS` aliases.
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT c.name, o.amount FROM customers AS c \
         INNER JOIN orders AS o ON c.id = o.customer_id;",
    );

    assert_eq!(result.rows.len(), 3);
    let expected_columns = vec!["name".to_string(), "amount".to_string()];
    assert_eq!(result.columns, expected_columns);
}
4230
#[test]
fn join_with_where_filter_applies_after_join() {
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT customers.name, orders.amount FROM customers \
         INNER JOIN orders ON customers.id = orders.customer_id \
         WHERE orders.amount >= 100;",
    );

    // Only the two orders with amount >= 100 survive, and both are Alice's.
    assert_eq!(result.rows.len(), 2);
    let only_alice = result
        .rows
        .iter()
        .all(|row| row[0].to_display_string() == "Alice");
    assert!(only_alice);
}
4249
#[test]
fn left_join_with_where_on_right_side_is_not_inner() {
    // Filtering on `orders.amount IS NULL` should leave just the padded row
    // the LEFT OUTER JOIN produced for the customer without orders.
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT customers.name, orders.amount FROM customers \
         LEFT OUTER JOIN orders ON customers.id = orders.customer_id \
         WHERE orders.amount IS NULL;",
    );

    assert_eq!(result.rows.len(), 1);
    let only_row = &result.rows[0];
    assert_eq!(only_row[0].to_display_string(), "Carol");
    assert_eq!(only_row[1], Value::Null);
}
4267
#[test]
fn select_star_over_join_emits_all_columns_from_both_tables() {
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT * FROM customers \
         INNER JOIN orders ON customers.id = orders.customer_id;",
    );

    // Left table's columns first, then the right table's; `id` shows up twice
    // because the headers carry no table qualifier.
    let expected_columns = vec![
        "id".to_string(),
        "name".to_string(),
        "id".to_string(),
        "customer_id".to_string(),
        "amount".to_string(),
    ];
    assert_eq!(result.columns, expected_columns);
    assert_eq!(result.rows.len(), 3);
}
4291
#[test]
fn join_order_by_sorts_full_joined_rows() {
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT c.name, o.amount FROM customers AS c \
         INNER JOIN orders AS o ON c.id = o.customer_id \
         ORDER BY o.amount;",
    );

    // Pull out the amount column in output order.
    let mut amounts: Vec<i64> = Vec::with_capacity(result.rows.len());
    for row in &result.rows {
        let v = &row[1];
        let Value::Integer(amount) = v else {
            panic!("expected integer, got {v:?}");
        };
        amounts.push(*amount);
    }
    // ORDER BY with no direction is expected to sort ascending.
    assert_eq!(amounts, vec![50, 100, 200]);
}
4311
#[test]
fn join_limit_truncates_after_join_and_sort() {
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT c.name, o.amount FROM customers AS c \
         INNER JOIN orders AS o ON c.id = o.customer_id \
         ORDER BY o.amount DESC LIMIT 2;",
    );
    assert_eq!(result.rows.len(), 2);

    // The two largest amounts, in descending order, must be what LIMIT keeps.
    let mut amounts: Vec<i64> = Vec::new();
    for row in result.rows.iter() {
        let v = &row[1];
        let amount = if let Value::Integer(i) = v {
            *i
        } else {
            panic!("expected integer, got {v:?}")
        };
        amounts.push(amount);
    }
    assert_eq!(amounts, vec![200, 100]);
}
4333
#[test]
fn three_table_join_chains_correctly() {
    // a -> b -> c chain: both `a` rows have a `b` row, but only b(1) has a `c`.
    let setup: &[&str] = &[
        "CREATE TABLE a (id INTEGER PRIMARY KEY, label TEXT);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, a_id INTEGER, tag TEXT);",
        "CREATE TABLE c (id INTEGER PRIMARY KEY, b_id INTEGER, note TEXT);",
        "INSERT INTO a (label) VALUES ('a-one');",
        "INSERT INTO a (label) VALUES ('a-two');",
        "INSERT INTO b (a_id, tag) VALUES (1, 'b1');",
        "INSERT INTO b (a_id, tag) VALUES (2, 'b2');",
        "INSERT INTO c (b_id, note) VALUES (1, 'c1');",
    ];
    let mut db = Database::new(String::from("t"));
    for &statement in setup {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    let result = run_rows(
        &db,
        "SELECT a.label, b.tag, c.note FROM a \
         INNER JOIN b ON a.id = b.a_id \
         INNER JOIN c ON b.id = c.b_id;",
    );

    // Only the fully matched chain survives two INNER JOINs.
    assert_eq!(result.rows.len(), 1);
    let row = &result.rows[0];
    assert_eq!(row[0].to_display_string(), "a-one");
    assert_eq!(row[1].to_display_string(), "b1");
    assert_eq!(row[2].to_display_string(), "c1");
}
4361
#[test]
fn ambiguous_unqualified_column_in_join_errors() {
    // Both joined tables have an `id` column, so a bare `id` is ambiguous.
    let db = seed_join_fixture();
    let sql = "SELECT id FROM customers INNER JOIN orders ON customers.id = orders.customer_id;";
    let outcome = execute_select_rows(parse_select(sql), &db);
    assert!(outcome.is_err(), "unqualified ambiguous 'id' should error");
}
4374
#[test]
fn join_self_without_alias_is_rejected() {
    let mut db = Database::new(String::from("t"));
    let create = "CREATE TABLE n (id INTEGER PRIMARY KEY, parent INTEGER);";
    crate::sql::process_command(create, &mut db).unwrap();

    // `n` appears on both sides of the join with no alias to tell them apart.
    let query = parse_select("SELECT n.id FROM n INNER JOIN n ON n.id = n.parent;");
    let outcome = execute_select_rows(query, &db);
    assert!(
        outcome.is_err(),
        "self-join without an alias should error on duplicate qualifier"
    );
}
4390
#[test]
fn using_or_natural_join_returns_not_implemented() {
    let mut db = Database::new(String::from("t"));
    for create in [
        "CREATE TABLE a (id INTEGER PRIMARY KEY);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY);",
    ] {
        crate::sql::process_command(create, &mut db).unwrap();
    }

    // JOIN ... USING (...) must fail.
    let using_outcome =
        crate::sql::process_command("SELECT * FROM a INNER JOIN b USING (id);", &mut db);
    assert!(using_outcome.is_err(), "USING is not yet supported");

    // NATURAL JOIN must fail too.
    let natural_outcome = crate::sql::process_command("SELECT * FROM a NATURAL JOIN b;", &mut db);
    assert!(natural_outcome.is_err(), "NATURAL is not supported");
}
4402
#[test]
fn aggregates_over_join_are_rejected() {
    // One mutable fixture is enough: `process_command` takes `&mut Database`,
    // so the query runs directly against it instead of against a second,
    // throwaway copy of the fixture.
    let mut db = seed_join_fixture();
    let err = crate::sql::process_command(
        "SELECT COUNT(*) FROM customers \
         INNER JOIN orders ON customers.id = orders.customer_id;",
        &mut db,
    );
    assert!(err.is_err(), "aggregates over JOIN are not yet supported");
}
4414
#[test]
fn left_join_with_no_matches_pads_every_row() {
    // a.x holds {1, 2} and b.y holds {10}, so `a.x = b.y` never matches.
    let setup: &[&str] = &[
        "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
        "INSERT INTO a (x) VALUES (1);",
        "INSERT INTO a (x) VALUES (2);",
        "INSERT INTO b (y) VALUES (10);",
    ];
    let mut db = Database::new(String::from("t"));
    for &statement in setup {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    let sql = "SELECT a.x, b.y FROM a LEFT OUTER JOIN b ON a.x = b.y;";
    let result = run_rows(&db, sql);

    // Both left rows are kept, each with a NULL right-hand column.
    assert_eq!(result.rows.len(), 2);
    result
        .rows
        .iter()
        .for_each(|row| assert_eq!(row[1], Value::Null));
}
4437
#[test]
fn left_outer_join_order_by_places_nulls_first() {
    // Carol has no orders, so her amount is NULL; with an ascending sort that
    // row is expected at the front.
    let db = seed_join_fixture();
    let result = run_rows(
        &db,
        "SELECT c.name, o.amount FROM customers AS c \
         LEFT OUTER JOIN orders AS o ON c.id = o.customer_id \
         ORDER BY o.amount ASC;",
    );

    assert_eq!(result.rows.len(), 4);
    let first = &result.rows[0];
    assert_eq!(first[0].to_display_string(), "Carol");
    assert_eq!(first[1], Value::Null);
}
4456
#[test]
fn chained_left_outer_join_preserves_left_through_two_levels() {
    // Only 'a-one' has a `b` row, and `c` is empty, so every row needs
    // NULL padding at one or both levels of the chain.
    let setup: &[&str] = &[
        "CREATE TABLE a (id INTEGER PRIMARY KEY, label TEXT);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, a_id INTEGER, tag TEXT);",
        "CREATE TABLE c (id INTEGER PRIMARY KEY, b_id INTEGER, note TEXT);",
        "INSERT INTO a (label) VALUES ('a-one');",
        "INSERT INTO a (label) VALUES ('a-two');",
        "INSERT INTO b (a_id, tag) VALUES (1, 'b1');",
    ];
    let mut db = Database::new(String::from("t"));
    for &statement in setup {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    let result = run_rows(
        &db,
        "SELECT a.label, b.tag, c.note FROM a \
         LEFT OUTER JOIN b ON a.id = b.a_id \
         LEFT OUTER JOIN c ON b.id = c.b_id;",
    );
    assert_eq!(result.rows.len(), 2);

    // Key each row by its `a.label` so row order does not matter.
    let mut by_label: std::collections::HashMap<String, &Vec<Value>> =
        std::collections::HashMap::new();
    for row in &result.rows {
        by_label.insert(row[0].to_display_string(), row);
    }

    let a_one = by_label["a-one"];
    assert_eq!(a_one[1].to_display_string(), "b1");
    assert_eq!(a_one[2], Value::Null);

    let a_two = by_label["a-two"];
    assert_eq!(a_two[1], Value::Null);
    assert_eq!(a_two[2], Value::Null);
}
4492
#[test]
fn on_clause_referencing_not_yet_joined_table_errors_clearly() {
    let setup: &[&str] = &[
        "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE c (id INTEGER PRIMARY KEY, x INTEGER);",
        "INSERT INTO a (x) VALUES (1);",
        "INSERT INTO b (x) VALUES (1);",
        "INSERT INTO c (x) VALUES (1);",
    ];
    let mut db = Database::new(String::from("t"));
    for &statement in setup {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    // The first ON clause mentions `c`, which only enters via the second JOIN.
    let sql = "SELECT a.x FROM a INNER JOIN b ON a.x = c.x INNER JOIN c ON b.x = c.x;";
    let outcome = execute_select_rows(parse_select(sql), &db);
    assert!(
        outcome.is_err(),
        "ON referencing not-yet-joined table 'c' should error"
    );
}
4517
#[test]
fn join_on_truthy_integer_is_accepted() {
    let setup: &[&str] = &[
        "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
        "INSERT INTO a (x) VALUES (1);",
        "INSERT INTO a (x) VALUES (2);",
        "INSERT INTO b (y) VALUES (10);",
        "INSERT INTO b (y) VALUES (20);",
    ];
    let mut db = Database::new(String::from("t"));
    for &statement in setup {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    // `ON 1` holds for every pair, so 2 x 2 = 4 rows are expected.
    let row_count = run_rows(&db, "SELECT a.x, b.y FROM a INNER JOIN b ON 1;")
        .rows
        .len();
    assert_eq!(row_count, 4);
}
4538
#[test]
fn full_join_on_empty_tables_returns_empty() {
    // Both tables are created but never populated.
    let setup: &[&str] = &[
        "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
        "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
    ];
    let mut db = Database::new(String::from("t"));
    for &statement in setup {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    let sql = "SELECT a.x, b.y FROM a FULL OUTER JOIN b ON a.x = b.y;";
    let result = run_rows(&db, sql);
    assert!(result.rows.is_empty());
}
4554}