1use std::cmp::Ordering;
5
6use prettytable::{Cell as PrintCell, Row as PrintRow, Table as PrintTable};
7use sqlparser::ast::{
8 AssignmentTarget, BinaryOperator, CreateIndex, Delete, Expr, FromTable, FunctionArg,
9 FunctionArgExpr, FunctionArguments, IndexType, ObjectNamePart, Statement, TableFactor,
10 TableWithJoins, UnaryOperator, Update, Value as AstValue,
11};
12
13use crate::error::{Result, SQLRiteError};
14use crate::sql::db::database::Database;
15use crate::sql::db::secondary_index::{IndexOrigin, SecondaryIndex};
16use crate::sql::db::table::{
17 DataType, FtsIndexEntry, HnswIndexEntry, Table, Value, parse_vector_literal,
18};
19use crate::sql::fts::{Bm25Params, PostingList};
20use crate::sql::hnsw::{DistanceMetric, HnswIndex};
21use crate::sql::parser::select::{OrderByClause, Projection, SelectQuery};
22
23pub struct SelectResult {
32 pub columns: Vec<String>,
33 pub rows: Vec<Vec<Value>>,
34}
35
36pub fn execute_select_rows(query: SelectQuery, db: &Database) -> Result<SelectResult> {
40 let table = db
41 .get_table(query.table_name.clone())
42 .map_err(|_| SQLRiteError::Internal(format!("Table '{}' not found", query.table_name)))?;
43
44 let projected_cols: Vec<String> = match &query.projection {
46 Projection::All => table.column_names(),
47 Projection::Columns(cols) => {
48 for c in cols {
49 if !table.contains_column(c.to_string()) {
50 return Err(SQLRiteError::Internal(format!(
51 "Column '{c}' does not exist on table '{}'",
52 query.table_name
53 )));
54 }
55 }
56 cols.clone()
57 }
58 };
59
60 let matching = match select_rowids(table, query.selection.as_ref())? {
64 RowidSource::IndexProbe(rowids) => rowids,
65 RowidSource::FullScan => {
66 let mut out = Vec::new();
67 for rowid in table.rowids() {
68 if let Some(expr) = &query.selection {
69 if !eval_predicate(expr, table, rowid)? {
70 continue;
71 }
72 }
73 out.push(rowid);
74 }
75 out
76 }
77 };
78 let mut matching = matching;
79
80 match (&query.order_by, query.limit) {
110 (Some(order), Some(k)) if try_hnsw_probe(table, &order.expr, k).is_some() => {
111 matching = try_hnsw_probe(table, &order.expr, k).unwrap();
112 }
113 (Some(order), Some(k))
114 if try_fts_probe(table, &order.expr, order.ascending, k).is_some() =>
115 {
116 matching = try_fts_probe(table, &order.expr, order.ascending, k).unwrap();
117 }
118 (Some(order), Some(k)) if k < matching.len() => {
119 matching = select_topk(&matching, table, order, k)?;
120 }
121 (Some(order), _) => {
122 sort_rowids(&mut matching, table, order)?;
123 if let Some(k) = query.limit {
124 matching.truncate(k);
125 }
126 }
127 (None, Some(k)) => {
128 matching.truncate(k);
129 }
130 (None, None) => {}
131 }
132
133 let mut rows: Vec<Vec<Value>> = Vec::with_capacity(matching.len());
137 for rowid in &matching {
138 let row: Vec<Value> = projected_cols
139 .iter()
140 .map(|col| table.get_value(col, *rowid).unwrap_or(Value::Null))
141 .collect();
142 rows.push(row);
143 }
144
145 Ok(SelectResult {
146 columns: projected_cols,
147 rows,
148 })
149}
150
151pub fn execute_select(query: SelectQuery, db: &Database) -> Result<(String, usize)> {
156 let result = execute_select_rows(query, db)?;
157 let row_count = result.rows.len();
158
159 let mut print_table = PrintTable::new();
160 let header_cells: Vec<PrintCell> = result.columns.iter().map(|c| PrintCell::new(c)).collect();
161 print_table.add_row(PrintRow::new(header_cells));
162
163 for row in &result.rows {
164 let cells: Vec<PrintCell> = row
165 .iter()
166 .map(|v| PrintCell::new(&v.to_display_string()))
167 .collect();
168 print_table.add_row(PrintRow::new(cells));
169 }
170
171 Ok((print_table.to_string(), row_count))
172}
173
174pub fn execute_delete(stmt: &Statement, db: &mut Database) -> Result<usize> {
176 let Statement::Delete(Delete {
177 from, selection, ..
178 }) = stmt
179 else {
180 return Err(SQLRiteError::Internal(
181 "execute_delete called on a non-DELETE statement".to_string(),
182 ));
183 };
184
185 let tables = match from {
186 FromTable::WithFromKeyword(t) | FromTable::WithoutKeyword(t) => t,
187 };
188 let table_name = extract_single_table_name(tables)?;
189
190 let matching: Vec<i64> = {
192 let table = db
193 .get_table(table_name.clone())
194 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
195 match select_rowids(table, selection.as_ref())? {
196 RowidSource::IndexProbe(rowids) => rowids,
197 RowidSource::FullScan => {
198 let mut out = Vec::new();
199 for rowid in table.rowids() {
200 if let Some(expr) = selection {
201 if !eval_predicate(expr, table, rowid)? {
202 continue;
203 }
204 }
205 out.push(rowid);
206 }
207 out
208 }
209 }
210 };
211
212 let table = db.get_table_mut(table_name)?;
213 for rowid in &matching {
214 table.delete_row(*rowid);
215 }
216 if !matching.is_empty() {
225 for entry in &mut table.hnsw_indexes {
226 entry.needs_rebuild = true;
227 }
228 for entry in &mut table.fts_indexes {
229 entry.needs_rebuild = true;
230 }
231 }
232 Ok(matching.len())
233}
234
235pub fn execute_update(stmt: &Statement, db: &mut Database) -> Result<usize> {
237 let Statement::Update(Update {
238 table,
239 assignments,
240 from,
241 selection,
242 ..
243 }) = stmt
244 else {
245 return Err(SQLRiteError::Internal(
246 "execute_update called on a non-UPDATE statement".to_string(),
247 ));
248 };
249
250 if from.is_some() {
251 return Err(SQLRiteError::NotImplemented(
252 "UPDATE ... FROM is not supported yet".to_string(),
253 ));
254 }
255
256 let table_name = extract_table_name(table)?;
257
258 let mut parsed_assignments: Vec<(String, Expr)> = Vec::with_capacity(assignments.len());
260 {
261 let tbl = db
262 .get_table(table_name.clone())
263 .map_err(|_| SQLRiteError::Internal(format!("Table '{table_name}' not found")))?;
264 for a in assignments {
265 let col = match &a.target {
266 AssignmentTarget::ColumnName(name) => name
267 .0
268 .last()
269 .map(|p| p.to_string())
270 .ok_or_else(|| SQLRiteError::Internal("empty column name".to_string()))?,
271 AssignmentTarget::Tuple(_) => {
272 return Err(SQLRiteError::NotImplemented(
273 "tuple assignment targets are not supported".to_string(),
274 ));
275 }
276 };
277 if !tbl.contains_column(col.clone()) {
278 return Err(SQLRiteError::Internal(format!(
279 "UPDATE references unknown column '{col}'"
280 )));
281 }
282 parsed_assignments.push((col, a.value.clone()));
283 }
284 }
285
286 let work: Vec<(i64, Vec<(String, Value)>)> = {
290 let tbl = db.get_table(table_name.clone())?;
291 let matched_rowids: Vec<i64> = match select_rowids(tbl, selection.as_ref())? {
292 RowidSource::IndexProbe(rowids) => rowids,
293 RowidSource::FullScan => {
294 let mut out = Vec::new();
295 for rowid in tbl.rowids() {
296 if let Some(expr) = selection {
297 if !eval_predicate(expr, tbl, rowid)? {
298 continue;
299 }
300 }
301 out.push(rowid);
302 }
303 out
304 }
305 };
306 let mut rows_to_update = Vec::new();
307 for rowid in matched_rowids {
308 let mut values = Vec::with_capacity(parsed_assignments.len());
309 for (col, expr) in &parsed_assignments {
310 let v = eval_expr(expr, tbl, rowid)?;
313 values.push((col.clone(), v));
314 }
315 rows_to_update.push((rowid, values));
316 }
317 rows_to_update
318 };
319
320 let tbl = db.get_table_mut(table_name)?;
321 for (rowid, values) in &work {
322 for (col, v) in values {
323 tbl.set_value(col, *rowid, v.clone())?;
324 }
325 }
326
327 if !work.is_empty() {
336 let updated_columns: std::collections::HashSet<&str> = work
337 .iter()
338 .flat_map(|(_, values)| values.iter().map(|(c, _)| c.as_str()))
339 .collect();
340 for entry in &mut tbl.hnsw_indexes {
341 if updated_columns.contains(entry.column_name.as_str()) {
342 entry.needs_rebuild = true;
343 }
344 }
345 for entry in &mut tbl.fts_indexes {
346 if updated_columns.contains(entry.column_name.as_str()) {
347 entry.needs_rebuild = true;
348 }
349 }
350 }
351 Ok(work.len())
352}
353
354pub fn execute_create_index(stmt: &Statement, db: &mut Database) -> Result<String> {
366 let Statement::CreateIndex(CreateIndex {
367 name,
368 table_name,
369 columns,
370 using,
371 unique,
372 if_not_exists,
373 predicate,
374 ..
375 }) = stmt
376 else {
377 return Err(SQLRiteError::Internal(
378 "execute_create_index called on a non-CREATE-INDEX statement".to_string(),
379 ));
380 };
381
382 if predicate.is_some() {
383 return Err(SQLRiteError::NotImplemented(
384 "partial indexes (CREATE INDEX ... WHERE) are not supported yet".to_string(),
385 ));
386 }
387
388 if columns.len() != 1 {
389 return Err(SQLRiteError::NotImplemented(format!(
390 "multi-column indexes are not supported yet ({} columns given)",
391 columns.len()
392 )));
393 }
394
395 let index_name = name.as_ref().map(|n| n.to_string()).ok_or_else(|| {
396 SQLRiteError::NotImplemented(
397 "anonymous CREATE INDEX (no name) is not supported — give it a name".to_string(),
398 )
399 })?;
400
401 let method = match using {
407 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("hnsw") => {
408 IndexMethod::Hnsw
409 }
410 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("fts") => {
411 IndexMethod::Fts
412 }
413 Some(IndexType::Custom(ident)) if ident.value.eq_ignore_ascii_case("btree") => {
414 IndexMethod::Btree
415 }
416 Some(other) => {
417 return Err(SQLRiteError::NotImplemented(format!(
418 "CREATE INDEX … USING {other:?} is not supported \
419 (try `hnsw`, `fts`, or no USING clause)"
420 )));
421 }
422 None => IndexMethod::Btree,
423 };
424
425 let table_name_str = table_name.to_string();
426 let column_name = match &columns[0].column.expr {
427 Expr::Identifier(ident) => ident.value.clone(),
428 Expr::CompoundIdentifier(parts) => parts
429 .last()
430 .map(|p| p.value.clone())
431 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
432 other => {
433 return Err(SQLRiteError::NotImplemented(format!(
434 "CREATE INDEX only supports simple column references, got {other:?}"
435 )));
436 }
437 };
438
439 let (datatype, existing_rowids_and_values): (DataType, Vec<(i64, Value)>) = {
444 let table = db.get_table(table_name_str.clone()).map_err(|_| {
445 SQLRiteError::General(format!(
446 "CREATE INDEX references unknown table '{table_name_str}'"
447 ))
448 })?;
449 if !table.contains_column(column_name.clone()) {
450 return Err(SQLRiteError::General(format!(
451 "CREATE INDEX references unknown column '{column_name}' on table '{table_name_str}'"
452 )));
453 }
454 let col = table
455 .columns
456 .iter()
457 .find(|c| c.column_name == column_name)
458 .expect("we just verified the column exists");
459
460 if table.index_by_name(&index_name).is_some()
463 || table.hnsw_indexes.iter().any(|i| i.name == index_name)
464 || table.fts_indexes.iter().any(|i| i.name == index_name)
465 {
466 if *if_not_exists {
467 return Ok(index_name);
468 }
469 return Err(SQLRiteError::General(format!(
470 "index '{index_name}' already exists"
471 )));
472 }
473 let datatype = clone_datatype(&col.datatype);
474
475 let mut pairs = Vec::new();
476 for rowid in table.rowids() {
477 if let Some(v) = table.get_value(&column_name, rowid) {
478 pairs.push((rowid, v));
479 }
480 }
481 (datatype, pairs)
482 };
483
484 match method {
485 IndexMethod::Btree => create_btree_index(
486 db,
487 &table_name_str,
488 &index_name,
489 &column_name,
490 &datatype,
491 *unique,
492 &existing_rowids_and_values,
493 ),
494 IndexMethod::Hnsw => create_hnsw_index(
495 db,
496 &table_name_str,
497 &index_name,
498 &column_name,
499 &datatype,
500 *unique,
501 &existing_rowids_and_values,
502 ),
503 IndexMethod::Fts => create_fts_index(
504 db,
505 &table_name_str,
506 &index_name,
507 &column_name,
508 &datatype,
509 *unique,
510 &existing_rowids_and_values,
511 ),
512 }
513}
514
/// Index implementation selected by a `CREATE INDEX … USING …` clause.
#[derive(Debug, Clone, Copy)]
enum IndexMethod {
    /// Secondary index; chosen when `USING` is omitted or names `btree`.
    Btree,
    /// HNSW vector index; requires a `VECTOR` column.
    Hnsw,
    /// Full-text (posting-list) index; requires a `TEXT` column.
    Fts,
}
525
526fn create_btree_index(
528 db: &mut Database,
529 table_name: &str,
530 index_name: &str,
531 column_name: &str,
532 datatype: &DataType,
533 unique: bool,
534 existing: &[(i64, Value)],
535) -> Result<String> {
536 let mut idx = SecondaryIndex::new(
537 index_name.to_string(),
538 table_name.to_string(),
539 column_name.to_string(),
540 datatype,
541 unique,
542 IndexOrigin::Explicit,
543 )?;
544
545 for (rowid, v) in existing {
549 if unique && idx.would_violate_unique(v) {
550 return Err(SQLRiteError::General(format!(
551 "cannot create UNIQUE index '{index_name}': column '{column_name}' \
552 already contains the duplicate value {}",
553 v.to_display_string()
554 )));
555 }
556 idx.insert(v, *rowid)?;
557 }
558
559 let table_mut = db.get_table_mut(table_name.to_string())?;
560 table_mut.secondary_indexes.push(idx);
561 Ok(index_name.to_string())
562}
563
564fn create_hnsw_index(
566 db: &mut Database,
567 table_name: &str,
568 index_name: &str,
569 column_name: &str,
570 datatype: &DataType,
571 unique: bool,
572 existing: &[(i64, Value)],
573) -> Result<String> {
574 let dim = match datatype {
577 DataType::Vector(d) => *d,
578 other => {
579 return Err(SQLRiteError::General(format!(
580 "USING hnsw requires a VECTOR column; '{column_name}' is {other}"
581 )));
582 }
583 };
584
585 if unique {
586 return Err(SQLRiteError::General(
587 "UNIQUE has no meaning for HNSW indexes".to_string(),
588 ));
589 }
590
591 let seed = hash_str_to_seed(index_name);
599 let mut idx = HnswIndex::new(DistanceMetric::L2, seed);
600
601 let mut vec_map: std::collections::HashMap<i64, Vec<f32>> =
605 std::collections::HashMap::with_capacity(existing.len());
606 for (rowid, v) in existing {
607 match v {
608 Value::Vector(vec) => {
609 if vec.len() != dim {
610 return Err(SQLRiteError::Internal(format!(
611 "row {rowid} stores a {}-dim vector in column '{column_name}' \
612 declared as VECTOR({dim}) — schema invariant violated",
613 vec.len()
614 )));
615 }
616 vec_map.insert(*rowid, vec.clone());
617 }
618 _ => continue,
622 }
623 }
624
625 for (rowid, _) in existing {
626 if let Some(v) = vec_map.get(rowid) {
627 let v_clone = v.clone();
628 idx.insert(*rowid, &v_clone, |id| {
629 vec_map.get(&id).cloned().unwrap_or_default()
630 });
631 }
632 }
633
634 let table_mut = db.get_table_mut(table_name.to_string())?;
635 table_mut.hnsw_indexes.push(HnswIndexEntry {
636 name: index_name.to_string(),
637 column_name: column_name.to_string(),
638 index: idx,
639 needs_rebuild: false,
641 });
642 Ok(index_name.to_string())
643}
644
645fn create_fts_index(
650 db: &mut Database,
651 table_name: &str,
652 index_name: &str,
653 column_name: &str,
654 datatype: &DataType,
655 unique: bool,
656 existing: &[(i64, Value)],
657) -> Result<String> {
658 match datatype {
663 DataType::Text => {}
664 other => {
665 return Err(SQLRiteError::General(format!(
666 "USING fts requires a TEXT column; '{column_name}' is {other}"
667 )));
668 }
669 }
670
671 if unique {
672 return Err(SQLRiteError::General(
673 "UNIQUE has no meaning for FTS indexes".to_string(),
674 ));
675 }
676
677 let mut idx = PostingList::new();
678 for (rowid, v) in existing {
679 if let Value::Text(text) = v {
680 idx.insert(*rowid, text);
681 }
682 }
685
686 let table_mut = db.get_table_mut(table_name.to_string())?;
687 table_mut.fts_indexes.push(FtsIndexEntry {
688 name: index_name.to_string(),
689 column_name: column_name.to_string(),
690 index: idx,
691 needs_rebuild: false,
692 });
693 Ok(index_name.to_string())
694}
695
/// Derives a deterministic 64-bit seed from `s` using 64-bit FNV-1a.
///
/// The same string always yields the same seed across runs and platforms.
fn hash_str_to_seed(s: &str) -> u64 {
    const FNV_OFFSET_BASIS: u64 = 0xCBF2_9CE4_8422_2325;
    const FNV_PRIME: u64 = 0x0000_0100_0000_01B3;

    s.bytes().fold(FNV_OFFSET_BASIS, |hash, byte| {
        (hash ^ u64::from(byte)).wrapping_mul(FNV_PRIME)
    })
}
707
708fn clone_datatype(dt: &DataType) -> DataType {
711 match dt {
712 DataType::Integer => DataType::Integer,
713 DataType::Text => DataType::Text,
714 DataType::Real => DataType::Real,
715 DataType::Bool => DataType::Bool,
716 DataType::Vector(dim) => DataType::Vector(*dim),
717 DataType::Json => DataType::Json,
718 DataType::None => DataType::None,
719 DataType::Invalid => DataType::Invalid,
720 }
721}
722
723fn extract_single_table_name(tables: &[TableWithJoins]) -> Result<String> {
724 if tables.len() != 1 {
725 return Err(SQLRiteError::NotImplemented(
726 "multi-table DELETE is not supported yet".to_string(),
727 ));
728 }
729 extract_table_name(&tables[0])
730}
731
732fn extract_table_name(twj: &TableWithJoins) -> Result<String> {
733 if !twj.joins.is_empty() {
734 return Err(SQLRiteError::NotImplemented(
735 "JOIN is not supported yet".to_string(),
736 ));
737 }
738 match &twj.relation {
739 TableFactor::Table { name, .. } => Ok(name.to_string()),
740 _ => Err(SQLRiteError::NotImplemented(
741 "only plain table references are supported".to_string(),
742 )),
743 }
744}
745
/// How a statement's candidate rowids should be obtained.
enum RowidSource {
    /// Rowids already resolved (sorted ascending) by a secondary-index lookup
    /// for a `WHERE` that is exactly one `column = literal` test, so the
    /// predicate does not need to be re-evaluated for them.
    IndexProbe(Vec<i64>),
    /// No usable index: the caller must visit every row and apply the `WHERE`
    /// clause itself (if there is one).
    FullScan,
}
756
757fn select_rowids(table: &Table, selection: Option<&Expr>) -> Result<RowidSource> {
762 let Some(expr) = selection else {
763 return Ok(RowidSource::FullScan);
764 };
765 let Some((col, literal)) = try_extract_equality(expr) else {
766 return Ok(RowidSource::FullScan);
767 };
768 let Some(idx) = table.index_for_column(&col) else {
769 return Ok(RowidSource::FullScan);
770 };
771
772 let literal_value = match convert_literal(&literal) {
776 Ok(v) => v,
777 Err(_) => return Ok(RowidSource::FullScan),
778 };
779
780 let mut rowids = idx.lookup(&literal_value);
784 rowids.sort_unstable();
785 Ok(RowidSource::IndexProbe(rowids))
786}
787
788fn try_extract_equality(expr: &Expr) -> Option<(String, sqlparser::ast::Value)> {
792 let peeled = match expr {
794 Expr::Nested(inner) => inner.as_ref(),
795 other => other,
796 };
797 let Expr::BinaryOp { left, op, right } = peeled else {
798 return None;
799 };
800 if !matches!(op, BinaryOperator::Eq) {
801 return None;
802 }
803 let col_from = |e: &Expr| -> Option<String> {
804 match e {
805 Expr::Identifier(ident) => Some(ident.value.clone()),
806 Expr::CompoundIdentifier(parts) => parts.last().map(|p| p.value.clone()),
807 _ => None,
808 }
809 };
810 let literal_from = |e: &Expr| -> Option<sqlparser::ast::Value> {
811 if let Expr::Value(v) = e {
812 Some(v.value.clone())
813 } else {
814 None
815 }
816 };
817 if let (Some(c), Some(l)) = (col_from(left), literal_from(right)) {
818 return Some((c, l));
819 }
820 if let (Some(l), Some(c)) = (literal_from(left), col_from(right)) {
821 return Some((c, l));
822 }
823 None
824}
825
826fn try_hnsw_probe(table: &Table, order_expr: &Expr, k: usize) -> Option<Vec<i64>> {
848 if k == 0 {
849 return None;
850 }
851
852 let func = match order_expr {
854 Expr::Function(f) => f,
855 _ => return None,
856 };
857 let fname = match func.name.0.as_slice() {
858 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
859 _ => return None,
860 };
861 if fname != "vec_distance_l2" {
862 return None;
863 }
864
865 let arg_list = match &func.args {
867 FunctionArguments::List(l) => &l.args,
868 _ => return None,
869 };
870 if arg_list.len() != 2 {
871 return None;
872 }
873 let exprs: Vec<&Expr> = arg_list
874 .iter()
875 .filter_map(|a| match a {
876 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
877 _ => None,
878 })
879 .collect();
880 if exprs.len() != 2 {
881 return None;
882 }
883
884 let (col_name, query_vec) = match identify_indexed_arg_and_literal(exprs[0], exprs[1]) {
889 Some(v) => v,
890 None => match identify_indexed_arg_and_literal(exprs[1], exprs[0]) {
891 Some(v) => v,
892 None => return None,
893 },
894 };
895
896 let entry = table
898 .hnsw_indexes
899 .iter()
900 .find(|e| e.column_name == col_name)?;
901
902 let declared_dim = match table.columns.iter().find(|c| c.column_name == col_name) {
908 Some(c) => match &c.datatype {
909 DataType::Vector(d) => *d,
910 _ => return None,
911 },
912 None => return None,
913 };
914 if query_vec.len() != declared_dim {
915 return None;
916 }
917
918 let column_for_closure = col_name.clone();
922 let table_ref = table;
923 let result = entry.index.search(&query_vec, k, |id| {
924 match table_ref.get_value(&column_for_closure, id) {
925 Some(Value::Vector(v)) => v,
926 _ => Vec::new(),
927 }
928 });
929 Some(result)
930}
931
932fn try_fts_probe(table: &Table, order_expr: &Expr, ascending: bool, k: usize) -> Option<Vec<i64>> {
948 if k == 0 || ascending {
949 return None;
953 }
954
955 let func = match order_expr {
956 Expr::Function(f) => f,
957 _ => return None,
958 };
959 let fname = match func.name.0.as_slice() {
960 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
961 _ => return None,
962 };
963 if fname != "bm25_score" {
964 return None;
965 }
966
967 let arg_list = match &func.args {
968 FunctionArguments::List(l) => &l.args,
969 _ => return None,
970 };
971 if arg_list.len() != 2 {
972 return None;
973 }
974 let exprs: Vec<&Expr> = arg_list
975 .iter()
976 .filter_map(|a| match a {
977 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => Some(e),
978 _ => None,
979 })
980 .collect();
981 if exprs.len() != 2 {
982 return None;
983 }
984
985 let col_name = match exprs[0] {
987 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
988 _ => return None,
989 };
990
991 let query = match exprs[1] {
995 Expr::Value(v) => match &v.value {
996 AstValue::SingleQuotedString(s) => s.clone(),
997 _ => return None,
998 },
999 _ => return None,
1000 };
1001
1002 let entry = table
1003 .fts_indexes
1004 .iter()
1005 .find(|e| e.column_name == col_name)?;
1006
1007 let scored = entry.index.query(&query, &Bm25Params::default());
1008 let mut out: Vec<i64> = scored.into_iter().map(|(id, _)| id).collect();
1009 if out.len() > k {
1010 out.truncate(k);
1011 }
1012 Some(out)
1013}
1014
1015fn identify_indexed_arg_and_literal(a: &Expr, b: &Expr) -> Option<(String, Vec<f32>)> {
1020 let col_name = match a {
1021 Expr::Identifier(ident) if ident.quote_style.is_none() => ident.value.clone(),
1022 _ => return None,
1023 };
1024 let lit_str = match b {
1025 Expr::Identifier(ident) if ident.quote_style == Some('[') => {
1026 format!("[{}]", ident.value)
1027 }
1028 _ => return None,
1029 };
1030 let v = parse_vector_literal(&lit_str).ok()?;
1031 Some((col_name, v))
1032}
1033
1034struct HeapEntry {
1047 key: Value,
1048 rowid: i64,
1049 asc: bool,
1050}
1051
1052impl PartialEq for HeapEntry {
1053 fn eq(&self, other: &Self) -> bool {
1054 self.cmp(other) == Ordering::Equal
1055 }
1056}
1057
1058impl Eq for HeapEntry {}
1059
1060impl PartialOrd for HeapEntry {
1061 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
1062 Some(self.cmp(other))
1063 }
1064}
1065
1066impl Ord for HeapEntry {
1067 fn cmp(&self, other: &Self) -> Ordering {
1068 let raw = compare_values(Some(&self.key), Some(&other.key));
1069 if self.asc { raw } else { raw.reverse() }
1070 }
1071}
1072
1073fn select_topk(
1082 matching: &[i64],
1083 table: &Table,
1084 order: &OrderByClause,
1085 k: usize,
1086) -> Result<Vec<i64>> {
1087 use std::collections::BinaryHeap;
1088
1089 if k == 0 || matching.is_empty() {
1090 return Ok(Vec::new());
1091 }
1092
1093 let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
1094
1095 for &rowid in matching {
1096 let key = eval_expr(&order.expr, table, rowid)?;
1097 let entry = HeapEntry {
1098 key,
1099 rowid,
1100 asc: order.ascending,
1101 };
1102
1103 if heap.len() < k {
1104 heap.push(entry);
1105 } else {
1106 if entry < *heap.peek().unwrap() {
1110 heap.pop();
1111 heap.push(entry);
1112 }
1113 }
1114 }
1115
1116 Ok(heap
1121 .into_sorted_vec()
1122 .into_iter()
1123 .map(|e| e.rowid)
1124 .collect())
1125}
1126
1127fn sort_rowids(rowids: &mut [i64], table: &Table, order: &OrderByClause) -> Result<()> {
1128 let mut keys: Vec<(i64, Result<Value>)> = rowids
1136 .iter()
1137 .map(|r| (*r, eval_expr(&order.expr, table, *r)))
1138 .collect();
1139
1140 for (_, k) in &keys {
1144 if let Err(e) = k {
1145 return Err(SQLRiteError::General(format!(
1146 "ORDER BY expression failed: {e}"
1147 )));
1148 }
1149 }
1150
1151 keys.sort_by(|(_, ka), (_, kb)| {
1152 let va = ka.as_ref().unwrap();
1155 let vb = kb.as_ref().unwrap();
1156 let ord = compare_values(Some(va), Some(vb));
1157 if order.ascending { ord } else { ord.reverse() }
1158 });
1159
1160 for (i, (rowid, _)) in keys.into_iter().enumerate() {
1162 rowids[i] = rowid;
1163 }
1164 Ok(())
1165}
1166
1167fn compare_values(a: Option<&Value>, b: Option<&Value>) -> Ordering {
1168 match (a, b) {
1169 (None, None) => Ordering::Equal,
1170 (None, _) => Ordering::Less,
1171 (_, None) => Ordering::Greater,
1172 (Some(a), Some(b)) => match (a, b) {
1173 (Value::Null, Value::Null) => Ordering::Equal,
1174 (Value::Null, _) => Ordering::Less,
1175 (_, Value::Null) => Ordering::Greater,
1176 (Value::Integer(x), Value::Integer(y)) => x.cmp(y),
1177 (Value::Real(x), Value::Real(y)) => x.partial_cmp(y).unwrap_or(Ordering::Equal),
1178 (Value::Integer(x), Value::Real(y)) => {
1179 (*x as f64).partial_cmp(y).unwrap_or(Ordering::Equal)
1180 }
1181 (Value::Real(x), Value::Integer(y)) => {
1182 x.partial_cmp(&(*y as f64)).unwrap_or(Ordering::Equal)
1183 }
1184 (Value::Text(x), Value::Text(y)) => x.cmp(y),
1185 (Value::Bool(x), Value::Bool(y)) => x.cmp(y),
1186 (x, y) => x.to_display_string().cmp(&y.to_display_string()),
1188 },
1189 }
1190}
1191
1192pub fn eval_predicate(expr: &Expr, table: &Table, rowid: i64) -> Result<bool> {
1194 let v = eval_expr(expr, table, rowid)?;
1195 match v {
1196 Value::Bool(b) => Ok(b),
1197 Value::Null => Ok(false), Value::Integer(i) => Ok(i != 0),
1199 other => Err(SQLRiteError::Internal(format!(
1200 "WHERE clause must evaluate to boolean, got {}",
1201 other.to_display_string()
1202 ))),
1203 }
1204}
1205
1206fn eval_expr(expr: &Expr, table: &Table, rowid: i64) -> Result<Value> {
1207 match expr {
1208 Expr::Nested(inner) => eval_expr(inner, table, rowid),
1209
1210 Expr::Identifier(ident) => {
1211 if ident.quote_style == Some('[') {
1221 let raw = format!("[{}]", ident.value);
1222 let v = parse_vector_literal(&raw)?;
1223 return Ok(Value::Vector(v));
1224 }
1225 Ok(table.get_value(&ident.value, rowid).unwrap_or(Value::Null))
1226 }
1227
1228 Expr::CompoundIdentifier(parts) => {
1229 let col = parts
1231 .last()
1232 .map(|i| i.value.as_str())
1233 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?;
1234 Ok(table.get_value(col, rowid).unwrap_or(Value::Null))
1235 }
1236
1237 Expr::Value(v) => convert_literal(&v.value),
1238
1239 Expr::UnaryOp { op, expr } => {
1240 let inner = eval_expr(expr, table, rowid)?;
1241 match op {
1242 UnaryOperator::Not => match inner {
1243 Value::Bool(b) => Ok(Value::Bool(!b)),
1244 Value::Null => Ok(Value::Null),
1245 other => Err(SQLRiteError::Internal(format!(
1246 "NOT applied to non-boolean value: {}",
1247 other.to_display_string()
1248 ))),
1249 },
1250 UnaryOperator::Minus => match inner {
1251 Value::Integer(i) => Ok(Value::Integer(-i)),
1252 Value::Real(f) => Ok(Value::Real(-f)),
1253 Value::Null => Ok(Value::Null),
1254 other => Err(SQLRiteError::Internal(format!(
1255 "unary minus on non-numeric value: {}",
1256 other.to_display_string()
1257 ))),
1258 },
1259 UnaryOperator::Plus => Ok(inner),
1260 other => Err(SQLRiteError::NotImplemented(format!(
1261 "unary operator {other:?} is not supported"
1262 ))),
1263 }
1264 }
1265
1266 Expr::BinaryOp { left, op, right } => match op {
1267 BinaryOperator::And => {
1268 let l = eval_expr(left, table, rowid)?;
1269 let r = eval_expr(right, table, rowid)?;
1270 Ok(Value::Bool(as_bool(&l)? && as_bool(&r)?))
1271 }
1272 BinaryOperator::Or => {
1273 let l = eval_expr(left, table, rowid)?;
1274 let r = eval_expr(right, table, rowid)?;
1275 Ok(Value::Bool(as_bool(&l)? || as_bool(&r)?))
1276 }
1277 cmp @ (BinaryOperator::Eq
1278 | BinaryOperator::NotEq
1279 | BinaryOperator::Lt
1280 | BinaryOperator::LtEq
1281 | BinaryOperator::Gt
1282 | BinaryOperator::GtEq) => {
1283 let l = eval_expr(left, table, rowid)?;
1284 let r = eval_expr(right, table, rowid)?;
1285 if matches!(l, Value::Null) || matches!(r, Value::Null) {
1287 return Ok(Value::Bool(false));
1288 }
1289 let ord = compare_values(Some(&l), Some(&r));
1290 let result = match cmp {
1291 BinaryOperator::Eq => ord == Ordering::Equal,
1292 BinaryOperator::NotEq => ord != Ordering::Equal,
1293 BinaryOperator::Lt => ord == Ordering::Less,
1294 BinaryOperator::LtEq => ord != Ordering::Greater,
1295 BinaryOperator::Gt => ord == Ordering::Greater,
1296 BinaryOperator::GtEq => ord != Ordering::Less,
1297 _ => unreachable!(),
1298 };
1299 Ok(Value::Bool(result))
1300 }
1301 arith @ (BinaryOperator::Plus
1302 | BinaryOperator::Minus
1303 | BinaryOperator::Multiply
1304 | BinaryOperator::Divide
1305 | BinaryOperator::Modulo) => {
1306 let l = eval_expr(left, table, rowid)?;
1307 let r = eval_expr(right, table, rowid)?;
1308 eval_arith(arith, &l, &r)
1309 }
1310 BinaryOperator::StringConcat => {
1311 let l = eval_expr(left, table, rowid)?;
1312 let r = eval_expr(right, table, rowid)?;
1313 if matches!(l, Value::Null) || matches!(r, Value::Null) {
1314 return Ok(Value::Null);
1315 }
1316 Ok(Value::Text(format!(
1317 "{}{}",
1318 l.to_display_string(),
1319 r.to_display_string()
1320 )))
1321 }
1322 other => Err(SQLRiteError::NotImplemented(format!(
1323 "binary operator {other:?} is not supported yet"
1324 ))),
1325 },
1326
1327 Expr::Function(func) => eval_function(func, table, rowid),
1338
1339 other => Err(SQLRiteError::NotImplemented(format!(
1340 "unsupported expression in WHERE/projection: {other:?}"
1341 ))),
1342 }
1343}
1344
1345fn eval_function(func: &sqlparser::ast::Function, table: &Table, rowid: i64) -> Result<Value> {
1350 let name = match func.name.0.as_slice() {
1353 [ObjectNamePart::Identifier(ident)] => ident.value.to_lowercase(),
1354 _ => {
1355 return Err(SQLRiteError::NotImplemented(format!(
1356 "qualified function names not supported: {:?}",
1357 func.name
1358 )));
1359 }
1360 };
1361
1362 match name.as_str() {
1363 "vec_distance_l2" | "vec_distance_cosine" | "vec_distance_dot" => {
1364 let (a, b) = extract_two_vector_args(&name, &func.args, table, rowid)?;
1365 let dist = match name.as_str() {
1366 "vec_distance_l2" => vec_distance_l2(&a, &b),
1367 "vec_distance_cosine" => vec_distance_cosine(&a, &b)?,
1368 "vec_distance_dot" => vec_distance_dot(&a, &b),
1369 _ => unreachable!(),
1370 };
1371 Ok(Value::Real(dist as f64))
1377 }
1378 "json_extract" => json_fn_extract(&name, &func.args, table, rowid),
1383 "json_type" => json_fn_type(&name, &func.args, table, rowid),
1384 "json_array_length" => json_fn_array_length(&name, &func.args, table, rowid),
1385 "json_object_keys" => json_fn_object_keys(&name, &func.args, table, rowid),
1386 "fts_match" => {
1390 let (entry, query) = resolve_fts_args(&name, &func.args, table, rowid)?;
1391 Ok(Value::Bool(entry.index.matches(rowid, &query)))
1392 }
1393 "bm25_score" => {
1394 let (entry, query) = resolve_fts_args(&name, &func.args, table, rowid)?;
1395 let s = entry.index.score(rowid, &query, &Bm25Params::default());
1396 Ok(Value::Real(s))
1397 }
1398 other => Err(SQLRiteError::NotImplemented(format!(
1399 "unknown function: {other}(...)"
1400 ))),
1401 }
1402}
1403
1404fn resolve_fts_args<'t>(
1409 fn_name: &str,
1410 args: &FunctionArguments,
1411 table: &'t Table,
1412 rowid: i64,
1413) -> Result<(&'t FtsIndexEntry, String)> {
1414 let arg_list = match args {
1415 FunctionArguments::List(l) => &l.args,
1416 _ => {
1417 return Err(SQLRiteError::General(format!(
1418 "{fn_name}() expects exactly two arguments: (column, query_text)"
1419 )));
1420 }
1421 };
1422 if arg_list.len() != 2 {
1423 return Err(SQLRiteError::General(format!(
1424 "{fn_name}() expects exactly 2 arguments, got {}",
1425 arg_list.len()
1426 )));
1427 }
1428
1429 let col_expr = match &arg_list[0] {
1433 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
1434 other => {
1435 return Err(SQLRiteError::NotImplemented(format!(
1436 "{fn_name}() argument 0 must be a column name, got {other:?}"
1437 )));
1438 }
1439 };
1440 let col_name = match col_expr {
1441 Expr::Identifier(ident) => ident.value.clone(),
1442 Expr::CompoundIdentifier(parts) => parts
1443 .last()
1444 .map(|p| p.value.clone())
1445 .ok_or_else(|| SQLRiteError::Internal("empty compound identifier".to_string()))?,
1446 other => {
1447 return Err(SQLRiteError::General(format!(
1448 "{fn_name}() argument 0 must be a column reference, got {other:?}"
1449 )));
1450 }
1451 };
1452
1453 let q_expr = match &arg_list[1] {
1457 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
1458 other => {
1459 return Err(SQLRiteError::NotImplemented(format!(
1460 "{fn_name}() argument 1 must be a text expression, got {other:?}"
1461 )));
1462 }
1463 };
1464 let query = match eval_expr(q_expr, table, rowid)? {
1465 Value::Text(s) => s,
1466 other => {
1467 return Err(SQLRiteError::General(format!(
1468 "{fn_name}() argument 1 must be TEXT, got {}",
1469 other.to_display_string()
1470 )));
1471 }
1472 };
1473
1474 let entry = table
1475 .fts_indexes
1476 .iter()
1477 .find(|e| e.column_name == col_name)
1478 .ok_or_else(|| {
1479 SQLRiteError::General(format!(
1480 "{fn_name}({col_name}, ...): no FTS index on column '{col_name}' \
1481 (run CREATE INDEX <name> ON <table> USING fts({col_name}) first)"
1482 ))
1483 })?;
1484 Ok((entry, query))
1485}
1486
1487fn extract_json_and_path(
1501 fn_name: &str,
1502 args: &FunctionArguments,
1503 table: &Table,
1504 rowid: i64,
1505) -> Result<(String, String)> {
1506 let arg_list = match args {
1507 FunctionArguments::List(l) => &l.args,
1508 _ => {
1509 return Err(SQLRiteError::General(format!(
1510 "{fn_name}() expects 1 or 2 arguments"
1511 )));
1512 }
1513 };
1514 if !(arg_list.len() == 1 || arg_list.len() == 2) {
1515 return Err(SQLRiteError::General(format!(
1516 "{fn_name}() expects 1 or 2 arguments, got {}",
1517 arg_list.len()
1518 )));
1519 }
1520 let first_expr = match &arg_list[0] {
1522 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
1523 other => {
1524 return Err(SQLRiteError::NotImplemented(format!(
1525 "{fn_name}() argument 0 has unsupported shape: {other:?}"
1526 )));
1527 }
1528 };
1529 let json_text = match eval_expr(first_expr, table, rowid)? {
1530 Value::Text(s) => s,
1531 Value::Null => {
1532 return Err(SQLRiteError::General(format!(
1533 "{fn_name}() called on NULL — JSON column has no value for this row"
1534 )));
1535 }
1536 other => {
1537 return Err(SQLRiteError::General(format!(
1538 "{fn_name}() argument 0 is not JSON-typed: got {}",
1539 other.to_display_string()
1540 )));
1541 }
1542 };
1543
1544 let path = if arg_list.len() == 2 {
1546 let path_expr = match &arg_list[1] {
1547 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
1548 other => {
1549 return Err(SQLRiteError::NotImplemented(format!(
1550 "{fn_name}() argument 1 has unsupported shape: {other:?}"
1551 )));
1552 }
1553 };
1554 match eval_expr(path_expr, table, rowid)? {
1555 Value::Text(s) => s,
1556 other => {
1557 return Err(SQLRiteError::General(format!(
1558 "{fn_name}() path argument must be a string literal, got {}",
1559 other.to_display_string()
1560 )));
1561 }
1562 }
1563 } else {
1564 "$".to_string()
1565 };
1566
1567 Ok((json_text, path))
1568}
1569
1570fn walk_json_path<'a>(
1580 value: &'a serde_json::Value,
1581 path: &str,
1582) -> Result<Option<&'a serde_json::Value>> {
1583 let mut chars = path.chars().peekable();
1584 if chars.next() != Some('$') {
1585 return Err(SQLRiteError::General(format!(
1586 "JSON path must start with '$', got `{path}`"
1587 )));
1588 }
1589 let mut current = value;
1590 while let Some(&c) = chars.peek() {
1591 match c {
1592 '.' => {
1593 chars.next();
1594 let mut key = String::new();
1595 while let Some(&c) = chars.peek() {
1596 if c == '.' || c == '[' {
1597 break;
1598 }
1599 key.push(c);
1600 chars.next();
1601 }
1602 if key.is_empty() {
1603 return Err(SQLRiteError::General(format!(
1604 "JSON path has empty key after '.' in `{path}`"
1605 )));
1606 }
1607 match current.get(&key) {
1608 Some(v) => current = v,
1609 None => return Ok(None),
1610 }
1611 }
1612 '[' => {
1613 chars.next();
1614 let mut idx_str = String::new();
1615 while let Some(&c) = chars.peek() {
1616 if c == ']' {
1617 break;
1618 }
1619 idx_str.push(c);
1620 chars.next();
1621 }
1622 if chars.next() != Some(']') {
1623 return Err(SQLRiteError::General(format!(
1624 "JSON path has unclosed `[` in `{path}`"
1625 )));
1626 }
1627 let idx: usize = idx_str.trim().parse().map_err(|_| {
1628 SQLRiteError::General(format!(
1629 "JSON path has non-integer index `[{idx_str}]` in `{path}`"
1630 ))
1631 })?;
1632 match current.get(idx) {
1633 Some(v) => current = v,
1634 None => return Ok(None),
1635 }
1636 }
1637 other => {
1638 return Err(SQLRiteError::General(format!(
1639 "JSON path has unexpected character `{other}` in `{path}` \
1640 (expected `.`, `[`, or end-of-path)"
1641 )));
1642 }
1643 }
1644 }
1645 Ok(Some(current))
1646}
1647
1648fn json_value_to_sql(v: &serde_json::Value) -> Value {
1652 match v {
1653 serde_json::Value::Null => Value::Null,
1654 serde_json::Value::Bool(b) => Value::Bool(*b),
1655 serde_json::Value::Number(n) => {
1656 if let Some(i) = n.as_i64() {
1658 Value::Integer(i)
1659 } else if let Some(f) = n.as_f64() {
1660 Value::Real(f)
1661 } else {
1662 Value::Null
1663 }
1664 }
1665 serde_json::Value::String(s) => Value::Text(s.clone()),
1666 composite => Value::Text(composite.to_string()),
1670 }
1671}
1672
1673fn json_fn_extract(
1674 name: &str,
1675 args: &FunctionArguments,
1676 table: &Table,
1677 rowid: i64,
1678) -> Result<Value> {
1679 let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
1680 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
1681 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
1682 })?;
1683 match walk_json_path(&parsed, &path)? {
1684 Some(v) => Ok(json_value_to_sql(v)),
1685 None => Ok(Value::Null),
1686 }
1687}
1688
1689fn json_fn_type(name: &str, args: &FunctionArguments, table: &Table, rowid: i64) -> Result<Value> {
1690 let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
1691 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
1692 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
1693 })?;
1694 let resolved = match walk_json_path(&parsed, &path)? {
1695 Some(v) => v,
1696 None => return Ok(Value::Null),
1697 };
1698 let ty = match resolved {
1699 serde_json::Value::Null => "null",
1700 serde_json::Value::Bool(true) => "true",
1701 serde_json::Value::Bool(false) => "false",
1702 serde_json::Value::Number(n) => {
1703 if n.is_i64() || n.is_u64() {
1704 "integer"
1705 } else {
1706 "real"
1707 }
1708 }
1709 serde_json::Value::String(_) => "text",
1710 serde_json::Value::Array(_) => "array",
1711 serde_json::Value::Object(_) => "object",
1712 };
1713 Ok(Value::Text(ty.to_string()))
1714}
1715
1716fn json_fn_array_length(
1717 name: &str,
1718 args: &FunctionArguments,
1719 table: &Table,
1720 rowid: i64,
1721) -> Result<Value> {
1722 let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
1723 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
1724 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
1725 })?;
1726 let resolved = match walk_json_path(&parsed, &path)? {
1727 Some(v) => v,
1728 None => return Ok(Value::Null),
1729 };
1730 match resolved.as_array() {
1731 Some(arr) => Ok(Value::Integer(arr.len() as i64)),
1732 None => Err(SQLRiteError::General(format!(
1733 "{name}() resolved to a non-array value at path `{path}`"
1734 ))),
1735 }
1736}
1737
1738fn json_fn_object_keys(
1739 name: &str,
1740 args: &FunctionArguments,
1741 table: &Table,
1742 rowid: i64,
1743) -> Result<Value> {
1744 let (json_text, path) = extract_json_and_path(name, args, table, rowid)?;
1745 let parsed: serde_json::Value = serde_json::from_str(&json_text).map_err(|e| {
1746 SQLRiteError::General(format!("{name}() got invalid JSON `{json_text}`: {e}"))
1747 })?;
1748 let resolved = match walk_json_path(&parsed, &path)? {
1749 Some(v) => v,
1750 None => return Ok(Value::Null),
1751 };
1752 let obj = resolved.as_object().ok_or_else(|| {
1753 SQLRiteError::General(format!(
1754 "{name}() resolved to a non-object value at path `{path}`"
1755 ))
1756 })?;
1757 let keys: Vec<serde_json::Value> = obj
1764 .keys()
1765 .map(|k| serde_json::Value::String(k.clone()))
1766 .collect();
1767 Ok(Value::Text(serde_json::Value::Array(keys).to_string()))
1768}
1769
1770fn extract_two_vector_args(
1774 fn_name: &str,
1775 args: &FunctionArguments,
1776 table: &Table,
1777 rowid: i64,
1778) -> Result<(Vec<f32>, Vec<f32>)> {
1779 let arg_list = match args {
1780 FunctionArguments::List(l) => &l.args,
1781 _ => {
1782 return Err(SQLRiteError::General(format!(
1783 "{fn_name}() expects exactly two vector arguments"
1784 )));
1785 }
1786 };
1787 if arg_list.len() != 2 {
1788 return Err(SQLRiteError::General(format!(
1789 "{fn_name}() expects exactly 2 arguments, got {}",
1790 arg_list.len()
1791 )));
1792 }
1793 let mut out: Vec<Vec<f32>> = Vec::with_capacity(2);
1794 for (i, arg) in arg_list.iter().enumerate() {
1795 let expr = match arg {
1796 FunctionArg::Unnamed(FunctionArgExpr::Expr(e)) => e,
1797 other => {
1798 return Err(SQLRiteError::NotImplemented(format!(
1799 "{fn_name}() argument {i} has unsupported shape: {other:?}"
1800 )));
1801 }
1802 };
1803 let val = eval_expr(expr, table, rowid)?;
1804 match val {
1805 Value::Vector(v) => out.push(v),
1806 other => {
1807 return Err(SQLRiteError::General(format!(
1808 "{fn_name}() argument {i} is not a vector: got {}",
1809 other.to_display_string()
1810 )));
1811 }
1812 }
1813 }
1814 let b = out.pop().unwrap();
1815 let a = out.pop().unwrap();
1816 if a.len() != b.len() {
1817 return Err(SQLRiteError::General(format!(
1818 "{fn_name}(): vector dimensions don't match (lhs={}, rhs={})",
1819 a.len(),
1820 b.len()
1821 )));
1822 }
1823 Ok((a, b))
1824}
1825
/// Euclidean (L2) distance between two equal-length vectors: `sqrt(Σ (aᵢ - bᵢ)²)`.
///
/// Equal lengths are only checked in debug builds. In release builds a shorter
/// `b` panics and any extra elements in `b` are ignored.
pub(crate) fn vec_distance_l2(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Re-slice so the walk is driven by `a`'s length, exactly like an indexed loop.
    let b = &b[..a.len()];
    a.iter()
        .zip(b)
        .fold(0.0f32, |acc, (x, y)| {
            let diff = x - y;
            acc + diff * diff
        })
        .sqrt()
}
1837
1838pub(crate) fn vec_distance_cosine(a: &[f32], b: &[f32]) -> Result<f32> {
1848 debug_assert_eq!(a.len(), b.len());
1849 let mut dot = 0.0f32;
1850 let mut norm_a_sq = 0.0f32;
1851 let mut norm_b_sq = 0.0f32;
1852 for i in 0..a.len() {
1853 dot += a[i] * b[i];
1854 norm_a_sq += a[i] * a[i];
1855 norm_b_sq += b[i] * b[i];
1856 }
1857 let denom = (norm_a_sq * norm_b_sq).sqrt();
1858 if denom == 0.0 {
1859 return Err(SQLRiteError::General(
1860 "vec_distance_cosine() is undefined for zero-magnitude vectors".to_string(),
1861 ));
1862 }
1863 Ok(1.0 - dot / denom)
1864}
1865
/// Negated dot product, `-(a·b)`, so that "smaller is closer" like the other
/// distance functions.
///
/// Equal lengths are only checked in debug builds. In release builds a shorter
/// `b` panics and any extra elements in `b` are ignored.
pub(crate) fn vec_distance_dot(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    // Re-slice so the walk is driven by `a`'s length, exactly like an indexed loop.
    let b = &b[..a.len()];
    let similarity = a.iter().zip(b).fold(0.0f32, |acc, (x, y)| acc + x * y);
    -similarity
}
1877
1878fn eval_arith(op: &BinaryOperator, l: &Value, r: &Value) -> Result<Value> {
1881 if matches!(l, Value::Null) || matches!(r, Value::Null) {
1882 return Ok(Value::Null);
1883 }
1884 match (l, r) {
1885 (Value::Integer(a), Value::Integer(b)) => match op {
1886 BinaryOperator::Plus => Ok(Value::Integer(a.wrapping_add(*b))),
1887 BinaryOperator::Minus => Ok(Value::Integer(a.wrapping_sub(*b))),
1888 BinaryOperator::Multiply => Ok(Value::Integer(a.wrapping_mul(*b))),
1889 BinaryOperator::Divide => {
1890 if *b == 0 {
1891 Err(SQLRiteError::General("division by zero".to_string()))
1892 } else {
1893 Ok(Value::Integer(a / b))
1894 }
1895 }
1896 BinaryOperator::Modulo => {
1897 if *b == 0 {
1898 Err(SQLRiteError::General("modulo by zero".to_string()))
1899 } else {
1900 Ok(Value::Integer(a % b))
1901 }
1902 }
1903 _ => unreachable!(),
1904 },
1905 (a, b) => {
1907 let af = as_number(a)?;
1908 let bf = as_number(b)?;
1909 match op {
1910 BinaryOperator::Plus => Ok(Value::Real(af + bf)),
1911 BinaryOperator::Minus => Ok(Value::Real(af - bf)),
1912 BinaryOperator::Multiply => Ok(Value::Real(af * bf)),
1913 BinaryOperator::Divide => {
1914 if bf == 0.0 {
1915 Err(SQLRiteError::General("division by zero".to_string()))
1916 } else {
1917 Ok(Value::Real(af / bf))
1918 }
1919 }
1920 BinaryOperator::Modulo => {
1921 if bf == 0.0 {
1922 Err(SQLRiteError::General("modulo by zero".to_string()))
1923 } else {
1924 Ok(Value::Real(af % bf))
1925 }
1926 }
1927 _ => unreachable!(),
1928 }
1929 }
1930 }
1931}
1932
1933fn as_number(v: &Value) -> Result<f64> {
1934 match v {
1935 Value::Integer(i) => Ok(*i as f64),
1936 Value::Real(f) => Ok(*f),
1937 Value::Bool(b) => Ok(if *b { 1.0 } else { 0.0 }),
1938 other => Err(SQLRiteError::General(format!(
1939 "arithmetic on non-numeric value '{}'",
1940 other.to_display_string()
1941 ))),
1942 }
1943}
1944
1945fn as_bool(v: &Value) -> Result<bool> {
1946 match v {
1947 Value::Bool(b) => Ok(*b),
1948 Value::Null => Ok(false),
1949 Value::Integer(i) => Ok(*i != 0),
1950 other => Err(SQLRiteError::Internal(format!(
1951 "expected boolean, got {}",
1952 other.to_display_string()
1953 ))),
1954 }
1955}
1956
1957fn convert_literal(v: &sqlparser::ast::Value) -> Result<Value> {
1958 use sqlparser::ast::Value as AstValue;
1959 match v {
1960 AstValue::Number(n, _) => {
1961 if let Ok(i) = n.parse::<i64>() {
1962 Ok(Value::Integer(i))
1963 } else if let Ok(f) = n.parse::<f64>() {
1964 Ok(Value::Real(f))
1965 } else {
1966 Err(SQLRiteError::Internal(format!(
1967 "could not parse numeric literal '{n}'"
1968 )))
1969 }
1970 }
1971 AstValue::SingleQuotedString(s) => Ok(Value::Text(s.clone())),
1972 AstValue::Boolean(b) => Ok(Value::Bool(*b)),
1973 AstValue::Null => Ok(Value::Null),
1974 other => Err(SQLRiteError::NotImplemented(format!(
1975 "unsupported literal value: {other:?}"
1976 ))),
1977 }
1978}
1979
// Unit tests for the vector distance helpers and for the bounded top-k
// selection (`select_topk`) against a full sort (`sort_rowids`) + truncate.
#[cfg(test)]
mod tests {
    use super::*;

    // Absolute-tolerance float comparison used by the distance tests.
    fn approx_eq(a: f32, b: f32, eps: f32) -> bool {
        (a - b).abs() < eps
    }

    // ----- vec_distance_l2 -------------------------------------------------

    #[test]
    fn vec_distance_l2_identical_is_zero() {
        let v = vec![0.1, 0.2, 0.3];
        assert_eq!(vec_distance_l2(&v, &v), 0.0);
    }

    #[test]
    fn vec_distance_l2_unit_basis_is_sqrt2() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(approx_eq(vec_distance_l2(&a, &b), 2.0_f32.sqrt(), 1e-6));
    }

    #[test]
    fn vec_distance_l2_known_value() {
        // 3-4-5 right triangle.
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![3.0, 4.0, 0.0];
        assert!(approx_eq(vec_distance_l2(&a, &b), 5.0, 1e-6));
    }

    // ----- vec_distance_cosine ---------------------------------------------

    #[test]
    fn vec_distance_cosine_identical_is_zero() {
        let v = vec![0.1, 0.2, 0.3];
        let d = vec_distance_cosine(&v, &v).unwrap();
        assert!(approx_eq(d, 0.0, 1e-6), "cos(v,v) = {d}, expected ≈ 0");
    }

    #[test]
    fn vec_distance_cosine_orthogonal_is_one() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 1.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_opposite_is_two() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![-1.0, 0.0, 0.0];
        assert!(approx_eq(vec_distance_cosine(&a, &b).unwrap(), 2.0, 1e-6));
    }

    #[test]
    fn vec_distance_cosine_zero_magnitude_errors() {
        // A zero vector has no direction, so the distance must be an error.
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 0.0];
        let err = vec_distance_cosine(&a, &b).unwrap_err();
        assert!(format!("{err}").contains("zero-magnitude"));
    }

    // ----- vec_distance_dot ------------------------------------------------

    #[test]
    fn vec_distance_dot_negates() {
        // 1*4 + 2*5 + 3*6 = 32, returned negated.
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        assert!(approx_eq(vec_distance_dot(&a, &b), -32.0, 1e-6));
    }

    #[test]
    fn vec_distance_dot_orthogonal_is_zero() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert_eq!(vec_distance_dot(&a, &b), 0.0);
    }

    #[test]
    fn vec_distance_dot_unit_norm_matches_cosine_minus_one() {
        // Both inputs have norm 1 (0.6² + 0.8² = 1), so -dot == cosine - 1.
        let a = vec![0.6f32, 0.8];
        let b = vec![0.8f32, 0.6];
        let dot = vec_distance_dot(&a, &b);
        let cos = vec_distance_cosine(&a, &b).unwrap();
        assert!(approx_eq(dot, cos - 1.0, 1e-5));
    }

    // ----- top-k selection ------------------------------------------------

    use crate::sql::db::database::Database;
    use crate::sql::parser::select::SelectQuery;
    use sqlparser::dialect::SQLiteDialect;
    use sqlparser::parser::Parser;

    // Builds an in-memory `docs(id, score)` table with `n` rows. The scores come
    // from `i * 2_654_435_761 % 1_000_000`, which scrambles them so insertion
    // order is not score order.
    // NOTE(review): distinct rows may collide on score; ties then depend on how
    // `select_topk` and `sort_rowids` break them — confirm they agree.
    fn seed_score_table(n: usize) -> Database {
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, score REAL);",
            &mut db,
        )
        .expect("create");
        for i in 0..n {
            let score = ((i as u64).wrapping_mul(2_654_435_761) % 1_000_000) as f64;
            let sql = format!("INSERT INTO docs (score) VALUES ({score});");
            crate::sql::process_command(&sql, &mut db).expect("insert");
        }
        db
    }

    // Parses a single SELECT statement into the executor's `SelectQuery`.
    fn parse_select(sql: &str) -> SelectQuery {
        let dialect = SQLiteDialect {};
        let mut ast = Parser::parse_sql(&dialect, sql).expect("parse");
        let stmt = ast.pop().expect("one statement");
        SelectQuery::new(&stmt).expect("select-query")
    }

    #[test]
    fn topk_matches_full_sort_asc() {
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        // Reference answer: sort everything, keep the first 10.
        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(10);

        let topk = select_topk(&all_rowids, table, order, 10).unwrap();

        assert_eq!(topk, full, "top-k via heap should match full-sort+truncate");
    }

    #[test]
    fn topk_matches_full_sort_desc() {
        let db = seed_score_table(200);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score DESC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(10);

        let topk = select_topk(&all_rowids, table, order, 10).unwrap();

        assert_eq!(
            topk, full,
            "top-k DESC via heap should match full-sort+truncate"
        );
    }

    #[test]
    fn topk_k_larger_than_n_returns_everything_sorted() {
        // k exceeds the row count: every row comes back, in ascending score order.
        let db = seed_score_table(50);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1000;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&table.rowids(), table, order, 1000).unwrap();
        assert_eq!(topk.len(), 50);
        let scores: Vec<f64> = topk
            .iter()
            .filter_map(|r| match table.get_value("score", *r) {
                Some(Value::Real(f)) => Some(f),
                _ => None,
            })
            .collect();
        assert!(scores.windows(2).all(|w| w[0] <= w[1]));
    }

    #[test]
    fn topk_k_zero_returns_empty() {
        // k is passed as 0 directly; the LIMIT in the SQL only supplies ORDER BY.
        let db = seed_score_table(10);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 1;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&table.rowids(), table, order, 0).unwrap();
        assert!(topk.is_empty());
    }

    #[test]
    fn topk_empty_input_returns_empty() {
        let db = seed_score_table(0);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 5;");
        let order = q.order_by.as_ref().unwrap();
        let topk = select_topk(&[], table, order, 5).unwrap();
        assert!(topk.is_empty());
    }

    #[test]
    fn topk_works_through_select_executor_with_distance_function() {
        // End-to-end: ORDER BY a vector-distance expression with LIMIT through
        // `process_command`. Only the returned row count is checked, not which rows.
        let mut db = Database::new("tempdb".to_string());
        crate::sql::process_command(
            "CREATE TABLE docs (id INTEGER PRIMARY KEY, e VECTOR(2));",
            &mut db,
        )
        .unwrap();
        for v in &[
            "[1.0, 0.0]",
            "[2.0, 0.0]",
            "[0.0, 3.0]",
            "[1.0, 4.0]",
            "[10.0, 10.0]",
        ] {
            crate::sql::process_command(&format!("INSERT INTO docs (e) VALUES ({v});"), &mut db)
                .unwrap();
        }
        let resp = crate::sql::process_command(
            "SELECT id FROM docs ORDER BY vec_distance_l2(e, [1.0, 0.0]) ASC LIMIT 3;",
            &mut db,
        )
        .unwrap();
        assert!(resp.contains("3 rows returned"), "got: {resp}");
    }

    // Wall-clock comparison of bounded-heap top-k vs full sort + truncate.
    // Timing-dependent and machine-dependent, hence `#[ignore]`. Run it
    // explicitly with `cargo test -- --ignored`. The 1.4x threshold is a
    // heuristic, not a guarantee.
    #[test]
    #[ignore]
    fn topk_benchmark() {
        use std::time::Instant;
        const N: usize = 10_000;
        const K: usize = 10;

        let db = seed_score_table(N);
        let table = db.get_table("docs".to_string()).unwrap();
        let q = parse_select("SELECT * FROM docs ORDER BY score ASC LIMIT 10;");
        let order = q.order_by.as_ref().unwrap();
        let all_rowids = table.rowids();

        let t0 = Instant::now();
        let _topk = select_topk(&all_rowids, table, order, K).unwrap();
        let heap_dur = t0.elapsed();

        // The clone is inside the timed region, so the sort side also pays for
        // one copy of the rowid list.
        let t1 = Instant::now();
        let mut full = all_rowids.clone();
        sort_rowids(&mut full, table, order).unwrap();
        full.truncate(K);
        let sort_dur = t1.elapsed();

        // `max(1e-9)` guards against dividing by a zero-length heap timing.
        let ratio = sort_dur.as_secs_f64() / heap_dur.as_secs_f64().max(1e-9);
        println!("\n--- topk_benchmark (N={N}, k={K}) ---");
        println!(" bounded heap: {heap_dur:?}");
        println!(" full sort+trunc: {sort_dur:?}");
        println!(" speedup ratio: {ratio:.2}×");

        assert!(
            ratio > 1.4,
            "bounded heap should be substantially faster than full sort, but ratio = {ratio:.2}"
        );
    }
}