1use std::cmp::Ordering;
5
6use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
7use sqlparser::ast::{
8 AlterTable, AlterTableOperation, AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr,
9 FromTable, FunctionArg, FunctionArgExpr, FunctionArguments, IndexType, ObjectName,
10 ObjectNamePart, RenameTableNameKind, Statement, TableFactor, TableWithJoins, UnaryOperator,
11 Update, Value as AstValue,
12};
13
14use crate::error::{Result, SQLRiteError};
15use crate::sql::db::database::Database;
16use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
17use crate::sql::db::table::{
18 DataType, FtsIndexEntry, HnswIndexEntry, Table, Value, parse_vector_literal,
19};
20use crate::sql::fts::{Bm25Params, PostingList};
21use crate::sql::hnsw::{DistanceMetric, HnswIndex};
22use crate::sql::parser::select::{OrderByClause, Projection, SelectQuery};
23
/// Materialized output of a SELECT statement.
pub struct SelectResult {
    /// Names of the projected columns, in output order.
    pub columns: Vec<String>,
    /// One inner vector per returned row; each holds that row's values in
    /// the same order as `columns`.
    pub rows: Vec<Vec<Value>>,
}
36
37pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
41 let table = db
42 .get_table(query.table_name.clone())
43 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
44
45 let projected_cols: Vec<String> = match &query.projection {
47 Projection::All => table.column_names(),
48 Projection::Columns(cols) => {
49 for c in cols {
50 if !table.contains_column(c.to_string()) {
51 return Err(SQLRiteError::Internal(format!(
52 "Column '{c}' does not exist on table '{}'",
53 query.table_name
54 )));
55 }
56 }
57 cols.clone()
58 }
59 };
60
61 let matching = match select_rowids(table, query.selection.as_ref())? {
65 RowidSource::IndexProbe(rowids) => rowids,
66 RowidSource::FullScan => {
67 let mut out = Vec::new();
68 for rowid in table.rowids() {
69 if let Some(expr) = &query.selection {
70 if !eval_predicate(expr, table, rowid)? {
71 continue;
72 }
73 }
74 out.push(rowid);
75 }
76 out
77 }
78 };
79 let mut matching = matching;
80
81 match (&query.order_by, query.limit) {
111 (Some(order), Some(k)) if try_hnsw_probe(table, &order.expr, k).is_some() => {
112 matching = try_hnsw_probe(table, &order.expr, k).unwrap();
113 }
114 (Some(order), Some(k))
115 if try_fts_probe(table, &order.expr, order.ascending, k).is_some() =>
116 {
117 matching = try_fts_probe(table, &order.expr, order.ascending, k).unwrap();
118 }
119 (Some(order), Some(k)) if k < matching.len() => {
120 matching = select_topk(&matching, table, order, k)?;
121 }
122 (Some(order), _) => {
123 sort_rowids(&mut matching, table, order)?;
124 if let Some(k) = query.limit {
125 matching.truncate(k);
126 }
127 }
128 (None, Some(k)) => {
129 matching.truncate(k);
130 }
131 (None, None) => {}
132 }
133
134 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
138 for rowid in &matching {
139 let row: Vec<Value> = projected_cols
140 .iter()
141 .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
142 .collect();
143 rows.push(row);
144 }
145
146 Ok(SelectResult {
147 columns: projected_cols,
148 rows,
149 })
150}
151
152pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
157 let result = execute_select_rows(query, db)?;
158 let row_count = result.rows.len();
159
160 let mut print_table = PrintTable::new();
161 let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
162 print_table.add_row(PrintRow::new(header_cells));
163
164 for row in &result.rows {
165 let cells: Vec<PrintCell> = row
166 .iter()
167 .map(|v| PrintCell::new(&v.to_display_string()))
168 .collect();
169 print_table.add_row(PrintRow::new(cells));
170 }
171
172 Ok((print_table.to_string(), row_count))
173}
174
175pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
177 let Statement::Delete(Delete {
178 from, selection, ..
179 }) = stmt
180 else {
181 return Err(SQLRiteError::Internal(
182 "execute_delete called on a non-DELETE statement".to_string(),
183 ));
184 };
185
186 let tables = match from {
187 FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
188 };
189 let table_name = extract_single_table_name(tables)?;
190
191 let matching: Vec<i64> = {
193 let table = db
194 .get_table(table_name.clone())
195 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
196 match select_rowids(table, selection.as_ref())? {
197 RowidSource::IndexProbe(rowids) => rowids,
198 RowidSource::FullScan => {
199 let mut out = Vec::new();
200 for rowid in table.rowids() {
201 if let Some(expr) = selection {
202 if !eval_predicate(expr, table, rowid)? {
203 continue;
204 }
205 }
206 out.push(rowid);
207 }
208 out
209 }
210 }
211 };
212
213 let table = db.get_table_mut(table_name)?;
214 for rowid in &matching {
215 table.delete_row(*rowid);
216 }
217 if !matching.is_empty() {
226 for entry in &mut table.hnsw_indexes {
227 entry.needs_rebuild = true;
228 }
229 for entry in &mut table.fts_indexes {
230 entry.needs_rebuild = true;
231 }
232 }
233 Ok(matching.len())
234}
235
236pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
238 let Statement::Update(Update {
239 table,
240 assignments,
241 from,
242 selection,
243 ..
244 }) = stmt
245 else {
246 return Err(SQLRiteError::Internal(
247 "execute_update called on a non-UPDATE statement".to_string(),
248 ));
249 };
250
251 if from.is_some() {
252 return Err(SQLRiteError::NotImplemented(
253 "UPDATE ... FROM is not supported yet".to_string(),
254 ));
255 }
256
257 let table_name = extract_table_name(table)?;
258
259 let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
261 {
262 let tbl = db
263 .get_table(table_name.clone())
264 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
265 for a in assignments {
266 let col = match &a.target {
267 AssignmentTarget::ColumnName(name) => name
268 .0
269 .last()
270 .map(|p| p.to_string())
271 .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
272 AssignmentTarget::Tuple(_) => {
273 return Err(SQLRiteError::NotImplemented(
274 "tuple assignment targets are not supported".to_string(),
275 ));
276 }
277 };
278 if !tbl.contains_column(col.clone()) {
279 return Err(SQLRiteError::Internal(format!(
280 "UPDATE references unknown column '{col}'"
281 )));
282 }
283 parsed_assignments.push((col, a.value.clone()));
284 }
285 }
286
287 let work: Vec<(i64, Vec<(String, Value)>)> = {
291 let tbl = db.get_table(table_name.clone())?;
292 let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
293 RowidSource::IndexProbe(rowids) => rowids,
294 RowidSource::FullScan => {
295 let mut out = Vec::new();
296 for rowid in tbl.rowids() {
297 if let Some(expr) = selection {
298 if !eval_predicate(expr, tbl, rowid)? {
299 continue;
300 }
301 }
302 out.push(rowid);
303 }
304 out
305 }
306 };
307 let mut rows_to_update = Vec::new();
308 for rowid in matched_rowids {
309 let mut values = Vec::with_capacity(parsed_assignments.len());
310 for (col, expr) in &parsed_assignments {
311 let v = eval_expr(expr, tbl, rowid)?;
314 values.push((col.clone(), v));
315 }
316 rows_to_update.push((rowid, values));
317 }
318 rows_to_update
319 };
320
321 let tbl = db.get_table_mut(table_name)?;
322 for (rowid, values) in &work {
323 for (col, v) in values {
324 tbl.set_value(col, *rowid, v.clone())?;
325 }
326 }
327
328 if !work.is_empty() {
337 let updated_columns: std::collections::HashSet<&str> = work
338 .iter()
339 .flat_map(|(_, values)| values.iter().map(|(c, _)| c.as_str()))
340 .collect();
341 for entry in &mut tbl.hnsw_indexes {
342 if updated_columns.contains(entry.column_name.as_str()) {
343 entry.needs_rebuild = true;
344 }
345 }
346 for entry in &mut tbl.fts_indexes {
347 if updated_columns.contains(entry.column_name.as_str()) {
348 entry.needs_rebuild = true;
349 }
350 }
351 }
352 Ok(work.len())
353}
354
355pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
367 let Statement::CreateIndex(CreateIndex {
368 name,
369 table_name,
370 columns,
371 using,
372 unique,
373 if_not_exists,
374 predicate,
375 ..
376 }) = stmt
377 else {
378 return Err(SQLRiteError::Internal(
379 "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
380 ));
381 };
382
383 if predicate.is_some() {
384 return Err(SQLRiteError::NotImplemented(
385 "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
386 ));
387 }
388
389 if columns.len() != 1 {
390 return Err(SQLRiteError::NotImplemented(format!(
391 "multi-column indexes are not supported yet ({} columns given)",
392 columns.len()
393 )));
394 }
395
396 let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
397 SQLRiteError::NotImplemented(
398 "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
399 )
400 })?;
401
402 let method = match using {
408 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("hnsw") => {
409 IndexMethod::Hnsw
410 }
411 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("fts") => {
412 IndexMethod::Fts
413 }
414 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("btree") => {
415 IndexMethod::Btree
416 }
417 Some(other) => {
418 return Err(SQLRiteError::NotImplemented(format!(
419 "CREATE INDEX … USING {other:?} is not supported \
420 (try `hnsw`, `fts`, or no USING clause)"
421 )));
422 }
423 None => IndexMethod::Btree,
424 };
425
426 let table_name_str = table_name.to_string();
427 let column_name = match &columns[0].column.expr {
428 Expr::Identifier(ident) => ident.value.clone(),
429 Expr::CompoundIdentifier(parts) => parts
430 .last()
431 .map(|p| p.value.clone())
432 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
433 other => {
434 return Err(SQLRiteError::NotImplemented(format!(
435 "CREATE INDEX only supports simple column references, got {other:?}"
436 )));
437 }
438 };
439
440 let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
445 let table = db.get_table(table_name_str.clone()).map_err(|_| {
446 SQLRiteError::General(format!(
447 "CREATE INDEX references unknown table '{table_name_str}'"
448 ))
449 })?;
450 if !table.contains_column(column_name.clone()) {
451 return Err(SQLRiteError::General(format!(
452 "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
453 )));
454 }
455 let col = table
456 .columns
457 .iter()
458 .find(|c| c.column_name == column_name)
459 .expect("we just verified the column exists");
460
461 if table.index_by_name(&index_name).is_some()
464 || table.hnsw_indexes.iter().any(|i| i.name == index_name)
465 || table.fts_indexes.iter().any(|i| i.name == index_name)
466 {
467 if *if_not_exists {
468 return Ok(index_name);
469 }
470 return Err(SQLRiteError::General(format!(
471 "index '{index_name}' already exists"
472 )));
473 }
474 let datatype = clone_datatype(&col.datatype);
475
476 let mut pairs = Vec::new();
477 for rowid in table.rowids() {
478 if let Some(v) = table.get_value(&column_name, rowid) {
479 pairs.push((rowid, v));
480 }
481 }
482 (datatype, pairs)
483 };
484
485 match method {
486 IndexMethod::Btree => create_btree_index(
487 db,
488 &table_name_str,
489 &index_name,
490 &column_name,
491 &datatype,
492 *unique,
493 &existing_rowids_and_values,
494 ),
495 IndexMethod::Hnsw => create_hnsw_index(
496 db,
497 &table_name_str,
498 &index_name,
499 &column_name,
500 &datatype,
501 *unique,
502 &existing_rowids_and_values,
503 ),
504 IndexMethod::Fts => create_fts_index(
505 db,
506 &table_name_str,
507 &index_name,
508 &column_name,
509 &datatype,
510 *unique,
511 &existing_rowids_and_values,
512 ),
513 }
514}
515
516pub fn execute_drop_table(
527 names: &[ObjectName],
528 if_exists: bool,
529 db: &mut Database,
530) -> Result<usize> {
531 if names.len() != 1 {
532 return Err(SQLRiteError::NotImplemented(
533 "DROP TABLE supports a single table per statement".to_string(),
534 ));
535 }
536 let name = names[0].to_string();
537
538 if name == crate::sql::pager::MASTER_TABLE_NAME {
539 return Err(SQLRiteError::General(format!(
540 "'{}' is a reserved name used by the internal schema catalog",
541 crate::sql::pager::MASTER_TABLE_NAME
542 )));
543 }
544
545 if !db.contains_table(name.clone()) {
546 return if if_exists {
547 Ok(0)
548 } else {
549 Err(SQLRiteError::General(format!(
550 "Table '{name}' does not exist"
551 )))
552 };
553 }
554
555 db.tables.remove(&name);
556 Ok(1)
557}
558
559pub fn execute_drop_index(
568 names: &[ObjectName],
569 if_exists: bool,
570 db: &mut Database,
571) -> Result<usize> {
572 if names.len() != 1 {
573 return Err(SQLRiteError::NotImplemented(
574 "DROP INDEX supports a single index per statement".to_string(),
575 ));
576 }
577 let name = names[0].to_string();
578
579 for table in db.tables.values_mut() {
580 if let Some(secondary) = table.secondary_indexes.iter().find(|i| i.name == name) {
581 if secondary.origin == IndexOrigin::Auto {
582 return Err(SQLRiteError::General(format!(
583 "cannot drop auto-created index '{name}' (drop the column or table instead)"
584 )));
585 }
586 table.secondary_indexes.retain(|i| i.name != name);
587 return Ok(1);
588 }
589 if table.hnsw_indexes.iter().any(|i| i.name == name) {
590 table.hnsw_indexes.retain(|i| i.name != name);
591 return Ok(1);
592 }
593 if table.fts_indexes.iter().any(|i| i.name == name) {
594 table.fts_indexes.retain(|i| i.name != name);
595 return Ok(1);
596 }
597 }
598
599 if if_exists {
600 Ok(0)
601 } else {
602 Err(SQLRiteError::General(format!(
603 "Index '{name}' does not exist"
604 )))
605 }
606}
607
608pub fn execute_alter_table(alter: AlterTable, db: &mut Database) -> Result<String> {
620 let table_name = alter.name.to_string();
621
622 if table_name == crate::sql::pager::MASTER_TABLE_NAME {
623 return Err(SQLRiteError::General(format!(
624 "'{}' is a reserved name used by the internal schema catalog",
625 crate::sql::pager::MASTER_TABLE_NAME
626 )));
627 }
628
629 if !db.contains_table(table_name.clone()) {
630 return if alter.if_exists {
631 Ok("ALTER TABLE: no-op (table does not exist)".to_string())
632 } else {
633 Err(SQLRiteError::General(format!(
634 "Table '{table_name}' does not exist"
635 )))
636 };
637 }
638
639 if alter.operations.len() != 1 {
640 return Err(SQLRiteError::NotImplemented(
641 "ALTER TABLE supports one operation per statement".to_string(),
642 ));
643 }
644
645 match &alter.operations[0] {
646 AlterTableOperation::RenameTable { table_name: kind } => {
647 let new_name = match kind {
648 RenameTableNameKind::To(name) => name.to_string(),
649 RenameTableNameKind::As(_) => {
650 return Err(SQLRiteError::NotImplemented(
651 "ALTER TABLE ... RENAME AS (MySQL-only) is not supported; use RENAME TO"
652 .to_string(),
653 ));
654 }
655 };
656 alter_rename_table(db, &table_name, &new_name)?;
657 Ok(format!(
658 "ALTER TABLE '{table_name}' RENAME TO '{new_name}' executed."
659 ))
660 }
661 AlterTableOperation::RenameColumn {
662 old_column_name,
663 new_column_name,
664 } => {
665 let old = old_column_name.value.clone();
666 let new = new_column_name.value.clone();
667 db.get_table_mut(table_name.clone())?
668 .rename_column(&old, &new)?;
669 Ok(format!(
670 "ALTER TABLE '{table_name}' RENAME COLUMN '{old}' TO '{new}' executed."
671 ))
672 }
673 AlterTableOperation::AddColumn {
674 column_def,
675 if_not_exists,
676 ..
677 } => {
678 let parsed = crate::sql::parser::create::parse_one_column(column_def)?;
679 let table = db.get_table_mut(table_name.clone())?;
680 if *if_not_exists && table.contains_column(parsed.name.clone()) {
681 return Ok(format!(
682 "ALTER TABLE '{table_name}' ADD COLUMN: no-op (column '{}' already exists)",
683 parsed.name
684 ));
685 }
686 let col_name = parsed.name.clone();
687 table.add_column(parsed)?;
688 Ok(format!(
689 "ALTER TABLE '{table_name}' ADD COLUMN '{col_name}' executed."
690 ))
691 }
692 AlterTableOperation::DropColumn {
693 column_names,
694 if_exists,
695 ..
696 } => {
697 if column_names.len() != 1 {
698 return Err(SQLRiteError::NotImplemented(
699 "ALTER TABLE DROP COLUMN supports a single column per statement".to_string(),
700 ));
701 }
702 let col_name = column_names[0].value.clone();
703 let table = db.get_table_mut(table_name.clone())?;
704 if *if_exists && !table.contains_column(col_name.clone()) {
705 return Ok(format!(
706 "ALTER TABLE '{table_name}' DROP COLUMN: no-op (column '{col_name}' does not exist)"
707 ));
708 }
709 table.drop_column(&col_name)?;
710 Ok(format!(
711 "ALTER TABLE '{table_name}' DROP COLUMN '{col_name}' executed."
712 ))
713 }
714 other => Err(SQLRiteError::NotImplemented(format!(
715 "ALTER TABLE operation {other:?} is not supported"
716 ))),
717 }
718}
719
720pub fn execute_vacuum(db: &mut Database) -> Result<String> {
730 if db.in_transaction() {
731 return Err(SQLRiteError::General(
732 "VACUUM cannot run inside a transaction".to_string(),
733 ));
734 }
735 let path = match db.source_path.clone() {
736 Some(p) => p,
737 None => {
738 return Ok("VACUUM is a no-op for in-memory databases".to_string());
739 }
740 };
741 if let Some(pager) = db.pager.as_mut() {
747 let _ = pager.checkpoint();
748 }
749 let size_before = std::fs::metadata(&path).ok().map(|m| m.len()).unwrap_or(0);
750 let pages_before = db
751 .pager
752 .as_ref()
753 .map(|p| p.header().page_count)
754 .unwrap_or(0);
755 crate::sql::pager::vacuum_database(db, &path)?;
756 if let Some(pager) = db.pager.as_mut() {
759 let _ = pager.checkpoint();
760 }
761 let size_after = std::fs::metadata(&path).ok().map(|m| m.len()).unwrap_or(0);
762 let pages_after = db
763 .pager
764 .as_ref()
765 .map(|p| p.header().page_count)
766 .unwrap_or(0);
767 let pages_reclaimed = pages_before.saturating_sub(pages_after);
768 let bytes_reclaimed = size_before.saturating_sub(size_after);
769 Ok(format!(
770 "VACUUM completed. {pages_reclaimed} pages reclaimed ({bytes_reclaimed} bytes)."
771 ))
772}
773
774fn alter_rename_table(db: &mut Database, old: &str, new: &str) -> Result<()> {
780 if new == crate::sql::pager::MASTER_TABLE_NAME {
781 return Err(SQLRiteError::General(format!(
782 "'{}' is a reserved name used by the internal schema catalog",
783 crate::sql::pager::MASTER_TABLE_NAME
784 )));
785 }
786 if old == new {
787 return Ok(());
788 }
789 if db.contains_table(new.to_string()) {
790 return Err(SQLRiteError::General(format!(
791 "target table '{new}' already exists"
792 )));
793 }
794
795 let mut table = db
796 .tables
797 .remove(old)
798 .ok_or_else(|| SQLRiteError::General(format!("Table '{old}' does not exist")))?;
799 table.tb_name = new.to_string();
800 for idx in table.secondary_indexes.iter_mut() {
801 idx.table_name = new.to_string();
802 if idx.origin == IndexOrigin::Auto
803 && idx.name == SecondaryIndex::auto_name(old, &idx.column_name)
804 {
805 idx.name = SecondaryIndex::auto_name(new, &idx.column_name);
806 }
807 }
808 db.tables.insert(new.to_string(), table);
809 Ok(())
810}
811
/// Index implementation chosen from the `USING` clause of `CREATE INDEX`.
#[derive(Debug, Clone, Copy)]
enum IndexMethod {
    /// Regular secondary index; also used when no `USING` clause is given.
    Btree,
    /// Vector index, selected with `USING hnsw`.
    Hnsw,
    /// Full-text index, selected with `USING fts`.
    Fts,
}
822
823fn create_btree_index(
825 db: &mut Database,
826 table_name: &str,
827 index_name: &str,
828 column_name: &str,
829 datatype: &DataType,
830 unique: bool,
831 existing: &[(i64, Value)],
832) -> Result<String> {
833 let mut idx = SecondaryIndex::new(
834 index_name.to_string(),
835 table_name.to_string(),
836 column_name.to_string(),
837 datatype,
838 unique,
839 IndexOrigin::Explicit,
840 )?;
841
842 for (rowid, v) in existing {
846 if unique && idx.would_violate_unique(v) {
847 return Err(SQLRiteError::General(format!(
848 "cannot create UNIQUE index '{index_name}': column '{column_name}' \
849 already contains the duplicate value {}",
850 v.to_display_string()
851 )));
852 }
853 idx.insert(v, *rowid)?;
854 }
855
856 let table_mut = db.get_table_mut(table_name.to_string())?;
857 table_mut.secondary_indexes.push(idx);
858 Ok(index_name.to_string())
859}
860
861fn create_hnsw_index(
863 db: &mut Database,
864 table_name: &str,
865 index_name: &str,
866 column_name: &str,
867 datatype: &DataType,
868 unique: bool,
869 existing: &[(i64, Value)],
870) -> Result<String> {
871 let dim = match datatype {
874 DataType::Vector(d) => *d,
875 other => {
876 return Err(SQLRiteError::General(format!(
877 "USING hnsw requires a VECTOR column; '{column_name}' is {other}"
878 )));
879 }
880 };
881
882 if unique {
883 return Err(SQLRiteError::General(
884 "UNIQUE has no meaning for HNSW indexes".to_string(),
885 ));
886 }
887
888 let seed = hash_str_to_seed(index_name);
896 let mut idx = HnswIndex::new(DistanceMetric::L2, seed);
897
898 let mut vec_map: std::collections::HashMap<i64, Vec<f32>> =
902 std::collections::HashMap::with_capacity(existing.len());
903 for (rowid, v) in existing {
904 match v {
905 Value::Vector(vec) => {
906 if vec.len() != dim {
907 return Err(SQLRiteError::Internal(format!(
908 "row {rowid} stores a {}-dim vector in column '{column_name}' \
909 declared as VECTOR({dim}) — schema invariant violated",
910 vec.len()
911 )));
912 }
913 vec_map.insert(*rowid, vec.clone());
914 }
915 _ => continue,
919 }
920 }
921
922 for (rowid, _) in existing {
923 if let Some(v) = vec_map.get(rowid) {
924 let v_clone = v.clone();
925 idx.insert(*rowid, &v_clone, |id| {
926 vec_map.get(&id).cloned().unwrap_or_default()
927 });
928 }
929 }
930
931 let table_mut = db.get_table_mut(table_name.to_string())?;
932 table_mut.hnsw_indexes.push(HnswIndexEntry {
933 name: index_name.to_string(),
934 column_name: column_name.to_string(),
935 index: idx,
936 needs_rebuild: false,
938 });
939 Ok(index_name.to_string())
940}
941
942fn create_fts_index(
947 db: &mut Database,
948 table_name: &str,
949 index_name: &str,
950 column_name: &str,
951 datatype: &DataType,
952 unique: bool,
953 existing: &[(i64, Value)],
954) -> Result<String> {
955 match datatype {
960 DataType::Text => {}
961 other => {
962 return Err(SQLRiteError::General(format!(
963 "USING fts requires a TEXT column; '{column_name}' is {other}"
964 )));
965 }
966 }
967
968 if unique {
969 return Err(SQLRiteError::General(
970 "UNIQUE has no meaning for FTS indexes".to_string(),
971 ));
972 }
973
974 let mut idx = PostingList::new();
975 for (rowid, v) in existing {
976 if let Value::Text(text) = v {
977 idx.insert(*rowid, text);
978 }
979 }
982
983 let table_mut = db.get_table_mut(table_name.to_string())?;
984 table_mut.fts_indexes.push(FtsIndexEntry {
985 name: index_name.to_string(),
986 column_name: column_name.to_string(),
987 index: idx,
988 needs_rebuild: false,
989 });
990 Ok(index_name.to_string())
991}
992
/// Hashes `s` to a 64-bit seed with FNV-1a: start from the offset basis,
/// then for each byte XOR it in and multiply by the FNV prime (wrapping).
///
/// The result depends only on the bytes of `s`, so the same string always
/// yields the same seed.
fn hash_str_to_seed(s: &str) -> u64 {
    const OFFSET_BASIS: u64 = 0xCBF2_9CE4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01B3;

    s.bytes().fold(OFFSET_BASIS, |hash, byte| {
        (hash ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}
1004
1005fn clone_datatype(dt: &DataType) -> DataType {
1008 match dt {
1009 DataType::Integer => DataType::Integer,
1010 DataType::Text => DataType::Text,
1011 DataType::Real => DataType::Real,
1012 DataType::Bool => DataType::Bool,
1013 DataType::Vector(dim) => DataType::Vector(*dim),
1014 DataType::Json => DataType::Json,
1015 DataType::None => DataType::None,
1016 DataType::Invalid => DataType::Invalid,
1017 }
1018}
1019
1020fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
1021 if tables.len() != 1 {
1022 return Err(SQLRiteError::NotImplemented(
1023 "multi-table DELETE is not supported yet".to_string(),
1024 ));
1025 }
1026 extract_table_name(&tables[0])
1027}
1028
1029fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
1030 if !twj.joins.is_empty() {
1031 return Err(SQLRiteError::NotImplemented(
1032 "JOIN is not supported yet".to_string(),
1033 ));
1034 }
1035 match &twj.relation {
1036 TableFactor::Table { name, .. } => Ok(name.to_string()),
1037 _ => Err(SQLRiteError::NotImplemented(
1038 "only plain table references are supported".to_string(),
1039 )),
1040 }
1041}
1042
/// Result of `select_rowids`: how the candidate rows for a WHERE clause
/// were (or were not) narrowed down.
enum RowidSource {
    /// Rowids produced by a secondary-index lookup. Callers in this file use
    /// them as the final match set without re-evaluating the predicate.
    IndexProbe(Vec<i64>),
    /// No index shortcut applied; the caller must walk every rowid and
    /// evaluate the predicate itself.
    FullScan,
}
1053
1054fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
1059 let Some(expr) = selection else {
1060 return Ok(RowidSource::FullScan);
1061 };
1062 let Some((col, literal)) = try_extract_equality(expr) else {
1063 return Ok(RowidSource::FullScan);
1064 };
1065 let Some(idx) = table.index_for_column(&col) else {
1066 return Ok(RowidSource::FullScan);
1067 };
1068
1069 let literal_value = match convert_literal(&literal) {
1073 Ok(v) => v,
1074 Err(_) => return Ok(RowidSource::FullScan),
1075 };
1076
1077 let mut rowids = idx.lookup(&literal_value);
1081 rowids.sort_unstable();
1082 Ok(RowidSource::IndexProbe(rowids))
1083}
1084
1085fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
1089 let peeled = match expr {
1091 Expr::Nested(inner) => inner.as_ref(),
1092 other => other,
1093 };
1094 let Expr::BinaryOp { left, op, right } = peeled else {
1095 return None;
1096 };
1097 if !matches!(op, BinaryOperator::Eq) {
1098 return None;
1099 }
1100 let col_from = |e: &Expr| -> Option<String> {
1101 match e {
1102 Expr::Identifier(ident) => Some(ident.value.clone()),
1103 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
1104 _ => None,
1105 }
1106 };
1107 let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
1108 if let Expr::Value(v) = e {
1109 Some(v.value.clone())
1110 } else {
1111 None
1112 }
1113 };
1114 if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
1115 return Some((c, l));
1116 }
1117 if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
1118 return Some((c, l));
1119 }
1120 None
1121}
1122
1123fn try_hnsw_probe(table: &Table, order_expr: &Expr, k: usize) -> Option<Vec<i64>> {
1145 if k == 0 {
1146 return None;
1147 }
1148
1149 let func = match order_expr {
1151 Expr::Function(f) => f,
1152 _ => return None,
1153 };
1154 let fname = match func.name.0.as_slice() {
1155 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
1156 _ => return None,
1157 };
1158 if fname != "vec_distance_l2" {
1159 return None;
1160 }
1161
1162 let arg_list = match &func.args {
1164 FunctionArguments::List(l) => &l.args,
1165 _ => return None,
1166 };
1167 if arg_list.len() != 2 {
1168 return None;
1169 }
1170 let exprs: Vec<&Expr> = arg_list
1171 .iter()
1172 .filter_map(|a| match a {
1173 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
1174 _ => None,
1175 })
1176 .collect();
1177 if exprs.len() != 2 {
1178 return None;
1179 }
1180
1181 let (col_name, query_vec) = match identify_indexed_arg_and_literal(exprs[0], exprs[1]) {
1186 Some(v) => v,
1187 None => match identify_indexed_arg_and_literal(exprs[1], exprs[0]) {
1188 Some(v) => v,
1189 None => return None,
1190 },
1191 };
1192
1193 let entry = table
1195 .hnsw_indexes
1196 .iter()
1197 .find(|e| e.column_name == col_name)?;
1198
1199 let declared_dim = match table.columns.iter().find(|c| c.column_name == col_name) {
1205 Some(c) => match &c.datatype {
1206 DataType::Vector(d) => *d,
1207 _ => return None,
1208 },
1209 None => return None,
1210 };
1211 if query_vec.len() != declared_dim {
1212 return None;
1213 }
1214
1215 let column_for_closure = col_name.clone();
1219 let table_ref = table;
1220 let result = entry.index.search(&query_vec, k, |id| {
1221 match table_ref.get_value(&column_for_closure, id) {
1222 Some(Value::Vector(v)) => v,
1223 _ => Vec::new(),
1224 }
1225 });
1226 Some(result)
1227}
1228
1229fn try_fts_probe(table: &Table, order_expr: &Expr, ascending: bool, k: usize) -> Option<Vec<i64>> {
1245 if k == 0 || ascending {
1246 return None;
1250 }
1251
1252 let func = match order_expr {
1253 Expr::Function(f) => f,
1254 _ => return None,
1255 };
1256 let fname = match func.name.0.as_slice() {
1257 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
1258 _ => return None,
1259 };
1260 if fname != "bm25_score" {
1261 return None;
1262 }
1263
1264 let arg_list = match &func.args {
1265 FunctionArguments::List(l) => &l.args,
1266 _ => return None,
1267 };
1268 if arg_list.len() != 2 {
1269 return None;
1270 }
1271 let exprs: Vec<&Expr> = arg_list
1272 .iter()
1273 .filter_map(|a| match a {
1274 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
1275 _ => None,
1276 })
1277 .collect();
1278 if exprs.len() != 2 {
1279 return None;
1280 }
1281
1282 let col_name = match exprs[0] {
1284 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
1285 _ => return None,
1286 };
1287
1288 let query = match exprs[1] {
1292 Expr::Value(v) => match &v.value {
1293 AstValue::SingleQuotedString(s) => s.clone(),
1294 _ => return None,
1295 },
1296 _ => return None,
1297 };
1298
1299 let entry = table
1300 .fts_indexes
1301 .iter()
1302 .find(|e| e.column_name == col_name)?;
1303
1304 let scored = entry.index.query(&query, &Bm25Params::default());
1305 let mut out: Vec<i64> = scored.into_iter().map(|(id, _)| id).collect();
1306 if out.len() > k {
1307 out.truncate(k);
1308 }
1309 Some(out)
1310}
1311
1312fn identify_indexed_arg_and_literal(a: &Expr, b: &Expr) -> Option<(String, Vec<f32>)> {
1317 let col_name = match a {
1318 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
1319 _ => return None,
1320 };
1321 let lit_str = match b {
1322 Expr::Identifier(ident) if ident.quote_style == Some('[') => {
1323 format!("[{}]", ident.value)
1324 }
1325 _ => return None,
1326 };
1327 let v = parse_vector_literal(&lit_str).ok()?;
1328 Some((col_name, v))
1329}
1330
1331struct HeapEntry {
1344 key: Value,
1345 rowid: i64,
1346 asc: bool,
1347}
1348
1349impl PartialEq for HeapEntry {
1350 fn eq(&self, other: &Self) -> bool {
1351 self.cmp(other) == Ordering::Equal
1352 }
1353}
1354
1355impl Eq for HeapEntry {}
1356
1357impl PartialOrd for HeapEntry {
1358 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1359 Some(self.cmp(other))
1360 }
1361}
1362
1363impl Ord for HeapEntry {
1364 fn cmp(&self, other: &Self) -> Ordering {
1365 let raw = compare_values(Some(&self.key), Some(&other.key));
1366 if self.asc { raw } else { raw.reverse() }
1367 }
1368}
1369
1370fn select_topk(
1379 matching: &[i64],
1380 table: &Table,
1381 order: &OrderByClause,
1382 k: usize,
1383) -> Result<Vec<i64>> {
1384 use std::collections::BinaryHeap;
1385
1386 if k == 0 || matching.is_empty() {
1387 return Ok(Vec::new());
1388 }
1389
1390 let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
1391
1392 for &rowid in matching {
1393 let key = eval_expr(&order.expr, table, rowid)?;
1394 let entry = HeapEntry {
1395 key,
1396 rowid,
1397 asc: order.ascending,
1398 };
1399
1400 if heap.len() < k {
1401 heap.push(entry);
1402 } else {
1403 if entry < *heap.peek().unwrap() {
1407 heap.pop();
1408 heap.push(entry);
1409 }
1410 }
1411 }
1412
1413 Ok(heap
1418 .into_sorted_vec()
1419 .into_iter()
1420 .map(|e| e.rowid)
1421 .collect())
1422}
1423
1424fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
1425 let mut keys: Vec<(i64, Result<Value>)> = rowids
1433 .iter()
1434 .map(|r| (*r, eval_expr(&order.expr, table, *r)))
1435 .collect();
1436
1437 for (_, k) in &keys {
1441 if let Err(e) = k {
1442 return Err(SQLRiteError::General(format!(
1443 "ORDER BY expression failed: {e}"
1444 )));
1445 }
1446 }
1447
1448 keys.sort_by(|(_, ka), (_, kb)| {
1449 let va = ka.as_ref().unwrap();
1452 let vb = kb.as_ref().unwrap();
1453 let ord = compare_values(Some(va), Some(vb));
1454 if order.ascending { ord } else { ord.reverse() }
1455 });
1456
1457 for (i, (rowid, _)) in keys.into_iter().enumerate() {
1459 rowids[i] = rowid;
1460 }
1461 Ok(())
1462}
1463
1464fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
1465 match (a, b) {
1466 (None, None) => Ordering::Equal,
1467 (None, _) => Ordering::Less,
1468 (_, None) => Ordering::Greater,
1469 (Some(a), Some(b)) => match (a, b) {
1470 (Value::Null, Value::Null) => Ordering::Equal,
1471 (Value::Null, _) => Ordering::Less,
1472 (_, Value::Null) => Ordering::Greater,
1473 (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
1474 (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
1475 (Value::Integer(x), Value::Real(y)) => {
1476 (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
1477 }
1478 (Value::Real(x), Value::Integer(y)) => {
1479 x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
1480 }
1481 (Value::Text(x), Value::Text(y)) => x.cmp(y),
1482 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
1483 (x, y) => x.to_display_string().cmp(&y.to_display_string()),
1485 },
1486 }
1487}
1488
1489pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
1491 let v = eval_expr(expr, table, rowid)?;
1492 match v {
1493 Value::Bool(b) => Ok(b),
1494 Value::Null => Ok(false), Value::Integer(i) => Ok(i != 0),
1496 other => Err(SQLRiteError::Internal(format!(
1497 "WHERE clause must evaluate to boolean, got {}",
1498 other.to_display_string()
1499 ))),
1500 }
1501}
1502
1503fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
1504 match expr {
1505 Expr::Nested(inner) => eval_expr(inner, table, rowid),
1506
1507 Expr::Identifier(ident) => {
1508 if ident.quote_style == Some('[') {
1518 let raw = format!("[{}]", ident.value);
1519 let v = parse_vector_literal(&raw)?;
1520 return Ok(Value::Vector(v));
1521 }
1522 Ok(table.get_value(&ident.value, rowid).unwrap_or(Value::Null))
1523 }
1524
1525 Expr::CompoundIdentifier(parts) => {
1526 let col = parts
1528 .last()
1529 .map(|i| i.value.as_str())
1530 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?;
1531 Ok(table.get_value(col, rowid).unwrap_or(Value::Null))
1532 }
1533
1534 Expr::Value(v) => convert_literal(&v.value),
1535
1536 Expr::UnaryOp { op, expr } => {
1537 let inner = eval_expr(expr, table, rowid)?;
1538 match op {
1539 UnaryOperator::Not => match inner {
1540 Value::Bool(b) => Ok(Value::Bool(!b)),
1541 Value::Null => Ok(Value::Null),
1542 other => Err(SQLRiteError::Internal(format!(
1543 "NOT applied to non-boolean value: {}",
1544 other.to_display_string()
1545 ))),
1546 },
1547 UnaryOperator::Minus => match inner {
1548 Value::Integer(i) => Ok(Value::Integer(-i)),
1549 Value::Real(f) => Ok(Value::Real(-f)),
1550 Value::Null => Ok(Value::Null),
1551 other => Err(SQLRiteError::Internal(format!(
1552 "unary minus on non-numeric value: {}",
1553 other.to_display_string()
1554 ))),
1555 },
1556 UnaryOperator::Plus => Ok(inner),
1557 other => Err(SQLRiteError::NotImplemented(format!(
1558 "unary operator {other:?} is not supported"
1559 ))),
1560 }
1561 }
1562
1563 Expr::BinaryOp { left, op, right } => match op {
1564 BinaryOperator::And => {
1565 let l = eval_expr(left, table, rowid)?;
1566 let r = eval_expr(right, table, rowid)?;
1567 Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
1568 }
1569 BinaryOperator::Or => {
1570 let l = eval_expr(left, table, rowid)?;
1571 let r = eval_expr(right, table, rowid)?;
1572 Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
1573 }
1574 cmp @ (BinaryOperator::Eq
1575 | BinaryOperator::NotEq
1576 | BinaryOperator::Lt
1577 | BinaryOperator::LtEq
1578 | BinaryOperator::Gt
1579 | BinaryOperator::GtEq) => {
1580 let l = eval_expr(left, table, rowid)?;
1581 let r = eval_expr(right, table, rowid)?;
1582 if matches!(l, Value::Null) || matches!(r, Value::Null) {
1584 return Ok(Value::Bool(false));
1585 }
1586 let ord = compare_values(Some(&l), Some(&r));
1587 let result = match cmp {
1588 BinaryOperator::Eq => ord == Ordering::Equal,
1589 BinaryOperator::NotEq => ord != Ordering::Equal,
1590 BinaryOperator::Lt => ord == Ordering::Less,
1591 BinaryOperator::LtEq => ord != Ordering::Greater,
1592 BinaryOperator::Gt => ord == Ordering::Greater,
1593 BinaryOperator::GtEq => ord != Ordering::Less,
1594 _ => unreachable!(),
1595 };
1596 Ok(Value::Bool(result))
1597 }
1598 arith @ (BinaryOperator::Plus
1599 | BinaryOperator::Minus
1600 | BinaryOperator::Multiply
1601 | BinaryOperator::Divide
1602 | BinaryOperator::Modulo) => {
1603 let l = eval_expr(left, table, rowid)?;
1604 let r = eval_expr(right, table, rowid)?;
1605 eval_arith(arith, &l, &r)
1606 }
1607 BinaryOperator::StringConcat => {
1608 let l = eval_expr(left, table, rowid)?;
1609 let r = eval_expr(right, table, rowid)?;
1610 if matches!(l, Value::Null) || matches!(r, Value::Null) {
1611 return Ok(Value::Null);
1612 }
1613 Ok(Value::Text(format!(
1614 "{}{}",
1615 l.to_display_string(),
1616 r.to_display_string()
1617 )))
1618 }
1619 other => Err(SQLRiteError::NotImplemented(format!(
1620 "binary operator {other:?} is not supported yet"
1621 ))),
1622 },
1623
1624 Expr::IsNull(inner) => {
1632 let v = eval_expr(inner, table, rowid)?;
1633 Ok(Value::Bool(matches!(v, Value::Null)))
1634 }
1635 Expr::IsNotNull(inner) => {
1636 let v = eval_expr(inner, table, rowid)?;
1637 Ok(Value::Bool(!matches!(v, Value::Null)))
1638 }
1639
1640 Expr::Function(func) => eval_function(func, table, rowid),
1651
1652 other => Err(SQLRiteError::NotImplemented(format!(
1653 "unsupported expression in WHERE/projection: {other:?}"
1654 ))),
1655 }
1656}
1657
1658fn eval_function(func: &sqlparser::ast::Function, table: &Table, rowid: i64) -> Result<Value> {
1663 let name = match func.name.0.as_slice() {
1666 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
1667 _ => {
1668 return Err(SQLRiteError::NotImplemented(format!(
1669 "qualified function names not supported: {:?}",
1670 func.name
1671 )));
1672 }
1673 };
1674
1675 match name.as_str() {
1676 "vec_distance_l2" | "vec_distance_cosine" | "vec_distance_dot" => {
1677 let (a, b) = extract_two_vector_args(&name, &func.args, table, rowid)?;
1678 let dist = match name.as_str() {
1679 "vec_distance_l2" => vec_distance_l2(&a, &b),
1680 "vec_distance_cosine" => vec_distance_cosine(&a, &b)?,
1681 "vec_distance_dot" => vec_distance_dot(&a, &b),
1682 _ => unreachable!(),
1683 };
1684 Ok(Value::Real(dist as f64))
1690 }
1691 "json_extract" => json_fn_extract(&name, &func.args, table, rowid),
1696 "json_type" => json_fn_type(&name, &func.args, table, rowid),
1697 "json_array_length" => json_fn_array_length(&name, &func.args, table, rowid),
1698 "json_object_keys" => json_fn_object_keys(&name, &func.args, table, rowid),
1699 "fts_match" => {
1703 let (entry, query) = resolve_fts_args(&name, &func.args, table, rowid)?;
1704 Ok(Value::Bool(entry.index.matches(rowid, &query)))
1705 }
1706 "bm25_score" => {
1707 let (entry, query) = resolve_fts_args(&name, &func.args, table, rowid)?;
1708 let s = entry.index.score(rowid, &query, &Bm25Params::default());
1709 Ok(Value::Real(s))
1710 }
1711 other => Err(SQLRiteError::NotImplemented(format!(
1712 "unknown function: {other}(...)"
1713 ))),
1714 }
1715}
1716
1717fn resolve_fts_args<'t>(
1722 fn_name: &str,
1723 args: &FunctionArguments,
1724 table: &'t Table,
1725 rowid: i64,
1726) -> Result<(&'t FtsIndexEntry, String)> {
1727 let arg_list = match args {
1728 FunctionArguments::List(l) => &l.args,
1729 _ => {
1730 return Err(SQLRiteError::General(format!(
1731 "{fn_name}() expects exactly two arguments: (column, query_text)"
1732 )));
1733 }
1734 };
1735 if arg_list.len() != 2 {
1736 return Err(SQLRiteError::General(format!(
1737 "{fn_name}() expects exactly 2 arguments, got {}",
1738 arg_list.len()
1739 )));
1740 }
1741
1742 let col_expr = match &arg_list[0] {
1746 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
1747 other => {
1748 return Err(SQLRiteError::NotImplemented(format!(
1749 "{fn_name}() argument 0 must be a column name, got {other:?}"
1750 )));
1751 }
1752 };
1753 let col_name = match col_expr {
1754 Expr::Identifier(ident) => ident.value.clone(),
1755 Expr::CompoundIdentifier(parts) => parts
1756 .last()
1757 .map(|p| p.value.clone())
1758 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
1759 other => {
1760 return Err(SQLRiteError::General(format!(
1761 "{fn_name}() argument 0 must be a column reference, got {other:?}"
1762 )));
1763 }
1764 };
1765
1766 let q_expr = match &arg_list[1] {
1770 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
1771 other => {
1772 return Err(SQLRiteError::NotImplemented(format!(
1773 "{fn_name}() argument 1 must be a text expression, got {other:?}"
1774 )));
1775 }
1776 };
1777 let query = match eval_expr(q_expr, table, rowid)? {
1778 Value::Text(s) => s,
1779 other => {
1780 return Err(SQLRiteError::General(format!(
1781 "{fn_name}() argument 1 must be TEXT, got {}",
1782 other.to_display_string()
1783 )));
1784 }
1785 };
1786
1787 let entry = table
1788 .fts_indexes
1789 .iter()
1790 .find(|e| e.column_name == col_name)
1791 .ok_or_else(|| {
1792 SQLRiteError::General(format!(
1793 "{fn_name}({col_name}, ...): no FTS index on column '{col_name}' \
1794 (run CREATE INDEX <name> ON <table> USING fts({col_name}) first)"
1795 ))
1796 })?;
1797 Ok((entry, query))
1798}
1799
1800fn extract_json_and_path(
1814 fn_name: &str,
1815 args: &FunctionArguments,
1816 table: &Table,
1817 rowid: i64,
1818) -> Result<(String, String)> {
1819 let arg_list = match args {
1820 FunctionArguments::List(l) => &l.args,
1821 _ => {
1822 return Err(SQLRiteError::General(format!(
1823 "{fn_name}() expects 1 or 2 arguments"
1824 )));
1825 }
1826 };
1827 if !(arg_list.len() == 1 || arg_list.len() == 2) {
1828 return Err(SQLRiteError::General(format!(
1829 "{fn_name}() expects 1 or 2 arguments, got {}",
1830 arg_list.len()
1831 )));
1832 }
1833 let first_expr = match &arg_list[0] {
1835 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
1836 other => {
1837 return Err(SQLRiteError::NotImplemented(format!(
1838 "{fn_name}() argument 0 has unsupported shape: {other:?}"
1839 )));
1840 }
1841 };
1842 let json_text = match eval_expr(first_expr, table, rowid)? {
1843 Value::Text(s) => s,
1844 Value::Null => {
1845 return Err(SQLRiteError::General(format!(
1846 "{fn_name}() called on NULL — JSON column has no value for this row"
1847 )));
1848 }
1849 other => {
1850 return Err(SQLRiteError::General(format!(
1851 "{fn_name}() argument 0 is not JSON-typed: got {}",
1852 other.to_display_string()
1853 )));
1854 }
1855 };
1856
1857 let path = if arg_list.len() == 2 {
1859 let path_expr = match &arg_list[1] {
1860 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
1861 other => {
1862 return Err(SQLRiteError::NotImplemented(format!(
1863 "{fn_name}() argument 1 has unsupported shape: {other:?}"
1864 )));
1865 }
1866 };
1867 match eval_expr(path_expr, table, rowid)? {
1868 Value::Text(s) => s,
1869 other => {
1870 return Err(SQLRiteError::General(format!(
1871 "{fn_name}() path argument must be a string literal, got {}",
1872 other.to_display_string()
1873 )));
1874 }
1875 }
1876 } else {
1877 "$".to_string()
1878 };
1879
1880 Ok((json_text, path))
1881}
1882
1883fn walk_json_path<'a>(
1893 value: &'a serde_json::Value,
1894 path: &str,
1895) -> Result<Option<&'a serde_json::Value>> {
1896 let mut chars = path.chars().peekable();
1897 if chars.next() != Some('$') {
1898 return Err(SQLRiteError::General(format!(
1899 "JSON path must start with '$', got `{path}`"
1900 )));
1901 }
1902 let mut current = value;
1903 while let Some(&c) = chars.peek() {
1904 match c {
1905 '.' => {
1906 chars.next();
1907 let mut key = String::new();
1908 while let Some(&c) = chars.peek() {
1909 if c == '.' || c == '[' {
1910 break;
1911 }
1912 key.push(c);
1913 chars.next();
1914 }
1915 if key.is_empty() {
1916 return Err(SQLRiteError::General(format!(
1917 "JSON path has empty key after '.' in `{path}`"
1918 )));
1919 }
1920 match current.get(&key) {
1921 Some(v) => current = v,
1922 None => return Ok(None),
1923 }
1924 }
1925 '[' => {
1926 chars.next();
1927 let mut idx_str = String::new();
1928 while let Some(&c) = chars.peek() {
1929 if c == ']' {
1930 break;
1931 }
1932 idx_str.push(c);
1933 chars.next();
1934 }
1935 if chars.next() != Some(']') {
1936 return Err(SQLRiteError::General(format!(
1937 "JSON path has unclosed `[` in `{path}`"
1938 )));
1939 }
1940 let idx: usize = idx_str.trim().parse().map_err(|_| {
1941 SQLRiteError::General(format!(
1942 "JSON path has non-integer index `[{idx_str}]` in `{path}`"
1943 ))
1944 })?;
1945 match current.get(idx) {
1946 Some(v) => current = v,
1947 None => return Ok(None),
1948 }
1949 }
1950 other => {
1951 return Err(SQLRiteError::General(format!(
1952 "JSON path has unexpected character `{other}` in `{path}` \
1953 (expected `.`, `[`, or end-of-path)"
1954 )));
1955 }
1956 }
1957 }
1958 Ok(Some(current))
1959}
1960
1961fn json_value_to_sql(v: &serde_json::Value) -> Value {
1965 match v {
1966 serde_json::Value::Null => Value::Null,
1967 serde_json::Value::Bool(b) => Value::Bool(*b),
1968 serde_json::Value::Number(n) => {
1969 if let Some(i) = n.as_i64() {
1971 Value::Integer(i)
1972 } else if let Some(f) = n.as_f64() {
1973 Value::Real(f)
1974 } else {
1975 Value::Null
1976 }
1977 }
1978 serde_json::Value::String(s) => Value::Text(s.clone()),
1979 composite => Value::Text(composite.to_string()),
1983 }
1984}
1985
1986fn json_fn_extract(
1987 name: &str,
1988 args: &FunctionArguments,
1989 table: &Table,
1990 rowid: i64,
1991) -> Result<Value> {
1992 let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
1993 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
1994 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
1995 })?;
1996 match walk_json_path(&parsed, &path)? {
1997 Some(v) => Ok(json_value_to_sql(v)),
1998 None => Ok(Value::Null),
1999 }
2000}
2001
2002fn json_fn_type(name: &str, args: &FunctionArguments, table: &Table, rowid: i64) -> Result<Value> {
2003 let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
2004 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2005 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2006 })?;
2007 let resolved = match walk_json_path(&parsed, &path)? {
2008 Some(v) => v,
2009 None => return Ok(Value::Null),
2010 };
2011 let ty = match resolved {
2012 serde_json::Value::Null => "null",
2013 serde_json::Value::Bool(true) => "true",
2014 serde_json::Value::Bool(false) => "false",
2015 serde_json::Value::Number(n) => {
2016 if n.is_i64() || n.is_u64() {
2017 "integer"
2018 } else {
2019 "real"
2020 }
2021 }
2022 serde_json::Value::String(_) => "text",
2023 serde_json::Value::Array(_) => "array",
2024 serde_json::Value::Object(_) => "object",
2025 };
2026 Ok(Value::Text(ty.to_string()))
2027}
2028
2029fn json_fn_array_length(
2030 name: &str,
2031 args: &FunctionArguments,
2032 table: &Table,
2033 rowid: i64,
2034) -> Result<Value> {
2035 let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
2036 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2037 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2038 })?;
2039 let resolved = match walk_json_path(&parsed, &path)? {
2040 Some(v) => v,
2041 None => return Ok(Value::Null),
2042 };
2043 match resolved.as_array() {
2044 Some(arr) => Ok(Value::Integer(arr.len() as i64)),
2045 None => Err(SQLRiteError::General(format!(
2046 "{name}() resolved to a non-array value at path `{path}`"
2047 ))),
2048 }
2049}
2050
2051fn json_fn_object_keys(
2052 name: &str,
2053 args: &FunctionArguments,
2054 table: &Table,
2055 rowid: i64,
2056) -> Result<Value> {
2057 let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
2058 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
2059 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
2060 })?;
2061 let resolved = match walk_json_path(&parsed, &path)? {
2062 Some(v) => v,
2063 None => return Ok(Value::Null),
2064 };
2065 let obj = resolved.as_object().ok_or_else(|| {
2066 SQLRiteError::General(format!(
2067 "{name}() resolved to a non-object value at path `{path}`"
2068 ))
2069 })?;
2070 let keys: Vec<serde_json::Value> = obj
2077 .keys()
2078 .map(|k| serde_json::Value::String(k.clone()))
2079 .collect();
2080 Ok(Value::Text(serde_json::Value::Array(keys).to_string()))
2081}
2082
2083fn extract_two_vector_args(
2087 fn_name: &str,
2088 args: &FunctionArguments,
2089 table: &Table,
2090 rowid: i64,
2091) -> Result<(Vec<f32>, Vec<f32>)> {
2092 let arg_list = match args {
2093 FunctionArguments::List(l) => &l.args,
2094 _ => {
2095 return Err(SQLRiteError::General(format!(
2096 "{fn_name}() expects exactly two vector arguments"
2097 )));
2098 }
2099 };
2100 if arg_list.len() != 2 {
2101 return Err(SQLRiteError::General(format!(
2102 "{fn_name}() expects exactly 2 arguments, got {}",
2103 arg_list.len()
2104 )));
2105 }
2106 let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
2107 for (i, arg) in arg_list.iter().enumerate() {
2108 let expr = match arg {
2109 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
2110 other => {
2111 return Err(SQLRiteError::NotImplemented(format!(
2112 "{fn_name}() argument {i} has unsupported shape: {other:?}"
2113 )));
2114 }
2115 };
2116 let val = eval_expr(expr, table, rowid)?;
2117 match val {
2118 Value::Vector(v) => out.push(v),
2119 other => {
2120 return Err(SQLRiteError::General(format!(
2121 "{fn_name}() argument {i} is not a vector: got {}",
2122 other.to_display_string()
2123 )));
2124 }
2125 }
2126 }
2127 let b = out.pop().unwrap();
2128 let a = out.pop().unwrap();
2129 if a.len() != b.len() {
2130 return Err(SQLRiteError::General(format!(
2131 "{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
2132 a.len(),
2133 b.len()
2134 )));
2135 }
2136 Ok((a, b))
2137}
2138
/// Euclidean (L2) distance between `a` and `b`: the square root of the sum
/// of squared component differences, accumulated in `f32`.
///
/// Callers must pass slices of equal length; this is checked with a
/// `debug_assert` only. Two empty slices yield `0.0`.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // `zip` walks both slices without per-element bounds checks and cannot
    // index out of range. The explicit `fold` from `0.0` keeps the same
    // left-to-right accumulation as a plain loop.
    a.iter()
        .zip(b)
        .fold(0.0f32, |sum, (x, y)| {
            let d = x - y;
            sum + d * d
        })
        .sqrt()
}
2150
2151pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
2161 debug_assert_eq!(a.len(), b.len());
2162 let mut dot = 0.0f32;
2163 let mut norm_a_sq = 0.0f32;
2164 let mut norm_b_sq = 0.0f32;
2165 for i in 0..a.len() {
2166 dot += a[i] * b[i];
2167 norm_a_sq += a[i] * a[i];
2168 norm_b_sq += b[i] * b[i];
2169 }
2170 let denom = (norm_a_sq * norm_b_sq).sqrt();
2171 if denom == 0.0 {
2172 return Err(SQLRiteError::General(
2173 "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
2174 ));
2175 }
2176 Ok(1.0 - dot / denom)
2177}
2178
/// Negated dot product of `a` and `b`, accumulated in `f32`.
///
/// The sign is flipped so that a larger dot product produces a smaller
/// "distance". Callers must pass slices of equal length; this is checked
/// with a `debug_assert` only.
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // `zip` replaces indexed access (no bounds checks, no out-of-range
    // panic). Folding from `0.0` keeps the loop's accumulation order.
    let dot = a.iter().zip(b).fold(0.0f32, |acc, (x, y)| acc + x * y);
    -dot
}
2190
2191fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
2194 if matches!(l, Value::Null) || matches!(r, Value::Null) {
2195 return Ok(Value::Null);
2196 }
2197 match (l, r) {
2198 (Value::Integer(a), Value::Integer(b)) => match op {
2199 BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
2200 BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
2201 BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
2202 BinaryOperator::Divide => {
2203 if *b == 0 {
2204 Err(SQLRiteError::General("division by zero".to_string()))
2205 } else {
2206 Ok(Value::Integer(a / b))
2207 }
2208 }
2209 BinaryOperator::Modulo => {
2210 if *b == 0 {
2211 Err(SQLRiteError::General("modulo by zero".to_string()))
2212 } else {
2213 Ok(Value::Integer(a % b))
2214 }
2215 }
2216 _ => unreachable!(),
2217 },
2218 (a, b) => {
2220 let af = as_number(a)?;
2221 let bf = as_number(b)?;
2222 match op {
2223 BinaryOperator::Plus => Ok(Value::Real(af + bf)),
2224 BinaryOperator::Minus => Ok(Value::Real(af - bf)),
2225 BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
2226 BinaryOperator::Divide => {
2227 if bf == 0.0 {
2228 Err(SQLRiteError::General("division by zero".to_string()))
2229 } else {
2230 Ok(Value::Real(af / bf))
2231 }
2232 }
2233 BinaryOperator::Modulo => {
2234 if bf == 0.0 {
2235 Err(SQLRiteError::General("modulo by zero".to_string()))
2236 } else {
2237 Ok(Value::Real(af % bf))
2238 }
2239 }
2240 _ => unreachable!(),
2241 }
2242 }
2243 }
2244}
2245
2246fn as_number(v: &Value) -> Result<f64> {
2247 match v {
2248 Value::Integer(i) => Ok(*i as f64),
2249 Value::Real(f) => Ok(*f),
2250 Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
2251 other => Err(SQLRiteError::General(format!(
2252 "arithmetic on non-numeric value '{}'",
2253 other.to_display_string()
2254 ))),
2255 }
2256}
2257
2258fn as_bool(v: &Value) -> Result<bool> {
2259 match v {
2260 Value::Bool(b) => Ok(*b),
2261 Value::Null => Ok(false),
2262 Value::Integer(i) => Ok(*i != 0),
2263 other => Err(SQLRiteError::Internal(format!(
2264 "expected boolean, got {}",
2265 other.to_display_string()
2266 ))),
2267 }
2268}
2269
2270fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
2271 use sqlparser::ast::Value as AstValue;
2272 match v {
2273 AstValue::Number(n, _) => {
2274 if let Ok(i) = n.parse::<i64>() {
2275 Ok(Value::Integer(i))
2276 } else if let Ok(f) = n.parse::<f64>() {
2277 Ok(Value::Real(f))
2278 } else {
2279 Err(SQLRiteError::Internal(format!(
2280 "could not parse numeric literal '{n}'"
2281 )))
2282 }
2283 }
2284 AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
2285 AstValue::Boolean(b) => Ok(Value::Bool(*b)),
2286 AstValue::Null => Ok(Value::Null),
2287 other => Err(SQLRiteError::NotImplemented(format!(
2288 "unsupported literal value: {other:?}"
2289 ))),
2290 }
2291}
2292
#[cfg(test)]
mod tests {
    use super::*;

    use crate::sql::db::database::Database;
    use crate::sql::parser::select::SelectQuery;
    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    // ---------------------------------------------------------------
    // Shared helpers
    // ---------------------------------------------------------------

    /// True when `a` and `b` differ by strictly less than `eps`.
    fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
        let gap = (a - b).abs();
        gap < eps
    }

    /// Builds a database with a `docs(id, score)` table holding `n` rows.
    /// Scores come from a multiplicative hash of the row number reduced
    /// modulo one million, so they are deterministic but not inserted in
    /// sorted order.
    fn seed_score_table(n: usize) -> Database {
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
            &mut db,
        )
        .expect("create");
        for i in 0..n as u64 {
            let score = (i.wrapping_mul(2_654_435_761) % 1_000_000) as f64;
            let sql = format!("INSERT INTO docs (score) VALUES ({score});");
            crate::sql::process_command(&sql, &mut db).expect("insert");
        }
        db
    }

    /// Parses `sql` with the SQLite dialect and lowers its last statement
    /// into a `SelectQuery`.
    fn parse_select(sql: &str) -> SelectQuery {
        let stmt = Parser::parse_sql(&SQLiteDialect {}, sql)
            .expect("parse")
            .pop()
            .expect("one statement");
        SelectQuery::new(&stmt).expect("select-query")
    }

    /// Runs a statement that is expected to succeed and returns the
    /// command's response text.
    fn run_select(db: &mut Database, sql: &str) -> String {
        crate::sql::process_command(sql, db).expect("select")
    }

    /// Creates a database named `t` and runs every statement in
    /// `statements` against it, panicking on the first failure.
    fn db_from_script(statements: &[&str]) -> Database {
        let mut db = Database::new("t".to_string());
        for stmt in statements {
            crate::sql::process_command(stmt, &mut db).unwrap();
        }
        db
    }

    /// Seeds 200 rows and checks that `select_topk` with k = 10 returns
    /// exactly what a full `sort_rowids` followed by truncation returns.
    fn check_topk_against_full_sort(sql: &str, failure_msg: &str) {
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let query = parse_select(sql);
        let order = query.order_by.as_ref().unwrap();
        let rowids = table.rowids();

        let mut full = rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(10);

        let topk = select_topk(&rowids, table, order, 10).unwrap();

        assert_eq!(topk, full, "{failure_msg}");
    }

    // ---------------------------------------------------------------
    // Vector distance functions
    // ---------------------------------------------------------------

    #[test]
    fn vec_distance_l2_identical_is_zero() {
        let v = [0.1_f32, 0.2, 0.3];
        assert_eq!(vec_distance_l2(&v, &v), 0.0);
    }

    #[test]
    fn vec_distance_l2_unit_basis_is_sqrt2() {
        let (a, b) = ([1.0_f32, 0.0], [0.0_f32, 1.0]);
        let expected = 2.0_f32.sqrt();
        assert!(approx_eq(vec_distance_l2(&a, &b), expected, 1e-6));
    }

    #[test]
    fn vec_distance_l2_known_value() {
        // 3-4-5 right triangle.
        let (a, b) = ([0.0_f32, 0.0, 0.0], [3.0_f32, 4.0, 0.0]);
        assert!(approx_eq(vec_distance_l2(&a, &b), 5.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_identical_is_zero() {
        let v = [0.1_f32, 0.2, 0.3];
        let d = vec_distance_cosine(&v, &v).unwrap();
        assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
    }

    #[test]
    fn vec_distance_cosine_orthogonal_is_one() {
        let (a, b) = ([1.0_f32, 0.0], [0.0_f32, 1.0]);
        let d = vec_distance_cosine(&a, &b).unwrap();
        assert!(approx_eq(d, 1.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_opposite_is_two() {
        let (a, b) = ([1.0_f32, 0.0, 0.0], [-1.0_f32, 0.0, 0.0]);
        let d = vec_distance_cosine(&a, &b).unwrap();
        assert!(approx_eq(d, 2.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_zero_magnitude_errors() {
        let (a, b) = ([0.0_f32, 0.0], [1.0_f32, 0.0]);
        let err = vec_distance_cosine(&a, &b).unwrap_err();
        assert!(format!("{err}").contains("zero-magnitude"));
    }

    #[test]
    fn vec_distance_dot_negates() {
        // 1*4 + 2*5 + 3*6 = 32, returned with the sign flipped.
        let (a, b) = ([1.0_f32, 2.0, 3.0], [4.0_f32, 5.0, 6.0]);
        assert!(approx_eq(vec_distance_dot(&a, &b), -32.0, 1e-6));
    }

    #[test]
    fn vec_distance_dot_orthogonal_is_zero() {
        let (a, b) = ([1.0_f32, 0.0], [0.0_f32, 1.0]);
        assert_eq!(vec_distance_dot(&a, &b), 0.0);
    }

    #[test]
    fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
        // Both inputs have unit length, so -dot and (cosine distance - 1)
        // should agree up to rounding.
        let a = [0.6_f32, 0.8];
        let b = [0.8_f32, 0.6];
        let dot = vec_distance_dot(&a, &b);
        let cos = vec_distance_cosine(&a, &b).unwrap();
        assert!(approx_eq(dot, cos - 1.0, 1e-5));
    }

    // ---------------------------------------------------------------
    // Bounded-heap top-k
    // ---------------------------------------------------------------

    #[test]
    fn topk_matches_full_sort_asc() {
        check_topk_against_full_sort(
            "SELECT * FROM docs ORDER BY score ASC LIMIT 10;",
            "top-k via heap should match full-sort+truncate",
        );
    }

    #[test]
    fn topk_matches_full_sort_desc() {
        check_topk_against_full_sort(
            "SELECT * FROM docs ORDER BY score DESC LIMIT 10;",
            "top-k DESC via heap should match full-sort+truncate",
        );
    }

    #[test]
    fn topk_k_larger_than_n_returns_everything_sorted() {
        let db = seed_score_table(50);
        let table = db.get_table("docs".to_string()).unwrap();
        let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
        let order = query.order_by.as_ref().unwrap();

        let topk = select_topk(&table.rowids(), table, order, 1000).unwrap();
        assert_eq!(topk.len(), 50);

        // Collect the REAL scores in result order and check they ascend.
        let mut scores: Vec<f64> = Vec::new();
        for rowid in &topk {
            if let Some(Value::Real(score)) = table.get_value("score", *rowid) {
                scores.push(score);
            }
        }
        assert!(scores.windows(2).all(|pair| pair[0] <= pair[1]));
    }

    #[test]
    fn topk_k_zero_returns_empty() {
        let db = seed_score_table(10);
        let table = db.get_table("docs".to_string()).unwrap();
        let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
        let order = query.order_by.as_ref().unwrap();

        let topk = select_topk(&table.rowids(), table, order, 0).unwrap();
        assert!(topk.is_empty());
    }

    #[test]
    fn topk_empty_input_returns_empty() {
        let db = seed_score_table(0);
        let table = db.get_table("docs".to_string()).unwrap();
        let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
        let order = query.order_by.as_ref().unwrap();

        let topk = select_topk(&[], table, order, 5).unwrap();
        assert!(topk.is_empty());
    }

    #[test]
    fn topk_works_through_select_executor_with_distance_function() {
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
            &mut db,
        )
        .unwrap();

        let vectors = [
            "[1.0, 0.0]",
            "[2.0, 0.0]",
            "[0.0, 3.0]",
            "[1.0, 4.0]",
            "[10.0, 10.0]",
        ];
        for v in vectors.iter() {
            let insert = format!("INSERT INTO docs (e) VALUES ({v});");
            crate::sql::process_command(&insert, &mut db).unwrap();
        }

        let resp = crate::sql::process_command(
            "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;",
            &mut db,
        )
        .unwrap();
        assert!(resp.contains("3 rows returned"), "got: {resp}");
    }

    /// Timing comparison between the bounded heap and a full sort. Ignored
    /// by default; run it explicitly (e.g. `cargo test -- --ignored`).
    #[test]
    #[ignore]
    fn topk_benchmark() {
        use std::time::Instant;
        const N: usize = 10_000;
        const K: usize = 10;

        let db = seed_score_table(N);
        let table = db.get_table("docs".to_string()).unwrap();
        let query = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = query.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        // Bounded heap first ...
        let heap_start = Instant::now();
        let _topk = select_topk(&all_rowids, table, order, K).unwrap();
        let heap_dur = heap_start.elapsed();

        // ... then full sort + truncate over the same input.
        let sort_start = Instant::now();
        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(K);
        let sort_dur = sort_start.elapsed();

        let ratio = sort_dur.as_secs_f64() / heap_dur.as_secs_f64().max(1e-9);
        println!("\n--- topk_benchmark (N={N}, k={K}) ---");
        println!("  bounded heap:    {heap_dur:?}");
        println!("  full sort+trunc: {sort_dur:?}");
        println!("  speedup ratio:   {ratio:.2}×");

        assert!(
            ratio > 1.4,
            "bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
        );
    }

    // ---------------------------------------------------------------
    // IS NULL / IS NOT NULL in WHERE
    // ---------------------------------------------------------------

    #[test]
    fn where_is_null_returns_null_rows() {
        let mut db = db_from_script(&[
            "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
            "INSERT INTO t (id, n) VALUES (1, 10);",
            "INSERT INTO t (id, n) VALUES (2, NULL);",
            "INSERT INTO t (id, n) VALUES (3, 30);",
            "INSERT INTO t (id, n) VALUES (4, NULL);",
        ]);

        let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL;");
        assert!(
            response.contains("2 rows returned"),
            "IS NULL should return 2 rows, got: {response}"
        );
    }

    #[test]
    fn where_is_not_null_returns_non_null_rows() {
        let mut db = db_from_script(&[
            "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
            "INSERT INTO t (id, n) VALUES (1, 10);",
            "INSERT INTO t (id, n) VALUES (2, NULL);",
            "INSERT INTO t (id, n) VALUES (3, 30);",
        ]);

        let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NOT NULL;");
        assert!(
            response.contains("2 rows returned"),
            "IS NOT NULL should return 2 rows, got: {response}"
        );
    }

    #[test]
    fn where_is_null_on_indexed_column() {
        // `name` is declared UNIQUE.
        let mut db = db_from_script(&[
            "CREATE TABLE t (id INTEGER PRIMARY KEY, name TEXT UNIQUE);",
            "INSERT INTO t (id, name) VALUES (1, 'alice');",
            "INSERT INTO t (id, name) VALUES (2, NULL);",
            "INSERT INTO t (id, name) VALUES (3, 'bob');",
        ]);

        let null_rows = run_select(&mut db, "SELECT id FROM t WHERE name IS NULL;");
        assert!(
            null_rows.contains("1 row returned"),
            "indexed IS NULL should return 1 row, got: {null_rows}"
        );
        let not_null_rows = run_select(&mut db, "SELECT id FROM t WHERE name IS NOT NULL;");
        assert!(
            not_null_rows.contains("2 rows returned"),
            "indexed IS NOT NULL should return 2 rows, got: {not_null_rows}"
        );
    }

    #[test]
    fn where_is_null_works_on_omitted_column() {
        // Row 2 is inserted without a `qty` value at all.
        let mut db = db_from_script(&[
            "CREATE TABLE t (id INTEGER PRIMARY KEY, qty INTEGER, label TEXT);",
            "INSERT INTO t (id, qty, label) VALUES (1, 7, 'a');",
            "INSERT INTO t (id, label) VALUES (2, 'b');",
        ]);

        let response = run_select(&mut db, "SELECT id FROM t WHERE qty IS NULL;");
        assert!(
            response.contains("1 row returned"),
            "IS NULL should match the omitted-column row, got: {response}"
        );
    }

    #[test]
    fn where_is_null_combines_with_and_or() {
        let mut db = db_from_script(&[
            "CREATE TABLE t (id INTEGER PRIMARY KEY, n INTEGER);",
            "INSERT INTO t (id, n) VALUES (1, NULL);",
            "INSERT INTO t (id, n) VALUES (2, NULL);",
            "INSERT INTO t (id, n) VALUES (3, 30);",
        ]);

        let response = run_select(&mut db, "SELECT id FROM t WHERE n IS NULL AND id > 1;");
        assert!(
            response.contains("1 row returned"),
            "IS NULL combined with AND should match exactly row 2, got: {response}"
        );
    }
}