1use std::cmp::Ordering;
5
6use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
7use sqlparser::ast::{
8 AlterTable, AlterTableOperation, AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr,
9 FromTable, FunctionArg, FunctionArgExpr, FunctionArguments, Ident, IndexType, ObjectName,
10 ObjectNamePart, RenameTableNameKind, Statement, TableFactor, TableWithJoins, UnaryOperator,
11 Update, Value as AstValue,
12};
13
14use crate::error::{Result, SQLRiteError};
15use crate::sql::agg::{AggState, DistinctKey, like_match};
16use crate::sql::db::database::Database;
17use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
18use crate::sql::db::table::{
19 DataType, FtsIndexEntry, HnswIndexEntry, Table, Value, parse_vector_literal,
20};
21use crate::sql::fts::{Bm25Params, PostingList};
22use crate::sql::hnsw::{DistanceMetric, HnswIndex};
23use crate::sql::parser::select::{
24 AggregateArg, AggregateFn, JoinConstraintKind, JoinType, OrderByClause, Projection,
25 ProjectionItem, ProjectionKind, SelectQuery, parse_aggregate_call,
26};
27
/// Resolves column references against the row(s) currently being evaluated.
///
/// Implemented in this file by `SingleTableScope` (one table, one rowid) and
/// `JoinedScope` (several tables, one optional rowid per table).
pub(crate) trait RowScope {
    /// Returns the value of column `col`, optionally restricted to the table
    /// identified by `qualifier`.
    fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value>;

    /// Returns the backing table and rowid when the scope wraps exactly one
    /// table row, or `None` when it does not.
    fn single_table_view(&self) -> Option<(&Table, i64)>;
}
66
/// A [`RowScope`] over a single row of a single table.
pub(crate) struct SingleTableScope<'a> {
    /// Table the row belongs to.
    table: &'a Table,
    /// Rowid of the row being evaluated.
    rowid: i64,
}
72
73impl<'a> SingleTableScope<'a> {
74 pub(crate) fn new(table: &'a Table, rowid: i64) -> Self {
75 Self { table, rowid }
76 }
77}
78
79impl RowScope for SingleTableScope<'_> {
80 fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value> {
81 let _ = qualifier;
86 Ok(self.table.get_value(col, self.rowid).unwrap_or(Value::Null))
87 }
88
89 fn single_table_view(&self) -> Option<(&Table, i64)> {
90 Some((self.table, self.rowid))
91 }
92}
93
/// One table participating in a joined SELECT.
pub(crate) struct JoinedTableRef<'a> {
    /// The table itself.
    pub table: &'a Table,
    /// Name that column qualifiers are matched against (case-insensitively)
    /// when resolving `qualifier.column` references.
    pub scope_name: String,
}
101
/// A [`RowScope`] over one combined row of a join.
pub(crate) struct JoinedScope<'a> {
    /// Tables in scope, in join order.
    pub tables: &'a [JoinedTableRef<'a>],
    /// Rowid for each entry of `tables` (same index). `None` means the table
    /// contributes no row to this combination, so its columns read as NULL.
    pub rowids: &'a [Option<i64>],
}
109
110impl RowScope for JoinedScope<'_> {
111 fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value> {
112 if let Some(q) = qualifier {
113 let pos = self
116 .tables
117 .iter()
118 .position(|t| t.scope_name.eq_ignore_ascii_case(q))
119 .ok_or_else(|| {
120 SQLRiteError::Internal(format!(
121 "unknown table qualifier '{q}' in column reference '{q}.{col}'"
122 ))
123 })?;
124 if !self.tables[pos].table.contains_column(col.to_string()) {
125 return Err(SQLRiteError::Internal(format!(
126 "column '{col}' does not exist on '{}'",
127 self.tables[pos].scope_name
128 )));
129 }
130 return Ok(match self.rowids[pos] {
131 None => Value::Null,
132 Some(r) => self.tables[pos]
133 .table
134 .get_value(col, r)
135 .unwrap_or(Value::Null),
136 });
137 }
138 let mut hit: Option<usize> = None;
142 for (i, t) in self.tables.iter().enumerate() {
143 if t.table.contains_column(col.to_string()) {
144 if hit.is_some() {
145 return Err(SQLRiteError::Internal(format!(
146 "column reference '{col}' is ambiguous — qualify it as <table>.{col}"
147 )));
148 }
149 hit = Some(i);
150 }
151 }
152 let i = hit.ok_or_else(|| {
153 SQLRiteError::Internal(format!(
154 "unknown column '{col}' in joined SELECT (no in-scope table has it)"
155 ))
156 })?;
157 Ok(match self.rowids[i] {
158 None => Value::Null,
159 Some(r) => self.tables[i]
160 .table
161 .get_value(col, r)
162 .unwrap_or(Value::Null),
163 })
164 }
165
166 fn single_table_view(&self) -> Option<(&Table, i64)> {
167 None
168 }
169}
170
/// Result set of a SELECT: column headers plus row values.
pub struct SelectResult {
    /// Output column names, in projection order.
    pub columns: Vec<String>,
    /// One entry per result row; each row holds one value per column.
    pub rows: Vec<Vec<Value>>,
}
183
184pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
188 if !query.joins.is_empty() {
193 return execute_select_rows_joined(query, db);
194 }
195
196 let master_snapshot;
205 let table: &Table = if query.table_name == crate::sql::pager::MASTER_TABLE_NAME {
206 master_snapshot = crate::sql::pager::build_master_table_snapshot(db)?;
207 &master_snapshot
208 } else {
209 db.get_table(query.table_name.clone()).map_err(|_| {
210 SQLRiteError::Internal(format!("Table '{}' not found", query.table_name))
211 })?
212 };
213
214 let proj_items: Vec<ProjectionItem> = match &query.projection {
219 Projection::All => table
220 .column_names()
221 .into_iter()
222 .map(|c| ProjectionItem {
223 kind: ProjectionKind::Column {
224 qualifier: None,
225 name: c,
226 },
227 alias: None,
228 })
229 .collect(),
230 Projection::Items(items) => items.clone(),
231 };
232 let has_aggregates = proj_items
233 .iter()
234 .any(|i| matches!(i.kind, ProjectionKind::Aggregate(_)));
235 for item in &proj_items {
237 if let ProjectionKind::Column { name: c, .. } = &item.kind
238 && !table.contains_column(c.clone())
239 {
240 return Err(SQLRiteError::Internal(format!(
241 "Column '{c}' does not exist on table '{}'",
242 query.table_name
243 )));
244 }
245 }
246 for c in &query.group_by {
247 if !table.contains_column(c.clone()) {
248 return Err(SQLRiteError::Internal(format!(
249 "GROUP BY references unknown column '{c}' on table '{}'",
250 query.table_name
251 )));
252 }
253 }
254 let matching = match select_rowids(table, query.selection.as_ref())? {
258 RowidSource::IndexProbe(rowids) => rowids,
259 RowidSource::FullScan => {
260 let mut out = Vec::new();
261 for rowid in table.rowids() {
262 if let Some(expr) = &query.selection
263 && !eval_predicate(expr, table, rowid)?
264 {
265 continue;
266 }
267 out.push(rowid);
268 }
269 out
270 }
271 };
272 let mut matching = matching;
273
274 let aggregating = has_aggregates || !query.group_by.is_empty();
275
276 if aggregating {
282 let mut all_items = proj_items.clone();
293 let having_expr = match &query.having {
294 Some(h) => {
295 for g in &query.group_by {
296 if !all_items
297 .iter()
298 .any(|i| i.output_name().eq_ignore_ascii_case(g))
299 {
300 all_items.push(ProjectionItem {
301 kind: ProjectionKind::Column {
302 qualifier: None,
303 name: g.clone(),
304 },
305 alias: None,
306 });
307 }
308 }
309 Some(lower_having_expr(h, &mut all_items)?)
310 }
311 None => None,
312 };
313
314 for item in &all_items {
316 if let ProjectionKind::Aggregate(call) = &item.kind
317 && let AggregateArg::Column(c) = &call.arg
318 && !table.contains_column(c.clone())
319 {
320 return Err(SQLRiteError::Internal(format!(
321 "{}({}) references unknown column '{c}' on table '{}'",
322 call.func.as_str(),
323 c,
324 query.table_name
325 )));
326 }
327 }
328
329 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
330 let mut rows = aggregate_rows(table, &matching, &query.group_by, &all_items)?;
331
332 if let Some(h) = &having_expr {
333 let all_columns: Vec<String> = all_items.iter().map(|i| i.output_name()).collect();
334 rows = filter_groups_by_having(rows, h, &all_columns)?;
335 }
336 if all_items.len() > proj_items.len() {
338 for row in &mut rows {
339 row.truncate(proj_items.len());
340 }
341 }
342
343 if query.distinct {
344 rows = dedupe_rows(rows);
345 }
346
347 if let Some(order) = &query.order_by {
348 sort_output_rows(&mut rows, &columns, &proj_items, order)?;
349 }
350 if let Some(k) = query.limit {
351 rows.truncate(k);
352 }
353
354 return Ok(SelectResult { columns, rows });
355 }
356
357 let defer_limit_for_distinct = query.distinct;
395 match (&query.order_by, query.limit) {
396 (Some(order), Some(k)) if try_hnsw_probe(table, &order.expr, k).is_some() => {
397 matching = try_hnsw_probe(table, &order.expr, k).unwrap();
398 }
399 (Some(order), Some(k))
400 if try_fts_probe(table, &order.expr, order.ascending, k).is_some() =>
401 {
402 matching = try_fts_probe(table, &order.expr, order.ascending, k).unwrap();
403 }
404 (Some(order), Some(k)) if !defer_limit_for_distinct && k < matching.len() => {
405 matching = select_topk(&matching, table, order, k)?;
406 }
407 (Some(order), _) => {
408 sort_rowids(&mut matching, table, order)?;
409 if let Some(k) = query.limit
410 && !defer_limit_for_distinct
411 {
412 matching.truncate(k);
413 }
414 }
415 (None, Some(k)) if !defer_limit_for_distinct => {
416 matching.truncate(k);
417 }
418 _ => {}
419 }
420
421 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
422 let projected_cols: Vec<String> = proj_items
423 .iter()
424 .map(|i| match &i.kind {
425 ProjectionKind::Column { name, .. } => name.clone(),
426 ProjectionKind::Aggregate(_) => unreachable!("aggregation handled above"),
427 })
428 .collect();
429
430 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
434 for rowid in &matching {
435 let row: Vec<Value> = projected_cols
436 .iter()
437 .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
438 .collect();
439 rows.push(row);
440 }
441
442 if query.distinct {
443 rows = dedupe_rows(rows);
444 if let Some(k) = query.limit {
445 rows.truncate(k);
446 }
447 }
448
449 Ok(SelectResult { columns, rows })
450}
451
/// A join constraint normalized into a single predicate.
struct ResolvedJoin {
    /// Predicate evaluated for each candidate (left rows, right row) pair.
    on: Expr,
    /// Columns named by USING / discovered for NATURAL; empty for plain ON.
    /// `SELECT *` uses this to skip the right-hand copy of these columns.
    using_columns: Vec<String>,
}
460
461fn resolve_join_constraint(
476 constraint: &JoinConstraintKind,
477 tables: &[JoinedTableRef<'_>],
478 right_pos: usize,
479) -> Result<ResolvedJoin> {
480 match constraint {
481 JoinConstraintKind::On(expr) => Ok(ResolvedJoin {
482 on: (**expr).clone(),
483 using_columns: Vec::new(),
484 }),
485 JoinConstraintKind::Using(cols) => build_using_join(cols, tables, right_pos),
486 JoinConstraintKind::Natural => {
487 let shared: Vec<String> = tables[right_pos]
491 .table
492 .column_names()
493 .into_iter()
494 .filter(|c| {
495 tables[..right_pos]
496 .iter()
497 .any(|t| t.table.contains_column(c.clone()))
498 })
499 .collect();
500 build_using_join(&shared, tables, right_pos)
501 }
502 }
503}
504
505fn build_using_join(
510 cols: &[String],
511 tables: &[JoinedTableRef<'_>],
512 right_pos: usize,
513) -> Result<ResolvedJoin> {
514 let right = &tables[right_pos];
515 let mut predicate: Option<Expr> = None;
516 for col in cols {
517 if !right.table.contains_column(col.clone()) {
519 return Err(SQLRiteError::Internal(format!(
520 "cannot join USING column '{col}' — it is not present on table '{}'",
521 right.scope_name
522 )));
523 }
524 let left = tables[..right_pos]
527 .iter()
528 .find(|t| t.table.contains_column(col.clone()))
529 .ok_or_else(|| {
530 SQLRiteError::Internal(format!(
531 "cannot join USING column '{col}' — it is not present on any left-side table"
532 ))
533 })?;
534 let eq = col_eq(&left.scope_name, &right.scope_name, col);
535 predicate = Some(match predicate {
536 None => eq,
537 Some(prev) => Expr::BinaryOp {
538 left: Box::new(prev),
539 op: BinaryOperator::And,
540 right: Box::new(eq),
541 },
542 });
543 }
544 Ok(ResolvedJoin {
545 on: predicate
546 .unwrap_or_else(|| Expr::Value(sqlparser::ast::Value::Boolean(true).with_empty_span())),
547 using_columns: cols.to_vec(),
548 })
549}
550
551fn col_eq(left_scope: &str, right_scope: &str, col: &str) -> Expr {
554 let col_ref = |scope: &str| {
555 Expr::CompoundIdentifier(vec![
556 Ident::new(scope.to_string()),
557 Ident::new(col.to_string()),
558 ])
559 };
560 Expr::BinaryOp {
561 left: Box::new(col_ref(left_scope)),
562 op: BinaryOperator::Eq,
563 right: Box::new(col_ref(right_scope)),
564 }
565}
566
567fn execute_select_rows_joined(query: SelectQuery, db: &Database) -> Result<SelectResult> {
594 let mut joined_tables: Vec<JoinedTableRef<'_>> = Vec::with_capacity(1 + query.joins.len());
601
602 let primary = db
603 .get_table(query.table_name.clone())
604 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
605 joined_tables.push(JoinedTableRef {
606 table: primary,
607 scope_name: query
608 .table_alias
609 .clone()
610 .unwrap_or_else(|| query.table_name.clone()),
611 });
612 for j in &query.joins {
613 let t = db
614 .get_table(j.right_table.clone())
615 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", j.right_table)))?;
616 joined_tables.push(JoinedTableRef {
617 table: t,
618 scope_name: j
619 .right_alias
620 .clone()
621 .unwrap_or_else(|| j.right_table.clone()),
622 });
623 }
624
625 {
630 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
631 for t in &joined_tables {
632 let key = t.scope_name.to_ascii_lowercase();
633 if !seen.insert(key) {
634 return Err(SQLRiteError::Internal(format!(
635 "duplicate table reference '{}' in FROM/JOIN — use AS to alias one side",
636 t.scope_name
637 )));
638 }
639 }
640 }
641
642 let resolved: Vec<ResolvedJoin> = query
650 .joins
651 .iter()
652 .enumerate()
653 .map(|(j_idx, join)| resolve_join_constraint(&join.constraint, &joined_tables, j_idx + 1))
654 .collect::<Result<Vec<_>>>()?;
655
656 let proj_items: Vec<ProjectionItem> = match &query.projection {
662 Projection::All => {
663 let mut all = Vec::new();
679 for (t_idx, t) in joined_tables.iter().enumerate() {
680 let dedup: &[String] = t_idx
683 .checked_sub(1)
684 .map(|r| resolved[r].using_columns.as_slice())
685 .unwrap_or(&[]);
686 for col in t.table.column_names() {
687 if dedup.contains(&col) {
688 continue;
689 }
690 all.push(ProjectionItem {
691 kind: ProjectionKind::Column {
692 qualifier: Some(t.scope_name.clone()),
697 name: col,
698 },
699 alias: None,
700 });
701 }
702 }
703 all
704 }
705 Projection::Items(items) => items.clone(),
706 };
707
708 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
709
710 let mut acc: Vec<Vec<Option<i64>>> = primary
715 .rowids()
716 .into_iter()
717 .map(|r| {
718 let mut row = Vec::with_capacity(joined_tables.len());
719 row.push(Some(r));
720 row
721 })
722 .collect();
723
724 for (j_idx, join) in query.joins.iter().enumerate() {
729 let right_pos = j_idx + 1;
730 let right_table = joined_tables[right_pos].table;
731 let right_rowids: Vec<i64> = right_table.rowids();
732
733 let mut right_matched: Vec<bool> = vec![false; right_rowids.len()];
737
738 let mut next_acc: Vec<Vec<Option<i64>>> = Vec::with_capacity(acc.len());
739
740 let on_scope_tables: &[JoinedTableRef<'_>] = &joined_tables[..=right_pos];
748
749 for left_row in acc.into_iter() {
750 let mut left_match_count = 0usize;
754 for (r_idx, &rrid) in right_rowids.iter().enumerate() {
755 let mut on_rowids: Vec<Option<i64>> = left_row.clone();
756 on_rowids.push(Some(rrid));
757 debug_assert_eq!(on_rowids.len(), on_scope_tables.len());
758 let scope = JoinedScope {
759 tables: on_scope_tables,
760 rowids: &on_rowids,
761 };
762 if eval_predicate_scope(&resolved[j_idx].on, &scope)? {
769 left_match_count += 1;
770 right_matched[r_idx] = true;
771 next_acc.push(on_rowids);
776 }
777 }
778
779 if left_match_count == 0
780 && matches!(join.join_type, JoinType::LeftOuter | JoinType::FullOuter)
781 {
782 let mut padded = left_row;
785 padded.push(None);
786 next_acc.push(padded);
787 }
788 }
789
790 if matches!(join.join_type, JoinType::RightOuter | JoinType::FullOuter) {
794 for (r_idx, matched) in right_matched.iter().enumerate() {
795 if *matched {
796 continue;
797 }
798 let mut row: Vec<Option<i64>> = vec![None; right_pos];
799 row.push(Some(right_rowids[r_idx]));
800 next_acc.push(row);
801 }
802 }
803
804 acc = next_acc;
805 }
806
807 let mut filtered: Vec<Vec<Option<i64>>> = if let Some(where_expr) = &query.selection {
812 let mut out = Vec::with_capacity(acc.len());
813 for row in acc {
814 let scope = JoinedScope {
815 tables: &joined_tables,
816 rowids: &row,
817 };
818 if eval_predicate_scope(where_expr, &scope)? {
819 out.push(row);
820 }
821 }
822 out
823 } else {
824 acc
825 };
826
827 if let Some(order) = &query.order_by {
831 let mut keys: Vec<(usize, Value)> = Vec::with_capacity(filtered.len());
834 for (i, row) in filtered.iter().enumerate() {
835 let scope = JoinedScope {
836 tables: &joined_tables,
837 rowids: row,
838 };
839 let v = eval_expr_scope(&order.expr, &scope)?;
840 keys.push((i, v));
841 }
842 keys.sort_by(|(_, a), (_, b)| {
843 let ord = compare_values(Some(a), Some(b));
844 if order.ascending { ord } else { ord.reverse() }
845 });
846 let mut sorted = Vec::with_capacity(filtered.len());
847 for (i, _) in keys {
848 sorted.push(filtered[i].clone());
849 }
850 filtered = sorted;
851 }
852
853 if let Some(k) = query.limit {
855 filtered.truncate(k);
856 }
857
858 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(filtered.len());
861 for row in &filtered {
862 let scope = JoinedScope {
863 tables: &joined_tables,
864 rowids: row,
865 };
866 let mut out_row = Vec::with_capacity(proj_items.len());
867 for item in &proj_items {
868 let v = match &item.kind {
869 ProjectionKind::Column { qualifier, name } => {
870 scope.lookup(qualifier.as_deref(), name)?
871 }
872 ProjectionKind::Aggregate(_) => {
873 return Err(SQLRiteError::Internal(
876 "aggregate functions over JOIN are not supported".to_string(),
877 ));
878 }
879 };
880 out_row.push(v);
881 }
882 rows.push(out_row);
883 }
884
885 Ok(SelectResult { columns, rows })
886}
887
888pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
893 let result = execute_select_rows(query, db)?;
894 let row_count = result.rows.len();
895
896 let mut print_table = PrintTable::new();
897 let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
898 print_table.add_row(PrintRow::new(header_cells));
899
900 for row in &result.rows {
901 let cells: Vec<PrintCell> = row
902 .iter()
903 .map(|v| PrintCell::new(&v.to_display_string()))
904 .collect();
905 print_table.add_row(PrintRow::new(cells));
906 }
907
908 Ok((print_table.to_string(), row_count))
909}
910
911pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
913 let Statement::Delete(Delete {
914 from, selection, ..
915 }) = stmt
916 else {
917 return Err(SQLRiteError::Internal(
918 "execute_delete called on a non-DELETE statement".to_string(),
919 ));
920 };
921
922 let tables = match from {
923 FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
924 };
925 let table_name = extract_single_table_name(tables)?;
926
927 let matching: Vec<i64> = {
929 let table = db
930 .get_table(table_name.clone())
931 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
932 match select_rowids(table, selection.as_ref())? {
933 RowidSource::IndexProbe(rowids) => rowids,
934 RowidSource::FullScan => {
935 let mut out = Vec::new();
936 for rowid in table.rowids() {
937 if let Some(expr) = selection {
938 if !eval_predicate(expr, table, rowid)? {
939 continue;
940 }
941 }
942 out.push(rowid);
943 }
944 out
945 }
946 }
947 };
948
949 let table = db.get_table_mut(table_name)?;
950 for rowid in &matching {
951 table.delete_row(*rowid);
952 }
953 if !matching.is_empty() {
962 for entry in &mut table.hnsw_indexes {
963 entry.needs_rebuild = true;
964 }
965 for entry in &mut table.fts_indexes {
966 entry.needs_rebuild = true;
967 }
968 }
969 Ok(matching.len())
970}
971
972pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
974 let Statement::Update(Update {
975 table,
976 assignments,
977 from,
978 selection,
979 ..
980 }) = stmt
981 else {
982 return Err(SQLRiteError::Internal(
983 "execute_update called on a non-UPDATE statement".to_string(),
984 ));
985 };
986
987 if from.is_some() {
988 return Err(SQLRiteError::NotImplemented(
989 "UPDATE ... FROM is not supported yet".to_string(),
990 ));
991 }
992
993 let table_name = extract_table_name(table)?;
994
995 let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
997 {
998 let tbl = db
999 .get_table(table_name.clone())
1000 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
1001 for a in assignments {
1002 let col = match &a.target {
1003 AssignmentTarget::ColumnName(name) => name
1004 .0
1005 .last()
1006 .map(|p| p.to_string())
1007 .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
1008 AssignmentTarget::Tuple(_) => {
1009 return Err(SQLRiteError::NotImplemented(
1010 "tuple assignment targets are not supported".to_string(),
1011 ));
1012 }
1013 };
1014 if !tbl.contains_column(col.clone()) {
1015 return Err(SQLRiteError::Internal(format!(
1016 "UPDATE references unknown column '{col}'"
1017 )));
1018 }
1019 parsed_assignments.push((col, a.value.clone()));
1020 }
1021 }
1022
1023 let work: Vec<(i64, Vec<(String, Value)>)> = {
1027 let tbl = db.get_table(table_name.clone())?;
1028 let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
1029 RowidSource::IndexProbe(rowids) => rowids,
1030 RowidSource::FullScan => {
1031 let mut out = Vec::new();
1032 for rowid in tbl.rowids() {
1033 if let Some(expr) = selection {
1034 if !eval_predicate(expr, tbl, rowid)? {
1035 continue;
1036 }
1037 }
1038 out.push(rowid);
1039 }
1040 out
1041 }
1042 };
1043 let mut rows_to_update = Vec::new();
1044 for rowid in matched_rowids {
1045 let mut values = Vec::with_capacity(parsed_assignments.len());
1046 for (col, expr) in &parsed_assignments {
1047 let v = eval_expr(expr, tbl, rowid)?;
1050 values.push((col.clone(), v));
1051 }
1052 rows_to_update.push((rowid, values));
1053 }
1054 rows_to_update
1055 };
1056
1057 let tbl = db.get_table_mut(table_name)?;
1058 for (rowid, values) in &work {
1059 for (col, v) in values {
1060 tbl.set_value(col, *rowid, v.clone())?;
1061 }
1062 }
1063
1064 if !work.is_empty() {
1073 let updated_columns: std::collections::HashSet<&str> = work
1074 .iter()
1075 .flat_map(|(_, values)| values.iter().map(|(c, _)| c.as_str()))
1076 .collect();
1077 for entry in &mut tbl.hnsw_indexes {
1078 if updated_columns.contains(entry.column_name.as_str()) {
1079 entry.needs_rebuild = true;
1080 }
1081 }
1082 for entry in &mut tbl.fts_indexes {
1083 if updated_columns.contains(entry.column_name.as_str()) {
1084 entry.needs_rebuild = true;
1085 }
1086 }
1087 }
1088 Ok(work.len())
1089}
1090
1091pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
1103 let Statement::CreateIndex(CreateIndex {
1104 name,
1105 table_name,
1106 columns,
1107 using,
1108 unique,
1109 if_not_exists,
1110 predicate,
1111 with,
1112 ..
1113 }) = stmt
1114 else {
1115 return Err(SQLRiteError::Internal(
1116 "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
1117 ));
1118 };
1119
1120 if predicate.is_some() {
1121 return Err(SQLRiteError::NotImplemented(
1122 "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
1123 ));
1124 }
1125
1126 if columns.len() != 1 {
1127 return Err(SQLRiteError::NotImplemented(format!(
1128 "multi-column indexes are not supported yet ({} columns given)",
1129 columns.len()
1130 )));
1131 }
1132
1133 let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
1134 SQLRiteError::NotImplemented(
1135 "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
1136 )
1137 })?;
1138
1139 let method = match using {
1145 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("hnsw") => {
1146 IndexMethod::Hnsw
1147 }
1148 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("fts") => {
1149 IndexMethod::Fts
1150 }
1151 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("btree") => {
1152 IndexMethod::Btree
1153 }
1154 Some(other) => {
1155 return Err(SQLRiteError::NotImplemented(format!(
1156 "CREATE INDEX … USING {other:?} is not supported \
1157 (try `hnsw`, `fts`, or no USING clause)"
1158 )));
1159 }
1160 None => IndexMethod::Btree,
1161 };
1162
1163 let hnsw_metric = parse_hnsw_with_options(with, &index_name, method)?;
1169
1170 let table_name_str = table_name.to_string();
1171 let column_name = match &columns[0].column.expr {
1172 Expr::Identifier(ident) => ident.value.clone(),
1173 Expr::CompoundIdentifier(parts) => parts
1174 .last()
1175 .map(|p| p.value.clone())
1176 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
1177 other => {
1178 return Err(SQLRiteError::NotImplemented(format!(
1179 "CREATE INDEX only supports simple column references, got {other:?}"
1180 )));
1181 }
1182 };
1183
1184 let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
1189 let table = db.get_table(table_name_str.clone()).map_err(|_| {
1190 SQLRiteError::General(format!(
1191 "CREATE INDEX references unknown table '{table_name_str}'"
1192 ))
1193 })?;
1194 if !table.contains_column(column_name.clone()) {
1195 return Err(SQLRiteError::General(format!(
1196 "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
1197 )));
1198 }
1199 let col = table
1200 .columns
1201 .iter()
1202 .find(|c| c.column_name == column_name)
1203 .expect("we just verified the column exists");
1204
1205 if table.index_by_name(&index_name).is_some()
1208 || table.hnsw_indexes.iter().any(|i| i.name == index_name)
1209 || table.fts_indexes.iter().any(|i| i.name == index_name)
1210 {
1211 if *if_not_exists {
1212 return Ok(index_name);
1213 }
1214 return Err(SQLRiteError::General(format!(
1215 "index '{index_name}' already exists"
1216 )));
1217 }
1218 let datatype = clone_datatype(&col.datatype);
1219
1220 let mut pairs = Vec::new();
1221 for rowid in table.rowids() {
1222 if let Some(v) = table.get_value(&column_name, rowid) {
1223 pairs.push((rowid, v));
1224 }
1225 }
1226 (datatype, pairs)
1227 };
1228
1229 match method {
1230 IndexMethod::Btree => create_btree_index(
1231 db,
1232 &table_name_str,
1233 &index_name,
1234 &column_name,
1235 &datatype,
1236 *unique,
1237 &existing_rowids_and_values,
1238 ),
1239 IndexMethod::Hnsw => create_hnsw_index(
1240 db,
1241 &table_name_str,
1242 &index_name,
1243 &column_name,
1244 &datatype,
1245 *unique,
1246 hnsw_metric.unwrap_or(DistanceMetric::L2),
1247 &existing_rowids_and_values,
1248 ),
1249 IndexMethod::Fts => create_fts_index(
1250 db,
1251 &table_name_str,
1252 &index_name,
1253 &column_name,
1254 &datatype,
1255 *unique,
1256 &existing_rowids_and_values,
1257 ),
1258 }
1259}
1260
1261pub fn execute_drop_table(
1272 names: &[ObjectName],
1273 if_exists: bool,
1274 db: &mut Database,
1275) -> Result<usize> {
1276 if names.len() != 1 {
1277 return Err(SQLRiteError::NotImplemented(
1278 "DROP TABLE supports a single table per statement".to_string(),
1279 ));
1280 }
1281 let name = names[0].to_string();
1282
1283 if name == crate::sql::pager::MASTER_TABLE_NAME {
1284 return Err(SQLRiteError::General(format!(
1285 "'{}' is a reserved name used by the internal schema catalog",
1286 crate::sql::pager::MASTER_TABLE_NAME
1287 )));
1288 }
1289
1290 if !db.contains_table(name.clone()) {
1291 return if if_exists {
1292 Ok(0)
1293 } else {
1294 Err(SQLRiteError::General(format!(
1295 "Table '{name}' does not exist"
1296 )))
1297 };
1298 }
1299
1300 db.tables.remove(&name);
1301 Ok(1)
1302}
1303
1304pub fn execute_drop_index(
1313 names: &[ObjectName],
1314 if_exists: bool,
1315 db: &mut Database,
1316) -> Result<usize> {
1317 if names.len() != 1 {
1318 return Err(SQLRiteError::NotImplemented(
1319 "DROP INDEX supports a single index per statement".to_string(),
1320 ));
1321 }
1322 let name = names[0].to_string();
1323
1324 for table in db.tables.values_mut() {
1325 if let Some(secondary) = table.secondary_indexes.iter().find(|i| i.name == name) {
1326 if secondary.origin == IndexOrigin::Auto {
1327 return Err(SQLRiteError::General(format!(
1328 "cannot drop auto-created index '{name}' (drop the column or table instead)"
1329 )));
1330 }
1331 table.secondary_indexes.retain(|i| i.name != name);
1332 return Ok(1);
1333 }
1334 if table.hnsw_indexes.iter().any(|i| i.name == name) {
1335 table.hnsw_indexes.retain(|i| i.name != name);
1336 return Ok(1);
1337 }
1338 if table.fts_indexes.iter().any(|i| i.name == name) {
1339 table.fts_indexes.retain(|i| i.name != name);
1340 return Ok(1);
1341 }
1342 }
1343
1344 if if_exists {
1345 Ok(0)
1346 } else {
1347 Err(SQLRiteError::General(format!(
1348 "Index '{name}' does not exist"
1349 )))
1350 }
1351}
1352
1353pub fn execute_alter_table(alter: AlterTable, db: &mut Database) -> Result<String> {
1365 let table_name = alter.name.to_string();
1366
1367 if table_name == crate::sql::pager::MASTER_TABLE_NAME {
1368 return Err(SQLRiteError::General(format!(
1369 "'{}' is a reserved name used by the internal schema catalog",
1370 crate::sql::pager::MASTER_TABLE_NAME
1371 )));
1372 }
1373
1374 if !db.contains_table(table_name.clone()) {
1375 return if alter.if_exists {
1376 Ok("ALTER TABLE: no-op (table does not exist)".to_string())
1377 } else {
1378 Err(SQLRiteError::General(format!(
1379 "Table '{table_name}' does not exist"
1380 )))
1381 };
1382 }
1383
1384 if alter.operations.len() != 1 {
1385 return Err(SQLRiteError::NotImplemented(
1386 "ALTER TABLE supports one operation per statement".to_string(),
1387 ));
1388 }
1389
1390 match &alter.operations[0] {
1391 AlterTableOperation::RenameTable { table_name: kind } => {
1392 let new_name = match kind {
1393 RenameTableNameKind::To(name) => name.to_string(),
1394 RenameTableNameKind::As(_) => {
1395 return Err(SQLRiteError::NotImplemented(
1396 "ALTER TABLE ... RENAME AS (MySQL-only) is not supported; use RENAME TO"
1397 .to_string(),
1398 ));
1399 }
1400 };
1401 alter_rename_table(db, &table_name, &new_name)?;
1402 Ok(format!(
1403 "ALTER TABLE '{table_name}' RENAME TO '{new_name}' executed."
1404 ))
1405 }
1406 AlterTableOperation::RenameColumn {
1407 old_column_name,
1408 new_column_name,
1409 } => {
1410 let old = old_column_name.value.clone();
1411 let new = new_column_name.value.clone();
1412 db.get_table_mut(table_name.clone())?
1413 .rename_column(&old, &new)?;
1414 Ok(format!(
1415 "ALTER TABLE '{table_name}' RENAME COLUMN '{old}' TO '{new}' executed."
1416 ))
1417 }
1418 AlterTableOperation::AddColumn {
1419 column_def,
1420 if_not_exists,
1421 ..
1422 } => {
1423 let parsed = crate::sql::parser::create::parse_one_column(column_def)?;
1424 let table = db.get_table_mut(table_name.clone())?;
1425 if *if_not_exists && table.contains_column(parsed.name.clone()) {
1426 return Ok(format!(
1427 "ALTER TABLE '{table_name}' ADD COLUMN: no-op (column '{}' already exists)",
1428 parsed.name
1429 ));
1430 }
1431 let col_name = parsed.name.clone();
1432 table.add_column(parsed)?;
1433 Ok(format!(
1434 "ALTER TABLE '{table_name}' ADD COLUMN '{col_name}' executed."
1435 ))
1436 }
1437 AlterTableOperation::DropColumn {
1438 column_names,
1439 if_exists,
1440 ..
1441 } => {
1442 if column_names.len() != 1 {
1443 return Err(SQLRiteError::NotImplemented(
1444 "ALTER TABLE DROP COLUMN supports a single column per statement".to_string(),
1445 ));
1446 }
1447 let col_name = column_names[0].value.clone();
1448 let table = db.get_table_mut(table_name.clone())?;
1449 if *if_exists && !table.contains_column(col_name.clone()) {
1450 return Ok(format!(
1451 "ALTER TABLE '{table_name}' DROP COLUMN: no-op (column '{col_name}' does not exist)"
1452 ));
1453 }
1454 table.drop_column(&col_name)?;
1455 Ok(format!(
1456 "ALTER TABLE '{table_name}' DROP COLUMN '{col_name}' executed."
1457 ))
1458 }
1459 other => Err(SQLRiteError::NotImplemented(format!(
1460 "ALTER TABLE operation {other:?} is not supported"
1461 ))),
1462 }
1463}
1464
1465pub fn execute_vacuum(db: &mut Database) -> Result<String> {
1475 if db.in_transaction() {
1476 return Err(SQLRiteError::General(
1477 "VACUUM cannot run inside a transaction".to_string(),
1478 ));
1479 }
1480 let path = match db.source_path.clone() {
1481 Some(p) => p,
1482 None => {
1483 return Ok("VACUUM is a no-op for in-memory databases".to_string());
1484 }
1485 };
1486 if let Some(pager) = db.pager.as_mut() {
1492 let _ = pager.checkpoint();
1493 }
1494 let size_before = std::fs::metadata(&path).ok().map(|m| m.len()).unwrap_or(0);
1495 let pages_before = db
1496 .pager
1497 .as_ref()
1498 .map(|p| p.header().page_count)
1499 .unwrap_or(0);
1500 crate::sql::pager::vacuum_database(db, &path)?;
1501 if let Some(pager) = db.pager.as_mut() {
1504 let _ = pager.checkpoint();
1505 }
1506 let size_after = std::fs::metadata(&path).ok().map(|m| m.len()).unwrap_or(0);
1507 let pages_after = db
1508 .pager
1509 .as_ref()
1510 .map(|p| p.header().page_count)
1511 .unwrap_or(0);
1512 let pages_reclaimed = pages_before.saturating_sub(pages_after);
1513 let bytes_reclaimed = size_before.saturating_sub(size_after);
1514 Ok(format!(
1515 "VACUUM completed. {pages_reclaimed} pages reclaimed ({bytes_reclaimed} bytes)."
1516 ))
1517}
1518
1519fn alter_rename_table(db: &mut Database, old: &str, new: &str) -> Result<()> {
1525 if new == crate::sql::pager::MASTER_TABLE_NAME {
1526 return Err(SQLRiteError::General(format!(
1527 "'{}' is a reserved name used by the internal schema catalog",
1528 crate::sql::pager::MASTER_TABLE_NAME
1529 )));
1530 }
1531 if old == new {
1532 return Ok(());
1533 }
1534 if db.contains_table(new.to_string()) {
1535 return Err(SQLRiteError::General(format!(
1536 "target table '{new}' already exists"
1537 )));
1538 }
1539
1540 let mut table = db
1541 .tables
1542 .remove(old)
1543 .ok_or_else(|| SQLRiteError::General(format!("Table '{old}' does not exist")))?;
1544 table.tb_name = new.to_string();
1545 for idx in table.secondary_indexes.iter_mut() {
1546 idx.table_name = new.to_string();
1547 if idx.origin == IndexOrigin::Auto
1548 && idx.name == SecondaryIndex::auto_name(old, &idx.column_name)
1549 {
1550 idx.name = SecondaryIndex::auto_name(new, &idx.column_name);
1551 }
1552 }
1553 db.tables.insert(new.to_string(), table);
1554 Ok(())
1555}
1556
/// The kind of index a `CREATE INDEX` statement asks for.
///
/// Of the three, only `Hnsw` accepts a `WITH (...)` option list
/// (see `parse_hnsw_with_options`).
#[derive(Clone, Copy, Debug)]
enum IndexMethod {
    /// B-tree index.
    Btree,
    /// HNSW index (`USING hnsw`).
    Hnsw,
    /// Full-text-search index.
    Fts,
}
1567
1568fn create_btree_index(
1570 db: &mut Database,
1571 table_name: &str,
1572 index_name: &str,
1573 column_name: &str,
1574 datatype: &DataType,
1575 unique: bool,
1576 existing: &[(i64, Value)],
1577) -> Result<String> {
1578 let mut idx = SecondaryIndex::new(
1579 index_name.to_string(),
1580 table_name.to_string(),
1581 column_name.to_string(),
1582 datatype,
1583 unique,
1584 IndexOrigin::Explicit,
1585 )?;
1586
1587 for (rowid, v) in existing {
1591 if unique && idx.would_violate_unique(v) {
1592 return Err(SQLRiteError::General(format!(
1593 "cannot create UNIQUE index '{index_name}': column '{column_name}' \
1594 already contains the duplicate value {}",
1595 v.to_display_string()
1596 )));
1597 }
1598 idx.insert(v, *rowid)?;
1599 }
1600
1601 let table_mut = db.get_table_mut(table_name.to_string())?;
1602 table_mut.secondary_indexes.push(idx);
1603 Ok(index_name.to_string())
1604}
1605
1606fn create_hnsw_index(
1608 db: &mut Database,
1609 table_name: &str,
1610 index_name: &str,
1611 column_name: &str,
1612 datatype: &DataType,
1613 unique: bool,
1614 metric: DistanceMetric,
1615 existing: &[(i64, Value)],
1616) -> Result<String> {
1617 let dim = match datatype {
1620 DataType::Vector(d) => *d,
1621 other => {
1622 return Err(SQLRiteError::General(format!(
1623 "USING hnsw requires a VECTOR column; '{column_name}' is {other}"
1624 )));
1625 }
1626 };
1627
1628 if unique {
1629 return Err(SQLRiteError::General(
1630 "UNIQUE has no meaning for HNSW indexes".to_string(),
1631 ));
1632 }
1633
1634 let seed = hash_str_to_seed(index_name);
1645 let mut idx = HnswIndex::new(metric, seed);
1646
1647 let mut vec_map: std::collections::HashMap<i64, Vec<f32>> =
1651 std::collections::HashMap::with_capacity(existing.len());
1652 for (rowid, v) in existing {
1653 match v {
1654 Value::Vector(vec) => {
1655 if vec.len() != dim {
1656 return Err(SQLRiteError::Internal(format!(
1657 "row {rowid} stores a {}-dim vector in column '{column_name}' \
1658 declared as VECTOR({dim}) — schema invariant violated",
1659 vec.len()
1660 )));
1661 }
1662 vec_map.insert(*rowid, vec.clone());
1663 }
1664 _ => continue,
1668 }
1669 }
1670
1671 for (rowid, _) in existing {
1672 if let Some(v) = vec_map.get(rowid) {
1673 let v_clone = v.clone();
1674 idx.insert(*rowid, &v_clone, |id| {
1675 vec_map.get(&id).cloned().unwrap_or_default()
1676 })?;
1677 }
1678 }
1679
1680 let table_mut = db.get_table_mut(table_name.to_string())?;
1681 table_mut.hnsw_indexes.push(HnswIndexEntry {
1682 name: index_name.to_string(),
1683 column_name: column_name.to_string(),
1684 metric,
1685 index: idx,
1686 needs_rebuild: false,
1688 });
1689 Ok(index_name.to_string())
1690}
1691
1692fn parse_hnsw_with_options(
1703 with: &[Expr],
1704 index_name: &str,
1705 method: IndexMethod,
1706) -> Result<Option<DistanceMetric>> {
1707 if with.is_empty() {
1708 return Ok(None);
1709 }
1710 if !matches!(method, IndexMethod::Hnsw) {
1711 return Err(SQLRiteError::General(format!(
1712 "CREATE INDEX '{index_name}' has a WITH (...) clause but its index method \
1713 doesn't support any options — only `USING hnsw` recognises `WITH (metric = ...)`"
1714 )));
1715 }
1716
1717 let mut metric: Option<DistanceMetric> = None;
1718 for opt in with {
1719 let Expr::BinaryOp { left, op, right } = opt else {
1720 return Err(SQLRiteError::General(format!(
1721 "CREATE INDEX '{index_name}': unsupported WITH option {opt:?} \
1722 (expected `key = 'value'`)"
1723 )));
1724 };
1725 if !matches!(op, BinaryOperator::Eq) {
1726 return Err(SQLRiteError::General(format!(
1727 "CREATE INDEX '{index_name}': WITH options must use `=` (got {op:?})"
1728 )));
1729 }
1730 let key = match left.as_ref() {
1731 Expr::Identifier(ident) => ident.value.clone(),
1732 other => {
1733 return Err(SQLRiteError::General(format!(
1734 "CREATE INDEX '{index_name}': WITH option key must be a bare identifier, \
1735 got {other:?}"
1736 )));
1737 }
1738 };
1739 let value = match right.as_ref() {
1740 Expr::Value(v) => match &v.value {
1741 AstValue::SingleQuotedString(s) => s.clone(),
1742 AstValue::DoubleQuotedString(s) => s.clone(),
1743 other => {
1744 return Err(SQLRiteError::General(format!(
1745 "CREATE INDEX '{index_name}': WITH option '{key}' value must be \
1746 a quoted string, got {other:?}"
1747 )));
1748 }
1749 },
1750 Expr::Identifier(ident) => ident.value.clone(),
1751 other => {
1752 return Err(SQLRiteError::General(format!(
1753 "CREATE INDEX '{index_name}': WITH option '{key}' value must be a \
1754 quoted string, got {other:?}"
1755 )));
1756 }
1757 };
1758
1759 if key.eq_ignore_ascii_case("metric") {
1760 let parsed = DistanceMetric::from_sql_name(&value).ok_or_else(|| {
1761 SQLRiteError::General(format!(
1762 "CREATE INDEX '{index_name}': unknown HNSW metric '{value}' \
1763 (try 'l2', 'cosine', or 'dot')"
1764 ))
1765 })?;
1766 if metric.is_some() {
1767 return Err(SQLRiteError::General(format!(
1768 "CREATE INDEX '{index_name}': metric specified more than once in WITH (...)"
1769 )));
1770 }
1771 metric = Some(parsed);
1772 } else {
1773 return Err(SQLRiteError::General(format!(
1774 "CREATE INDEX '{index_name}': unknown WITH option '{key}' \
1775 (only 'metric' is recognised on HNSW indexes)"
1776 )));
1777 }
1778 }
1779
1780 Ok(metric)
1781}
1782
1783fn create_fts_index(
1788 db: &mut Database,
1789 table_name: &str,
1790 index_name: &str,
1791 column_name: &str,
1792 datatype: &DataType,
1793 unique: bool,
1794 existing: &[(i64, Value)],
1795) -> Result<String> {
1796 match datatype {
1801 DataType::Text => {}
1802 other => {
1803 return Err(SQLRiteError::General(format!(
1804 "USING fts requires a TEXT column; '{column_name}' is {other}"
1805 )));
1806 }
1807 }
1808
1809 if unique {
1810 return Err(SQLRiteError::General(
1811 "UNIQUE has no meaning for FTS indexes".to_string(),
1812 ));
1813 }
1814
1815 let mut idx = PostingList::new();
1816 for (rowid, v) in existing {
1817 if let Value::Text(text) = v {
1818 idx.insert(*rowid, text);
1819 }
1820 }
1823
1824 let table_mut = db.get_table_mut(table_name.to_string())?;
1825 table_mut.fts_indexes.push(FtsIndexEntry {
1826 name: index_name.to_string(),
1827 column_name: column_name.to_string(),
1828 index: idx,
1829 needs_rebuild: false,
1830 });
1831 Ok(index_name.to_string())
1832}
1833
/// Deterministically maps a string to a 64-bit seed using the FNV-1a hash
/// (xor each byte into the state, then multiply by the FNV prime, wrapping).
///
/// The empty string maps to the FNV offset basis.
fn hash_str_to_seed(s: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xCBF2_9CE4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01B3;

    s.bytes().fold(FNV_OFFSET_BASIS, |state, byte| {
        (state ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
1845
1846fn clone_datatype(dt: &DataType) -> DataType {
1849 match dt {
1850 DataType::Integer => DataType::Integer,
1851 DataType::Text => DataType::Text,
1852 DataType::Real => DataType::Real,
1853 DataType::Bool => DataType::Bool,
1854 DataType::Vector(dim) => DataType::Vector(*dim),
1855 DataType::Json => DataType::Json,
1856 DataType::None => DataType::None,
1857 DataType::Invalid => DataType::Invalid,
1858 }
1859}
1860
1861fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
1862 if tables.len() != 1 {
1863 return Err(SQLRiteError::NotImplemented(
1864 "multi-table DELETE is not supported yet".to_string(),
1865 ));
1866 }
1867 extract_table_name(&tables[0])
1868}
1869
1870fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
1871 if !twj.joins.is_empty() {
1872 return Err(SQLRiteError::NotImplemented(
1873 "JOIN is not supported yet".to_string(),
1874 ));
1875 }
1876 match &twj.relation {
1877 TableFactor::Table { name, .. } => Ok(name.to_string()),
1878 _ => Err(SQLRiteError::NotImplemented(
1879 "only plain table references are supported".to_string(),
1880 )),
1881 }
1882}
1883
1884enum RowidSource {
1886 IndexProbe(Vec<i64>),
1890 FullScan,
1893}
1894
1895fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
1900 let Some(expr) = selection else {
1901 return Ok(RowidSource::FullScan);
1902 };
1903 let Some((col, literal)) = try_extract_equality(expr) else {
1904 return Ok(RowidSource::FullScan);
1905 };
1906 let Some(idx) = table.index_for_column(&col) else {
1907 return Ok(RowidSource::FullScan);
1908 };
1909
1910 let literal_value = match convert_literal(&literal) {
1914 Ok(v) => v,
1915 Err(_) => return Ok(RowidSource::FullScan),
1916 };
1917
1918 let mut rowids = idx.lookup(&literal_value);
1922 rowids.sort_unstable();
1923 Ok(RowidSource::IndexProbe(rowids))
1924}
1925
1926fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
1930 let peeled = match expr {
1932 Expr::Nested(inner) => inner.as_ref(),
1933 other => other,
1934 };
1935 let Expr::BinaryOp { left, op, right } = peeled else {
1936 return None;
1937 };
1938 if !matches!(op, BinaryOperator::Eq) {
1939 return None;
1940 }
1941 let col_from = |e: &Expr| -> Option<String> {
1942 match e {
1943 Expr::Identifier(ident) => Some(ident.value.clone()),
1944 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
1945 _ => None,
1946 }
1947 };
1948 let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
1949 if let Expr::Value(v) = e {
1950 Some(v.value.clone())
1951 } else {
1952 None
1953 }
1954 };
1955 if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
1956 return Some((c, l));
1957 }
1958 if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
1959 return Some((c, l));
1960 }
1961 None
1962}
1963
1964fn try_hnsw_probe(table: &Table, order_expr: &Expr, k: usize) -> Option<Vec<i64>> {
1989 if k == 0 {
1990 return None;
1991 }
1992
1993 let func = match order_expr {
1996 Expr::Function(f) => f,
1997 _ => return None,
1998 };
1999 let fname = match func.name.0.as_slice() {
2000 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
2001 _ => return None,
2002 };
2003 let query_metric = match fname.as_str() {
2004 "vec_distance_l2" => DistanceMetric::L2,
2005 "vec_distance_cosine" => DistanceMetric::Cosine,
2006 "vec_distance_dot" => DistanceMetric::Dot,
2007 _ => return None,
2008 };
2009
2010 let arg_list = match &func.args {
2012 FunctionArguments::List(l) => &l.args,
2013 _ => return None,
2014 };
2015 if arg_list.len() != 2 {
2016 return None;
2017 }
2018 let exprs: Vec<&Expr> = arg_list
2019 .iter()
2020 .filter_map(|a| match a {
2021 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
2022 _ => None,
2023 })
2024 .collect();
2025 if exprs.len() != 2 {
2026 return None;
2027 }
2028
2029 let (col_name, query_vec) = match identify_indexed_arg_and_literal(exprs[0], exprs[1]) {
2034 Some(v) => v,
2035 None => match identify_indexed_arg_and_literal(exprs[1], exprs[0]) {
2036 Some(v) => v,
2037 None => return None,
2038 },
2039 };
2040
2041 let entry = table
2046 .hnsw_indexes
2047 .iter()
2048 .find(|e| e.column_name == col_name && e.metric == query_metric)?;
2049
2050 let declared_dim = match table.columns.iter().find(|c| c.column_name == col_name) {
2056 Some(c) => match &c.datatype {
2057 DataType::Vector(d) => *d,
2058 _ => return None,
2059 },
2060 None => return None,
2061 };
2062 if query_vec.len() != declared_dim {
2063 return None;
2064 }
2065
2066 let column_for_closure = col_name.clone();
2070 let table_ref = table;
2071 let result = entry
2072 .index
2073 .search(&query_vec, k, |id| {
2074 match table_ref.get_value(&column_for_closure, id) {
2075 Some(Value::Vector(v)) => v,
2076 _ => Vec::new(),
2077 }
2078 })
2079 .ok()?;
2080 Some(result)
2081}
2082
2083fn try_fts_probe(table: &Table, order_expr: &Expr, ascending: bool, k: usize) -> Option<Vec<i64>> {
2099 if k == 0 || ascending {
2100 return None;
2104 }
2105
2106 let func = match order_expr {
2107 Expr::Function(f) => f,
2108 _ => return None,
2109 };
2110 let fname = match func.name.0.as_slice() {
2111 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
2112 _ => return None,
2113 };
2114 if fname != "bm25_score" {
2115 return None;
2116 }
2117
2118 let arg_list = match &func.args {
2119 FunctionArguments::List(l) => &l.args,
2120 _ => return None,
2121 };
2122 if arg_list.len() != 2 {
2123 return None;
2124 }
2125 let exprs: Vec<&Expr> = arg_list
2126 .iter()
2127 .filter_map(|a| match a {
2128 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
2129 _ => None,
2130 })
2131 .collect();
2132 if exprs.len() != 2 {
2133 return None;
2134 }
2135
2136 let col_name = match exprs[0] {
2138 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
2139 _ => return None,
2140 };
2141
2142 let query = match exprs[1] {
2146 Expr::Value(v) => match &v.value {
2147 AstValue::SingleQuotedString(s) => s.clone(),
2148 _ => return None,
2149 },
2150 _ => return None,
2151 };
2152
2153 let entry = table
2154 .fts_indexes
2155 .iter()
2156 .find(|e| e.column_name == col_name)?;
2157
2158 let scored = entry.index.query(&query, &Bm25Params::default());
2159 let mut out: Vec<i64> = scored.into_iter().map(|(id, _)| id).collect();
2160 if out.len() > k {
2161 out.truncate(k);
2162 }
2163 Some(out)
2164}
2165
2166fn identify_indexed_arg_and_literal(a: &Expr, b: &Expr) -> Option<(String, Vec<f32>)> {
2171 let col_name = match a {
2172 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
2173 _ => return None,
2174 };
2175 let lit_str = match b {
2176 Expr::Identifier(ident) if ident.quote_style == Some('[') => {
2177 format!("[{}]", ident.value)
2178 }
2179 _ => return None,
2180 };
2181 let v = parse_vector_literal(&lit_str).ok()?;
2182 Some((col_name, v))
2183}
2184
2185struct HeapEntry {
2198 key: Value,
2199 rowid: i64,
2200 asc: bool,
2201}
2202
2203impl PartialEq for HeapEntry {
2204 fn eq(&self, other: &Self) -> bool {
2205 self.cmp(other) == Ordering::Equal
2206 }
2207}
2208
2209impl Eq for HeapEntry {}
2210
2211impl PartialOrd for HeapEntry {
2212 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
2213 Some(self.cmp(other))
2214 }
2215}
2216
2217impl Ord for HeapEntry {
2218 fn cmp(&self, other: &Self) -> Ordering {
2219 let raw = compare_values(Some(&self.key), Some(&other.key));
2220 if self.asc { raw } else { raw.reverse() }
2221 }
2222}
2223
2224fn select_topk(
2233 matching: &[i64],
2234 table: &Table,
2235 order: &OrderByClause,
2236 k: usize,
2237) -> Result<Vec<i64>> {
2238 use std::collections::BinaryHeap;
2239
2240 if k == 0 || matching.is_empty() {
2241 return Ok(Vec::new());
2242 }
2243
2244 let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
2245
2246 for &rowid in matching {
2247 let key = eval_expr(&order.expr, table, rowid)?;
2248 let entry = HeapEntry {
2249 key,
2250 rowid,
2251 asc: order.ascending,
2252 };
2253
2254 if heap.len() < k {
2255 heap.push(entry);
2256 } else {
2257 if entry < *heap.peek().unwrap() {
2261 heap.pop();
2262 heap.push(entry);
2263 }
2264 }
2265 }
2266
2267 Ok(heap
2272 .into_sorted_vec()
2273 .into_iter()
2274 .map(|e| e.rowid)
2275 .collect())
2276}
2277
2278fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
2279 let mut keys: Vec<(i64, Result<Value>)> = rowids
2287 .iter()
2288 .map(|r| (*r, eval_expr(&order.expr, table, *r)))
2289 .collect();
2290
2291 for (_, k) in &keys {
2295 if let Err(e) = k {
2296 return Err(SQLRiteError::General(format!(
2297 "ORDER BY expression failed: {e}"
2298 )));
2299 }
2300 }
2301
2302 keys.sort_by(|(_, ka), (_, kb)| {
2303 let va = ka.as_ref().unwrap();
2306 let vb = kb.as_ref().unwrap();
2307 let ord = compare_values(Some(va), Some(vb));
2308 if order.ascending { ord } else { ord.reverse() }
2309 });
2310
2311 for (i, (rowid, _)) in keys.into_iter().enumerate() {
2313 rowids[i] = rowid;
2314 }
2315 Ok(())
2316}
2317
2318fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
2319 match (a, b) {
2320 (None, None) => Ordering::Equal,
2321 (None, _) => Ordering::Less,
2322 (_, None) => Ordering::Greater,
2323 (Some(a), Some(b)) => match (a, b) {
2324 (Value::Null, Value::Null) => Ordering::Equal,
2325 (Value::Null, _) => Ordering::Less,
2326 (_, Value::Null) => Ordering::Greater,
2327 (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
2328 (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
2329 (Value::Integer(x), Value::Real(y)) => {
2330 (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
2331 }
2332 (Value::Real(x), Value::Integer(y)) => {
2333 x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
2334 }
2335 (Value::Text(x), Value::Text(y)) => x.cmp(y),
2336 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
2337 (x, y) => x.to_display_string().cmp(&y.to_display_string()),
2339 },
2340 }
2341}
2342
2343pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
2345 eval_predicate_scope(expr, &SingleTableScope::new(table, rowid))
2346}
2347
2348pub(crate) fn eval_predicate_scope(expr: &Expr, scope: &dyn RowScope) -> Result<bool> {
2352 let v = eval_expr_scope(expr, scope)?;
2353 match v {
2354 Value::Bool(b) => Ok(b),
2355 Value::Null => Ok(false), Value::Integer(i) => Ok(i != 0),
2357 other => Err(SQLRiteError::Internal(format!(
2358 "WHERE clause must evaluate to boolean, got {}",
2359 other.to_display_string()
2360 ))),
2361 }
2362}
2363
2364fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
2366 eval_expr_scope(expr, &SingleTableScope::new(table, rowid))
2367}
2368
2369fn eval_expr_scope(expr: &Expr, scope: &dyn RowScope) -> Result<Value> {
2370 match expr {
2371 Expr::Nested(inner) => eval_expr_scope(inner, scope),
2372
2373 Expr::Identifier(ident) => {
2374 if ident.quote_style == Some('[') {
2384 let raw = format!("[{}]", ident.value);
2385 let v = parse_vector_literal(&raw)?;
2386 return Ok(Value::Vector(v));
2387 }
2388 scope.lookup(None, &ident.value)
2389 }
2390
2391 Expr::CompoundIdentifier(parts) => {
2392 match parts.as_slice() {
2398 [only] => scope.lookup(None, &only.value),
2399 [q, c] => scope.lookup(Some(&q.value), &c.value),
2400 _ => Err(SQLRiteError::NotImplemented(format!(
2401 "compound identifier with {} parts is not supported",
2402 parts.len()
2403 ))),
2404 }
2405 }
2406
2407 Expr::Value(v) => convert_literal(&v.value),
2408
2409 Expr::UnaryOp { op, expr } => {
2410 let inner = eval_expr_scope(expr, scope)?;
2411 match op {
2412 UnaryOperator::Not => match inner {
2413 Value::Bool(b) => Ok(Value::Bool(!b)),
2414 Value::Null => Ok(Value::Null),
2415 other => Err(SQLRiteError::Internal(format!(
2416 "NOT applied to non-boolean value: {}",
2417 other.to_display_string()
2418 ))),
2419 },
2420 UnaryOperator::Minus => match inner {
2421 Value::Integer(i) => Ok(Value::Integer(-i)),
2422 Value::Real(f) => Ok(Value::Real(-f)),
2423 Value::Null => Ok(Value::Null),
2424 other => Err(SQLRiteError::Internal(format!(
2425 "unary minus on non-numeric value: {}",
2426 other.to_display_string()
2427 ))),
2428 },
2429 UnaryOperator::Plus => Ok(inner),
2430 other => Err(SQLRiteError::NotImplemented(format!(
2431 "unary operator {other:?} is not supported"
2432 ))),
2433 }
2434 }
2435
2436 Expr::BinaryOp { left, op, right } => match op {
2437 BinaryOperator::And => {
2438 let l = eval_expr_scope(left, scope)?;
2439 let r = eval_expr_scope(right, scope)?;
2440 Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
2441 }
2442 BinaryOperator::Or => {
2443 let l = eval_expr_scope(left, scope)?;
2444 let r = eval_expr_scope(right, scope)?;
2445 Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
2446 }
2447 cmp @ (BinaryOperator::Eq
2448 | BinaryOperator::NotEq
2449 | BinaryOperator::Lt
2450 | BinaryOperator::LtEq
2451 | BinaryOperator::Gt
2452 | BinaryOperator::GtEq) => {
2453 let l = eval_expr_scope(left, scope)?;
2454 let r = eval_expr_scope(right, scope)?;
2455 if matches!(l, Value::Null) || matches!(r, Value::Null) {
2457 return Ok(Value::Bool(false));
2458 }
2459 let ord = compare_values(Some(&l), Some(&r));
2460 let result = match cmp {
2461 BinaryOperator::Eq => ord == Ordering::Equal,
2462 BinaryOperator::NotEq => ord != Ordering::Equal,
2463 BinaryOperator::Lt => ord == Ordering::Less,
2464 BinaryOperator::LtEq => ord != Ordering::Greater,
2465 BinaryOperator::Gt => ord == Ordering::Greater,
2466 BinaryOperator::GtEq => ord != Ordering::Less,
2467 _ => unreachable!(),
2468 };
2469 Ok(Value::Bool(result))
2470 }
2471 arith @ (BinaryOperator::Plus
2472 | BinaryOperator::Minus
2473 | BinaryOperator::Multiply
2474 | BinaryOperator::Divide
2475 | BinaryOperator::Modulo) => {
2476 let l = eval_expr_scope(left, scope)?;
2477 let r = eval_expr_scope(right, scope)?;
2478 eval_arith(arith, &l, &r)
2479 }
2480 BinaryOperator::StringConcat => {
2481 let l = eval_expr_scope(left, scope)?;
2482 let r = eval_expr_scope(right, scope)?;
2483 if matches!(l, Value::Null) || matches!(r, Value::Null) {
2484 return Ok(Value::Null);
2485 }
2486 Ok(Value::Text(format!(
2487 "{}{}",
2488 l.to_display_string(),
2489 r.to_display_string()
2490 )))
2491 }
2492 other => Err(SQLRiteError::NotImplemented(format!(
2493 "binary operator {other:?} is not supported yet"
2494 ))),
2495 },
2496
2497 Expr::IsNull(inner) => {
2505 let v = eval_expr_scope(inner, scope)?;
2506 Ok(Value::Bool(matches!(v, Value::Null)))
2507 }
2508 Expr::IsNotNull(inner) => {
2509 let v = eval_expr_scope(inner, scope)?;
2510 Ok(Value::Bool(!matches!(v, Value::Null)))
2511 }
2512
2513 Expr::Like {
2520 negated,
2521 any,
2522 expr: lhs,
2523 pattern,
2524 escape_char,
2525 } => eval_like(
2526 scope,
2527 *negated,
2528 *any,
2529 lhs,
2530 pattern,
2531 escape_char.as_ref(),
2532 true,
2533 ),
2534 Expr::ILike {
2535 negated,
2536 any,
2537 expr: lhs,
2538 pattern,
2539 escape_char,
2540 } => eval_like(
2541 scope,
2542 *negated,
2543 *any,
2544 lhs,
2545 pattern,
2546 escape_char.as_ref(),
2547 true,
2548 ),
2549
2550 Expr::InList {
2556 expr: lhs,
2557 list,
2558 negated,
2559 } => eval_in_list(scope, lhs, list, *negated),
2560 Expr::InSubquery { .. } => Err(SQLRiteError::NotImplemented(
2561 "IN (subquery) is not supported (only literal lists are)".to_string(),
2562 )),
2563
2564 Expr::Function(func) => eval_function(func, scope),
2575
2576 other => Err(SQLRiteError::NotImplemented(format!(
2577 "unsupported expression in WHERE/projection: {other:?}"
2578 ))),
2579 }
2580}
2581
2582fn eval_function(func: &sqlparser::ast::Function, scope: &dyn RowScope) -> Result<Value> {
2587 let name = match func.name.0.as_slice() {
2590 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
2591 _ => {
2592 return Err(SQLRiteError::NotImplemented(format!(
2593 "qualified function names not supported: {:?}",
2594 func.name
2595 )));
2596 }
2597 };
2598
2599 match name.as_str() {
2600 "vec_distance_l2" | "vec_distance_cosine" | "vec_distance_dot" => {
2601 let (a, b) = extract_two_vector_args(&name, &func.args, scope)?;
2602 let dist = match name.as_str() {
2603 "vec_distance_l2" => vec_distance_l2(&a, &b),
2604 "vec_distance_cosine" => vec_distance_cosine(&a, &b)?,
2605 "vec_distance_dot" => vec_distance_dot(&a, &b),
2606 _ => unreachable!(),
2607 };
2608 Ok(Value::Real(dist as f64))
2614 }
2615 "json_extract" => json_fn_extract(&name, &func.args, scope),
2620 "json_type" => json_fn_type(&name, &func.args, scope),
2621 "json_array_length" => json_fn_array_length(&name, &func.args, scope),
2622 "json_object_keys" => json_fn_object_keys(&name, &func.args, scope),
2623 "fts_match" | "bm25_score" => {
2634 let Some((table, rowid)) = scope.single_table_view() else {
2635 return Err(SQLRiteError::NotImplemented(format!(
2636 "{name}() is not yet supported inside a JOIN query — \
2637 use it on a single-table SELECT or move the FTS lookup into a subquery"
2638 )));
2639 };
2640 let (entry, query) = resolve_fts_args(&name, &func.args, table, scope)?;
2641 Ok(match name.as_str() {
2642 "fts_match" => Value::Bool(entry.index.matches(rowid, &query)),
2643 "bm25_score" => {
2644 Value::Real(entry.index.score(rowid, &query, &Bm25Params::default()))
2645 }
2646 _ => unreachable!(),
2647 })
2648 }
2649 "count" | "sum" | "avg" | "min" | "max" => Err(SQLRiteError::NotImplemented(format!(
2653 "aggregate function '{name}' is not allowed in WHERE / projection-scalar position; \
2654 use it as a top-level projection item (HAVING is not yet supported)"
2655 ))),
2656 other => Err(SQLRiteError::NotImplemented(format!(
2657 "unknown function: {other}(...)"
2658 ))),
2659 }
2660}
2661
2662fn resolve_fts_args<'t>(
2667 fn_name: &str,
2668 args: &FunctionArguments,
2669 table: &'t Table,
2670 scope: &dyn RowScope,
2671) -> Result<(&'t FtsIndexEntry, String)> {
2672 let arg_list = match args {
2673 FunctionArguments::List(l) => &l.args,
2674 _ => {
2675 return Err(SQLRiteError::General(format!(
2676 "{fn_name}() expects exactly two arguments: (column, query_text)"
2677 )));
2678 }
2679 };
2680 if arg_list.len() != 2 {
2681 return Err(SQLRiteError::General(format!(
2682 "{fn_name}() expects exactly 2 arguments, got {}",
2683 arg_list.len()
2684 )));
2685 }
2686
2687 let col_expr = match &arg_list[0] {
2691 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2692 other => {
2693 return Err(SQLRiteError::NotImplemented(format!(
2694 "{fn_name}() argument 0 must be a column name, got {other:?}"
2695 )));
2696 }
2697 };
2698 let col_name = match col_expr {
2699 Expr::Identifier(ident) => ident.value.clone(),
2700 Expr::CompoundIdentifier(parts) => parts
2701 .last()
2702 .map(|p| p.value.clone())
2703 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
2704 other => {
2705 return Err(SQLRiteError::General(format!(
2706 "{fn_name}() argument 0 must be a column reference, got {other:?}"
2707 )));
2708 }
2709 };
2710
2711 let q_expr = match &arg_list[1] {
2715 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2716 other => {
2717 return Err(SQLRiteError::NotImplemented(format!(
2718 "{fn_name}() argument 1 must be a text expression, got {other:?}"
2719 )));
2720 }
2721 };
2722 let query = match eval_expr_scope(q_expr, scope)? {
2723 Value::Text(s) => s,
2724 other => {
2725 return Err(SQLRiteError::General(format!(
2726 "{fn_name}() argument 1 must be TEXT, got {}",
2727 other.to_display_string()
2728 )));
2729 }
2730 };
2731
2732 let entry = table
2733 .fts_indexes
2734 .iter()
2735 .find(|e| e.column_name == col_name)
2736 .ok_or_else(|| {
2737 SQLRiteError::General(format!(
2738 "{fn_name}({col_name}, ...): no FTS index on column '{col_name}' \
2739 (run CREATE INDEX <name> ON <table> USING fts({col_name}) first)"
2740 ))
2741 })?;
2742 Ok((entry, query))
2743}
2744
2745fn extract_json_and_path(
2759 fn_name: &str,
2760 args: &FunctionArguments,
2761 scope: &dyn RowScope,
2762) -> Result<(String, String)> {
2763 let arg_list = match args {
2764 FunctionArguments::List(l) => &l.args,
2765 _ => {
2766 return Err(SQLRiteError::General(format!(
2767 "{fn_name}() expects 1 or 2 arguments"
2768 )));
2769 }
2770 };
2771 if !(arg_list.len() == 1 || arg_list.len() == 2) {
2772 return Err(SQLRiteError::General(format!(
2773 "{fn_name}() expects 1 or 2 arguments, got {}",
2774 arg_list.len()
2775 )));
2776 }
2777 let first_expr = match &arg_list[0] {
2779 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2780 other => {
2781 return Err(SQLRiteError::NotImplemented(format!(
2782 "{fn_name}() argument 0 has unsupported shape: {other:?}"
2783 )));
2784 }
2785 };
2786 let json_text = match eval_expr_scope(first_expr, scope)? {
2787 Value::Text(s) => s,
2788 Value::Null => {
2789 return Err(SQLRiteError::General(format!(
2790 "{fn_name}() called on NULL — JSON column has no value for this row"
2791 )));
2792 }
2793 other => {
2794 return Err(SQLRiteError::General(format!(
2795 "{fn_name}() argument 0 is not JSON-typed: got {}",
2796 other.to_display_string()
2797 )));
2798 }
2799 };
2800
2801 let path = if arg_list.len() == 2 {
2803 let path_expr = match &arg_list[1] {
2804 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2805 other => {
2806 return Err(SQLRiteError::NotImplemented(format!(
2807 "{fn_name}() argument 1 has unsupported shape: {other:?}"
2808 )));
2809 }
2810 };
2811 match eval_expr_scope(path_expr, scope)? {
2812 Value::Text(s) => s,
2813 other => {
2814 return Err(SQLRiteError::General(format!(
2815 "{fn_name}() path argument must be a string literal, got {}",
2816 other.to_display_string()
2817 )));
2818 }
2819 }
2820 } else {
2821 "$".to_string()
2822 };
2823
2824 Ok((json_text, path))
2825}
2826
2827fn walk_json_path<'a>(
2837 value: &'a serde_json::Value,
2838 path: &str,
2839) -> Result<Option<&'a serde_json::Value>> {
2840 let mut chars = path.chars().peekable();
2841 if chars.next() != Some('$') {
2842 return Err(SQLRiteError::General(format!(
2843 "JSON path must start with '$', got `{path}`"
2844 )));
2845 }
2846 let mut current = value;
2847 while let Some(&c) = chars.peek() {
2848 match c {
2849 '.' => {
2850 chars.next();
2851 let mut key = String::new();
2852 while let Some(&c) = chars.peek() {
2853 if c == '.' || c == '[' {
2854 break;
2855 }
2856 key.push(c);
2857 chars.next();
2858 }
2859 if key.is_empty() {
2860 return Err(SQLRiteError::General(format!(
2861 "JSON path has empty key after '.' in `{path}`"
2862 )));
2863 }
2864 match current.get(&key) {
2865 Some(v) => current = v,
2866 None => return Ok(None),
2867 }
2868 }
2869 '[' => {
2870 chars.next();
2871 let mut idx_str = String::new();
2872 while let Some(&c) = chars.peek() {
2873 if c == ']' {
2874 break;
2875 }
2876 idx_str.push(c);
2877 chars.next();
2878 }
2879 if chars.next() != Some(']') {
2880 return Err(SQLRiteError::General(format!(
2881 "JSON path has unclosed `[` in `{path}`"
2882 )));
2883 }
2884 let idx: usize = idx_str.trim().parse().map_err(|_| {
2885 SQLRiteError::General(format!(
2886 "JSON path has non-integer index `[{idx_str}]` in `{path}`"
2887 ))
2888 })?;
2889 match current.get(idx) {
2890 Some(v) => current = v,
2891 None => return Ok(None),
2892 }
2893 }
2894 other => {
2895 return Err(SQLRiteError::General(format!(
2896 "JSON path has unexpected character `{other}` in `{path}` \
2897 (expected `.`, `[`, or end-of-path)"
2898 )));
2899 }
2900 }
2901 }
2902 Ok(Some(current))
2903}
2904
2905fn json_value_to_sql(v: &serde_json::Value) -> Value {
2909 match v {
2910 serde_json::Value::Null => Value::Null,
2911 serde_json::Value::Bool(b) => Value::Bool(*b),
2912 serde_json::Value::Number(n) => {
2913 if let Some(i) = n.as_i64() {
2915 Value::Integer(i)
2916 } else if let Some(f) = n.as_f64() {
2917 Value::Real(f)
2918 } else {
2919 Value::Null
2920 }
2921 }
2922 serde_json::Value::String(s) => Value::Text(s.clone()),
2923 composite => Value::Text(composite.to_string()),
2927 }
2928}
2929
2930fn json_fn_extract(name: &str, args: &FunctionArguments, scope: &dyn RowScope) -> Result<Value> {
2931 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2932 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2933 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2934 })?;
2935 match walk_json_path(&parsed, &path)? {
2936 Some(v) => Ok(json_value_to_sql(v)),
2937 None => Ok(Value::Null),
2938 }
2939}
2940
2941fn json_fn_type(name: &str, args: &FunctionArguments, scope: &dyn RowScope) -> Result<Value> {
2942 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2943 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2944 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2945 })?;
2946 let resolved = match walk_json_path(&parsed, &path)? {
2947 Some(v) => v,
2948 None => return Ok(Value::Null),
2949 };
2950 let ty = match resolved {
2951 serde_json::Value::Null => "null",
2952 serde_json::Value::Bool(true) => "true",
2953 serde_json::Value::Bool(false) => "false",
2954 serde_json::Value::Number(n) => {
2955 if n.is_i64() || n.is_u64() {
2956 "integer"
2957 } else {
2958 "real"
2959 }
2960 }
2961 serde_json::Value::String(_) => "text",
2962 serde_json::Value::Array(_) => "array",
2963 serde_json::Value::Object(_) => "object",
2964 };
2965 Ok(Value::Text(ty.to_string()))
2966}
2967
2968fn json_fn_array_length(
2969 name: &str,
2970 args: &FunctionArguments,
2971 scope: &dyn RowScope,
2972) -> Result<Value> {
2973 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2974 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2975 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2976 })?;
2977 let resolved = match walk_json_path(&parsed, &path)? {
2978 Some(v) => v,
2979 None => return Ok(Value::Null),
2980 };
2981 match resolved.as_array() {
2982 Some(arr) => Ok(Value::Integer(arr.len() as i64)),
2983 None => Err(SQLRiteError::General(format!(
2984 "{name}() resolved to a non-array value at path `{path}`"
2985 ))),
2986 }
2987}
2988
2989fn json_fn_object_keys(
2990 name: &str,
2991 args: &FunctionArguments,
2992 scope: &dyn RowScope,
2993) -> Result<Value> {
2994 let (json_text, path) = extract_json_and_path(name, args, scope)?;
2995 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2996 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2997 })?;
2998 let resolved = match walk_json_path(&parsed, &path)? {
2999 Some(v) => v,
3000 None => return Ok(Value::Null),
3001 };
3002 let obj = resolved.as_object().ok_or_else(|| {
3003 SQLRiteError::General(format!(
3004 "{name}() resolved to a non-object value at path `{path}`"
3005 ))
3006 })?;
3007 let keys: Vec<serde_json::Value> = obj
3014 .keys()
3015 .map(|k| serde_json::Value::String(k.clone()))
3016 .collect();
3017 Ok(Value::Text(serde_json::Value::Array(keys).to_string()))
3018}
3019
3020fn extract_two_vector_args(
3024 fn_name: &str,
3025 args: &FunctionArguments,
3026 scope: &dyn RowScope,
3027) -> Result<(Vec<f32>, Vec<f32>)> {
3028 let arg_list = match args {
3029 FunctionArguments::List(l) => &l.args,
3030 _ => {
3031 return Err(SQLRiteError::General(format!(
3032 "{fn_name}() expects exactly two vector arguments"
3033 )));
3034 }
3035 };
3036 if arg_list.len() != 2 {
3037 return Err(SQLRiteError::General(format!(
3038 "{fn_name}() expects exactly 2 arguments, got {}",
3039 arg_list.len()
3040 )));
3041 }
3042 let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
3043 for (i, arg) in arg_list.iter().enumerate() {
3044 let expr = match arg {
3045 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
3046 other => {
3047 return Err(SQLRiteError::NotImplemented(format!(
3048 "{fn_name}() argument {i} has unsupported shape: {other:?}"
3049 )));
3050 }
3051 };
3052 let val = eval_expr_scope(expr, scope)?;
3053 match val {
3054 Value::Vector(v) => out.push(v),
3055 other => {
3056 return Err(SQLRiteError::General(format!(
3057 "{fn_name}() argument {i} is not a vector: got {}",
3058 other.to_display_string()
3059 )));
3060 }
3061 }
3062 }
3063 let b = out.pop().unwrap();
3064 let a = out.pop().unwrap();
3065 if a.len() != b.len() {
3066 return Err(SQLRiteError::General(format!(
3067 "{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
3068 a.len(),
3069 b.len()
3070 )));
3071 }
3072 Ok((a, b))
3073}
3074
/// Euclidean (L2) distance between `a` and `b`.
///
/// Both slices are expected to have the same length; this is checked with a
/// debug assertion. Accumulation is done left to right in `f32`.
///
/// # Panics
/// Panics when `b` is shorter than `a`.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Slice `b` up front so a too-short rhs still panics instead of truncating.
    let rhs = &b[..a.len()];
    a.iter()
        .zip(rhs)
        .fold(0.0f32, |acc, (x, y)| {
            let delta = x - y;
            acc + delta * delta
        })
        .sqrt()
}
3086
3087pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
3097 debug_assert_eq!(a.len(), b.len());
3098 let mut dot = 0.0f32;
3099 let mut norm_a_sq = 0.0f32;
3100 let mut norm_b_sq = 0.0f32;
3101 for i in 0..a.len() {
3102 dot += a[i] * b[i];
3103 norm_a_sq += a[i] * a[i];
3104 norm_b_sq += b[i] * b[i];
3105 }
3106 let denom = (norm_a_sq * norm_b_sq).sqrt();
3107 if denom == 0.0 {
3108 return Err(SQLRiteError::General(
3109 "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
3110 ));
3111 }
3112 Ok(1.0 - dot / denom)
3113}
3114
/// Negated dot product of `a` and `b`.
///
/// Both slices are expected to have the same length; this is checked with a
/// debug assertion. Accumulation is done left to right in `f32`.
///
/// # Panics
/// Panics when `b` is shorter than `a`.
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Slice `b` up front so a too-short rhs still panics instead of truncating.
    let rhs = &b[..a.len()];
    let similarity = a.iter().zip(rhs).fold(0.0f32, |acc, (x, y)| acc + x * y);
    -similarity
}
3126
3127fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
3130 if matches!(l, Value::Null) || matches!(r, Value::Null) {
3131 return Ok(Value::Null);
3132 }
3133 match (l, r) {
3134 (Value::Integer(a), Value::Integer(b)) => match op {
3135 BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
3136 BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
3137 BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
3138 BinaryOperator::Divide => {
3139 if *b == 0 {
3140 Err(SQLRiteError::General("division by zero".to_string()))
3141 } else {
3142 Ok(Value::Integer(a / b))
3143 }
3144 }
3145 BinaryOperator::Modulo => {
3146 if *b == 0 {
3147 Err(SQLRiteError::General("modulo by zero".to_string()))
3148 } else {
3149 Ok(Value::Integer(a % b))
3150 }
3151 }
3152 _ => unreachable!(),
3153 },
3154 (a, b) => {
3156 let af = as_number(a)?;
3157 let bf = as_number(b)?;
3158 match op {
3159 BinaryOperator::Plus => Ok(Value::Real(af + bf)),
3160 BinaryOperator::Minus => Ok(Value::Real(af - bf)),
3161 BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
3162 BinaryOperator::Divide => {
3163 if bf == 0.0 {
3164 Err(SQLRiteError::General("division by zero".to_string()))
3165 } else {
3166 Ok(Value::Real(af / bf))
3167 }
3168 }
3169 BinaryOperator::Modulo => {
3170 if bf == 0.0 {
3171 Err(SQLRiteError::General("modulo by zero".to_string()))
3172 } else {
3173 Ok(Value::Real(af % bf))
3174 }
3175 }
3176 _ => unreachable!(),
3177 }
3178 }
3179 }
3180}
3181
3182fn as_number(v: &Value) -> Result<f64> {
3183 match v {
3184 Value::Integer(i) => Ok(*i as f64),
3185 Value::Real(f) => Ok(*f),
3186 Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
3187 other => Err(SQLRiteError::General(format!(
3188 "arithmetic on non-numeric value '{}'",
3189 other.to_display_string()
3190 ))),
3191 }
3192}
3193
3194fn as_bool(v: &Value) -> Result<bool> {
3195 match v {
3196 Value::Bool(b) => Ok(*b),
3197 Value::Null => Ok(false),
3198 Value::Integer(i) => Ok(*i != 0),
3199 other => Err(SQLRiteError::Internal(format!(
3200 "expected boolean, got {}",
3201 other.to_display_string()
3202 ))),
3203 }
3204}
3205
3206#[allow(clippy::too_many_arguments)]
3211fn eval_like(
3212 scope: &dyn RowScope,
3213 negated: bool,
3214 any: bool,
3215 lhs: &Expr,
3216 pattern: &Expr,
3217 escape_char: Option<&AstValue>,
3218 case_insensitive: bool,
3219) -> Result<Value> {
3220 if any {
3221 return Err(SQLRiteError::NotImplemented(
3222 "LIKE ANY (...) is not supported".to_string(),
3223 ));
3224 }
3225 if escape_char.is_some() {
3226 return Err(SQLRiteError::NotImplemented(
3227 "LIKE ... ESCAPE '<char>' is not supported (default `\\` escape only)".to_string(),
3228 ));
3229 }
3230
3231 let l = eval_expr_scope(lhs, scope)?;
3232 let p = eval_expr_scope(pattern, scope)?;
3233 if matches!(l, Value::Null) || matches!(p, Value::Null) {
3234 return Ok(Value::Null);
3235 }
3236 let text = match l {
3237 Value::Text(s) => s,
3238 other => other.to_display_string(),
3239 };
3240 let pat = match p {
3241 Value::Text(s) => s,
3242 other => other.to_display_string(),
3243 };
3244 let m = like_match(&text, &pat, case_insensitive);
3245 Ok(Value::Bool(if negated { !m } else { m }))
3246}
3247
3248fn eval_in_list(scope: &dyn RowScope, lhs: &Expr, list: &[Expr], negated: bool) -> Result<Value> {
3249 let l = eval_expr_scope(lhs, scope)?;
3250 if matches!(l, Value::Null) {
3251 return Ok(Value::Null);
3252 }
3253 let mut saw_null = false;
3254 for item in list {
3255 let r = eval_expr_scope(item, scope)?;
3256 if matches!(r, Value::Null) {
3257 saw_null = true;
3258 continue;
3259 }
3260 if compare_values(Some(&l), Some(&r)) == Ordering::Equal {
3261 return Ok(Value::Bool(!negated));
3262 }
3263 }
3264 if saw_null {
3265 Ok(Value::Null)
3268 } else {
3269 Ok(Value::Bool(negated))
3270 }
3271}
3272
3273fn aggregate_rows(
3284 table: &Table,
3285 matching: &[i64],
3286 group_by: &[String],
3287 proj_items: &[ProjectionItem],
3288) -> Result<Vec<Vec<Value>>> {
3289 let template: Vec<Option<AggState>> = proj_items
3293 .iter()
3294 .map(|i| match &i.kind {
3295 ProjectionKind::Aggregate(call) => Some(AggState::new(call)),
3296 ProjectionKind::Column { .. } => None,
3297 })
3298 .collect();
3299
3300 let mut keys: Vec<Vec<DistinctKey>> = Vec::new();
3306 let mut group_states: Vec<Vec<Option<AggState>>> = Vec::new();
3307 let mut group_key_values: Vec<Vec<Value>> = Vec::new();
3308
3309 for &rowid in matching {
3310 let mut key_values: Vec<Value> = Vec::with_capacity(group_by.len());
3311 let mut key: Vec<DistinctKey> = Vec::with_capacity(group_by.len());
3312 for col in group_by {
3313 let v = table.get_value(col, rowid).unwrap_or(Value::Null);
3314 key.push(DistinctKey::from_value(&v));
3315 key_values.push(v);
3316 }
3317 let idx = match keys.iter().position(|k| k == &key) {
3318 Some(i) => i,
3319 None => {
3320 keys.push(key);
3321 group_states.push(template.clone());
3322 group_key_values.push(key_values);
3323 keys.len() - 1
3324 }
3325 };
3326
3327 for (slot, item) in proj_items.iter().enumerate() {
3328 if let ProjectionKind::Aggregate(call) = &item.kind {
3329 let v = match &call.arg {
3330 AggregateArg::Star => Value::Null,
3331 AggregateArg::Column(c) => table.get_value(c, rowid).unwrap_or(Value::Null),
3332 };
3333 if let Some(state) = group_states[idx][slot].as_mut() {
3334 state.update(&v)?;
3335 }
3336 }
3337 }
3338 }
3339
3340 if keys.is_empty() && group_by.is_empty() {
3346 keys.push(Vec::new());
3349 group_states.push(template.clone());
3350 group_key_values.push(Vec::new());
3351 }
3352
3353 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(keys.len());
3355 for (group_idx, _) in keys.iter().enumerate() {
3356 let mut row: Vec<Value> = Vec::with_capacity(proj_items.len());
3357 for (slot, item) in proj_items.iter().enumerate() {
3358 match &item.kind {
3359 ProjectionKind::Column { name: c, .. } => {
3360 let pos = group_by
3363 .iter()
3364 .position(|g| g == c)
3365 .expect("validated to be in GROUP BY");
3366 row.push(group_key_values[group_idx][pos].clone());
3367 }
3368 ProjectionKind::Aggregate(_) => {
3369 let state = group_states[group_idx][slot]
3370 .as_ref()
3371 .expect("aggregate slot has state");
3372 row.push(state.finalize());
3373 }
3374 }
3375 }
3376 rows.push(row);
3377 }
3378 Ok(rows)
3379}
3380
3381struct GroupRowScope<'a> {
3391 columns: &'a [String],
3392 values: &'a [Value],
3393}
3394
3395impl RowScope for GroupRowScope<'_> {
3396 fn lookup(&self, qualifier: Option<&str>, col: &str) -> Result<Value> {
3397 let _ = qualifier;
3400 self.columns
3401 .iter()
3402 .position(|c| c.eq_ignore_ascii_case(col))
3403 .map(|i| self.values[i].clone())
3404 .ok_or_else(|| {
3405 SQLRiteError::Internal(format!(
3406 "HAVING references '{col}', which is neither a GROUP BY column nor an \
3407 aggregate in scope"
3408 ))
3409 })
3410 }
3411
3412 fn single_table_view(&self) -> Option<(&Table, i64)> {
3413 None
3414 }
3415}
3416
3417fn lower_having_expr(expr: &Expr, items: &mut Vec<ProjectionItem>) -> Result<Expr> {
3425 Ok(match expr {
3426 Expr::Function(func) => {
3427 let is_aggregate = matches!(
3428 func.name.0.as_slice(),
3429 [ObjectNamePart::Identifier(ident)] if AggregateFn::from_name(&ident.value).is_some()
3430 );
3431 if !is_aggregate {
3432 return Ok(expr.clone());
3433 }
3434 let call = parse_aggregate_call(func)?;
3435 let display = call.display_name();
3436 let already_known = items
3442 .iter()
3443 .any(|i| i.output_name().eq_ignore_ascii_case(&display));
3444 if !already_known {
3445 items.push(ProjectionItem {
3446 kind: ProjectionKind::Aggregate(call),
3447 alias: None,
3448 });
3449 }
3450 Expr::Identifier(Ident::new(display))
3451 }
3452 Expr::Nested(inner) => Expr::Nested(Box::new(lower_having_expr(inner, items)?)),
3453 Expr::UnaryOp { op, expr: inner } => Expr::UnaryOp {
3454 op: *op,
3455 expr: Box::new(lower_having_expr(inner, items)?),
3456 },
3457 Expr::BinaryOp { left, op, right } => Expr::BinaryOp {
3458 left: Box::new(lower_having_expr(left, items)?),
3459 op: op.clone(),
3460 right: Box::new(lower_having_expr(right, items)?),
3461 },
3462 Expr::IsNull(inner) => Expr::IsNull(Box::new(lower_having_expr(inner, items)?)),
3463 Expr::IsNotNull(inner) => Expr::IsNotNull(Box::new(lower_having_expr(inner, items)?)),
3464 Expr::InList {
3465 expr: lhs,
3466 list,
3467 negated,
3468 } => Expr::InList {
3469 expr: Box::new(lower_having_expr(lhs, items)?),
3470 list: list
3471 .iter()
3472 .map(|e| lower_having_expr(e, items))
3473 .collect::<Result<Vec<_>>>()?,
3474 negated: *negated,
3475 },
3476 Expr::Like {
3477 negated,
3478 any,
3479 expr: lhs,
3480 pattern,
3481 escape_char,
3482 } => Expr::Like {
3483 negated: *negated,
3484 any: *any,
3485 expr: Box::new(lower_having_expr(lhs, items)?),
3486 pattern: Box::new(lower_having_expr(pattern, items)?),
3487 escape_char: escape_char.clone(),
3488 },
3489 Expr::ILike {
3490 negated,
3491 any,
3492 expr: lhs,
3493 pattern,
3494 escape_char,
3495 } => Expr::ILike {
3496 negated: *negated,
3497 any: *any,
3498 expr: Box::new(lower_having_expr(lhs, items)?),
3499 pattern: Box::new(lower_having_expr(pattern, items)?),
3500 escape_char: escape_char.clone(),
3501 },
3502 other => other.clone(),
3505 })
3506}
3507
3508fn filter_groups_by_having(
3512 rows: Vec<Vec<Value>>,
3513 having: &Expr,
3514 columns: &[String],
3515) -> Result<Vec<Vec<Value>>> {
3516 let mut out = Vec::with_capacity(rows.len());
3517 for row in rows {
3518 let scope = GroupRowScope {
3519 columns,
3520 values: &row,
3521 };
3522 let keep = match eval_expr_scope(having, &scope)? {
3523 Value::Bool(b) => b,
3524 Value::Null => false,
3525 Value::Integer(i) => i != 0,
3526 other => {
3527 return Err(SQLRiteError::Internal(format!(
3528 "HAVING clause must evaluate to boolean, got {}",
3529 other.to_display_string()
3530 )));
3531 }
3532 };
3533 if keep {
3534 out.push(row);
3535 }
3536 }
3537 Ok(out)
3538}
3539
3540fn dedupe_rows(rows: Vec<Vec<Value>>) -> Vec<Vec<Value>> {
3544 use std::collections::HashSet;
3545 let mut seen: HashSet<Vec<DistinctKey>> = HashSet::new();
3546 let mut out = Vec::with_capacity(rows.len());
3547 for row in rows {
3548 let key: Vec<DistinctKey> = row.iter().map(DistinctKey::from_value).collect();
3549 if seen.insert(key) {
3550 out.push(row);
3551 }
3552 }
3553 out
3554}
3555
3556fn sort_output_rows(
3560 rows: &mut [Vec<Value>],
3561 columns: &[String],
3562 proj_items: &[ProjectionItem],
3563 order: &OrderByClause,
3564) -> Result<()> {
3565 let target_idx = resolve_order_by_index(&order.expr, columns, proj_items)?;
3566 rows.sort_by(|a, b| {
3567 let va = &a[target_idx];
3568 let vb = &b[target_idx];
3569 let ord = compare_values(Some(va), Some(vb));
3570 if order.ascending { ord } else { ord.reverse() }
3571 });
3572 Ok(())
3573}
3574
3575fn resolve_order_by_index(
3578 expr: &Expr,
3579 columns: &[String],
3580 proj_items: &[ProjectionItem],
3581) -> Result<usize> {
3582 let target_name: Option<String> = match expr {
3584 Expr::Identifier(ident) => Some(ident.value.clone()),
3585 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
3586 Expr::Function(_) => None,
3587 Expr::Nested(inner) => return resolve_order_by_index(inner, columns, proj_items),
3588 other => {
3589 return Err(SQLRiteError::NotImplemented(format!(
3590 "ORDER BY expression not supported on aggregating queries: {other:?}"
3591 )));
3592 }
3593 };
3594 if let Some(name) = target_name {
3595 if let Some(i) = columns.iter().position(|c| c.eq_ignore_ascii_case(&name)) {
3596 return Ok(i);
3597 }
3598 return Err(SQLRiteError::Internal(format!(
3599 "ORDER BY references unknown column '{name}' in the SELECT output"
3600 )));
3601 }
3602 if let Expr::Function(func) = expr {
3606 let user_disp = format_function_display(func);
3607 for (i, item) in proj_items.iter().enumerate() {
3608 if let ProjectionKind::Aggregate(call) = &item.kind
3609 && call.display_name().eq_ignore_ascii_case(&user_disp)
3610 {
3611 return Ok(i);
3612 }
3613 }
3614 return Err(SQLRiteError::Internal(format!(
3615 "ORDER BY references aggregate '{user_disp}' that isn't in the SELECT output"
3616 )));
3617 }
3618 Err(SQLRiteError::Internal(
3619 "ORDER BY expression could not be resolved against the output columns".to_string(),
3620 ))
3621}
3622
3623fn format_function_display(func: &sqlparser::ast::Function) -> String {
3627 let name = match func.name.0.as_slice() {
3628 [ObjectNamePart::Identifier(ident)] => ident.value.to_uppercase(),
3629 _ => format!("{:?}", func.name).to_uppercase(),
3630 };
3631 let inner = match &func.args {
3632 FunctionArguments::List(l) => {
3633 let distinct = matches!(
3634 l.duplicate_treatment,
3635 Some(sqlparser::ast::DuplicateTreatment::Distinct)
3636 );
3637 let arg = l.args.first().map(|a| match a {
3638 FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => "*".to_string(),
3639 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier(i))) => i.value.clone(),
3640 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::CompoundIdentifier(parts))) => {
3641 parts.last().map(|p| p.value.clone()).unwrap_or_default()
3642 }
3643 _ => String::new(),
3644 });
3645 match (distinct, arg) {
3646 (true, Some(a)) if a != "*" => format!("DISTINCT {a}"),
3647 (_, Some(a)) => a,
3648 _ => String::new(),
3649 }
3650 }
3651 _ => String::new(),
3652 };
3653 format!("{name}({inner})")
3654}
3655
3656fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
3657 use sqlparser::ast::Value as AstValue;
3658 match v {
3659 AstValue::Number(n, _) => {
3660 if let Ok(i) = n.parse::<i64>() {
3661 Ok(Value::Integer(i))
3662 } else if let Ok(f) = n.parse::<f64>() {
3663 Ok(Value::Real(f))
3664 } else {
3665 Err(SQLRiteError::Internal(format!(
3666 "could not parse numeric literal '{n}'"
3667 )))
3668 }
3669 }
3670 AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
3671 AstValue::Boolean(b) => Ok(Value::Bool(*b)),
3672 AstValue::Null => Ok(Value::Null),
3673 other => Err(SQLRiteError::NotImplemented(format!(
3674 "unsupported literal value: {other:?}"
3675 ))),
3676 }
3677}
3678
3679#[cfg(test)]
3680mod tests {
3681 use super::*;
3682
/// Returns `true` when `a` and `b` differ by strictly less than `eps`.
fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
    let gap = (a - b).abs();
    gap < eps
}
3692
/// The L2 distance from a vector to itself is exactly zero.
#[test]
fn vec_distance_l2_identical_is_zero() {
    let point = [0.1_f32, 0.2, 0.3];
    let distance = vec_distance_l2(&point, &point);
    assert_eq!(distance, 0.0);
}
3698
/// Two distinct unit basis vectors are sqrt(2) apart.
#[test]
fn vec_distance_l2_unit_basis_is_sqrt2() {
    let x_axis = [1.0_f32, 0.0];
    let y_axis = [0.0_f32, 1.0];
    let expected = 2.0_f32.sqrt();
    assert!(approx_eq(vec_distance_l2(&x_axis, &y_axis), expected, 1e-6));
}
3706
/// The 3-4-5 triangle: the origin to (3, 4, 0) is 5 units.
#[test]
fn vec_distance_l2_known_value() {
    let origin = [0.0_f32; 3];
    let target = [3.0_f32, 4.0, 0.0];
    let distance = vec_distance_l2(&origin, &target);
    assert!(approx_eq(distance, 5.0, 1e-6));
}
3714
/// The cosine distance from a vector to itself is approximately zero.
#[test]
fn vec_distance_cosine_identical_is_zero() {
    let point = [0.1_f32, 0.2, 0.3];
    let d = vec_distance_cosine(&point, &point).unwrap();
    assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
}
3721
/// Orthogonal vectors have cosine distance 1.
#[test]
fn vec_distance_cosine_orthogonal_is_one() {
    let x_axis = [1.0_f32, 0.0];
    let y_axis = [0.0_f32, 1.0];
    let distance = vec_distance_cosine(&x_axis, &y_axis).unwrap();
    assert!(approx_eq(distance, 1.0, 1e-6));
}
3730
/// Vectors pointing in opposite directions have cosine distance 2.
#[test]
fn vec_distance_cosine_opposite_is_two() {
    let forward = [1.0_f32, 0.0, 0.0];
    let backward = [-1.0_f32, 0.0, 0.0];
    let distance = vec_distance_cosine(&forward, &backward).unwrap();
    assert!(approx_eq(distance, 2.0, 1e-6));
}
3738
/// A zero-magnitude operand is rejected with a descriptive error.
#[test]
fn vec_distance_cosine_zero_magnitude_errors() {
    let zero = [0.0_f32, 0.0];
    let unit = [1.0_f32, 0.0];
    let message = vec_distance_cosine(&zero, &unit).unwrap_err().to_string();
    assert!(message.contains("zero-magnitude"));
}
3747
/// The dot "distance" is the negated dot product: 1*4 + 2*5 + 3*6 = 32.
#[test]
fn vec_distance_dot_negates() {
    let lhs = [1.0_f32, 2.0, 3.0];
    let rhs = [4.0_f32, 5.0, 6.0];
    let distance = vec_distance_dot(&lhs, &rhs);
    assert!(approx_eq(distance, -32.0, 1e-6));
}
3755
/// Orthogonal vectors have a dot distance of zero.
#[test]
fn vec_distance_dot_orthogonal_is_zero() {
    let x_axis = [1.0_f32, 0.0];
    let y_axis = [0.0_f32, 1.0];
    let distance = vec_distance_dot(&x_axis, &y_axis);
    assert_eq!(distance, 0.0);
}
3763
/// For unit-norm inputs the dot distance equals the cosine distance minus 1.
#[test]
fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
    // Both vectors have norm 1 (0.6² + 0.8² = 1).
    let lhs = [0.6_f32, 0.8];
    let rhs = [0.8_f32, 0.6];
    let cos = vec_distance_cosine(&lhs, &rhs).unwrap();
    let dot = vec_distance_dot(&lhs, &rhs);
    assert!(approx_eq(dot, cos - 1.0, 1e-5));
}
3775
3776 use crate::sql::db::database::Database;
3781 use crate::sql::dialect::SqlriteDialect;
3782 use crate::sql::parser::select::SelectQuery;
3783 use sqlparser::parser::Parser;
3784
3785 fn seed_score_table(n: usize) -> Database {
3798 let mut db = Database::new("tempdb".to_string());
3799 crate::sql::process_command(
3800 "CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
3801 &mut db,
3802 )
3803 .expect("create");
3804 for i in 0..n {
3805 let score = ((i as u64).wrapping_mul(2_654_435_761) % 1_000_000) as f64;
3809 let sql = format!("INSERT INTO docs (score) VALUES ({score});");
3810 crate::sql::process_command(&sql, &mut db).expect("insert");
3811 }
3812 db
3813 }
3814
3815 fn parse_select(sql: &str) -> SelectQuery {
3819 let dialect = SqlriteDialect::new();
3820 let mut ast = Parser::parse_sql(&dialect, sql).expect("parse");
3821 let stmt = ast.pop().expect("one statement");
3822 SelectQuery::new(&stmt).expect("select-query")
3823 }
3824
/// Ascending top-k selection must agree with sorting everything and
/// truncating to k.
#[test]
fn topk_matches_full_sort_asc() {
    let db = seed_score_table(200);
    let table = db.get_table("docs".to_string()).unwrap();
    let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
    let order = query.order_by.as_ref().unwrap();
    let candidates = table.rowids();

    let topk = select_topk(&candidates, table, order, 10).unwrap();

    // Reference result: full sort, then keep the first ten.
    let mut full = candidates;
    sort_rowids(&mut full, table, order).unwrap();
    full.truncate(10);

    assert_eq!(topk, full, "top-k via heap should match full-sort+truncate");
}
3845
/// Descending top-k selection must agree with sorting everything and
/// truncating to k.
#[test]
fn topk_matches_full_sort_desc() {
    let db = seed_score_table(200);
    let table = db.get_table("docs".to_string()).unwrap();
    let query = parse_select("SELECT * FROM docs ORDER BY score DESC LIMIT 10;");
    let order = query.order_by.as_ref().unwrap();
    let candidates = table.rowids();

    let topk = select_topk(&candidates, table, order, 10).unwrap();

    // Reference result: full sort, then keep the first ten.
    let mut full = candidates;
    sort_rowids(&mut full, table, order).unwrap();
    full.truncate(10);

    assert_eq!(
        topk, full,
        "top-k DESC via heap should match full-sort+truncate"
    );
}
3866
/// Asking for more rows than exist returns every row, in ascending order.
#[test]
fn topk_k_larger_than_n_returns_everything_sorted() {
    let db = seed_score_table(50);
    let table = db.get_table("docs".to_string()).unwrap();
    let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
    let order = query.order_by.as_ref().unwrap();

    let topk = select_topk(&table.rowids(), table, order, 1000).unwrap();
    assert_eq!(topk.len(), 50);

    // Collect the REAL scores in result order; other values are skipped.
    let mut scores: Vec<f64> = Vec::new();
    for rowid in &topk {
        if let Some(Value::Real(score)) = table.get_value("score", *rowid) {
            scores.push(score);
        }
    }
    assert!(scores.windows(2).all(|pair| pair[0] <= pair[1]));
}
3889
/// A k of zero selects nothing, whatever the query's own LIMIT says.
#[test]
fn topk_k_zero_returns_empty() {
    let db = seed_score_table(10);
    let table = db.get_table("docs".to_string()).unwrap();
    let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
    let order = query.order_by.as_ref().unwrap();

    let selected = select_topk(&table.rowids(), table, order, 0).unwrap();
    assert!(selected.is_empty());
}
3899
/// Selecting from an empty candidate list yields an empty result.
#[test]
fn topk_empty_input_returns_empty() {
    let db = seed_score_table(0);
    let table = db.get_table("docs".to_string()).unwrap();
    let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
    let order = query.order_by.as_ref().unwrap();

    let selected = select_topk(&[], table, order, 5).unwrap();
    assert!(selected.is_empty());
}
3909
/// End-to-end check: `ORDER BY vec_distance_l2(...) LIMIT 3` through the
/// public command entry point returns three rows.
#[test]
fn topk_works_through_select_executor_with_distance_function() {
    const EMBEDDINGS: [&str; 5] = [
        "[1.0, 0.0]",
        "[2.0, 0.0]",
        "[0.0, 3.0]",
        "[1.0, 4.0]",
        "[10.0, 10.0]",
    ];

    let mut db = Database::new("tempdb".to_string());
    crate::sql::process_command(
        "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
        &mut db,
    )
    .unwrap();
    for v in EMBEDDINGS {
        let insert = format!("INSERT INTO docs (e) VALUES ({v});");
        crate::sql::process_command(&insert, &mut db).unwrap();
    }

    let resp = crate::sql::process_command(
        "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;",
        &mut db,
    )
    .unwrap();
    assert!(resp.contains("3 rows returned"), "got: {resp}");
}
3946
/// Ignored-by-default timing comparison between bounded top-k selection and
/// a full sort followed by truncation, over `N` rows with `k = K`.
///
/// Prints both durations and their ratio, then asserts the ratio exceeds
/// 1.4.
#[test]
#[ignore]
fn topk_benchmark() {
    const N: usize = 10_000;
    const K: usize = 10;

    let db = seed_score_table(N);
    let table = db.get_table("docs".to_string()).unwrap();
    let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
    let order = query.order_by.as_ref().unwrap();
    let candidates = table.rowids();

    // Timed section 1: bounded selection of the best K.
    let heap_started = std::time::Instant::now();
    let _selected = select_topk(&candidates, table, order, K).unwrap();
    let heap_dur = heap_started.elapsed();

    // Timed section 2: clone, sort everything, truncate to K. The clone is
    // deliberately inside the timed region, as in the reference path.
    let sort_started = std::time::Instant::now();
    let mut everything = candidates.clone();
    sort_rowids(&mut everything, table, order).unwrap();
    everything.truncate(K);
    let sort_dur = sort_started.elapsed();

    // Guard the denominator so a zero-length measurement cannot divide by zero.
    let ratio = sort_dur.as_secs_f64() / heap_dur.as_secs_f64().max(1e-9);
    println!("\n--- topk_benchmark (N={N}, k={K}) ---");
    println!(" bounded heap: {heap_dur:?}");
    println!(" full sort+trunc: {sort_dur:?}");
    println!(" speedup ratio: {ratio:.2}×");

    assert!(
        ratio > 1.4,
        "bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
    );
}
4011
4012 fn run_select(db: &mut Database, sql: &str) -> String {
4020 crate::sql::process_command(sql, db).expect("select")
4021 }
4022
/// `WHERE n IS NULL` returns exactly the rows whose `n` is NULL.
#[test]
fn where_is_null_returns_null_rows() {
    let mut db = Database::new("t".to_string());
    for statement in [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
        "INSERT INTO t (id, n) VALUES (1, 10);",
        "INSERT INTO t (id, n) VALUES (2, NULL);",
        "INSERT INTO t (id, n) VALUES (3, 30);",
        "INSERT INTO t (id, n) VALUES (4, NULL);",
    ] {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL;");
    assert!(
        response.contains("2 rows returned"),
        "IS NULL should return 2 rows, got: {response}"
    );
}
4042
/// `WHERE n IS NOT NULL` returns exactly the rows whose `n` has a value.
#[test]
fn where_is_not_null_returns_non_null_rows() {
    let mut db = Database::new("t".to_string());
    for statement in [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
        "INSERT INTO t (id, n) VALUES (1, 10);",
        "INSERT INTO t (id, n) VALUES (2, NULL);",
        "INSERT INTO t (id, n) VALUES (3, 30);",
    ] {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NOT NULL;");
    assert!(
        response.contains("2 rows returned"),
        "IS NOT NULL should return 2 rows, got: {response}"
    );
}
4061
/// `IS NULL` / `IS NOT NULL` on a column declared `UNIQUE` return the
/// expected row counts.
#[test]
fn where_is_null_on_indexed_column() {
    let mut db = Database::new("t".to_string());
    for statement in [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT UNIQUE);",
        "INSERT INTO t (id, name) VALUES (1, 'alice');",
        "INSERT INTO t (id, name) VALUES (2, NULL);",
        "INSERT INTO t (id, name) VALUES (3, 'bob');",
    ] {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    let null_rows = run_select(&mut db, "SELECT id FROM t WHERE name IS NULL;");
    assert!(
        null_rows.contains("1 row returned"),
        "indexed IS NULL should return 1 row, got: {null_rows}"
    );

    let not_null_rows = run_select(&mut db, "SELECT id FROM t WHERE name IS NOT NULL;");
    assert!(
        not_null_rows.contains("2 rows returned"),
        "indexed IS NOT NULL should return 2 rows, got: {not_null_rows}"
    );
}
4091
/// A column left out of the INSERT column list is matched by `IS NULL`.
#[test]
fn where_is_null_works_on_omitted_column() {
    let mut db = Database::new("t".to_string());
    for statement in [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, qty INTEGER, label TEXT);",
        "INSERT INTO t (id, qty, label) VALUES (1, 7, 'a');",
        // Row 2 omits `qty` entirely.
        "INSERT INTO t (id, label) VALUES (2, 'b');",
    ] {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    let response = run_select(&mut db, "SELECT id FROM t WHERE qty IS NULL;");
    assert!(
        response.contains("1 row returned"),
        "IS NULL should match the omitted-column row, got: {response}"
    );
}
4117
/// `IS NULL` composes with another predicate through `AND`.
#[test]
fn where_is_null_combines_with_and_or() {
    let mut db = Database::new("t".to_string());
    for statement in [
        "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
        "INSERT INTO t (id, n) VALUES (1, NULL);",
        "INSERT INTO t (id, n) VALUES (2, NULL);",
        "INSERT INTO t (id, n) VALUES (3, 30);",
    ] {
        crate::sql::process_command(statement, &mut db).unwrap();
    }

    let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL AND id > 1;");
    assert!(
        response.contains("1 row returned"),
        "IS NULL combined with AND should match exactly row 2, got: {response}"
    );
}
4139
4140 fn seed_employees() -> Database {
4146 let mut db = Database::new("t".to_string());
4147 crate::sql::process_command(
4148 "CREATE TABLE emp (id INTEGER PRIMARY KEY, name TEXT, dept TEXT, salary INTEGER);",
4149 &mut db,
4150 )
4151 .unwrap();
4152 let rows = [
4153 "INSERT INTO emp (name, dept, salary) VALUES ('Alice', 'eng', 100);",
4154 "INSERT INTO emp (name, dept, salary) VALUES ('alex', 'eng', 120);",
4155 "INSERT INTO emp (name, dept, salary) VALUES ('Bob', 'eng', 100);",
4156 "INSERT INTO emp (name, dept, salary) VALUES ('Carol', 'sales', 90);",
4157 "INSERT INTO emp (name, dept, salary) VALUES ('Dave', 'sales', NULL);",
4158 "INSERT INTO emp (name, dept, salary) VALUES ('Eve', 'ops', 80);",
4159 ];
4160 for sql in rows {
4161 crate::sql::process_command(sql, &mut db).unwrap();
4162 }
4163 db
4164 }
4165
4166 fn run_rows(db: &Database, sql: &str) -> SelectResult {
4168 let q = parse_select(sql);
4169 execute_select_rows(q, db).expect("select")
4170 }
4171
4172 #[test]
4175 fn like_percent_prefix_case_insensitive() {
4176 let db = seed_employees();
4177 let r = run_rows(&db, "SELECT name FROM emp WHERE name LIKE 'a%';");
4178 let names: Vec<_> = r.rows.iter().map(|r| r[0].to_display_string()).collect();
4180 assert_eq!(names.len(), 2, "expected 2 rows, got {names:?}");
4181 assert!(names.contains(&"Alice".to_string()));
4182 assert!(names.contains(&"alex".to_string()));
4183 }
4184
4185 #[test]
4186 fn like_underscore_singlechar() {
4187 let db = seed_employees();
4188 let r = run_rows(&db, "SELECT name FROM emp WHERE name LIKE '_ve';");
4189 let names: Vec<_> = r.rows.iter().map(|r| r[0].to_display_string()).collect();
4191 assert_eq!(names, vec!["Eve".to_string()]);
4192 }
4193
4194 #[test]
4195 fn not_like_excludes_match() {
4196 let db = seed_employees();
4197 let r = run_rows(&db, "SELECT name FROM emp WHERE name NOT LIKE 'a%';");
4198 assert_eq!(r.rows.len(), 4);
4200 }
4201
4202 #[test]
4203 fn like_with_null_excludes_row() {
4204 let db = seed_employees();
4205 let r = run_rows(
4207 &db,
4208 "SELECT name FROM emp WHERE dept LIKE 'sales' AND salary IS NULL;",
4209 );
4210 assert_eq!(r.rows.len(), 1);
4211 assert_eq!(r.rows[0][0].to_display_string(), "Dave");
4212 }
4213
4214 #[test]
4217 fn in_list_positive() {
4218 let db = seed_employees();
4219 let r = run_rows(&db, "SELECT name FROM emp WHERE id IN (1, 3, 5);");
4220 let names: Vec<_> = r.rows.iter().map(|r| r[0].to_display_string()).collect();
4221 assert_eq!(names.len(), 3);
4222 assert!(names.contains(&"Alice".to_string()));
4223 assert!(names.contains(&"Bob".to_string()));
4224 assert!(names.contains(&"Dave".to_string()));
4225 }
4226
4227 #[test]
4228 fn not_in_excludes_listed() {
4229 let db = seed_employees();
4230 let r = run_rows(&db, "SELECT name FROM emp WHERE id NOT IN (1, 2);");
4231 assert_eq!(r.rows.len(), 4);
4233 }
4234
4235 #[test]
4236 fn in_list_with_null_three_valued() {
4237 let db = seed_employees();
4238 let r = run_rows(&db, "SELECT name FROM emp WHERE id IN (1, NULL);");
4241 assert_eq!(r.rows.len(), 1);
4242 assert_eq!(r.rows[0][0].to_display_string(), "Alice");
4243 }
4244
4245 #[test]
4248 fn distinct_single_column() {
4249 let db = seed_employees();
4250 let r = run_rows(&db, "SELECT DISTINCT dept FROM emp;");
4251 assert_eq!(r.rows.len(), 3);
4253 }
4254
4255 #[test]
4256 fn distinct_multi_column_with_null() {
4257 let db = seed_employees();
4258 let r = run_rows(&db, "SELECT DISTINCT dept, salary FROM emp;");
4260 assert_eq!(r.rows.len(), 5);
4262 }
4263
4264 #[test]
4267 fn count_star_no_groupby() {
4268 let db = seed_employees();
4269 let r = run_rows(&db, "SELECT COUNT(*) FROM emp;");
4270 assert_eq!(r.rows.len(), 1);
4271 assert_eq!(r.rows[0][0], Value::Integer(6));
4272 }
4273
4274 #[test]
4275 fn count_col_skips_nulls() {
4276 let db = seed_employees();
4277 let r = run_rows(&db, "SELECT COUNT(salary) FROM emp;");
4278 assert_eq!(r.rows[0][0], Value::Integer(5));
4280 }
4281
4282 #[test]
4283 fn count_distinct_dedupes_and_skips_nulls() {
4284 let db = seed_employees();
4285 let r = run_rows(&db, "SELECT COUNT(DISTINCT salary) FROM emp;");
4286 assert_eq!(r.rows[0][0], Value::Integer(4));
4288 }
4289
4290 #[test]
4291 fn sum_int_stays_integer() {
4292 let db = seed_employees();
4293 let r = run_rows(&db, "SELECT SUM(salary) FROM emp;");
4294 assert_eq!(r.rows[0][0], Value::Integer(490));
4296 }
4297
4298 #[test]
4299 fn avg_returns_real() {
4300 let db = seed_employees();
4301 let r = run_rows(&db, "SELECT AVG(salary) FROM emp;");
4302 match &r.rows[0][0] {
4304 Value::Real(v) => assert!((v - 98.0).abs() < 1e-9),
4305 other => panic!("expected Real, got {other:?}"),
4306 }
4307 }
4308
4309 #[test]
4310 fn min_max_skip_nulls() {
4311 let db = seed_employees();
4312 let r = run_rows(&db, "SELECT MIN(salary), MAX(salary) FROM emp;");
4313 assert_eq!(r.rows[0][0], Value::Integer(80));
4314 assert_eq!(r.rows[0][1], Value::Integer(120));
4315 }
4316
4317 #[test]
4318 fn aggregates_on_empty_table_emit_one_row() {
4319 let mut db = Database::new("t".to_string());
4320 crate::sql::process_command("CREATE TABLE t (x INTEGER);", &mut db).unwrap();
4321 let r = run_rows(
4322 &db,
4323 "SELECT COUNT(*), SUM(x), AVG(x), MIN(x), MAX(x) FROM t;",
4324 );
4325 assert_eq!(r.rows.len(), 1);
4326 assert_eq!(r.rows[0][0], Value::Integer(0));
4327 assert_eq!(r.rows[0][1], Value::Null);
4328 assert_eq!(r.rows[0][2], Value::Null);
4329 assert_eq!(r.rows[0][3], Value::Null);
4330 assert_eq!(r.rows[0][4], Value::Null);
4331 }
4332
4333 #[test]
4336 fn group_by_single_col_with_count() {
4337 let db = seed_employees();
4338 let r = run_rows(&db, "SELECT dept, COUNT(*) FROM emp GROUP BY dept;");
4339 assert_eq!(r.rows.len(), 3);
4340 let mut by_dept: std::collections::HashMap<String, i64> = Default::default();
4342 for row in &r.rows {
4343 let d = row[0].to_display_string();
4344 let c = match &row[1] {
4345 Value::Integer(i) => *i,
4346 v => panic!("expected Integer count, got {v:?}"),
4347 };
4348 by_dept.insert(d, c);
4349 }
4350 assert_eq!(by_dept["eng"], 3);
4351 assert_eq!(by_dept["sales"], 2);
4352 assert_eq!(by_dept["ops"], 1);
4353 }
4354
4355 #[test]
4356 fn group_by_with_where_filter() {
4357 let db = seed_employees();
4358 let r = run_rows(
4359 &db,
4360 "SELECT dept, SUM(salary) FROM emp WHERE salary > 80 GROUP BY dept;",
4361 );
4362 let by: std::collections::HashMap<String, i64> = r
4365 .rows
4366 .iter()
4367 .map(|row| {
4368 (
4369 row[0].to_display_string(),
4370 match &row[1] {
4371 Value::Integer(i) => *i,
4372 v => panic!("expected Integer sum, got {v:?}"),
4373 },
4374 )
4375 })
4376 .collect();
4377 assert_eq!(by.len(), 2);
4378 assert_eq!(by["eng"], 320);
4379 assert_eq!(by["sales"], 90);
4380 }
4381
4382 #[test]
4383 fn group_by_without_aggregates_is_distinct() {
4384 let db = seed_employees();
4385 let r = run_rows(&db, "SELECT dept FROM emp GROUP BY dept;");
4386 assert_eq!(r.rows.len(), 3);
4387 }
4388
4389 #[test]
4390 fn order_by_count_desc() {
4391 let db = seed_employees();
4392 let r = run_rows(
4393 &db,
4394 "SELECT dept, COUNT(*) AS n FROM emp GROUP BY dept ORDER BY n DESC LIMIT 2;",
4395 );
4396 assert_eq!(r.rows.len(), 2);
4397 assert_eq!(r.rows[0][0].to_display_string(), "eng");
4399 assert_eq!(r.rows[0][1], Value::Integer(3));
4400 }
4401
4402 #[test]
4403 fn order_by_aggregate_call_form() {
4404 let db = seed_employees();
4405 let r = run_rows(
4407 &db,
4408 "SELECT dept, COUNT(*) FROM emp GROUP BY dept ORDER BY COUNT(*) DESC;",
4409 );
4410 assert_eq!(r.rows.len(), 3);
4411 assert_eq!(r.rows[0][0].to_display_string(), "eng");
4412 }
4413
4414 #[test]
4415 fn group_by_invalid_bare_column_errors() {
4416 let mut db = Database::new("t".to_string());
4418 crate::sql::process_command(
4419 "CREATE TABLE t (id INTEGER PRIMARY KEY, dept TEXT, name TEXT);",
4420 &mut db,
4421 )
4422 .unwrap();
4423 let err = crate::sql::process_command("SELECT dept, name FROM t GROUP BY dept;", &mut db);
4424 assert!(err.is_err(), "should reject bare 'name' not in GROUP BY");
4425 }
4426
4427 #[test]
4428 fn aggregate_in_where_errors_friendly() {
4429 let mut db = Database::new("t".to_string());
4430 crate::sql::process_command("CREATE TABLE t (x INTEGER);", &mut db).unwrap();
4431 crate::sql::process_command("INSERT INTO t (x) VALUES (1);", &mut db).unwrap();
4432 let err = crate::sql::process_command("SELECT x FROM t WHERE COUNT(*) > 0;", &mut db);
4433 assert!(err.is_err(), "aggregates must not be allowed in WHERE");
4434 }
4435
4436 #[test]
4445 fn having_count_filters_groups() {
4446 let db = seed_employees();
4447 let r = run_rows(
4448 &db,
4449 "SELECT dept, COUNT(*) FROM emp GROUP BY dept HAVING COUNT(*) > 1;",
4450 );
4451 assert_eq!(r.columns, vec!["dept".to_string(), "COUNT(*)".to_string()]);
4454 let got: Vec<(String, i64)> = r
4455 .rows
4456 .iter()
4457 .map(|row| (row[0].to_display_string(), expect_int(&row[1])))
4458 .collect();
4459 assert_eq!(got, vec![("eng".to_string(), 3), ("sales".to_string(), 2)]);
4460 }
4461
4462 #[test]
4463 fn having_sum_threshold() {
4464 let db = seed_employees();
4465 let r = run_rows(
4466 &db,
4467 "SELECT dept, SUM(salary) FROM emp GROUP BY dept HAVING SUM(salary) > 100;",
4468 );
4469 assert_eq!(r.rows.len(), 1);
4470 assert_eq!(r.rows[0][0].to_display_string(), "eng");
4471 assert_eq!(r.rows[0][1], Value::Integer(320));
4472 }
4473
4474 #[test]
4475 fn having_references_aggregate_alias() {
4476 let db = seed_employees();
4477 let r = run_rows(
4478 &db,
4479 "SELECT dept, SUM(salary) AS total FROM emp GROUP BY dept HAVING total > 100;",
4480 );
4481 assert_eq!(r.columns, vec!["dept".to_string(), "total".to_string()]);
4482 assert_eq!(r.rows.len(), 1);
4483 assert_eq!(r.rows[0][1], Value::Integer(320));
4484 }
4485
4486 #[test]
4487 fn having_aggregate_not_in_projection() {
4488 let db = seed_employees();
4489 let r = run_rows(
4492 &db,
4493 "SELECT dept FROM emp GROUP BY dept HAVING COUNT(*) > 1;",
4494 );
4495 assert_eq!(r.columns, vec!["dept".to_string()]);
4496 let depts: Vec<String> = r
4497 .rows
4498 .iter()
4499 .map(|row| row[0].to_display_string())
4500 .collect();
4501 assert_eq!(depts, vec!["eng".to_string(), "sales".to_string()]);
4502 }
4503
4504 #[test]
4505 fn having_group_key_not_in_projection() {
4506 let db = seed_employees();
4507 let r = run_rows(
4509 &db,
4510 "SELECT COUNT(*) FROM emp GROUP BY dept HAVING dept = 'eng';",
4511 );
4512 assert_eq!(r.columns, vec!["COUNT(*)".to_string()]);
4513 assert_eq!(r.rows.len(), 1);
4514 assert_eq!(r.rows[0][0], Value::Integer(3));
4515 }
4516
4517 #[test]
4518 fn having_compound_and_predicate() {
4519 let db = seed_employees();
4520 let r = run_rows(
4521 &db,
4522 "SELECT dept FROM emp GROUP BY dept \
4523 HAVING COUNT(*) > 1 AND SUM(salary) > 100;",
4524 );
4525 assert_eq!(r.rows.len(), 1);
4527 assert_eq!(r.rows[0][0].to_display_string(), "eng");
4528 }
4529
4530 #[test]
4531 fn having_composes_with_order_by_and_limit() {
4532 let db = seed_employees();
4533 let r = run_rows(
4534 &db,
4535 "SELECT dept, COUNT(*) AS n FROM emp GROUP BY dept \
4536 HAVING n >= 1 ORDER BY n DESC LIMIT 2;",
4537 );
4538 let got: Vec<(String, i64)> = r
4539 .rows
4540 .iter()
4541 .map(|row| (row[0].to_display_string(), expect_int(&row[1])))
4542 .collect();
4543 assert_eq!(got, vec![("eng".to_string(), 3), ("sales".to_string(), 2)]);
4544 }
4545
4546 #[test]
4547 fn having_can_exclude_every_group() {
4548 let db = seed_employees();
4549 let r = run_rows(
4550 &db,
4551 "SELECT dept FROM emp GROUP BY dept HAVING COUNT(*) > 99;",
4552 );
4553 assert_eq!(r.rows.len(), 0);
4554 }
4555
4556 #[test]
4557 fn having_null_aggregate_collapses_to_false() {
4558 let mut db = seed_employees();
4559 crate::sql::process_command(
4562 "INSERT INTO emp (name, dept, salary) VALUES ('Zoe', 'mkt', NULL);",
4563 &mut db,
4564 )
4565 .unwrap();
4566 let r = run_rows(
4567 &db,
4568 "SELECT dept FROM emp GROUP BY dept HAVING SUM(salary) > 0;",
4569 );
4570 let depts: Vec<String> = r
4571 .rows
4572 .iter()
4573 .map(|row| row[0].to_display_string())
4574 .collect();
4575 assert_eq!(
4576 depts,
4577 vec!["eng".to_string(), "sales".to_string(), "ops".to_string()],
4578 "mkt (all-NULL salaries) must be filtered out"
4579 );
4580 }
4581
4582 #[test]
4583 fn having_lowercase_function_form_matches() {
4584 let db = seed_employees();
4585 let r = run_rows(
4586 &db,
4587 "SELECT dept FROM emp GROUP BY dept HAVING count(*) > 1;",
4588 );
4589 assert_eq!(r.rows.len(), 2);
4590 }
4591
4592 #[test]
4593 fn having_without_group_by_is_rejected() {
4594 let mut db = seed_employees();
4595 let err =
4596 crate::sql::process_command("SELECT COUNT(*) FROM emp HAVING COUNT(*) > 0;", &mut db);
4597 match err {
4598 Err(SQLRiteError::NotImplemented(msg)) => assert!(
4599 msg.contains("HAVING without GROUP BY"),
4600 "unexpected message: {msg}"
4601 ),
4602 other => panic!("expected NotImplemented, got {other:?}"),
4603 }
4604 }
4605
4606 #[test]
4607 fn having_unknown_column_is_rejected() {
4608 let mut db = seed_employees();
4609 let err = crate::sql::process_command(
4612 "SELECT dept, COUNT(*) FROM emp GROUP BY dept HAVING name = 'Alice';",
4613 &mut db,
4614 );
4615 match err {
4616 Err(e) => {
4617 let msg = e.to_string();
4618 assert!(
4619 msg.contains("HAVING references"),
4620 "unexpected message: {msg}"
4621 );
4622 }
4623 Ok(_) => panic!("HAVING on an out-of-scope column must error"),
4624 }
4625 }
4626
4627 #[test]
4628 fn having_over_join_rejected_for_all_flavors() {
4629 for flavor in ["INNER", "LEFT OUTER", "RIGHT OUTER", "FULL OUTER"] {
4633 let sql = format!(
4634 "SELECT customers.name, COUNT(*) FROM customers \
4635 {flavor} JOIN orders ON customers.id = orders.customer_id \
4636 GROUP BY customers.name HAVING COUNT(*) > 1;"
4637 );
4638 let err = crate::sql::process_command(&sql, &mut seed_join_fixture());
4639 match err {
4640 Err(SQLRiteError::NotImplemented(msg)) => {
4641 assert!(msg.contains("JOIN"), "{flavor}: unexpected message: {msg}")
4642 }
4643 other => panic!("{flavor}: expected NotImplemented, got {other:?}"),
4644 }
4645 }
4646 }
4647
4648 fn expect_int(v: &Value) -> i64 {
4650 match v {
4651 Value::Integer(i) => *i,
4652 other => panic!("expected integer value, got {other:?}"),
4653 }
4654 }
4655
4656 fn seed_join_fixture() -> Database {
4667 let mut db = Database::new("t".to_string());
4668 for sql in [
4669 "CREATE TABLE customers (id INTEGER PRIMARY KEY, name TEXT);",
4670 "CREATE TABLE orders (id INTEGER PRIMARY KEY, customer_id INTEGER, amount INTEGER);",
4671 "INSERT INTO customers (name) VALUES ('Alice');",
4672 "INSERT INTO customers (name) VALUES ('Bob');",
4673 "INSERT INTO customers (name) VALUES ('Carol');",
4674 "INSERT INTO orders (customer_id, amount) VALUES (1, 100);",
4675 "INSERT INTO orders (customer_id, amount) VALUES (1, 200);",
4676 "INSERT INTO orders (customer_id, amount) VALUES (2, 50);",
4677 "INSERT INTO orders (customer_id, amount) VALUES (4, 999);",
4678 ] {
4679 crate::sql::process_command(sql, &mut db).unwrap();
4680 }
4681 db
4682 }
4683
4684 #[test]
4685 fn inner_join_returns_only_matched_rows() {
4686 let db = seed_join_fixture();
4687 let r = run_rows(
4688 &db,
4689 "SELECT customers.name, orders.amount FROM customers \
4690 INNER JOIN orders ON customers.id = orders.customer_id;",
4691 );
4692 assert_eq!(r.columns, vec!["name".to_string(), "amount".to_string()]);
4693 let pairs: Vec<(String, i64)> = r
4696 .rows
4697 .iter()
4698 .map(|row| {
4699 (
4700 row[0].to_display_string(),
4701 match row[1] {
4702 Value::Integer(i) => i,
4703 ref v => panic!("expected integer amount, got {v:?}"),
4704 },
4705 )
4706 })
4707 .collect();
4708 assert_eq!(pairs.len(), 3);
4709 assert!(pairs.contains(&("Alice".to_string(), 100)));
4710 assert!(pairs.contains(&("Alice".to_string(), 200)));
4711 assert!(pairs.contains(&("Bob".to_string(), 50)));
4712 }
4713
4714 #[test]
4715 fn bare_join_defaults_to_inner() {
4716 let db = seed_join_fixture();
4717 let r = run_rows(
4718 &db,
4719 "SELECT customers.name FROM customers \
4720 JOIN orders ON customers.id = orders.customer_id;",
4721 );
4722 assert_eq!(r.rows.len(), 3, "JOIN without prefix should be INNER");
4723 }
4724
4725 #[test]
4726 fn left_outer_join_preserves_unmatched_left() {
4727 let db = seed_join_fixture();
4728 let r = run_rows(
4729 &db,
4730 "SELECT customers.name, orders.amount FROM customers \
4731 LEFT OUTER JOIN orders ON customers.id = orders.customer_id;",
4732 );
4733 assert_eq!(r.rows.len(), 4);
4736 let carol = r
4737 .rows
4738 .iter()
4739 .find(|row| row[0].to_display_string() == "Carol")
4740 .expect("Carol should appear with a NULL-padded right side");
4741 assert_eq!(carol[1], Value::Null);
4742 }
4743
4744 #[test]
4745 fn right_outer_join_preserves_unmatched_right() {
4746 let db = seed_join_fixture();
4747 let r = run_rows(
4748 &db,
4749 "SELECT customers.name, orders.amount FROM customers \
4750 RIGHT OUTER JOIN orders ON customers.id = orders.customer_id;",
4751 );
4752 assert_eq!(r.rows.len(), 4);
4756 let dangling = r
4757 .rows
4758 .iter()
4759 .find(|row| matches!(row[1], Value::Integer(999)))
4760 .expect("dangling order 999 should appear with a NULL-padded customer name");
4761 assert_eq!(dangling[0], Value::Null);
4762 }
4763
4764 #[test]
4765 fn full_outer_join_preserves_both_sides() {
4766 let db = seed_join_fixture();
4767 let r = run_rows(
4768 &db,
4769 "SELECT customers.name, orders.amount FROM customers \
4770 FULL OUTER JOIN orders ON customers.id = orders.customer_id;",
4771 );
4772 assert_eq!(r.rows.len(), 5);
4775 assert!(
4777 r.rows
4778 .iter()
4779 .any(|row| row[0].to_display_string() == "Carol" && matches!(row[1], Value::Null))
4780 );
4781 assert!(
4783 r.rows
4784 .iter()
4785 .any(|row| matches!(row[1], Value::Integer(999)) && matches!(row[0], Value::Null))
4786 );
4787 }
4788
4789 #[test]
4790 fn join_with_table_aliases_resolves_qualifiers() {
4791 let db = seed_join_fixture();
4792 let r = run_rows(
4793 &db,
4794 "SELECT c.name, o.amount FROM customers AS c \
4795 INNER JOIN orders AS o ON c.id = o.customer_id;",
4796 );
4797 assert_eq!(r.rows.len(), 3);
4798 assert_eq!(r.columns, vec!["name".to_string(), "amount".to_string()]);
4799 }
4800
4801 #[test]
4802 fn join_with_where_filter_applies_after_join() {
4803 let db = seed_join_fixture();
4804 let r = run_rows(
4807 &db,
4808 "SELECT customers.name, orders.amount FROM customers \
4809 INNER JOIN orders ON customers.id = orders.customer_id \
4810 WHERE orders.amount >= 100;",
4811 );
4812 assert_eq!(r.rows.len(), 2);
4813 assert!(
4814 r.rows
4815 .iter()
4816 .all(|row| row[0].to_display_string() == "Alice")
4817 );
4818 }
4819
4820 #[test]
4821 fn left_join_with_where_on_right_side_is_not_inner() {
4822 let db = seed_join_fixture();
4826 let r = run_rows(
4827 &db,
4828 "SELECT customers.name, orders.amount FROM customers \
4829 LEFT OUTER JOIN orders ON customers.id = orders.customer_id \
4830 WHERE orders.amount IS NULL;",
4831 );
4832 assert_eq!(r.rows.len(), 1);
4834 assert_eq!(r.rows[0][0].to_display_string(), "Carol");
4835 assert_eq!(r.rows[0][1], Value::Null);
4836 }
4837
4838 #[test]
4839 fn select_star_over_join_emits_all_columns_from_both_tables() {
4840 let db = seed_join_fixture();
4841 let r = run_rows(
4842 &db,
4843 "SELECT * FROM customers \
4844 INNER JOIN orders ON customers.id = orders.customer_id;",
4845 );
4846 assert_eq!(
4850 r.columns,
4851 vec![
4852 "id".to_string(),
4853 "name".to_string(),
4854 "id".to_string(),
4855 "customer_id".to_string(),
4856 "amount".to_string(),
4857 ]
4858 );
4859 assert_eq!(r.rows.len(), 3);
4860 }
4861
4862 #[test]
4863 fn join_order_by_sorts_full_joined_rows() {
4864 let db = seed_join_fixture();
4865 let r = run_rows(
4866 &db,
4867 "SELECT c.name, o.amount FROM customers AS c \
4868 INNER JOIN orders AS o ON c.id = o.customer_id \
4869 ORDER BY o.amount;",
4870 );
4871 let amounts: Vec<i64> = r
4872 .rows
4873 .iter()
4874 .map(|row| match row[1] {
4875 Value::Integer(i) => i,
4876 ref v => panic!("expected integer, got {v:?}"),
4877 })
4878 .collect();
4879 assert_eq!(amounts, vec![50, 100, 200]);
4880 }
4881
4882 #[test]
4883 fn join_limit_truncates_after_join_and_sort() {
4884 let db = seed_join_fixture();
4885 let r = run_rows(
4886 &db,
4887 "SELECT c.name, o.amount FROM customers AS c \
4888 INNER JOIN orders AS o ON c.id = o.customer_id \
4889 ORDER BY o.amount DESC LIMIT 2;",
4890 );
4891 assert_eq!(r.rows.len(), 2);
4892 let amounts: Vec<i64> = r
4894 .rows
4895 .iter()
4896 .map(|row| match row[1] {
4897 Value::Integer(i) => i,
4898 ref v => panic!("expected integer, got {v:?}"),
4899 })
4900 .collect();
4901 assert_eq!(amounts, vec![200, 100]);
4902 }
4903
4904 #[test]
4905 fn three_table_join_chains_correctly() {
4906 let mut db = Database::new("t".to_string());
4907 for sql in [
4908 "CREATE TABLE a (id INTEGER PRIMARY KEY, label TEXT);",
4909 "CREATE TABLE b (id INTEGER PRIMARY KEY, a_id INTEGER, tag TEXT);",
4910 "CREATE TABLE c (id INTEGER PRIMARY KEY, b_id INTEGER, note TEXT);",
4911 "INSERT INTO a (label) VALUES ('a-one');",
4912 "INSERT INTO a (label) VALUES ('a-two');",
4913 "INSERT INTO b (a_id, tag) VALUES (1, 'b1');",
4914 "INSERT INTO b (a_id, tag) VALUES (2, 'b2');",
4915 "INSERT INTO c (b_id, note) VALUES (1, 'c1');",
4916 ] {
4917 crate::sql::process_command(sql, &mut db).unwrap();
4918 }
4919 let r = run_rows(
4920 &db,
4921 "SELECT a.label, b.tag, c.note FROM a \
4922 INNER JOIN b ON a.id = b.a_id \
4923 INNER JOIN c ON b.id = c.b_id;",
4924 );
4925 assert_eq!(r.rows.len(), 1);
4927 assert_eq!(r.rows[0][0].to_display_string(), "a-one");
4928 assert_eq!(r.rows[0][1].to_display_string(), "b1");
4929 assert_eq!(r.rows[0][2].to_display_string(), "c1");
4930 }
4931
4932 #[test]
4933 fn ambiguous_unqualified_column_in_join_errors() {
4934 let db = seed_join_fixture();
4938 let q = parse_select(
4939 "SELECT id FROM customers INNER JOIN orders ON customers.id = orders.customer_id;",
4940 );
4941 let res = execute_select_rows(q, &db);
4942 assert!(res.is_err(), "unqualified ambiguous 'id' should error");
4943 }
4944
4945 #[test]
4946 fn join_self_without_alias_is_rejected() {
4947 let mut db = Database::new("t".to_string());
4948 crate::sql::process_command(
4949 "CREATE TABLE n (id INTEGER PRIMARY KEY, parent INTEGER);",
4950 &mut db,
4951 )
4952 .unwrap();
4953 let q = parse_select("SELECT n.id FROM n INNER JOIN n ON n.id = n.parent;");
4954 let res = execute_select_rows(q, &db);
4955 assert!(
4956 res.is_err(),
4957 "self-join without an alias should error on duplicate qualifier"
4958 );
4959 }
4960
4961 #[test]
4967 fn join_using_matches_same_rows_as_on() {
4968 let db = seed_join_fixture();
4969 let using = run_rows(
4970 &db,
4971 "SELECT customers.name, orders.amount FROM customers \
4972 INNER JOIN orders USING (id) ORDER BY orders.amount;",
4973 );
4974 let on = run_rows(
4975 &db,
4976 "SELECT customers.name, orders.amount FROM customers \
4977 INNER JOIN orders ON customers.id = orders.id ORDER BY orders.amount;",
4978 );
4979 let pairs: Vec<(String, Value)> = using
4981 .rows
4982 .iter()
4983 .map(|r| (r[0].to_display_string(), r[1].clone()))
4984 .collect();
4985 assert_eq!(pairs.len(), 3);
4986 assert_eq!(
4987 using.rows, on.rows,
4988 "USING must mirror the explicit ON rows"
4989 );
4990 }
4991
4992 #[test]
4995 fn select_star_using_dedups_joined_column() {
4996 let db = seed_join_fixture();
4997 let r = run_rows(&db, "SELECT * FROM customers INNER JOIN orders USING (id);");
4998 assert_eq!(
5002 r.columns,
5003 vec![
5004 "id".to_string(),
5005 "name".to_string(),
5006 "customer_id".to_string(),
5007 "amount".to_string(),
5008 ]
5009 );
5010 assert_eq!(r.rows.len(), 3);
5011 for row in &r.rows {
5014 assert!(matches!(row[0], Value::Integer(_)));
5015 }
5016 }
5017
5018 fn seed_natural_fixture() -> Database {
5019 let mut db = Database::new("t".to_string());
5020 for sql in [
5021 "CREATE TABLE l (lid INTEGER PRIMARY KEY, k1 INTEGER, k2 INTEGER, v1 TEXT);",
5024 "CREATE TABLE r (rid INTEGER PRIMARY KEY, k1 INTEGER, k2 INTEGER, v2 TEXT);",
5025 "INSERT INTO l (k1, k2, v1) VALUES (1, 1, 'l-a');",
5026 "INSERT INTO l (k1, k2, v1) VALUES (1, 2, 'l-b');",
5027 "INSERT INTO l (k1, k2, v1) VALUES (2, 1, 'l-c');",
5028 "INSERT INTO r (k1, k2, v2) VALUES (1, 1, 'r-a');",
5029 "INSERT INTO r (k1, k2, v2) VALUES (1, 2, 'r-b');",
5030 "INSERT INTO r (k1, k2, v2) VALUES (9, 9, 'r-z');",
5031 ] {
5032 crate::sql::process_command(sql, &mut db).unwrap();
5033 }
5034 db
5035 }
5036
5037 #[test]
5040 fn natural_join_matches_on_all_shared_columns() {
5041 let db = seed_natural_fixture();
5042 let natural = run_rows(&db, "SELECT v1, v2 FROM l NATURAL JOIN r ORDER BY v1;");
5043 let pairs: Vec<(String, String)> = natural
5045 .rows
5046 .iter()
5047 .map(|r| (r[0].to_display_string(), r[1].to_display_string()))
5048 .collect();
5049 assert_eq!(
5050 pairs,
5051 vec![
5052 ("l-a".to_string(), "r-a".to_string()),
5053 ("l-b".to_string(), "r-b".to_string()),
5054 ]
5055 );
5056 let explicit = run_rows(
5058 &db,
5059 "SELECT v1, v2 FROM l INNER JOIN r ON l.k1 = r.k1 AND l.k2 = r.k2 ORDER BY v1;",
5060 );
5061 assert_eq!(natural.rows, explicit.rows);
5062 }
5063
5064 #[test]
5066 fn select_star_natural_dedups_shared_columns() {
5067 let db = seed_natural_fixture();
5068 let r = run_rows(&db, "SELECT * FROM l NATURAL JOIN r;");
5069 assert_eq!(
5072 r.columns,
5073 vec![
5074 "lid".to_string(),
5075 "k1".to_string(),
5076 "k2".to_string(),
5077 "v1".to_string(),
5078 "rid".to_string(),
5079 "v2".to_string(),
5080 ]
5081 );
5082 assert_eq!(r.rows.len(), 2);
5083 }
5084
5085 #[test]
5088 fn natural_join_without_common_columns_is_cross_product() {
5089 let mut db = Database::new("t".to_string());
5090 for sql in [
5091 "CREATE TABLE p (pid INTEGER PRIMARY KEY, pa TEXT);",
5092 "CREATE TABLE q (qid INTEGER PRIMARY KEY, qb TEXT);",
5093 "INSERT INTO p (pa) VALUES ('p1');",
5094 "INSERT INTO p (pa) VALUES ('p2');",
5095 "INSERT INTO q (qb) VALUES ('q1');",
5096 "INSERT INTO q (qb) VALUES ('q2');",
5097 "INSERT INTO q (qb) VALUES ('q3');",
5098 ] {
5099 crate::sql::process_command(sql, &mut db).unwrap();
5100 }
5101 let r = run_rows(&db, "SELECT p.pa, q.qb FROM p NATURAL JOIN q;");
5102 assert_eq!(r.rows.len(), 2 * 3, "no shared columns ⇒ cross product");
5103 }
5104
5105 #[test]
5108 fn cross_join_produces_cartesian_product() {
5109 let db = seed_join_fixture();
5110 let cross = run_rows(
5111 &db,
5112 "SELECT customers.name, orders.amount FROM customers CROSS JOIN orders;",
5113 );
5114 assert_eq!(cross.rows.len(), 12);
5116 let on_true = run_rows(
5117 &db,
5118 "SELECT customers.name, orders.amount FROM customers INNER JOIN orders ON 1;",
5119 );
5120 assert_eq!(cross.rows.len(), on_true.rows.len());
5121 let star = run_rows(&db, "SELECT * FROM customers CROSS JOIN orders;");
5123 assert_eq!(star.columns.len(), 5);
5124 assert_eq!(star.rows.len(), 12);
5125 }
5126
5127 #[test]
5131 fn left_outer_join_using_preserves_unmatched_left() {
5132 let db = seed_join_fixture();
5133 let r = run_rows(
5134 &db,
5135 "SELECT * FROM customers LEFT OUTER JOIN orders USING (id);",
5136 );
5137 assert_eq!(r.columns.len(), 4, "id is shown once");
5141 assert_eq!(r.rows.len(), 3);
5142 }
5143
5144 #[test]
5147 fn using_unknown_column_errors() {
5148 let db = seed_join_fixture();
5149 let q = parse_select("SELECT * FROM customers INNER JOIN orders USING (nope);");
5150 let res = execute_select_rows(q, &db);
5151 assert!(res.is_err(), "USING (nope) must error — column absent");
5152 }
5153
5154 #[test]
5155 fn aggregates_over_join_are_rejected() {
5156 let db = seed_join_fixture();
5157 let err = crate::sql::process_command(
5158 "SELECT COUNT(*) FROM customers \
5159 INNER JOIN orders ON customers.id = orders.customer_id;",
5160 &mut seed_join_fixture(),
5161 );
5162 assert!(err.is_err(), "aggregates over JOIN are not yet supported");
5163 let _ = db; }
5165
5166 #[test]
5167 fn left_join_with_no_matches_pads_every_row() {
5168 let mut db = Database::new("t".to_string());
5169 for sql in [
5170 "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
5171 "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
5172 "INSERT INTO a (x) VALUES (1);",
5173 "INSERT INTO a (x) VALUES (2);",
5174 "INSERT INTO b (y) VALUES (10);",
5175 ] {
5176 crate::sql::process_command(sql, &mut db).unwrap();
5177 }
5178 let r = run_rows(
5180 &db,
5181 "SELECT a.x, b.y FROM a LEFT OUTER JOIN b ON a.x = b.y;",
5182 );
5183 assert_eq!(r.rows.len(), 2);
5184 for row in &r.rows {
5185 assert_eq!(row[1], Value::Null);
5186 }
5187 }
5188
5189 #[test]
5190 fn left_outer_join_order_by_places_nulls_first() {
5191 let db = seed_join_fixture();
5196 let r = run_rows(
5197 &db,
5198 "SELECT c.name, o.amount FROM customers AS c \
5199 LEFT OUTER JOIN orders AS o ON c.id = o.customer_id \
5200 ORDER BY o.amount ASC;",
5201 );
5202 assert_eq!(r.rows.len(), 4);
5203 assert_eq!(r.rows[0][0].to_display_string(), "Carol");
5205 assert_eq!(r.rows[0][1], Value::Null);
5206 }
5207
5208 #[test]
5209 fn chained_left_outer_join_preserves_left_through_two_levels() {
5210 let mut db = Database::new("t".to_string());
5213 for sql in [
5214 "CREATE TABLE a (id INTEGER PRIMARY KEY, label TEXT);",
5215 "CREATE TABLE b (id INTEGER PRIMARY KEY, a_id INTEGER, tag TEXT);",
5216 "CREATE TABLE c (id INTEGER PRIMARY KEY, b_id INTEGER, note TEXT);",
5217 "INSERT INTO a (label) VALUES ('a-one');",
5218 "INSERT INTO a (label) VALUES ('a-two');",
5219 "INSERT INTO b (a_id, tag) VALUES (1, 'b1');",
5221 ] {
5223 crate::sql::process_command(sql, &mut db).unwrap();
5224 }
5225 let r = run_rows(
5226 &db,
5227 "SELECT a.label, b.tag, c.note FROM a \
5228 LEFT OUTER JOIN b ON a.id = b.a_id \
5229 LEFT OUTER JOIN c ON b.id = c.b_id;",
5230 );
5231 assert_eq!(r.rows.len(), 2);
5233 let by_label: std::collections::HashMap<String, &Vec<Value>> = r
5234 .rows
5235 .iter()
5236 .map(|row| (row[0].to_display_string(), row))
5237 .collect();
5238 assert_eq!(by_label["a-one"][1].to_display_string(), "b1");
5239 assert_eq!(by_label["a-one"][2], Value::Null);
5240 assert_eq!(by_label["a-two"][1], Value::Null);
5241 assert_eq!(by_label["a-two"][2], Value::Null);
5242 }
5243
5244 #[test]
5245 fn on_clause_referencing_not_yet_joined_table_errors_clearly() {
5246 let mut db = Database::new("t".to_string());
5250 for sql in [
5251 "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
5252 "CREATE TABLE b (id INTEGER PRIMARY KEY, x INTEGER);",
5253 "CREATE TABLE c (id INTEGER PRIMARY KEY, x INTEGER);",
5254 "INSERT INTO a (x) VALUES (1);",
5255 "INSERT INTO b (x) VALUES (1);",
5256 "INSERT INTO c (x) VALUES (1);",
5257 ] {
5258 crate::sql::process_command(sql, &mut db).unwrap();
5259 }
5260 let q =
5261 parse_select("SELECT a.x FROM a INNER JOIN b ON a.x = c.x INNER JOIN c ON b.x = c.x;");
5262 let res = execute_select_rows(q, &db);
5263 assert!(
5264 res.is_err(),
5265 "ON referencing not-yet-joined table 'c' should error"
5266 );
5267 }
5268
5269 #[test]
5270 fn join_on_truthy_integer_is_accepted() {
5271 let mut db = Database::new("t".to_string());
5275 for sql in [
5276 "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
5277 "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
5278 "INSERT INTO a (x) VALUES (1);",
5279 "INSERT INTO a (x) VALUES (2);",
5280 "INSERT INTO b (y) VALUES (10);",
5281 "INSERT INTO b (y) VALUES (20);",
5282 ] {
5283 crate::sql::process_command(sql, &mut db).unwrap();
5284 }
5285 let r = run_rows(&db, "SELECT a.x, b.y FROM a INNER JOIN b ON 1;");
5286 assert_eq!(r.rows.len(), 4);
5288 }
5289
5290 #[test]
5291 fn full_join_on_empty_tables_returns_empty() {
5292 let mut db = Database::new("t".to_string());
5293 for sql in [
5294 "CREATE TABLE a (id INTEGER PRIMARY KEY, x INTEGER);",
5295 "CREATE TABLE b (id INTEGER PRIMARY KEY, y INTEGER);",
5296 ] {
5297 crate::sql::process_command(sql, &mut db).unwrap();
5298 }
5299 let r = run_rows(
5300 &db,
5301 "SELECT a.x, b.y FROM a FULL OUTER JOIN b ON a.x = b.y;",
5302 );
5303 assert!(r.rows.is_empty());
5304 }
5305}