1use std::cmp::Ordering;
5
6use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
7use sqlparser::ast::{
8 AlterTable, AlterTableOperation, AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr,
9 FromTable, FunctionArg, FunctionArgExpr, FunctionArguments, IndexType, ObjectName,
10 ObjectNamePart, RenameTableNameKind, Statement, TableFactor, TableWithJoins, UnaryOperator,
11 Update, Value as AstValue,
12};
13
14use crate::error::{Result, SQLRiteError};
15use crate::sql::agg::{AggState, DistinctKey, like_match};
16use crate::sql::db::database::Database;
17use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
18use crate::sql::db::table::{
19 DataType, FtsIndexEntry, HnswIndexEntry, Table, Value, parse_vector_literal,
20};
21use crate::sql::fts::{Bm25Params, PostingList};
22use crate::sql::hnsw::{DistanceMetric, HnswIndex};
23use crate::sql::parser::select::{
24 AggregateArg, OrderByClause, Projection, ProjectionItem, ProjectionKind, SelectQuery,
25};
26
27pub struct SelectResult {
36 pub columns: Vec<String>,
37 pub rows: Vec<Vec<Value>>,
38}
39
40pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
44 let table = db
45 .get_table(query.table_name.clone())
46 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
47
48 let proj_items: Vec<ProjectionItem> = match &query.projection {
53 Projection::All => table
54 .column_names()
55 .into_iter()
56 .map(|c| ProjectionItem {
57 kind: ProjectionKind::Column(c),
58 alias: None,
59 })
60 .collect(),
61 Projection::Items(items) => items.clone(),
62 };
63 let has_aggregates = proj_items
64 .iter()
65 .any(|i| matches!(i.kind, ProjectionKind::Aggregate(_)));
66 for item in &proj_items {
68 if let ProjectionKind::Column(c) = &item.kind
69 && !table.contains_column(c.clone())
70 {
71 return Err(SQLRiteError::Internal(format!(
72 "Column '{c}' does not exist on table '{}'",
73 query.table_name
74 )));
75 }
76 }
77 for c in &query.group_by {
78 if !table.contains_column(c.clone()) {
79 return Err(SQLRiteError::Internal(format!(
80 "GROUP BY references unknown column '{c}' on table '{}'",
81 query.table_name
82 )));
83 }
84 }
85 let matching = match select_rowids(table, query.selection.as_ref())? {
89 RowidSource::IndexProbe(rowids) => rowids,
90 RowidSource::FullScan => {
91 let mut out = Vec::new();
92 for rowid in table.rowids() {
93 if let Some(expr) = &query.selection
94 && !eval_predicate(expr, table, rowid)?
95 {
96 continue;
97 }
98 out.push(rowid);
99 }
100 out
101 }
102 };
103 let mut matching = matching;
104
105 let aggregating = has_aggregates || !query.group_by.is_empty();
106
107 if aggregating {
113 for item in &proj_items {
115 if let ProjectionKind::Aggregate(call) = &item.kind
116 && let AggregateArg::Column(c) = &call.arg
117 && !table.contains_column(c.clone())
118 {
119 return Err(SQLRiteError::Internal(format!(
120 "{}({}) references unknown column '{c}' on table '{}'",
121 call.func.as_str(),
122 c,
123 query.table_name
124 )));
125 }
126 }
127
128 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
129 let mut rows = aggregate_rows(table, &matching, &query.group_by, &proj_items)?;
130
131 if query.distinct {
132 rows = dedupe_rows(rows);
133 }
134
135 if let Some(order) = &query.order_by {
136 sort_output_rows(&mut rows, &columns, &proj_items, order)?;
137 }
138 if let Some(k) = query.limit {
139 rows.truncate(k);
140 }
141
142 return Ok(SelectResult { columns, rows });
143 }
144
145 let defer_limit_for_distinct = query.distinct;
183 match (&query.order_by, query.limit) {
184 (Some(order), Some(k)) if try_hnsw_probe(table, &order.expr, k).is_some() => {
185 matching = try_hnsw_probe(table, &order.expr, k).unwrap();
186 }
187 (Some(order), Some(k))
188 if try_fts_probe(table, &order.expr, order.ascending, k).is_some() =>
189 {
190 matching = try_fts_probe(table, &order.expr, order.ascending, k).unwrap();
191 }
192 (Some(order), Some(k)) if !defer_limit_for_distinct && k < matching.len() => {
193 matching = select_topk(&matching, table, order, k)?;
194 }
195 (Some(order), _) => {
196 sort_rowids(&mut matching, table, order)?;
197 if let Some(k) = query.limit
198 && !defer_limit_for_distinct
199 {
200 matching.truncate(k);
201 }
202 }
203 (None, Some(k)) if !defer_limit_for_distinct => {
204 matching.truncate(k);
205 }
206 _ => {}
207 }
208
209 let columns: Vec<String> = proj_items.iter().map(|i| i.output_name()).collect();
210 let projected_cols: Vec<String> = proj_items
211 .iter()
212 .map(|i| match &i.kind {
213 ProjectionKind::Column(c) => c.clone(),
214 ProjectionKind::Aggregate(_) => unreachable!("aggregation handled above"),
215 })
216 .collect();
217
218 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
222 for rowid in &matching {
223 let row: Vec<Value> = projected_cols
224 .iter()
225 .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
226 .collect();
227 rows.push(row);
228 }
229
230 if query.distinct {
231 rows = dedupe_rows(rows);
232 if let Some(k) = query.limit {
233 rows.truncate(k);
234 }
235 }
236
237 Ok(SelectResult { columns, rows })
238}
239
240pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
245 let result = execute_select_rows(query, db)?;
246 let row_count = result.rows.len();
247
248 let mut print_table = PrintTable::new();
249 let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
250 print_table.add_row(PrintRow::new(header_cells));
251
252 for row in &result.rows {
253 let cells: Vec<PrintCell> = row
254 .iter()
255 .map(|v| PrintCell::new(&v.to_display_string()))
256 .collect();
257 print_table.add_row(PrintRow::new(cells));
258 }
259
260 Ok((print_table.to_string(), row_count))
261}
262
263pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
265 let Statement::Delete(Delete {
266 from, selection, ..
267 }) = stmt
268 else {
269 return Err(SQLRiteError::Internal(
270 "execute_delete called on a non-DELETE statement".to_string(),
271 ));
272 };
273
274 let tables = match from {
275 FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
276 };
277 let table_name = extract_single_table_name(tables)?;
278
279 let matching: Vec<i64> = {
281 let table = db
282 .get_table(table_name.clone())
283 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
284 match select_rowids(table, selection.as_ref())? {
285 RowidSource::IndexProbe(rowids) => rowids,
286 RowidSource::FullScan => {
287 let mut out = Vec::new();
288 for rowid in table.rowids() {
289 if let Some(expr) = selection {
290 if !eval_predicate(expr, table, rowid)? {
291 continue;
292 }
293 }
294 out.push(rowid);
295 }
296 out
297 }
298 }
299 };
300
301 let table = db.get_table_mut(table_name)?;
302 for rowid in &matching {
303 table.delete_row(*rowid);
304 }
305 if !matching.is_empty() {
314 for entry in &mut table.hnsw_indexes {
315 entry.needs_rebuild = true;
316 }
317 for entry in &mut table.fts_indexes {
318 entry.needs_rebuild = true;
319 }
320 }
321 Ok(matching.len())
322}
323
324pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
326 let Statement::Update(Update {
327 table,
328 assignments,
329 from,
330 selection,
331 ..
332 }) = stmt
333 else {
334 return Err(SQLRiteError::Internal(
335 "execute_update called on a non-UPDATE statement".to_string(),
336 ));
337 };
338
339 if from.is_some() {
340 return Err(SQLRiteError::NotImplemented(
341 "UPDATE ... FROM is not supported yet".to_string(),
342 ));
343 }
344
345 let table_name = extract_table_name(table)?;
346
347 let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
349 {
350 let tbl = db
351 .get_table(table_name.clone())
352 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
353 for a in assignments {
354 let col = match &a.target {
355 AssignmentTarget::ColumnName(name) => name
356 .0
357 .last()
358 .map(|p| p.to_string())
359 .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
360 AssignmentTarget::Tuple(_) => {
361 return Err(SQLRiteError::NotImplemented(
362 "tuple assignment targets are not supported".to_string(),
363 ));
364 }
365 };
366 if !tbl.contains_column(col.clone()) {
367 return Err(SQLRiteError::Internal(format!(
368 "UPDATE references unknown column '{col}'"
369 )));
370 }
371 parsed_assignments.push((col, a.value.clone()));
372 }
373 }
374
375 let work: Vec<(i64, Vec<(String, Value)>)> = {
379 let tbl = db.get_table(table_name.clone())?;
380 let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
381 RowidSource::IndexProbe(rowids) => rowids,
382 RowidSource::FullScan => {
383 let mut out = Vec::new();
384 for rowid in tbl.rowids() {
385 if let Some(expr) = selection {
386 if !eval_predicate(expr, tbl, rowid)? {
387 continue;
388 }
389 }
390 out.push(rowid);
391 }
392 out
393 }
394 };
395 let mut rows_to_update = Vec::new();
396 for rowid in matched_rowids {
397 let mut values = Vec::with_capacity(parsed_assignments.len());
398 for (col, expr) in &parsed_assignments {
399 let v = eval_expr(expr, tbl, rowid)?;
402 values.push((col.clone(), v));
403 }
404 rows_to_update.push((rowid, values));
405 }
406 rows_to_update
407 };
408
409 let tbl = db.get_table_mut(table_name)?;
410 for (rowid, values) in &work {
411 for (col, v) in values {
412 tbl.set_value(col, *rowid, v.clone())?;
413 }
414 }
415
416 if !work.is_empty() {
425 let updated_columns: std::collections::HashSet<&str> = work
426 .iter()
427 .flat_map(|(_, values)| values.iter().map(|(c, _)| c.as_str()))
428 .collect();
429 for entry in &mut tbl.hnsw_indexes {
430 if updated_columns.contains(entry.column_name.as_str()) {
431 entry.needs_rebuild = true;
432 }
433 }
434 for entry in &mut tbl.fts_indexes {
435 if updated_columns.contains(entry.column_name.as_str()) {
436 entry.needs_rebuild = true;
437 }
438 }
439 }
440 Ok(work.len())
441}
442
443pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
455 let Statement::CreateIndex(CreateIndex {
456 name,
457 table_name,
458 columns,
459 using,
460 unique,
461 if_not_exists,
462 predicate,
463 ..
464 }) = stmt
465 else {
466 return Err(SQLRiteError::Internal(
467 "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
468 ));
469 };
470
471 if predicate.is_some() {
472 return Err(SQLRiteError::NotImplemented(
473 "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
474 ));
475 }
476
477 if columns.len() != 1 {
478 return Err(SQLRiteError::NotImplemented(format!(
479 "multi-column indexes are not supported yet ({} columns given)",
480 columns.len()
481 )));
482 }
483
484 let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
485 SQLRiteError::NotImplemented(
486 "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
487 )
488 })?;
489
490 let method = match using {
496 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("hnsw") => {
497 IndexMethod::Hnsw
498 }
499 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("fts") => {
500 IndexMethod::Fts
501 }
502 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("btree") => {
503 IndexMethod::Btree
504 }
505 Some(other) => {
506 return Err(SQLRiteError::NotImplemented(format!(
507 "CREATE INDEX … USING {other:?} is not supported \
508 (try `hnsw`, `fts`, or no USING clause)"
509 )));
510 }
511 None => IndexMethod::Btree,
512 };
513
514 let table_name_str = table_name.to_string();
515 let column_name = match &columns[0].column.expr {
516 Expr::Identifier(ident) => ident.value.clone(),
517 Expr::CompoundIdentifier(parts) => parts
518 .last()
519 .map(|p| p.value.clone())
520 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
521 other => {
522 return Err(SQLRiteError::NotImplemented(format!(
523 "CREATE INDEX only supports simple column references, got {other:?}"
524 )));
525 }
526 };
527
528 let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
533 let table = db.get_table(table_name_str.clone()).map_err(|_| {
534 SQLRiteError::General(format!(
535 "CREATE INDEX references unknown table '{table_name_str}'"
536 ))
537 })?;
538 if !table.contains_column(column_name.clone()) {
539 return Err(SQLRiteError::General(format!(
540 "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
541 )));
542 }
543 let col = table
544 .columns
545 .iter()
546 .find(|c| c.column_name == column_name)
547 .expect("we just verified the column exists");
548
549 if table.index_by_name(&index_name).is_some()
552 || table.hnsw_indexes.iter().any(|i| i.name == index_name)
553 || table.fts_indexes.iter().any(|i| i.name == index_name)
554 {
555 if *if_not_exists {
556 return Ok(index_name);
557 }
558 return Err(SQLRiteError::General(format!(
559 "index '{index_name}' already exists"
560 )));
561 }
562 let datatype = clone_datatype(&col.datatype);
563
564 let mut pairs = Vec::new();
565 for rowid in table.rowids() {
566 if let Some(v) = table.get_value(&column_name, rowid) {
567 pairs.push((rowid, v));
568 }
569 }
570 (datatype, pairs)
571 };
572
573 match method {
574 IndexMethod::Btree => create_btree_index(
575 db,
576 &table_name_str,
577 &index_name,
578 &column_name,
579 &datatype,
580 *unique,
581 &existing_rowids_and_values,
582 ),
583 IndexMethod::Hnsw => create_hnsw_index(
584 db,
585 &table_name_str,
586 &index_name,
587 &column_name,
588 &datatype,
589 *unique,
590 &existing_rowids_and_values,
591 ),
592 IndexMethod::Fts => create_fts_index(
593 db,
594 &table_name_str,
595 &index_name,
596 &column_name,
597 &datatype,
598 *unique,
599 &existing_rowids_and_values,
600 ),
601 }
602}
603
604pub fn execute_drop_table(
615 names: &[ObjectName],
616 if_exists: bool,
617 db: &mut Database,
618) -> Result<usize> {
619 if names.len() != 1 {
620 return Err(SQLRiteError::NotImplemented(
621 "DROP TABLE supports a single table per statement".to_string(),
622 ));
623 }
624 let name = names[0].to_string();
625
626 if name == crate::sql::pager::MASTER_TABLE_NAME {
627 return Err(SQLRiteError::General(format!(
628 "'{}' is a reserved name used by the internal schema catalog",
629 crate::sql::pager::MASTER_TABLE_NAME
630 )));
631 }
632
633 if !db.contains_table(name.clone()) {
634 return if if_exists {
635 Ok(0)
636 } else {
637 Err(SQLRiteError::General(format!(
638 "Table '{name}' does not exist"
639 )))
640 };
641 }
642
643 db.tables.remove(&name);
644 Ok(1)
645}
646
647pub fn execute_drop_index(
656 names: &[ObjectName],
657 if_exists: bool,
658 db: &mut Database,
659) -> Result<usize> {
660 if names.len() != 1 {
661 return Err(SQLRiteError::NotImplemented(
662 "DROP INDEX supports a single index per statement".to_string(),
663 ));
664 }
665 let name = names[0].to_string();
666
667 for table in db.tables.values_mut() {
668 if let Some(secondary) = table.secondary_indexes.iter().find(|i| i.name == name) {
669 if secondary.origin == IndexOrigin::Auto {
670 return Err(SQLRiteError::General(format!(
671 "cannot drop auto-created index '{name}' (drop the column or table instead)"
672 )));
673 }
674 table.secondary_indexes.retain(|i| i.name != name);
675 return Ok(1);
676 }
677 if table.hnsw_indexes.iter().any(|i| i.name == name) {
678 table.hnsw_indexes.retain(|i| i.name != name);
679 return Ok(1);
680 }
681 if table.fts_indexes.iter().any(|i| i.name == name) {
682 table.fts_indexes.retain(|i| i.name != name);
683 return Ok(1);
684 }
685 }
686
687 if if_exists {
688 Ok(0)
689 } else {
690 Err(SQLRiteError::General(format!(
691 "Index '{name}' does not exist"
692 )))
693 }
694}
695
696pub fn execute_alter_table(alter: AlterTable, db: &mut Database) -> Result<String> {
708 let table_name = alter.name.to_string();
709
710 if table_name == crate::sql::pager::MASTER_TABLE_NAME {
711 return Err(SQLRiteError::General(format!(
712 "'{}' is a reserved name used by the internal schema catalog",
713 crate::sql::pager::MASTER_TABLE_NAME
714 )));
715 }
716
717 if !db.contains_table(table_name.clone()) {
718 return if alter.if_exists {
719 Ok("ALTER TABLE: no-op (table does not exist)".to_string())
720 } else {
721 Err(SQLRiteError::General(format!(
722 "Table '{table_name}' does not exist"
723 )))
724 };
725 }
726
727 if alter.operations.len() != 1 {
728 return Err(SQLRiteError::NotImplemented(
729 "ALTER TABLE supports one operation per statement".to_string(),
730 ));
731 }
732
733 match &alter.operations[0] {
734 AlterTableOperation::RenameTable { table_name: kind } => {
735 let new_name = match kind {
736 RenameTableNameKind::To(name) => name.to_string(),
737 RenameTableNameKind::As(_) => {
738 return Err(SQLRiteError::NotImplemented(
739 "ALTER TABLE ... RENAME AS (MySQL-only) is not supported; use RENAME TO"
740 .to_string(),
741 ));
742 }
743 };
744 alter_rename_table(db, &table_name, &new_name)?;
745 Ok(format!(
746 "ALTER TABLE '{table_name}' RENAME TO '{new_name}' executed."
747 ))
748 }
749 AlterTableOperation::RenameColumn {
750 old_column_name,
751 new_column_name,
752 } => {
753 let old = old_column_name.value.clone();
754 let new = new_column_name.value.clone();
755 db.get_table_mut(table_name.clone())?
756 .rename_column(&old, &new)?;
757 Ok(format!(
758 "ALTER TABLE '{table_name}' RENAME COLUMN '{old}' TO '{new}' executed."
759 ))
760 }
761 AlterTableOperation::AddColumn {
762 column_def,
763 if_not_exists,
764 ..
765 } => {
766 let parsed = crate::sql::parser::create::parse_one_column(column_def)?;
767 let table = db.get_table_mut(table_name.clone())?;
768 if *if_not_exists && table.contains_column(parsed.name.clone()) {
769 return Ok(format!(
770 "ALTER TABLE '{table_name}' ADD COLUMN: no-op (column '{}' already exists)",
771 parsed.name
772 ));
773 }
774 let col_name = parsed.name.clone();
775 table.add_column(parsed)?;
776 Ok(format!(
777 "ALTER TABLE '{table_name}' ADD COLUMN '{col_name}' executed."
778 ))
779 }
780 AlterTableOperation::DropColumn {
781 column_names,
782 if_exists,
783 ..
784 } => {
785 if column_names.len() != 1 {
786 return Err(SQLRiteError::NotImplemented(
787 "ALTER TABLE DROP COLUMN supports a single column per statement".to_string(),
788 ));
789 }
790 let col_name = column_names[0].value.clone();
791 let table = db.get_table_mut(table_name.clone())?;
792 if *if_exists && !table.contains_column(col_name.clone()) {
793 return Ok(format!(
794 "ALTER TABLE '{table_name}' DROP COLUMN: no-op (column '{col_name}' does not exist)"
795 ));
796 }
797 table.drop_column(&col_name)?;
798 Ok(format!(
799 "ALTER TABLE '{table_name}' DROP COLUMN '{col_name}' executed."
800 ))
801 }
802 other => Err(SQLRiteError::NotImplemented(format!(
803 "ALTER TABLE operation {other:?} is not supported"
804 ))),
805 }
806}
807
808pub fn execute_vacuum(db: &mut Database) -> Result<String> {
818 if db.in_transaction() {
819 return Err(SQLRiteError::General(
820 "VACUUM cannot run inside a transaction".to_string(),
821 ));
822 }
823 let path = match db.source_path.clone() {
824 Some(p) => p,
825 None => {
826 return Ok("VACUUM is a no-op for in-memory databases".to_string());
827 }
828 };
829 if let Some(pager) = db.pager.as_mut() {
835 let _ = pager.checkpoint();
836 }
837 let size_before = std::fs::metadata(&path).ok().map(|m| m.len()).unwrap_or(0);
838 let pages_before = db
839 .pager
840 .as_ref()
841 .map(|p| p.header().page_count)
842 .unwrap_or(0);
843 crate::sql::pager::vacuum_database(db, &path)?;
844 if let Some(pager) = db.pager.as_mut() {
847 let _ = pager.checkpoint();
848 }
849 let size_after = std::fs::metadata(&path).ok().map(|m| m.len()).unwrap_or(0);
850 let pages_after = db
851 .pager
852 .as_ref()
853 .map(|p| p.header().page_count)
854 .unwrap_or(0);
855 let pages_reclaimed = pages_before.saturating_sub(pages_after);
856 let bytes_reclaimed = size_before.saturating_sub(size_after);
857 Ok(format!(
858 "VACUUM completed. {pages_reclaimed} pages reclaimed ({bytes_reclaimed} bytes)."
859 ))
860}
861
862fn alter_rename_table(db: &mut Database, old: &str, new: &str) -> Result<()> {
868 if new == crate::sql::pager::MASTER_TABLE_NAME {
869 return Err(SQLRiteError::General(format!(
870 "'{}' is a reserved name used by the internal schema catalog",
871 crate::sql::pager::MASTER_TABLE_NAME
872 )));
873 }
874 if old == new {
875 return Ok(());
876 }
877 if db.contains_table(new.to_string()) {
878 return Err(SQLRiteError::General(format!(
879 "target table '{new}' already exists"
880 )));
881 }
882
883 let mut table = db
884 .tables
885 .remove(old)
886 .ok_or_else(|| SQLRiteError::General(format!("Table '{old}' does not exist")))?;
887 table.tb_name = new.to_string();
888 for idx in table.secondary_indexes.iter_mut() {
889 idx.table_name = new.to_string();
890 if idx.origin == IndexOrigin::Auto
891 && idx.name == SecondaryIndex::auto_name(old, &idx.column_name)
892 {
893 idx.name = SecondaryIndex::auto_name(new, &idx.column_name);
894 }
895 }
896 db.tables.insert(new.to_string(), table);
897 Ok(())
898}
899
/// Index implementation selected by the `USING` clause of `CREATE INDEX`.
#[derive(Debug, Clone, Copy)]
enum IndexMethod {
    /// Secondary B-tree index; also the default when `USING` is absent.
    Btree,
    /// HNSW index over a `VECTOR` column (`USING hnsw`).
    Hnsw,
    /// Full-text index over a `TEXT` column (`USING fts`).
    Fts,
}
910
911fn create_btree_index(
913 db: &mut Database,
914 table_name: &str,
915 index_name: &str,
916 column_name: &str,
917 datatype: &DataType,
918 unique: bool,
919 existing: &[(i64, Value)],
920) -> Result<String> {
921 let mut idx = SecondaryIndex::new(
922 index_name.to_string(),
923 table_name.to_string(),
924 column_name.to_string(),
925 datatype,
926 unique,
927 IndexOrigin::Explicit,
928 )?;
929
930 for (rowid, v) in existing {
934 if unique && idx.would_violate_unique(v) {
935 return Err(SQLRiteError::General(format!(
936 "cannot create UNIQUE index '{index_name}': column '{column_name}' \
937 already contains the duplicate value {}",
938 v.to_display_string()
939 )));
940 }
941 idx.insert(v, *rowid)?;
942 }
943
944 let table_mut = db.get_table_mut(table_name.to_string())?;
945 table_mut.secondary_indexes.push(idx);
946 Ok(index_name.to_string())
947}
948
949fn create_hnsw_index(
951 db: &mut Database,
952 table_name: &str,
953 index_name: &str,
954 column_name: &str,
955 datatype: &DataType,
956 unique: bool,
957 existing: &[(i64, Value)],
958) -> Result<String> {
959 let dim = match datatype {
962 DataType::Vector(d) => *d,
963 other => {
964 return Err(SQLRiteError::General(format!(
965 "USING hnsw requires a VECTOR column; '{column_name}' is {other}"
966 )));
967 }
968 };
969
970 if unique {
971 return Err(SQLRiteError::General(
972 "UNIQUE has no meaning for HNSW indexes".to_string(),
973 ));
974 }
975
976 let seed = hash_str_to_seed(index_name);
984 let mut idx = HnswIndex::new(DistanceMetric::L2, seed);
985
986 let mut vec_map: std::collections::HashMap<i64, Vec<f32>> =
990 std::collections::HashMap::with_capacity(existing.len());
991 for (rowid, v) in existing {
992 match v {
993 Value::Vector(vec) => {
994 if vec.len() != dim {
995 return Err(SQLRiteError::Internal(format!(
996 "row {rowid} stores a {}-dim vector in column '{column_name}' \
997 declared as VECTOR({dim}) — schema invariant violated",
998 vec.len()
999 )));
1000 }
1001 vec_map.insert(*rowid, vec.clone());
1002 }
1003 _ => continue,
1007 }
1008 }
1009
1010 for (rowid, _) in existing {
1011 if let Some(v) = vec_map.get(rowid) {
1012 let v_clone = v.clone();
1013 idx.insert(*rowid, &v_clone, |id| {
1014 vec_map.get(&id).cloned().unwrap_or_default()
1015 });
1016 }
1017 }
1018
1019 let table_mut = db.get_table_mut(table_name.to_string())?;
1020 table_mut.hnsw_indexes.push(HnswIndexEntry {
1021 name: index_name.to_string(),
1022 column_name: column_name.to_string(),
1023 index: idx,
1024 needs_rebuild: false,
1026 });
1027 Ok(index_name.to_string())
1028}
1029
1030fn create_fts_index(
1035 db: &mut Database,
1036 table_name: &str,
1037 index_name: &str,
1038 column_name: &str,
1039 datatype: &DataType,
1040 unique: bool,
1041 existing: &[(i64, Value)],
1042) -> Result<String> {
1043 match datatype {
1048 DataType::Text => {}
1049 other => {
1050 return Err(SQLRiteError::General(format!(
1051 "USING fts requires a TEXT column; '{column_name}' is {other}"
1052 )));
1053 }
1054 }
1055
1056 if unique {
1057 return Err(SQLRiteError::General(
1058 "UNIQUE has no meaning for FTS indexes".to_string(),
1059 ));
1060 }
1061
1062 let mut idx = PostingList::new();
1063 for (rowid, v) in existing {
1064 if let Value::Text(text) = v {
1065 idx.insert(*rowid, text);
1066 }
1067 }
1070
1071 let table_mut = db.get_table_mut(table_name.to_string())?;
1072 table_mut.fts_indexes.push(FtsIndexEntry {
1073 name: index_name.to_string(),
1074 column_name: column_name.to_string(),
1075 index: idx,
1076 needs_rebuild: false,
1077 });
1078 Ok(index_name.to_string())
1079}
1080
/// Hashes `s` to a deterministic 64-bit seed.
///
/// Starting from the offset basis `0xCBF29CE484222325`, each byte is XORed
/// into the state, which is then multiplied (wrapping) by `0x100000001B3`.
/// The empty string therefore hashes to the offset basis itself.
fn hash_str_to_seed(s: &str) -> u64 {
    const OFFSET_BASIS: u64 = 0xCBF29CE484222325;
    const PRIME: u64 = 0x100000001B3;

    s.bytes()
        .fold(OFFSET_BASIS, |state, byte| {
            (state ^ u64::from(byte)).wrapping_mul(PRIME)
        })
}
1092
1093fn clone_datatype(dt: &DataType) -> DataType {
1096 match dt {
1097 DataType::Integer => DataType::Integer,
1098 DataType::Text => DataType::Text,
1099 DataType::Real => DataType::Real,
1100 DataType::Bool => DataType::Bool,
1101 DataType::Vector(dim) => DataType::Vector(*dim),
1102 DataType::Json => DataType::Json,
1103 DataType::None => DataType::None,
1104 DataType::Invalid => DataType::Invalid,
1105 }
1106}
1107
1108fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
1109 if tables.len() != 1 {
1110 return Err(SQLRiteError::NotImplemented(
1111 "multi-table DELETE is not supported yet".to_string(),
1112 ));
1113 }
1114 extract_table_name(&tables[0])
1115}
1116
1117fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
1118 if !twj.joins.is_empty() {
1119 return Err(SQLRiteError::NotImplemented(
1120 "JOIN is not supported yet".to_string(),
1121 ));
1122 }
1123 match &twj.relation {
1124 TableFactor::Table { name, .. } => Ok(name.to_string()),
1125 _ => Err(SQLRiteError::NotImplemented(
1126 "only plain table references are supported".to_string(),
1127 )),
1128 }
1129}
1130
/// Outcome of planning a `WHERE` clause in [`select_rowids`].
enum RowidSource {
    /// An index answered the predicate; these are the matching rowids.
    IndexProbe(Vec<i64>),
    /// No usable index: the caller must scan every row and evaluate the
    /// predicate (if any) itself.
    FullScan,
}
1141
1142fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
1147 let Some(expr) = selection else {
1148 return Ok(RowidSource::FullScan);
1149 };
1150 let Some((col, literal)) = try_extract_equality(expr) else {
1151 return Ok(RowidSource::FullScan);
1152 };
1153 let Some(idx) = table.index_for_column(&col) else {
1154 return Ok(RowidSource::FullScan);
1155 };
1156
1157 let literal_value = match convert_literal(&literal) {
1161 Ok(v) => v,
1162 Err(_) => return Ok(RowidSource::FullScan),
1163 };
1164
1165 let mut rowids = idx.lookup(&literal_value);
1169 rowids.sort_unstable();
1170 Ok(RowidSource::IndexProbe(rowids))
1171}
1172
1173fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
1177 let peeled = match expr {
1179 Expr::Nested(inner) => inner.as_ref(),
1180 other => other,
1181 };
1182 let Expr::BinaryOp { left, op, right } = peeled else {
1183 return None;
1184 };
1185 if !matches!(op, BinaryOperator::Eq) {
1186 return None;
1187 }
1188 let col_from = |e: &Expr| -> Option<String> {
1189 match e {
1190 Expr::Identifier(ident) => Some(ident.value.clone()),
1191 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
1192 _ => None,
1193 }
1194 };
1195 let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
1196 if let Expr::Value(v) = e {
1197 Some(v.value.clone())
1198 } else {
1199 None
1200 }
1201 };
1202 if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
1203 return Some((c, l));
1204 }
1205 if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
1206 return Some((c, l));
1207 }
1208 None
1209}
1210
1211fn try_hnsw_probe(table: &Table, order_expr: &Expr, k: usize) -> Option<Vec<i64>> {
1233 if k == 0 {
1234 return None;
1235 }
1236
1237 let func = match order_expr {
1239 Expr::Function(f) => f,
1240 _ => return None,
1241 };
1242 let fname = match func.name.0.as_slice() {
1243 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
1244 _ => return None,
1245 };
1246 if fname != "vec_distance_l2" {
1247 return None;
1248 }
1249
1250 let arg_list = match &func.args {
1252 FunctionArguments::List(l) => &l.args,
1253 _ => return None,
1254 };
1255 if arg_list.len() != 2 {
1256 return None;
1257 }
1258 let exprs: Vec<&Expr> = arg_list
1259 .iter()
1260 .filter_map(|a| match a {
1261 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
1262 _ => None,
1263 })
1264 .collect();
1265 if exprs.len() != 2 {
1266 return None;
1267 }
1268
1269 let (col_name, query_vec) = match identify_indexed_arg_and_literal(exprs[0], exprs[1]) {
1274 Some(v) => v,
1275 None => match identify_indexed_arg_and_literal(exprs[1], exprs[0]) {
1276 Some(v) => v,
1277 None => return None,
1278 },
1279 };
1280
1281 let entry = table
1283 .hnsw_indexes
1284 .iter()
1285 .find(|e| e.column_name == col_name)?;
1286
1287 let declared_dim = match table.columns.iter().find(|c| c.column_name == col_name) {
1293 Some(c) => match &c.datatype {
1294 DataType::Vector(d) => *d,
1295 _ => return None,
1296 },
1297 None => return None,
1298 };
1299 if query_vec.len() != declared_dim {
1300 return None;
1301 }
1302
1303 let column_for_closure = col_name.clone();
1307 let table_ref = table;
1308 let result = entry.index.search(&query_vec, k, |id| {
1309 match table_ref.get_value(&column_for_closure, id) {
1310 Some(Value::Vector(v)) => v,
1311 _ => Vec::new(),
1312 }
1313 });
1314 Some(result)
1315}
1316
1317fn try_fts_probe(table: &Table, order_expr: &Expr, ascending: bool, k: usize) -> Option<Vec<i64>> {
1333 if k == 0 || ascending {
1334 return None;
1338 }
1339
1340 let func = match order_expr {
1341 Expr::Function(f) => f,
1342 _ => return None,
1343 };
1344 let fname = match func.name.0.as_slice() {
1345 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
1346 _ => return None,
1347 };
1348 if fname != "bm25_score" {
1349 return None;
1350 }
1351
1352 let arg_list = match &func.args {
1353 FunctionArguments::List(l) => &l.args,
1354 _ => return None,
1355 };
1356 if arg_list.len() != 2 {
1357 return None;
1358 }
1359 let exprs: Vec<&Expr> = arg_list
1360 .iter()
1361 .filter_map(|a| match a {
1362 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
1363 _ => None,
1364 })
1365 .collect();
1366 if exprs.len() != 2 {
1367 return None;
1368 }
1369
1370 let col_name = match exprs[0] {
1372 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
1373 _ => return None,
1374 };
1375
1376 let query = match exprs[1] {
1380 Expr::Value(v) => match &v.value {
1381 AstValue::SingleQuotedString(s) => s.clone(),
1382 _ => return None,
1383 },
1384 _ => return None,
1385 };
1386
1387 let entry = table
1388 .fts_indexes
1389 .iter()
1390 .find(|e| e.column_name == col_name)?;
1391
1392 let scored = entry.index.query(&query, &Bm25Params::default());
1393 let mut out: Vec<i64> = scored.into_iter().map(|(id, _)| id).collect();
1394 if out.len() > k {
1395 out.truncate(k);
1396 }
1397 Some(out)
1398}
1399
1400fn identify_indexed_arg_and_literal(a: &Expr, b: &Expr) -> Option<(String, Vec<f32>)> {
1405 let col_name = match a {
1406 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
1407 _ => return None,
1408 };
1409 let lit_str = match b {
1410 Expr::Identifier(ident) if ident.quote_style == Some('[') => {
1411 format!("[{}]", ident.value)
1412 }
1413 _ => return None,
1414 };
1415 let v = parse_vector_literal(&lit_str).ok()?;
1416 Some((col_name, v))
1417}
1418
1419struct HeapEntry {
1432 key: Value,
1433 rowid: i64,
1434 asc: bool,
1435}
1436
1437impl PartialEq for HeapEntry {
1438 fn eq(&self, other: &Self) -> bool {
1439 self.cmp(other) == Ordering::Equal
1440 }
1441}
1442
1443impl Eq for HeapEntry {}
1444
1445impl PartialOrd for HeapEntry {
1446 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1447 Some(self.cmp(other))
1448 }
1449}
1450
1451impl Ord for HeapEntry {
1452 fn cmp(&self, other: &Self) -> Ordering {
1453 let raw = compare_values(Some(&self.key), Some(&other.key));
1454 if self.asc { raw } else { raw.reverse() }
1455 }
1456}
1457
1458fn select_topk(
1467 matching: &[i64],
1468 table: &Table,
1469 order: &OrderByClause,
1470 k: usize,
1471) -> Result<Vec<i64>> {
1472 use std::collections::BinaryHeap;
1473
1474 if k == 0 || matching.is_empty() {
1475 return Ok(Vec::new());
1476 }
1477
1478 let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
1479
1480 for &rowid in matching {
1481 let key = eval_expr(&order.expr, table, rowid)?;
1482 let entry = HeapEntry {
1483 key,
1484 rowid,
1485 asc: order.ascending,
1486 };
1487
1488 if heap.len() < k {
1489 heap.push(entry);
1490 } else {
1491 if entry < *heap.peek().unwrap() {
1495 heap.pop();
1496 heap.push(entry);
1497 }
1498 }
1499 }
1500
1501 Ok(heap
1506 .into_sorted_vec()
1507 .into_iter()
1508 .map(|e| e.rowid)
1509 .collect())
1510}
1511
1512fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
1513 let mut keys: Vec<(i64, Result<Value>)> = rowids
1521 .iter()
1522 .map(|r| (*r, eval_expr(&order.expr, table, *r)))
1523 .collect();
1524
1525 for (_, k) in &keys {
1529 if let Err(e) = k {
1530 return Err(SQLRiteError::General(format!(
1531 "ORDER BY expression failed: {e}"
1532 )));
1533 }
1534 }
1535
1536 keys.sort_by(|(_, ka), (_, kb)| {
1537 let va = ka.as_ref().unwrap();
1540 let vb = kb.as_ref().unwrap();
1541 let ord = compare_values(Some(va), Some(vb));
1542 if order.ascending { ord } else { ord.reverse() }
1543 });
1544
1545 for (i, (rowid, _)) in keys.into_iter().enumerate() {
1547 rowids[i] = rowid;
1548 }
1549 Ok(())
1550}
1551
1552fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
1553 match (a, b) {
1554 (None, None) => Ordering::Equal,
1555 (None, _) => Ordering::Less,
1556 (_, None) => Ordering::Greater,
1557 (Some(a), Some(b)) => match (a, b) {
1558 (Value::Null, Value::Null) => Ordering::Equal,
1559 (Value::Null, _) => Ordering::Less,
1560 (_, Value::Null) => Ordering::Greater,
1561 (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
1562 (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
1563 (Value::Integer(x), Value::Real(y)) => {
1564 (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
1565 }
1566 (Value::Real(x), Value::Integer(y)) => {
1567 x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
1568 }
1569 (Value::Text(x), Value::Text(y)) => x.cmp(y),
1570 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
1571 (x, y) => x.to_display_string().cmp(&y.to_display_string()),
1573 },
1574 }
1575}
1576
1577pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
1579 let v = eval_expr(expr, table, rowid)?;
1580 match v {
1581 Value::Bool(b) => Ok(b),
1582 Value::Null => Ok(false), Value::Integer(i) => Ok(i != 0),
1584 other => Err(SQLRiteError::Internal(format!(
1585 "WHERE clause must evaluate to boolean, got {}",
1586 other.to_display_string()
1587 ))),
1588 }
1589}
1590
1591fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
1592 match expr {
1593 Expr::Nested(inner) => eval_expr(inner, table, rowid),
1594
1595 Expr::Identifier(ident) => {
1596 if ident.quote_style == Some('[') {
1606 let raw = format!("[{}]", ident.value);
1607 let v = parse_vector_literal(&raw)?;
1608 return Ok(Value::Vector(v));
1609 }
1610 Ok(table.get_value(&ident.value, rowid).unwrap_or(Value::Null))
1611 }
1612
1613 Expr::CompoundIdentifier(parts) => {
1614 let col = parts
1616 .last()
1617 .map(|i| i.value.as_str())
1618 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?;
1619 Ok(table.get_value(col, rowid).unwrap_or(Value::Null))
1620 }
1621
1622 Expr::Value(v) => convert_literal(&v.value),
1623
1624 Expr::UnaryOp { op, expr } => {
1625 let inner = eval_expr(expr, table, rowid)?;
1626 match op {
1627 UnaryOperator::Not => match inner {
1628 Value::Bool(b) => Ok(Value::Bool(!b)),
1629 Value::Null => Ok(Value::Null),
1630 other => Err(SQLRiteError::Internal(format!(
1631 "NOT applied to non-boolean value: {}",
1632 other.to_display_string()
1633 ))),
1634 },
1635 UnaryOperator::Minus => match inner {
1636 Value::Integer(i) => Ok(Value::Integer(-i)),
1637 Value::Real(f) => Ok(Value::Real(-f)),
1638 Value::Null => Ok(Value::Null),
1639 other => Err(SQLRiteError::Internal(format!(
1640 "unary minus on non-numeric value: {}",
1641 other.to_display_string()
1642 ))),
1643 },
1644 UnaryOperator::Plus => Ok(inner),
1645 other => Err(SQLRiteError::NotImplemented(format!(
1646 "unary operator {other:?} is not supported"
1647 ))),
1648 }
1649 }
1650
1651 Expr::BinaryOp { left, op, right } => match op {
1652 BinaryOperator::And => {
1653 let l = eval_expr(left, table, rowid)?;
1654 let r = eval_expr(right, table, rowid)?;
1655 Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
1656 }
1657 BinaryOperator::Or => {
1658 let l = eval_expr(left, table, rowid)?;
1659 let r = eval_expr(right, table, rowid)?;
1660 Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
1661 }
1662 cmp @ (BinaryOperator::Eq
1663 | BinaryOperator::NotEq
1664 | BinaryOperator::Lt
1665 | BinaryOperator::LtEq
1666 | BinaryOperator::Gt
1667 | BinaryOperator::GtEq) => {
1668 let l = eval_expr(left, table, rowid)?;
1669 let r = eval_expr(right, table, rowid)?;
1670 if matches!(l, Value::Null) || matches!(r, Value::Null) {
1672 return Ok(Value::Bool(false));
1673 }
1674 let ord = compare_values(Some(&l), Some(&r));
1675 let result = match cmp {
1676 BinaryOperator::Eq => ord == Ordering::Equal,
1677 BinaryOperator::NotEq => ord != Ordering::Equal,
1678 BinaryOperator::Lt => ord == Ordering::Less,
1679 BinaryOperator::LtEq => ord != Ordering::Greater,
1680 BinaryOperator::Gt => ord == Ordering::Greater,
1681 BinaryOperator::GtEq => ord != Ordering::Less,
1682 _ => unreachable!(),
1683 };
1684 Ok(Value::Bool(result))
1685 }
1686 arith @ (BinaryOperator::Plus
1687 | BinaryOperator::Minus
1688 | BinaryOperator::Multiply
1689 | BinaryOperator::Divide
1690 | BinaryOperator::Modulo) => {
1691 let l = eval_expr(left, table, rowid)?;
1692 let r = eval_expr(right, table, rowid)?;
1693 eval_arith(arith, &l, &r)
1694 }
1695 BinaryOperator::StringConcat => {
1696 let l = eval_expr(left, table, rowid)?;
1697 let r = eval_expr(right, table, rowid)?;
1698 if matches!(l, Value::Null) || matches!(r, Value::Null) {
1699 return Ok(Value::Null);
1700 }
1701 Ok(Value::Text(format!(
1702 "{}{}",
1703 l.to_display_string(),
1704 r.to_display_string()
1705 )))
1706 }
1707 other => Err(SQLRiteError::NotImplemented(format!(
1708 "binary operator {other:?} is not supported yet"
1709 ))),
1710 },
1711
1712 Expr::IsNull(inner) => {
1720 let v = eval_expr(inner, table, rowid)?;
1721 Ok(Value::Bool(matches!(v, Value::Null)))
1722 }
1723 Expr::IsNotNull(inner) => {
1724 let v = eval_expr(inner, table, rowid)?;
1725 Ok(Value::Bool(!matches!(v, Value::Null)))
1726 }
1727
1728 Expr::Like {
1735 negated,
1736 any,
1737 expr: lhs,
1738 pattern,
1739 escape_char,
1740 } => eval_like(
1741 table,
1742 rowid,
1743 *negated,
1744 *any,
1745 lhs,
1746 pattern,
1747 escape_char.as_ref(),
1748 true,
1749 ),
1750 Expr::ILike {
1751 negated,
1752 any,
1753 expr: lhs,
1754 pattern,
1755 escape_char,
1756 } => eval_like(
1757 table,
1758 rowid,
1759 *negated,
1760 *any,
1761 lhs,
1762 pattern,
1763 escape_char.as_ref(),
1764 true,
1765 ),
1766
1767 Expr::InList {
1773 expr: lhs,
1774 list,
1775 negated,
1776 } => eval_in_list(table, rowid, lhs, list, *negated),
1777 Expr::InSubquery { .. } => Err(SQLRiteError::NotImplemented(
1778 "IN (subquery) is not supported (only literal lists are)".to_string(),
1779 )),
1780
1781 Expr::Function(func) => eval_function(func, table, rowid),
1792
1793 other => Err(SQLRiteError::NotImplemented(format!(
1794 "unsupported expression in WHERE/projection: {other:?}"
1795 ))),
1796 }
1797}
1798
1799fn eval_function(func: &sqlparser::ast::Function, table: &Table, rowid: i64) -> Result<Value> {
1804 let name = match func.name.0.as_slice() {
1807 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
1808 _ => {
1809 return Err(SQLRiteError::NotImplemented(format!(
1810 "qualified function names not supported: {:?}",
1811 func.name
1812 )));
1813 }
1814 };
1815
1816 match name.as_str() {
1817 "vec_distance_l2" | "vec_distance_cosine" | "vec_distance_dot" => {
1818 let (a, b) = extract_two_vector_args(&name, &func.args, table, rowid)?;
1819 let dist = match name.as_str() {
1820 "vec_distance_l2" => vec_distance_l2(&a, &b),
1821 "vec_distance_cosine" => vec_distance_cosine(&a, &b)?,
1822 "vec_distance_dot" => vec_distance_dot(&a, &b),
1823 _ => unreachable!(),
1824 };
1825 Ok(Value::Real(dist as f64))
1831 }
1832 "json_extract" => json_fn_extract(&name, &func.args, table, rowid),
1837 "json_type" => json_fn_type(&name, &func.args, table, rowid),
1838 "json_array_length" => json_fn_array_length(&name, &func.args, table, rowid),
1839 "json_object_keys" => json_fn_object_keys(&name, &func.args, table, rowid),
1840 "fts_match" => {
1844 let (entry, query) = resolve_fts_args(&name, &func.args, table, rowid)?;
1845 Ok(Value::Bool(entry.index.matches(rowid, &query)))
1846 }
1847 "bm25_score" => {
1848 let (entry, query) = resolve_fts_args(&name, &func.args, table, rowid)?;
1849 let s = entry.index.score(rowid, &query, &Bm25Params::default());
1850 Ok(Value::Real(s))
1851 }
1852 "count" | "sum" | "avg" | "min" | "max" => Err(SQLRiteError::NotImplemented(format!(
1856 "aggregate function '{name}' is not allowed in WHERE / projection-scalar position; \
1857 use it as a top-level projection item (HAVING is not yet supported)"
1858 ))),
1859 other => Err(SQLRiteError::NotImplemented(format!(
1860 "unknown function: {other}(...)"
1861 ))),
1862 }
1863}
1864
1865fn resolve_fts_args<'t>(
1870 fn_name: &str,
1871 args: &FunctionArguments,
1872 table: &'t Table,
1873 rowid: i64,
1874) -> Result<(&'t FtsIndexEntry, String)> {
1875 let arg_list = match args {
1876 FunctionArguments::List(l) => &l.args,
1877 _ => {
1878 return Err(SQLRiteError::General(format!(
1879 "{fn_name}() expects exactly two arguments: (column, query_text)"
1880 )));
1881 }
1882 };
1883 if arg_list.len() != 2 {
1884 return Err(SQLRiteError::General(format!(
1885 "{fn_name}() expects exactly 2 arguments, got {}",
1886 arg_list.len()
1887 )));
1888 }
1889
1890 let col_expr = match &arg_list[0] {
1894 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
1895 other => {
1896 return Err(SQLRiteError::NotImplemented(format!(
1897 "{fn_name}() argument 0 must be a column name, got {other:?}"
1898 )));
1899 }
1900 };
1901 let col_name = match col_expr {
1902 Expr::Identifier(ident) => ident.value.clone(),
1903 Expr::CompoundIdentifier(parts) => parts
1904 .last()
1905 .map(|p| p.value.clone())
1906 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
1907 other => {
1908 return Err(SQLRiteError::General(format!(
1909 "{fn_name}() argument 0 must be a column reference, got {other:?}"
1910 )));
1911 }
1912 };
1913
1914 let q_expr = match &arg_list[1] {
1918 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
1919 other => {
1920 return Err(SQLRiteError::NotImplemented(format!(
1921 "{fn_name}() argument 1 must be a text expression, got {other:?}"
1922 )));
1923 }
1924 };
1925 let query = match eval_expr(q_expr, table, rowid)? {
1926 Value::Text(s) => s,
1927 other => {
1928 return Err(SQLRiteError::General(format!(
1929 "{fn_name}() argument 1 must be TEXT, got {}",
1930 other.to_display_string()
1931 )));
1932 }
1933 };
1934
1935 let entry = table
1936 .fts_indexes
1937 .iter()
1938 .find(|e| e.column_name == col_name)
1939 .ok_or_else(|| {
1940 SQLRiteError::General(format!(
1941 "{fn_name}({col_name}, ...): no FTS index on column '{col_name}' \
1942 (run CREATE INDEX <name> ON <table> USING fts({col_name}) first)"
1943 ))
1944 })?;
1945 Ok((entry, query))
1946}
1947
1948fn extract_json_and_path(
1962 fn_name: &str,
1963 args: &FunctionArguments,
1964 table: &Table,
1965 rowid: i64,
1966) -> Result<(String, String)> {
1967 let arg_list = match args {
1968 FunctionArguments::List(l) => &l.args,
1969 _ => {
1970 return Err(SQLRiteError::General(format!(
1971 "{fn_name}() expects 1 or 2 arguments"
1972 )));
1973 }
1974 };
1975 if !(arg_list.len() == 1 || arg_list.len() == 2) {
1976 return Err(SQLRiteError::General(format!(
1977 "{fn_name}() expects 1 or 2 arguments, got {}",
1978 arg_list.len()
1979 )));
1980 }
1981 let first_expr = match &arg_list[0] {
1983 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
1984 other => {
1985 return Err(SQLRiteError::NotImplemented(format!(
1986 "{fn_name}() argument 0 has unsupported shape: {other:?}"
1987 )));
1988 }
1989 };
1990 let json_text = match eval_expr(first_expr, table, rowid)? {
1991 Value::Text(s) => s,
1992 Value::Null => {
1993 return Err(SQLRiteError::General(format!(
1994 "{fn_name}() called on NULL — JSON column has no value for this row"
1995 )));
1996 }
1997 other => {
1998 return Err(SQLRiteError::General(format!(
1999 "{fn_name}() argument 0 is not JSON-typed: got {}",
2000 other.to_display_string()
2001 )));
2002 }
2003 };
2004
2005 let path = if arg_list.len() == 2 {
2007 let path_expr = match &arg_list[1] {
2008 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2009 other => {
2010 return Err(SQLRiteError::NotImplemented(format!(
2011 "{fn_name}() argument 1 has unsupported shape: {other:?}"
2012 )));
2013 }
2014 };
2015 match eval_expr(path_expr, table, rowid)? {
2016 Value::Text(s) => s,
2017 other => {
2018 return Err(SQLRiteError::General(format!(
2019 "{fn_name}() path argument must be a string literal, got {}",
2020 other.to_display_string()
2021 )));
2022 }
2023 }
2024 } else {
2025 "$".to_string()
2026 };
2027
2028 Ok((json_text, path))
2029}
2030
2031fn walk_json_path<'a>(
2041 value: &'a serde_json::Value,
2042 path: &str,
2043) -> Result<Option<&'a serde_json::Value>> {
2044 let mut chars = path.chars().peekable();
2045 if chars.next() != Some('$') {
2046 return Err(SQLRiteError::General(format!(
2047 "JSON path must start with '$', got `{path}`"
2048 )));
2049 }
2050 let mut current = value;
2051 while let Some(&c) = chars.peek() {
2052 match c {
2053 '.' => {
2054 chars.next();
2055 let mut key = String::new();
2056 while let Some(&c) = chars.peek() {
2057 if c == '.' || c == '[' {
2058 break;
2059 }
2060 key.push(c);
2061 chars.next();
2062 }
2063 if key.is_empty() {
2064 return Err(SQLRiteError::General(format!(
2065 "JSON path has empty key after '.' in `{path}`"
2066 )));
2067 }
2068 match current.get(&key) {
2069 Some(v) => current = v,
2070 None => return Ok(None),
2071 }
2072 }
2073 '[' => {
2074 chars.next();
2075 let mut idx_str = String::new();
2076 while let Some(&c) = chars.peek() {
2077 if c == ']' {
2078 break;
2079 }
2080 idx_str.push(c);
2081 chars.next();
2082 }
2083 if chars.next() != Some(']') {
2084 return Err(SQLRiteError::General(format!(
2085 "JSON path has unclosed `[` in `{path}`"
2086 )));
2087 }
2088 let idx: usize = idx_str.trim().parse().map_err(|_| {
2089 SQLRiteError::General(format!(
2090 "JSON path has non-integer index `[{idx_str}]` in `{path}`"
2091 ))
2092 })?;
2093 match current.get(idx) {
2094 Some(v) => current = v,
2095 None => return Ok(None),
2096 }
2097 }
2098 other => {
2099 return Err(SQLRiteError::General(format!(
2100 "JSON path has unexpected character `{other}` in `{path}` \
2101 (expected `.`, `[`, or end-of-path)"
2102 )));
2103 }
2104 }
2105 }
2106 Ok(Some(current))
2107}
2108
2109fn json_value_to_sql(v: &serde_json::Value) -> Value {
2113 match v {
2114 serde_json::Value::Null => Value::Null,
2115 serde_json::Value::Bool(b) => Value::Bool(*b),
2116 serde_json::Value::Number(n) => {
2117 if let Some(i) = n.as_i64() {
2119 Value::Integer(i)
2120 } else if let Some(f) = n.as_f64() {
2121 Value::Real(f)
2122 } else {
2123 Value::Null
2124 }
2125 }
2126 serde_json::Value::String(s) => Value::Text(s.clone()),
2127 composite => Value::Text(composite.to_string()),
2131 }
2132}
2133
2134fn json_fn_extract(
2135 name: &str,
2136 args: &FunctionArguments,
2137 table: &Table,
2138 rowid: i64,
2139) -> Result<Value> {
2140 let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
2141 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2142 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2143 })?;
2144 match walk_json_path(&parsed, &path)? {
2145 Some(v) => Ok(json_value_to_sql(v)),
2146 None => Ok(Value::Null),
2147 }
2148}
2149
2150fn json_fn_type(name: &str, args: &FunctionArguments, table: &Table, rowid: i64) -> Result<Value> {
2151 let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
2152 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2153 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2154 })?;
2155 let resolved = match walk_json_path(&parsed, &path)? {
2156 Some(v) => v,
2157 None => return Ok(Value::Null),
2158 };
2159 let ty = match resolved {
2160 serde_json::Value::Null => "null",
2161 serde_json::Value::Bool(true) => "true",
2162 serde_json::Value::Bool(false) => "false",
2163 serde_json::Value::Number(n) => {
2164 if n.is_i64() || n.is_u64() {
2165 "integer"
2166 } else {
2167 "real"
2168 }
2169 }
2170 serde_json::Value::String(_) => "text",
2171 serde_json::Value::Array(_) => "array",
2172 serde_json::Value::Object(_) => "object",
2173 };
2174 Ok(Value::Text(ty.to_string()))
2175}
2176
2177fn json_fn_array_length(
2178 name: &str,
2179 args: &FunctionArguments,
2180 table: &Table,
2181 rowid: i64,
2182) -> Result<Value> {
2183 let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
2184 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2185 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2186 })?;
2187 let resolved = match walk_json_path(&parsed, &path)? {
2188 Some(v) => v,
2189 None => return Ok(Value::Null),
2190 };
2191 match resolved.as_array() {
2192 Some(arr) => Ok(Value::Integer(arr.len() as i64)),
2193 None => Err(SQLRiteError::General(format!(
2194 "{name}() resolved to a non-array value at path `{path}`"
2195 ))),
2196 }
2197}
2198
2199fn json_fn_object_keys(
2200 name: &str,
2201 args: &FunctionArguments,
2202 table: &Table,
2203 rowid: i64,
2204) -> Result<Value> {
2205 let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
2206 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2207 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2208 })?;
2209 let resolved = match walk_json_path(&parsed, &path)? {
2210 Some(v) => v,
2211 None => return Ok(Value::Null),
2212 };
2213 let obj = resolved.as_object().ok_or_else(|| {
2214 SQLRiteError::General(format!(
2215 "{name}() resolved to a non-object value at path `{path}`"
2216 ))
2217 })?;
2218 let keys: Vec<serde_json::Value> = obj
2225 .keys()
2226 .map(|k| serde_json::Value::String(k.clone()))
2227 .collect();
2228 Ok(Value::Text(serde_json::Value::Array(keys).to_string()))
2229}
2230
2231fn extract_two_vector_args(
2235 fn_name: &str,
2236 args: &FunctionArguments,
2237 table: &Table,
2238 rowid: i64,
2239) -> Result<(Vec<f32>, Vec<f32>)> {
2240 let arg_list = match args {
2241 FunctionArguments::List(l) => &l.args,
2242 _ => {
2243 return Err(SQLRiteError::General(format!(
2244 "{fn_name}() expects exactly two vector arguments"
2245 )));
2246 }
2247 };
2248 if arg_list.len() != 2 {
2249 return Err(SQLRiteError::General(format!(
2250 "{fn_name}() expects exactly 2 arguments, got {}",
2251 arg_list.len()
2252 )));
2253 }
2254 let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
2255 for (i, arg) in arg_list.iter().enumerate() {
2256 let expr = match arg {
2257 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2258 other => {
2259 return Err(SQLRiteError::NotImplemented(format!(
2260 "{fn_name}() argument {i} has unsupported shape: {other:?}"
2261 )));
2262 }
2263 };
2264 let val = eval_expr(expr, table, rowid)?;
2265 match val {
2266 Value::Vector(v) => out.push(v),
2267 other => {
2268 return Err(SQLRiteError::General(format!(
2269 "{fn_name}() argument {i} is not a vector: got {}",
2270 other.to_display_string()
2271 )));
2272 }
2273 }
2274 }
2275 let b = out.pop().unwrap();
2276 let a = out.pop().unwrap();
2277 if a.len() != b.len() {
2278 return Err(SQLRiteError::General(format!(
2279 "{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
2280 a.len(),
2281 b.len()
2282 )));
2283 }
2284 Ok((a, b))
2285}
2286
/// Euclidean (L2) distance between `a` and `b`: the square root of the
/// sum of squared component differences, accumulated in `f32`.
///
/// Callers are expected to pass slices of equal length (checked with a
/// debug assertion).
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // One up-front bounds check: only the first `a.len()` items of `b` are read.
    let b = &b[..a.len()];
    a.iter()
        .zip(b)
        .fold(0.0f32, |acc, (x, y)| {
            let diff = x - y;
            acc + diff * diff
        })
        .sqrt()
}
2298
2299pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
2309 debug_assert_eq!(a.len(), b.len());
2310 let mut dot = 0.0f32;
2311 let mut norm_a_sq = 0.0f32;
2312 let mut norm_b_sq = 0.0f32;
2313 for i in 0..a.len() {
2314 dot += a[i] * b[i];
2315 norm_a_sq += a[i] * a[i];
2316 norm_b_sq += b[i] * b[i];
2317 }
2318 let denom = (norm_a_sq * norm_b_sq).sqrt();
2319 if denom == 0.0 {
2320 return Err(SQLRiteError::General(
2321 "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
2322 ));
2323 }
2324 Ok(1.0 - dot / denom)
2325}
2326
/// Negated dot product of `a` and `b`, accumulated in `f32`, so that a
/// larger inner product yields a smaller "distance".
///
/// Callers are expected to pass slices of equal length (checked with a
/// debug assertion).
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // One up-front bounds check: only the first `a.len()` items of `b` are read.
    let b = &b[..a.len()];
    let dot = a.iter().zip(b).fold(0.0f32, |acc, (x, y)| acc + x * y);
    -dot
}
2338
2339fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
2342 if matches!(l, Value::Null) || matches!(r, Value::Null) {
2343 return Ok(Value::Null);
2344 }
2345 match (l, r) {
2346 (Value::Integer(a), Value::Integer(b)) => match op {
2347 BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
2348 BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
2349 BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
2350 BinaryOperator::Divide => {
2351 if *b == 0 {
2352 Err(SQLRiteError::General("division by zero".to_string()))
2353 } else {
2354 Ok(Value::Integer(a / b))
2355 }
2356 }
2357 BinaryOperator::Modulo => {
2358 if *b == 0 {
2359 Err(SQLRiteError::General("modulo by zero".to_string()))
2360 } else {
2361 Ok(Value::Integer(a % b))
2362 }
2363 }
2364 _ => unreachable!(),
2365 },
2366 (a, b) => {
2368 let af = as_number(a)?;
2369 let bf = as_number(b)?;
2370 match op {
2371 BinaryOperator::Plus => Ok(Value::Real(af + bf)),
2372 BinaryOperator::Minus => Ok(Value::Real(af - bf)),
2373 BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
2374 BinaryOperator::Divide => {
2375 if bf == 0.0 {
2376 Err(SQLRiteError::General("division by zero".to_string()))
2377 } else {
2378 Ok(Value::Real(af / bf))
2379 }
2380 }
2381 BinaryOperator::Modulo => {
2382 if bf == 0.0 {
2383 Err(SQLRiteError::General("modulo by zero".to_string()))
2384 } else {
2385 Ok(Value::Real(af % bf))
2386 }
2387 }
2388 _ => unreachable!(),
2389 }
2390 }
2391 }
2392}
2393
2394fn as_number(v: &Value) -> Result<f64> {
2395 match v {
2396 Value::Integer(i) => Ok(*i as f64),
2397 Value::Real(f) => Ok(*f),
2398 Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
2399 other => Err(SQLRiteError::General(format!(
2400 "arithmetic on non-numeric value '{}'",
2401 other.to_display_string()
2402 ))),
2403 }
2404}
2405
2406fn as_bool(v: &Value) -> Result<bool> {
2407 match v {
2408 Value::Bool(b) => Ok(*b),
2409 Value::Null => Ok(false),
2410 Value::Integer(i) => Ok(*i != 0),
2411 other => Err(SQLRiteError::Internal(format!(
2412 "expected boolean, got {}",
2413 other.to_display_string()
2414 ))),
2415 }
2416}
2417
2418#[allow(clippy::too_many_arguments)]
2423fn eval_like(
2424 table: &Table,
2425 rowid: i64,
2426 negated: bool,
2427 any: bool,
2428 lhs: &Expr,
2429 pattern: &Expr,
2430 escape_char: Option<&AstValue>,
2431 case_insensitive: bool,
2432) -> Result<Value> {
2433 if any {
2434 return Err(SQLRiteError::NotImplemented(
2435 "LIKE ANY (...) is not supported".to_string(),
2436 ));
2437 }
2438 if escape_char.is_some() {
2439 return Err(SQLRiteError::NotImplemented(
2440 "LIKE ... ESCAPE '<char>' is not supported (default `\\` escape only)".to_string(),
2441 ));
2442 }
2443
2444 let l = eval_expr(lhs, table, rowid)?;
2445 let p = eval_expr(pattern, table, rowid)?;
2446 if matches!(l, Value::Null) || matches!(p, Value::Null) {
2447 return Ok(Value::Null);
2448 }
2449 let text = match l {
2450 Value::Text(s) => s,
2451 other => other.to_display_string(),
2452 };
2453 let pat = match p {
2454 Value::Text(s) => s,
2455 other => other.to_display_string(),
2456 };
2457 let m = like_match(&text, &pat, case_insensitive);
2458 Ok(Value::Bool(if negated { !m } else { m }))
2459}
2460
2461fn eval_in_list(
2462 table: &Table,
2463 rowid: i64,
2464 lhs: &Expr,
2465 list: &[Expr],
2466 negated: bool,
2467) -> Result<Value> {
2468 let l = eval_expr(lhs, table, rowid)?;
2469 if matches!(l, Value::Null) {
2470 return Ok(Value::Null);
2471 }
2472 let mut saw_null = false;
2473 for item in list {
2474 let r = eval_expr(item, table, rowid)?;
2475 if matches!(r, Value::Null) {
2476 saw_null = true;
2477 continue;
2478 }
2479 if compare_values(Some(&l), Some(&r)) == Ordering::Equal {
2480 return Ok(Value::Bool(!negated));
2481 }
2482 }
2483 if saw_null {
2484 Ok(Value::Null)
2487 } else {
2488 Ok(Value::Bool(negated))
2489 }
2490}
2491
2492fn aggregate_rows(
2503 table: &Table,
2504 matching: &[i64],
2505 group_by: &[String],
2506 proj_items: &[ProjectionItem],
2507) -> Result<Vec<Vec<Value>>> {
2508 let template: Vec<Option<AggState>> = proj_items
2512 .iter()
2513 .map(|i| match &i.kind {
2514 ProjectionKind::Aggregate(call) => Some(AggState::new(call)),
2515 ProjectionKind::Column(_) => None,
2516 })
2517 .collect();
2518
2519 let mut keys: Vec<Vec<DistinctKey>> = Vec::new();
2525 let mut group_states: Vec<Vec<Option<AggState>>> = Vec::new();
2526 let mut group_key_values: Vec<Vec<Value>> = Vec::new();
2527
2528 for &rowid in matching {
2529 let mut key_values: Vec<Value> = Vec::with_capacity(group_by.len());
2530 let mut key: Vec<DistinctKey> = Vec::with_capacity(group_by.len());
2531 for col in group_by {
2532 let v = table.get_value(col, rowid).unwrap_or(Value::Null);
2533 key.push(DistinctKey::from_value(&v));
2534 key_values.push(v);
2535 }
2536 let idx = match keys.iter().position(|k| k == &key) {
2537 Some(i) => i,
2538 None => {
2539 keys.push(key);
2540 group_states.push(template.clone());
2541 group_key_values.push(key_values);
2542 keys.len() - 1
2543 }
2544 };
2545
2546 for (slot, item) in proj_items.iter().enumerate() {
2547 if let ProjectionKind::Aggregate(call) = &item.kind {
2548 let v = match &call.arg {
2549 AggregateArg::Star => Value::Null,
2550 AggregateArg::Column(c) => table.get_value(c, rowid).unwrap_or(Value::Null),
2551 };
2552 if let Some(state) = group_states[idx][slot].as_mut() {
2553 state.update(&v)?;
2554 }
2555 }
2556 }
2557 }
2558
2559 if keys.is_empty() && group_by.is_empty() {
2565 keys.push(Vec::new());
2568 group_states.push(template.clone());
2569 group_key_values.push(Vec::new());
2570 }
2571
2572 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(keys.len());
2574 for (group_idx, _) in keys.iter().enumerate() {
2575 let mut row: Vec<Value> = Vec::with_capacity(proj_items.len());
2576 for (slot, item) in proj_items.iter().enumerate() {
2577 match &item.kind {
2578 ProjectionKind::Column(c) => {
2579 let pos = group_by
2582 .iter()
2583 .position(|g| g == c)
2584 .expect("validated to be in GROUP BY");
2585 row.push(group_key_values[group_idx][pos].clone());
2586 }
2587 ProjectionKind::Aggregate(_) => {
2588 let state = group_states[group_idx][slot]
2589 .as_ref()
2590 .expect("aggregate slot has state");
2591 row.push(state.finalize());
2592 }
2593 }
2594 }
2595 rows.push(row);
2596 }
2597 Ok(rows)
2598}
2599
2600fn dedupe_rows(rows: Vec<Vec<Value>>) -> Vec<Vec<Value>> {
2604 use std::collections::HashSet;
2605 let mut seen: HashSet<Vec<DistinctKey>> = HashSet::new();
2606 let mut out = Vec::with_capacity(rows.len());
2607 for row in rows {
2608 let key: Vec<DistinctKey> = row.iter().map(DistinctKey::from_value).collect();
2609 if seen.insert(key) {
2610 out.push(row);
2611 }
2612 }
2613 out
2614}
2615
2616fn sort_output_rows(
2620 rows: &mut [Vec<Value>],
2621 columns: &[String],
2622 proj_items: &[ProjectionItem],
2623 order: &OrderByClause,
2624) -> Result<()> {
2625 let target_idx = resolve_order_by_index(&order.expr, columns, proj_items)?;
2626 rows.sort_by(|a, b| {
2627 let va = &a[target_idx];
2628 let vb = &b[target_idx];
2629 let ord = compare_values(Some(va), Some(vb));
2630 if order.ascending { ord } else { ord.reverse() }
2631 });
2632 Ok(())
2633}
2634
2635fn resolve_order_by_index(
2638 expr: &Expr,
2639 columns: &[String],
2640 proj_items: &[ProjectionItem],
2641) -> Result<usize> {
2642 let target_name: Option<String> = match expr {
2644 Expr::Identifier(ident) => Some(ident.value.clone()),
2645 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
2646 Expr::Function(_) => None,
2647 Expr::Nested(inner) => return resolve_order_by_index(inner, columns, proj_items),
2648 other => {
2649 return Err(SQLRiteError::NotImplemented(format!(
2650 "ORDER BY expression not supported on aggregating queries: {other:?}"
2651 )));
2652 }
2653 };
2654 if let Some(name) = target_name {
2655 if let Some(i) = columns.iter().position(|c| c.eq_ignore_ascii_case(&name)) {
2656 return Ok(i);
2657 }
2658 return Err(SQLRiteError::Internal(format!(
2659 "ORDER BY references unknown column '{name}' in the SELECT output"
2660 )));
2661 }
2662 if let Expr::Function(func) = expr {
2666 let user_disp = format_function_display(func);
2667 for (i, item) in proj_items.iter().enumerate() {
2668 if let ProjectionKind::Aggregate(call) = &item.kind
2669 && call.display_name().eq_ignore_ascii_case(&user_disp)
2670 {
2671 return Ok(i);
2672 }
2673 }
2674 return Err(SQLRiteError::Internal(format!(
2675 "ORDER BY references aggregate '{user_disp}' that isn't in the SELECT output"
2676 )));
2677 }
2678 Err(SQLRiteError::Internal(
2679 "ORDER BY expression could not be resolved against the output columns".to_string(),
2680 ))
2681}
2682
2683fn format_function_display(func: &sqlparser::ast::Function) -> String {
2687 let name = match func.name.0.as_slice() {
2688 [ObjectNamePart::Identifier(ident)] => ident.value.to_uppercase(),
2689 _ => format!("{:?}", func.name).to_uppercase(),
2690 };
2691 let inner = match &func.args {
2692 FunctionArguments::List(l) => {
2693 let distinct = matches!(
2694 l.duplicate_treatment,
2695 Some(sqlparser::ast::DuplicateTreatment::Distinct)
2696 );
2697 let arg = l.args.first().map(|a| match a {
2698 FunctionArg::Unnamed(FunctionArgExpr::Wildcard) => "*".to_string(),
2699 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Identifier(i))) => i.value.clone(),
2700 FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::CompoundIdentifier(parts))) => {
2701 parts.last().map(|p| p.value.clone()).unwrap_or_default()
2702 }
2703 _ => String::new(),
2704 });
2705 match (distinct, arg) {
2706 (true, Some(a)) if a != "*" => format!("DISTINCT {a}"),
2707 (_, Some(a)) => a,
2708 _ => String::new(),
2709 }
2710 }
2711 _ => String::new(),
2712 };
2713 format!("{name}({inner})")
2714}
2715
2716fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
2717 use sqlparser::ast::Value as AstValue;
2718 match v {
2719 AstValue::Number(n, _) => {
2720 if let Ok(i) = n.parse::<i64>() {
2721 Ok(Value::Integer(i))
2722 } else if let Ok(f) = n.parse::<f64>() {
2723 Ok(Value::Real(f))
2724 } else {
2725 Err(SQLRiteError::Internal(format!(
2726 "could not parse numeric literal '{n}'"
2727 )))
2728 }
2729 }
2730 AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
2731 AstValue::Boolean(b) => Ok(Value::Bool(*b)),
2732 AstValue::Null => Ok(Value::Null),
2733 other => Err(SQLRiteError::NotImplemented(format!(
2734 "unsupported literal value: {other:?}"
2735 ))),
2736 }
2737}
2738
#[cfg(test)]
mod tests {
    use super::*;

    use crate::sql::db::database::Database;
    use crate::sql::parser::select::SelectQuery;
    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    // ------------------------------------------------------------------
    // Shared helpers
    // ------------------------------------------------------------------

    /// True when `a` and `b` differ by strictly less than `eps`.
    fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
        let gap = (a - b).abs();
        gap < eps
    }

    /// Runs every statement in order against `db`, unwrapping each result so
    /// the first failing statement aborts the test.
    fn run_all(db: &mut Database, statements: &[&str]) {
        for &stmt in statements {
            crate::sql::process_command(stmt, &mut *db).unwrap();
        }
    }

    /// Creates a database named `t` and applies `statements` to it in order.
    fn db_from(statements: &[&str]) -> Database {
        let mut db = Database::new("t".to_string());
        run_all(&mut db, statements);
        db
    }

    /// Display text of the first cell of every row in `result`.
    fn first_column_text(result: &SelectResult) -> Vec<String> {
        result
            .rows
            .iter()
            .map(|row| row[0].to_display_string())
            .collect()
    }

    // ------------------------------------------------------------------
    // Vector distance functions
    // ------------------------------------------------------------------

    /// The L2 distance from a vector to itself is exactly zero.
    #[test]
    fn vec_distance_l2_identical_is_zero() {
        let point = vec![0.1, 0.2, 0.3];
        let d = vec_distance_l2(&point, &point);
        assert_eq!(d, 0.0);
    }

    /// Two perpendicular unit vectors are sqrt(2) apart.
    #[test]
    fn vec_distance_l2_unit_basis_is_sqrt2() {
        let (x_axis, y_axis) = (vec![1.0, 0.0], vec![0.0, 1.0]);
        let d = vec_distance_l2(&x_axis, &y_axis);
        assert!(approx_eq(d, 2.0_f32.sqrt(), 1e-6));
    }

    /// Classic 3-4-5 triangle: the origin to (3, 4, 0) is 5.
    #[test]
    fn vec_distance_l2_known_value() {
        let origin = vec![0.0, 0.0, 0.0];
        let target = vec![3.0, 4.0, 0.0];
        let d = vec_distance_l2(&origin, &target);
        assert!(approx_eq(d, 5.0, 1e-6));
    }

    /// Cosine distance of a vector with itself is (approximately) zero.
    #[test]
    fn vec_distance_cosine_identical_is_zero() {
        let point = vec![0.1, 0.2, 0.3];
        let d = vec_distance_cosine(&point, &point).unwrap();
        assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
    }

    /// Perpendicular vectors have cosine distance 1.
    #[test]
    fn vec_distance_cosine_orthogonal_is_one() {
        let (x_axis, y_axis) = (vec![1.0, 0.0], vec![0.0, 1.0]);
        let d = vec_distance_cosine(&x_axis, &y_axis).unwrap();
        assert!(approx_eq(d, 1.0, 1e-6));
    }

    /// Vectors pointing in opposite directions have cosine distance 2.
    #[test]
    fn vec_distance_cosine_opposite_is_two() {
        let forward = vec![1.0, 0.0, 0.0];
        let backward = vec![-1.0, 0.0, 0.0];
        let d = vec_distance_cosine(&forward, &backward).unwrap();
        assert!(approx_eq(d, 2.0, 1e-6));
    }

    /// A zero-magnitude operand is rejected with a "zero-magnitude" error.
    #[test]
    fn vec_distance_cosine_zero_magnitude_errors() {
        let zero = vec![0.0, 0.0];
        let unit = vec![1.0, 0.0];
        let err = vec_distance_cosine(&zero, &unit).unwrap_err();
        assert!(err.to_string().contains("zero-magnitude"));
    }

    /// The dot "distance" is the negated dot product: 1*4 + 2*5 + 3*6 = 32.
    #[test]
    fn vec_distance_dot_negates() {
        let lhs = vec![1.0, 2.0, 3.0];
        let rhs = vec![4.0, 5.0, 6.0];
        let d = vec_distance_dot(&lhs, &rhs);
        assert!(approx_eq(d, -32.0, 1e-6));
    }

    /// Perpendicular vectors have a dot distance of exactly zero.
    #[test]
    fn vec_distance_dot_orthogonal_is_zero() {
        let (x_axis, y_axis) = (vec![1.0, 0.0], vec![0.0, 1.0]);
        let d = vec_distance_dot(&x_axis, &y_axis);
        assert_eq!(d, 0.0);
    }

    /// For unit-norm inputs the dot distance equals the cosine distance
    /// minus one.
    #[test]
    fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
        // Both vectors have norm 1 (0.36 + 0.64).
        let a = vec![0.6f32, 0.8];
        let b = vec![0.8f32, 0.6];
        let dot = vec_distance_dot(&a, &b);
        let cos = vec_distance_cosine(&a, &b).unwrap();
        assert!(approx_eq(dot, cos - 1.0, 1e-5));
    }

    // ------------------------------------------------------------------
    // Bounded-heap top-k vs. full sort
    // ------------------------------------------------------------------

    /// Builds a `docs(id, score)` table holding `n` rows. Each score is a
    /// deterministic scramble of the row index, so insertion order is not
    /// already sorted by score.
    fn seed_score_table(n: usize) -> Database {
        let mut db = Database::new(String::from("tempdb"));
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
            &mut db,
        )
        .expect("create");

        let scores = (0..n as u64).map(|i| (i.wrapping_mul(2_654_435_761) % 1_000_000) as f64);
        for score in scores {
            let sql = format!("INSERT INTO docs (score) VALUES ({score});");
            crate::sql::process_command(&sql, &mut db).expect("insert");
        }
        db
    }

    /// Parses `sql` with the SQLite dialect and lowers the last statement
    /// into a `SelectQuery`.
    fn parse_select(sql: &str) -> SelectQuery {
        let mut statements = Parser::parse_sql(&SQLiteDialect {}, sql).expect("parse");
        let last = statements.pop().expect("one statement");
        SelectQuery::new(&last).expect("select-query")
    }

    /// Ascending top-10 from the heap equals a full sort cut to 10.
    #[test]
    fn topk_matches_full_sort_asc() {
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = query.order_by.as_ref().unwrap();
        let rowids = table.rowids();

        // Reference answer: sort everything, keep the first ten.
        let mut expected = rowids.clone();
        sort_rowids(&mut expected, table, order).unwrap();
        expected.truncate(10);

        let via_heap = select_topk(&rowids, table, order, 10).unwrap();

        assert_eq!(
            via_heap, expected,
            "top-k via heap should match full-sort+truncate"
        );
    }

    /// Descending top-10 from the heap equals a full sort cut to 10.
    #[test]
    fn topk_matches_full_sort_desc() {
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let query = parse_select("SELECT * FROM docs ORDER BY score DESC LIMIT 10;");
        let order = query.order_by.as_ref().unwrap();
        let rowids = table.rowids();

        // Reference answer: sort everything, keep the first ten.
        let mut expected = rowids.clone();
        sort_rowids(&mut expected, table, order).unwrap();
        expected.truncate(10);

        let via_heap = select_topk(&rowids, table, order, 10).unwrap();

        assert_eq!(
            via_heap, expected,
            "top-k DESC via heap should match full-sort+truncate"
        );
    }

    /// Asking for more rows than exist returns every row, in sorted order.
    #[test]
    fn topk_k_larger_than_n_returns_everything_sorted() {
        let db = seed_score_table(50);
        let table = db.get_table("docs".to_string()).unwrap();
        let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
        let order = query.order_by.as_ref().unwrap();

        let picked = select_topk(&table.rowids(), table, order, 1000).unwrap();
        assert_eq!(picked.len(), 50);

        // Collect the scores that come back as `Real`; anything else is skipped.
        let mut scores: Vec<f64> = Vec::new();
        for rowid in &picked {
            if let Some(Value::Real(f)) = table.get_value("score", *rowid) {
                scores.push(f);
            }
        }
        assert!(scores.windows(2).all(|pair| pair[0] <= pair[1]));
    }

    /// `k = 0` yields no rows, whatever the query's own LIMIT says.
    #[test]
    fn topk_k_zero_returns_empty() {
        let db = seed_score_table(10);
        let table = db.get_table("docs".to_string()).unwrap();
        let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
        let order = query.order_by.as_ref().unwrap();

        let picked = select_topk(&table.rowids(), table, order, 0).unwrap();
        assert!(picked.is_empty());
    }

    /// An empty rowid slice yields no rows.
    #[test]
    fn topk_empty_input_returns_empty() {
        let db = seed_score_table(0);
        let table = db.get_table("docs".to_string()).unwrap();
        let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
        let order = query.order_by.as_ref().unwrap();

        let picked = select_topk(&[], table, order, 5).unwrap();
        assert!(picked.is_empty());
    }

    /// End-to-end: ORDER BY a distance-function expression with LIMIT 3
    /// reports three rows.
    #[test]
    fn topk_works_through_select_executor_with_distance_function() {
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
            &mut db,
        )
        .unwrap();

        let vectors = [
            "[1.0, 0.0]",
            "[2.0, 0.0]",
            "[0.0, 3.0]",
            "[1.0, 4.0]",
            "[10.0, 10.0]",
        ];
        for v in vectors {
            let insert = format!("INSERT INTO docs (e) VALUES ({v});");
            crate::sql::process_command(&insert, &mut db).unwrap();
        }

        let resp = crate::sql::process_command(
            "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;",
            &mut db,
        )
        .unwrap();
        assert!(resp.contains("3 rows returned"), "got: {resp}");
    }

    /// Timing comparison between the bounded heap and a full sort. Ignored
    /// by default; the heap is measured first, then the full sort.
    #[test]
    #[ignore]
    fn topk_benchmark() {
        use std::time::Instant;
        const N: usize = 10_000;
        const K: usize = 10;

        let db = seed_score_table(N);
        let table = db.get_table("docs".to_string()).unwrap();
        let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = query.order_by.as_ref().unwrap();
        let rowids = table.rowids();

        let heap_started = Instant::now();
        let _topk = select_topk(&rowids, table, order, K).unwrap();
        let heap_dur = heap_started.elapsed();

        let sort_started = Instant::now();
        let mut sorted = rowids.clone();
        sort_rowids(&mut sorted, table, order).unwrap();
        sorted.truncate(K);
        let sort_dur = sort_started.elapsed();

        // Clamp the denominator so a zero heap timing cannot divide by zero.
        let ratio = sort_dur.as_secs_f64() / heap_dur.as_secs_f64().max(1e-9);
        println!("\n--- topk_benchmark (N={N}, k={K}) ---");
        println!(" bounded heap: {heap_dur:?}");
        println!(" full sort+trunc: {sort_dur:?}");
        println!(" speedup ratio: {ratio:.2}×");

        assert!(
            ratio > 1.4,
            "bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
        );
    }

    // ------------------------------------------------------------------
    // IS NULL / IS NOT NULL
    // ------------------------------------------------------------------

    /// Executes `sql` through the command entry point and returns the
    /// rendered response text.
    fn run_select(db: &mut Database, sql: &str) -> String {
        crate::sql::process_command(sql, db).expect("select")
    }

    /// `IS NULL` matches exactly the rows whose column holds NULL.
    #[test]
    fn where_is_null_returns_null_rows() {
        let mut db = db_from(&[
            "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
            "INSERT INTO t (id, n) VALUES (1, 10);",
            "INSERT INTO t (id, n) VALUES (2, NULL);",
            "INSERT INTO t (id, n) VALUES (3, 30);",
            "INSERT INTO t (id, n) VALUES (4, NULL);",
        ]);

        let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL;");
        assert!(
            response.contains("2 rows returned"),
            "IS NULL should return 2 rows, got: {response}"
        );
    }

    /// `IS NOT NULL` matches exactly the rows whose column holds a value.
    #[test]
    fn where_is_not_null_returns_non_null_rows() {
        let mut db = db_from(&[
            "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
            "INSERT INTO t (id, n) VALUES (1, 10);",
            "INSERT INTO t (id, n) VALUES (2, NULL);",
            "INSERT INTO t (id, n) VALUES (3, 30);",
        ]);

        let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NOT NULL;");
        assert!(
            response.contains("2 rows returned"),
            "IS NOT NULL should return 2 rows, got: {response}"
        );
    }

    /// NULL checks on a `UNIQUE` column (presumably index-backed, going by
    /// the test name) give the same answers as on a plain column.
    #[test]
    fn where_is_null_on_indexed_column() {
        let mut db = db_from(&[
            "CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT UNIQUE);",
            "INSERT INTO t (id, name) VALUES (1, 'alice');",
            "INSERT INTO t (id, name) VALUES (2, NULL);",
            "INSERT INTO t (id, name) VALUES (3, 'bob');",
        ]);

        let null_rows = run_select(&mut db, "SELECT id FROM t WHERE name IS NULL;");
        assert!(
            null_rows.contains("1 row returned"),
            "indexed IS NULL should return 1 row, got: {null_rows}"
        );

        let not_null_rows = run_select(&mut db, "SELECT id FROM t WHERE name IS NOT NULL;");
        assert!(
            not_null_rows.contains("2 rows returned"),
            "indexed IS NOT NULL should return 2 rows, got: {not_null_rows}"
        );
    }

    /// A column left out of the INSERT column list is seen as NULL.
    #[test]
    fn where_is_null_works_on_omitted_column() {
        let mut db = db_from(&[
            "CREATE TABLE t (id INTEGER PRIMARY KEY, qty INTEGER, label TEXT);",
            "INSERT INTO t (id, qty, label) VALUES (1, 7, 'a');",
            // Row 2 never mentions `qty`.
            "INSERT INTO t (id, label) VALUES (2, 'b');",
        ]);

        let response = run_select(&mut db, "SELECT id FROM t WHERE qty IS NULL;");
        assert!(
            response.contains("1 row returned"),
            "IS NULL should match the omitted-column row, got: {response}"
        );
    }

    /// `IS NULL` composes with another predicate through `AND`.
    #[test]
    fn where_is_null_combines_with_and_or() {
        let mut db = db_from(&[
            "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
            "INSERT INTO t (id, n) VALUES (1, NULL);",
            "INSERT INTO t (id, n) VALUES (2, NULL);",
            "INSERT INTO t (id, n) VALUES (3, 30);",
        ]);

        let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL AND id > 1;");
        assert!(
            response.contains("1 row returned"),
            "IS NULL combined with AND should match exactly row 2, got: {response}"
        );
    }

    // ------------------------------------------------------------------
    // LIKE / IN / DISTINCT / aggregates / GROUP BY
    // ------------------------------------------------------------------

    /// Six-row `emp` fixture: three departments, one NULL salary (Dave) and
    /// a repeated salary (100). The tests below rely on ids being assigned
    /// 1..=6 in insertion order.
    fn seed_employees() -> Database {
        db_from(&[
            "CREATE TABLE emp (id INTEGER PRIMARY KEY, name TEXT, dept TEXT, salary INTEGER);",
            "INSERT INTO emp (name, dept, salary) VALUES ('Alice', 'eng', 100);",
            "INSERT INTO emp (name, dept, salary) VALUES ('alex', 'eng', 120);",
            "INSERT INTO emp (name, dept, salary) VALUES ('Bob', 'eng', 100);",
            "INSERT INTO emp (name, dept, salary) VALUES ('Carol', 'sales', 90);",
            "INSERT INTO emp (name, dept, salary) VALUES ('Dave', 'sales', NULL);",
            "INSERT INTO emp (name, dept, salary) VALUES ('Eve', 'ops', 80);",
        ])
    }

    /// Parses `sql` and runs it through `execute_select_rows`, returning the
    /// typed rows instead of rendered text.
    fn run_rows(db: &Database, sql: &str) -> SelectResult {
        execute_select_rows(parse_select(sql), db).expect("select")
    }

    /// `LIKE 'a%'` matches both 'Alice' and 'alex', i.e. ignoring case.
    #[test]
    fn like_percent_prefix_case_insensitive() {
        let db = seed_employees();
        let result = run_rows(&db, "SELECT name FROM emp WHERE name LIKE 'a%';");
        let names = first_column_text(&result);

        assert_eq!(names.len(), 2, "expected 2 rows, got {names:?}");
        assert!(names.iter().any(|n| n == "Alice"));
        assert!(names.iter().any(|n| n == "alex"));
    }

    /// `_` matches exactly one character, so `_ve` finds only 'Eve'.
    #[test]
    fn like_underscore_singlechar() {
        let db = seed_employees();
        let result = run_rows(&db, "SELECT name FROM emp WHERE name LIKE '_ve';");
        let names = first_column_text(&result);

        assert_eq!(names, vec!["Eve".to_string()]);
    }

    /// `NOT LIKE 'a%'` keeps the four names that do not start with a/A.
    #[test]
    fn not_like_excludes_match() {
        let db = seed_employees();
        let result = run_rows(&db, "SELECT name FROM emp WHERE name NOT LIKE 'a%';");
        assert_eq!(result.rows.len(), 4);
    }

    /// A wildcard-free LIKE combined with `IS NULL` isolates Dave.
    #[test]
    fn like_with_null_excludes_row() {
        let db = seed_employees();
        let result = run_rows(
            &db,
            "SELECT name FROM emp WHERE dept LIKE 'sales' AND salary IS NULL;",
        );

        assert_eq!(result.rows.len(), 1);
        assert_eq!(result.rows[0][0].to_display_string(), "Dave");
    }

    /// `IN (1, 3, 5)` returns the rows with those ids.
    #[test]
    fn in_list_positive() {
        let db = seed_employees();
        let result = run_rows(&db, "SELECT name FROM emp WHERE id IN (1, 3, 5);");
        let names = first_column_text(&result);

        assert_eq!(names.len(), 3);
        for expected in ["Alice", "Bob", "Dave"] {
            assert!(names.iter().any(|n| n == expected));
        }
    }

    /// `NOT IN (1, 2)` leaves the other four rows.
    #[test]
    fn not_in_excludes_listed() {
        let db = seed_employees();
        let result = run_rows(&db, "SELECT name FROM emp WHERE id NOT IN (1, 2);");
        assert_eq!(result.rows.len(), 4);
    }

    /// With a NULL in the list only the definite match (id 1) is returned.
    #[test]
    fn in_list_with_null_three_valued() {
        let db = seed_employees();
        let result = run_rows(&db, "SELECT name FROM emp WHERE id IN (1, NULL);");

        assert_eq!(result.rows.len(), 1);
        assert_eq!(result.rows[0][0].to_display_string(), "Alice");
    }

    /// `DISTINCT dept` collapses to the three departments.
    #[test]
    fn distinct_single_column() {
        let db = seed_employees();
        let result = run_rows(&db, "SELECT DISTINCT dept FROM emp;");
        assert_eq!(result.rows.len(), 3);
    }

    /// Of the six (dept, salary) pairs only (eng, 100) repeats, leaving
    /// five; the pair with a NULL salary counts as its own row.
    #[test]
    fn distinct_multi_column_with_null() {
        let db = seed_employees();
        let result = run_rows(&db, "SELECT DISTINCT dept, salary FROM emp;");
        assert_eq!(result.rows.len(), 5);
    }

    /// `COUNT(*)` without GROUP BY yields a single row counting every row.
    #[test]
    fn count_star_no_groupby() {
        let db = seed_employees();
        let result = run_rows(&db, "SELECT COUNT(*) FROM emp;");

        assert_eq!(result.rows.len(), 1);
        assert_eq!(result.rows[0][0], Value::Integer(6));
    }

    /// `COUNT(col)` ignores the one NULL salary.
    #[test]
    fn count_col_skips_nulls() {
        let db = seed_employees();
        let result = run_rows(&db, "SELECT COUNT(salary) FROM emp;");
        assert_eq!(result.rows[0][0], Value::Integer(5));
    }

    /// `COUNT(DISTINCT col)` counts 100, 120, 90 and 80 once each.
    #[test]
    fn count_distinct_dedupes_and_skips_nulls() {
        let db = seed_employees();
        let result = run_rows(&db, "SELECT COUNT(DISTINCT salary) FROM emp;");
        assert_eq!(result.rows[0][0], Value::Integer(4));
    }

    /// Summing an INTEGER column produces an Integer: 100+120+100+90+80.
    #[test]
    fn sum_int_stays_integer() {
        let db = seed_employees();
        let result = run_rows(&db, "SELECT SUM(salary) FROM emp;");
        assert_eq!(result.rows[0][0], Value::Integer(490));
    }

    /// `AVG` yields a Real: 490 over the five non-NULL salaries is 98.
    #[test]
    fn avg_returns_real() {
        let db = seed_employees();
        let result = run_rows(&db, "SELECT AVG(salary) FROM emp;");

        match &result.rows[0][0] {
            Value::Real(avg) => {
                let off_by = (avg - 98.0).abs();
                assert!(off_by < 1e-9);
            }
            other => panic!("expected Real, got {other:?}"),
        }
    }

    /// `MIN` / `MAX` ignore the NULL salary.
    #[test]
    fn min_max_skip_nulls() {
        let db = seed_employees();
        let result = run_rows(&db, "SELECT MIN(salary), MAX(salary) FROM emp;");
        let row = &result.rows[0];

        assert_eq!(row[0], Value::Integer(80));
        assert_eq!(row[1], Value::Integer(120));
    }

    /// Aggregates over an empty table still produce one row: COUNT is 0 and
    /// every other aggregate is NULL.
    #[test]
    fn aggregates_on_empty_table_emit_one_row() {
        let db = db_from(&["CREATE TABLE t (x INTEGER);"]);
        let result = run_rows(
            &db,
            "SELECT COUNT(*), SUM(x), AVG(x), MIN(x), MAX(x) FROM t;",
        );

        assert_eq!(result.rows.len(), 1);
        let row = &result.rows[0];
        assert_eq!(row[0], Value::Integer(0));
        // SUM, AVG, MIN and MAX, in projection order.
        for idx in 1..5 {
            assert_eq!(row[idx], Value::Null);
        }
    }

    /// `GROUP BY dept` with `COUNT(*)` yields one row per department.
    #[test]
    fn group_by_single_col_with_count() {
        let db = seed_employees();
        let result = run_rows(&db, "SELECT dept, COUNT(*) FROM emp GROUP BY dept;");
        assert_eq!(result.rows.len(), 3);

        // Key the counts by department so row order does not matter.
        let by_dept: std::collections::HashMap<String, i64> = result
            .rows
            .iter()
            .map(|row| {
                let count = match &row[1] {
                    Value::Integer(i) => *i,
                    v => panic!("expected Integer count, got {v:?}"),
                };
                (row[0].to_display_string(), count)
            })
            .collect();

        assert_eq!(by_dept["eng"], 3);
        assert_eq!(by_dept["sales"], 2);
        assert_eq!(by_dept["ops"], 1);
    }

    /// The WHERE filter applies before grouping: 'ops' (salary 80) drops
    /// out entirely and the NULL salary adds nothing to 'sales'.
    #[test]
    fn group_by_with_where_filter() {
        let db = seed_employees();
        let result = run_rows(
            &db,
            "SELECT dept, SUM(salary) FROM emp WHERE salary > 80 GROUP BY dept;",
        );

        let mut by: std::collections::HashMap<String, i64> = Default::default();
        for row in &result.rows {
            let total = match &row[1] {
                Value::Integer(i) => *i,
                v => panic!("expected Integer sum, got {v:?}"),
            };
            by.insert(row[0].to_display_string(), total);
        }

        assert_eq!(by.len(), 2);
        assert_eq!(by["eng"], 320);
        assert_eq!(by["sales"], 90);
    }

    /// GROUP BY with no aggregate returns one row per department.
    #[test]
    fn group_by_without_aggregates_is_distinct() {
        let db = seed_employees();
        let result = run_rows(&db, "SELECT dept FROM emp GROUP BY dept;");
        assert_eq!(result.rows.len(), 3);
    }

    /// ORDER BY can name an aggregate through its alias.
    #[test]
    fn order_by_count_desc() {
        let db = seed_employees();
        let result = run_rows(
            &db,
            "SELECT dept, COUNT(*) AS n FROM emp GROUP BY dept ORDER BY n DESC LIMIT 2;",
        );

        assert_eq!(result.rows.len(), 2);
        let top = &result.rows[0];
        assert_eq!(top[0].to_display_string(), "eng");
        assert_eq!(top[1], Value::Integer(3));
    }

    /// ORDER BY can repeat the aggregate call itself instead of an alias.
    #[test]
    fn order_by_aggregate_call_form() {
        let db = seed_employees();
        let result = run_rows(
            &db,
            "SELECT dept, COUNT(*) FROM emp GROUP BY dept ORDER BY COUNT(*) DESC;",
        );

        assert_eq!(result.rows.len(), 3);
        assert_eq!(result.rows[0][0].to_display_string(), "eng");
    }

    /// Selecting a bare column that is not in the GROUP BY list is an error.
    #[test]
    fn group_by_invalid_bare_column_errors() {
        let mut db = db_from(&["CREATE TABLE t (id INTEGER PRIMARY KEY, dept TEXT, name TEXT);"]);

        let err = crate::sql::process_command("SELECT dept, name FROM t GROUP BY dept;", &mut db);
        assert!(err.is_err(), "should reject bare 'name' not in GROUP BY");
    }

    /// An aggregate call inside WHERE is rejected rather than evaluated.
    #[test]
    fn aggregate_in_where_errors_friendly() {
        let mut db = db_from(&[
            "CREATE TABLE t (x INTEGER);",
            "INSERT INTO t (x) VALUES (1);",
        ]);

        let err = crate::sql::process_command("SELECT x FROM t WHERE COUNT(*) > 0;", &mut db);
        assert!(err.is_err(), "aggregates must not be allowed in WHERE");
    }
}