1use rig_core::embeddings::{Embedding, EmbeddingModel};
11use rig_core::vector_store::request::{FilterError, SearchFilter, VectorSearchRequest};
12use rig_core::vector_store::{InsertDocuments, VectorStoreError, VectorStoreIndex};
13use rig_core::wasm_compat::{WasmCompatSend, WasmCompatSync};
14use rig_core::{Embed, OneOrMany};
15use rusqlite::OptionalExtension;
16use rusqlite::types::{Type, Value, ValueRef};
17use serde::{Deserialize, Serialize};
18use std::fmt::{self, Display};
19use std::marker::PhantomData;
20use std::ops::RangeInclusive;
21use tokio_rusqlite::Connection;
22use tracing::{debug, info};
23use zerocopy::IntoBytes;
24
/// Error type for the SQLite vector store integration.
///
/// NOTE(review): the variant meanings below are presumed from their names — confirm
/// against the call sites that construct them.
#[derive(Debug)]
pub enum SqliteError {
    /// Boxed error presumably originating from the database layer.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// Boxed error presumably raised while (de)serializing values.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// Description of a column type that presumably could not be used.
    InvalidColumnType(String),
}
31
/// A value that can be written into one column of a [`SqliteVectorStoreTable`] row.
pub trait ColumnValue: Send + Sync {
    /// Converts the value into a `rusqlite` [`Value`] that is bound as a query parameter
    /// when the document row (and any vec0 metadata column) is inserted.
    fn to_sql_value(&self) -> Value;

    /// Returns a SQL type name for this value.
    ///
    /// NOTE(review): presumably the column type (e.g. `TEXT`); not used in this part of
    /// the file — confirm.
    fn column_type(&self) -> &'static str;
}
42
/// Describes one column of the document table created by `SqliteVectorStore`.
///
/// A column marked with [`Column::indexed`] gets a regular SQLite index on the
/// document table. It is also mirrored as a sqlite-vec metadata column in the
/// embeddings table.
#[derive(Clone, Debug)]
pub struct Column {
    /// Column name; interpolated verbatim into generated SQL.
    name: &'static str,
    /// Column type declaration, e.g. `TEXT` or `INTEGER NOT NULL`.
    col_type: &'static str,
    /// Whether an index and a vec0 metadata column are created for this column.
    indexed: bool,
}

impl Column {
    /// Creates a non-indexed column with the given name and SQL type declaration.
    pub fn new(name: &'static str, col_type: &'static str) -> Self {
        let indexed = false;
        Column {
            name,
            col_type,
            indexed,
        }
    }

    /// Marks this column as indexed (builder style: consumes and returns the column).
    pub fn indexed(self) -> Self {
        Column {
            indexed: true,
            ..self
        }
    }
}
71
/// A document type that can be stored in a `SqliteVectorStore`.
///
/// Implementors describe the backing table (its name and columns) and how to turn a
/// document into column values for insertion.
pub trait SqliteVectorStoreTable: Send + Sync + Clone {
    /// Name of the document table. The store also derives the `{name}_embeddings` and
    /// `{name}_embedding_map` table names from it.
    fn name() -> &'static str;
    /// Column definitions used to create the document table. The store creates an
    /// index on an `id` column, so the schema is expected to declare one.
    fn schema() -> Vec<Column>;
    /// Identifier of this document. It is used for logging, and as the `id` lookup
    /// value when [`column_values`](Self::column_values) has no `id` entry.
    fn id(&self) -> String;
    /// `(column name, value)` pairs written to the document table for this document.
    fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
}
115
116#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
124pub enum SqliteDistanceMetric {
125 #[default]
127 Cosine,
128 L2,
130 L1,
132}
133
134impl SqliteDistanceMetric {
135 fn vec0_name(self) -> &'static str {
136 match self {
137 Self::Cosine => "cosine",
138 Self::L2 => "l2",
139 Self::L1 => "l1",
140 }
141 }
142
143 fn score_expression(self, query_param: &str, embedding_expr: &str) -> String {
144 match self {
145 Self::Cosine => {
146 format!("(1 - vec_distance_cosine({query_param}, {embedding_expr}))")
147 }
148 Self::L2 => format!("(-vec_distance_l2({query_param}, {embedding_expr}))"),
149 Self::L1 => format!("(-vec_distance_l1({query_param}, {embedding_expr}))"),
150 }
151 }
152}
153
154#[derive(Debug)]
155struct SqliteDistanceMetricMismatch {
156 table_name: String,
157 requested: SqliteDistanceMetric,
158 configured: SqliteDistanceMetric,
159}
160
161impl Display for SqliteDistanceMetricMismatch {
162 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163 write!(
164 f,
165 "SQLite vector table `{}` uses {:?}, but {:?} was requested",
166 self.table_name, self.configured, self.requested
167 )
168 }
169}
170
171impl std::error::Error for SqliteDistanceMetricMismatch {}
172
/// Error returned when the embeddings table cannot be found in `sqlite_schema`
/// immediately after its `CREATE VIRTUAL TABLE IF NOT EXISTS` statement ran.
#[derive(Debug)]
struct SqliteVectorTableMissingSchema {
    table_name: String,
}

impl Display for SqliteVectorTableMissingSchema {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let table_name = &self.table_name;
        write!(
            f,
            "SQLite vector table `{table_name}` was created but is missing from sqlite_schema"
        )
    }
}

impl std::error::Error for SqliteVectorTableMissingSchema {}
189
190#[derive(Clone, Copy, Debug, Eq, PartialEq)]
191enum SqliteMetadataType {
192 Text,
193 Integer,
194 Float,
195 Boolean,
196}
197
198impl SqliteMetadataType {
199 fn from_column_type(column_type: &str) -> Option<Self> {
200 let first_type_token = column_type
201 .split_whitespace()
202 .next()
203 .unwrap_or_default()
204 .to_ascii_uppercase();
205
206 match first_type_token.as_str() {
207 "TEXT" => Some(Self::Text),
208 "INTEGER" | "INT" | "INT64" | "INTEGER64" => Some(Self::Integer),
209 "FLOAT" | "REAL" | "DOUBLE" | "FLOAT64" | "F64" => Some(Self::Float),
210 "BOOLEAN" | "BOOL" => Some(Self::Boolean),
211 _ => match SqliteColumnAffinity::from_column_type(column_type) {
212 SqliteColumnAffinity::Text => Some(Self::Text),
213 SqliteColumnAffinity::Integer => Some(Self::Integer),
214 SqliteColumnAffinity::Float => Some(Self::Float),
215 SqliteColumnAffinity::Boolean => Some(Self::Boolean),
216 SqliteColumnAffinity::Numeric | SqliteColumnAffinity::Blob => None,
217 },
218 }
219 }
220
221 fn vec0_name(self) -> &'static str {
222 match self {
223 Self::Text => "TEXT",
224 Self::Integer => "INTEGER",
225 Self::Float => "FLOAT",
226 Self::Boolean => "BOOLEAN",
227 }
228 }
229
230 fn supports_native_comparison(self, op: SqliteComparisonOp) -> bool {
231 !matches!(
232 (self, op),
233 (
234 Self::Boolean,
235 SqliteComparisonOp::Gt
236 | SqliteComparisonOp::Lt
237 | SqliteComparisonOp::Gte
238 | SqliteComparisonOp::Lte
239 )
240 )
241 }
242}
243
/// Column affinity following SQLite's type-affinity rules, with an extra BOOLEAN case
/// checked just before the NUMERIC fallback.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SqliteColumnAffinity {
    Text,
    Integer,
    Float,
    Boolean,
    Numeric,
    Blob,
}

impl SqliteColumnAffinity {
    /// Determines the affinity of a declared column type (case-insensitive).
    ///
    /// Rules are tried in SQLite's order: `INT` → INTEGER; `CHAR`/`CLOB`/`TEXT` → TEXT;
    /// `BLOB` or an empty declaration → BLOB; `REAL`/`FLOA`/`DOUB` → REAL. Then `BOOL`
    /// → BOOLEAN, and anything else → NUMERIC.
    fn from_column_type(column_type: &str) -> Self {
        let declared = column_type.to_ascii_uppercase();
        let has = |needle: &str| declared.contains(needle);

        if has("INT") {
            return Self::Integer;
        }
        if ["CHAR", "CLOB", "TEXT"].iter().any(|&needle| has(needle)) {
            return Self::Text;
        }
        if has("BLOB") || declared.trim().is_empty() {
            return Self::Blob;
        }
        if ["REAL", "FLOA", "DOUB"].iter().any(|&needle| has(needle)) {
            return Self::Float;
        }
        if has("BOOL") {
            return Self::Boolean;
        }
        Self::Numeric
    }
}
279
/// An indexed document-table column that is mirrored into the vec0 embeddings table as
/// a metadata column of the same name.
#[derive(Clone, Debug, Eq, PartialEq)]
struct SqliteMetadataColumn {
    /// Column name, shared by the document table and the vec0 table.
    name: &'static str,
    /// vec0 metadata type derived from the column's declared SQL type.
    metadata_type: SqliteMetadataType,
}
285
/// Error for an indexed column whose declared type cannot be stored as a vec0
/// metadata column.
#[derive(Debug)]
struct SqliteUnsupportedMetadataColumn {
    column_name: &'static str,
    column_type: &'static str,
}

impl Display for SqliteUnsupportedMetadataColumn {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            column_name,
            column_type,
        } = self;
        write!(
            f,
            "SQLite metadata column `{column_name}` has unsupported type `{column_type}`"
        )
    }
}

impl std::error::Error for SqliteUnsupportedMetadataColumn {}
303
304#[derive(Debug)]
305struct SqliteMetadataSchemaMismatch {
306 table_name: String,
307 column_name: &'static str,
308 column_type: SqliteMetadataType,
309}
310
311impl Display for SqliteMetadataSchemaMismatch {
312 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
313 write!(
314 f,
315 "SQLite vector table `{}` is missing metadata column `{} {}`",
316 self.table_name,
317 self.column_name,
318 self.column_type.vec0_name()
319 )
320 }
321}
322
323impl std::error::Error for SqliteMetadataSchemaMismatch {}
324
325#[derive(Debug)]
326struct SqliteMetadataValueError {
327 column_name: &'static str,
328 column_type: SqliteMetadataType,
329 value_type: Type,
330}
331
332impl Display for SqliteMetadataValueError {
333 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
334 write!(
335 f,
336 "could not convert SQLite value type `{:?}` for metadata column `{} {}`",
337 self.value_type,
338 self.column_name,
339 self.column_type.vec0_name()
340 )
341 }
342}
343
344impl std::error::Error for SqliteMetadataValueError {}
345
/// Error for a vector-store table whose schema has no `id` column.
#[derive(Debug)]
struct SqliteMissingIdColumn {
    table_name: String,
}

impl Display for SqliteMissingIdColumn {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("SQLite vector store table `")?;
        f.write_str(&self.table_name)?;
        f.write_str("` is missing an `id` column")
    }
}

impl std::error::Error for SqliteMissingIdColumn {}
362
363fn sqlite_metadata_columns(
364 schema: &[Column],
365) -> Result<Vec<SqliteMetadataColumn>, VectorStoreError> {
366 schema
367 .iter()
368 .filter(|column| column.indexed)
369 .map(|column| {
370 let metadata_type =
371 SqliteMetadataType::from_column_type(column.col_type).ok_or_else(|| {
372 VectorStoreError::DatastoreError(Box::new(SqliteUnsupportedMetadataColumn {
373 column_name: column.name,
374 column_type: column.col_type,
375 }))
376 })?;
377
378 Ok(SqliteMetadataColumn {
379 name: column.name,
380 metadata_type,
381 })
382 })
383 .collect()
384}
385
386fn sqlite_metadata_value(
387 values: &[(&'static str, Box<dyn ColumnValue>)],
388 column: &SqliteMetadataColumn,
389) -> rusqlite::Result<Value> {
390 let value = values
391 .iter()
392 .find(|(name, _)| *name == column.name)
393 .ok_or_else(|| rusqlite::Error::InvalidParameterName(column.name.to_string()))?
394 .1
395 .to_sql_value();
396
397 match (column.metadata_type, value) {
398 (SqliteMetadataType::Text, Value::Text(value)) => Ok(Value::Text(value)),
399 (SqliteMetadataType::Integer, Value::Integer(value)) => Ok(Value::Integer(value)),
400 (SqliteMetadataType::Float, Value::Real(value)) => Ok(Value::Real(value)),
401 (SqliteMetadataType::Float, Value::Integer(value)) => Ok(Value::Real(value as f64)),
402 (SqliteMetadataType::Boolean, Value::Integer(value @ (0 | 1))) => Ok(Value::Integer(value)),
403 (_, value) => Err(rusqlite::Error::ToSqlConversionFailure(Box::new(
404 SqliteMetadataValueError {
405 column_name: column.name,
406 column_type: column.metadata_type,
407 value_type: value.data_type(),
408 },
409 ))),
410 }
411}
412
/// A vector store backed by SQLite and the sqlite-vec `vec0` virtual table.
///
/// Documents of type `T` live in the `T::name()` table and their embeddings in
/// `{name}_embeddings`. The `{name}_embedding_map` table links each embedding row to
/// its document row, so one document may own several embeddings.
#[derive(Clone)]
pub struct SqliteVectorStore<E, T>
where
    E: EmbeddingModel + 'static,
    T: SqliteVectorStoreTable + 'static,
{
    /// Async connection handle used for all queries.
    conn: Connection,
    /// Metric the embeddings table was created with.
    distance_metric: SqliteDistanceMetric,
    /// Indexed document columns mirrored as vec0 metadata columns, in the order they
    /// are declared in (and inserted into) the embeddings table.
    metadata_columns: Vec<SqliteMetadataColumn>,
    _phantom: PhantomData<(E, T)>,
}
424
425impl<E, T> SqliteVectorStore<E, T>
426where
427 E: EmbeddingModel + 'static,
428 T: SqliteVectorStoreTable + 'static,
429{
430 async fn candidate_limit(&self, samples: u64, exhaustive: bool) -> Result<u64, VectorStoreError>
431 where
432 Self: 'static,
433 {
434 if samples == 0 {
435 return Ok(0);
436 }
437
438 let embedding_map_table_name = format!("{}_embedding_map", T::name());
439 let (embedding_count, document_count) = self
440 .conn
441 .call(move |conn| {
442 Ok(conn.query_row(
443 &format!(
444 "SELECT COUNT(*), COUNT(DISTINCT document_rowid) FROM {embedding_map_table_name}"
445 ),
446 [],
447 |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)),
448 )?)
449 })
450 .await
451 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
452
453 let embedding_count = u64::try_from(embedding_count).unwrap_or(0);
454 let document_count = u64::try_from(document_count).unwrap_or(0);
455
456 if exhaustive {
457 Ok(embedding_count.max(samples))
458 } else if embedding_count > document_count {
459 Ok(embedding_count)
460 } else {
461 Ok(samples)
462 }
463 }
464}
465
466impl<E, T> SqliteVectorStore<E, T>
467where
468 E: EmbeddingModel + Clone + 'static,
469 T: SqliteVectorStoreTable + 'static,
470{
471 pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
473 Self::with_distance_metric(conn, embedding_model, SqliteDistanceMetric::default()).await
474 }
475
476 pub async fn with_distance_metric(
482 conn: Connection,
483 embedding_model: &E,
484 distance_metric: SqliteDistanceMetric,
485 ) -> Result<Self, VectorStoreError> {
486 let dims = embedding_model.ndims();
487 let table_name = T::name();
488 let embeddings_table_name = format!("{table_name}_embeddings");
489 let embeddings_table_name_for_sql = embeddings_table_name.clone();
490 let embedding_map_table_name_for_sql = format!("{table_name}_embedding_map");
491 let schema = T::schema();
492 let metadata_columns = sqlite_metadata_columns(&schema)?;
493 let metadata_columns_for_schema_check = metadata_columns.clone();
494 let distance_metric_name = distance_metric.vec0_name();
495 let mut embeddings_columns =
496 format!("embedding float[{dims}] distance_metric={distance_metric_name}");
497 for column in &metadata_columns {
498 embeddings_columns.push_str(&format!(
499 ", {} {}",
500 column.name,
501 column.metadata_type.vec0_name()
502 ));
503 }
504
505 let mut create_table = format!("CREATE TABLE IF NOT EXISTS {table_name} (");
507
508 let mut first = true;
510 for column in &schema {
511 if !first {
512 create_table.push(',');
513 }
514 create_table.push_str(&format!("\n {} {}", column.name, column.col_type));
515 first = false;
516 }
517
518 create_table.push_str("\n)");
519
520 let mut create_indexes = vec![format!(
522 "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
523 table_name, table_name
524 )];
525
526 for column in schema {
528 if column.indexed {
529 create_indexes.push(format!(
530 "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
531 table_name, column.name, table_name, column.name
532 ));
533 }
534 }
535
536 let embeddings_table_sql = conn
537 .call(move |conn| {
538 conn.execute_batch("BEGIN")?;
539
540 conn.execute_batch(&create_table)?;
542
543 for index_stmt in create_indexes {
545 conn.execute_batch(&index_stmt)?;
546 }
547
548 conn.execute_batch(&format!(
550 "CREATE VIRTUAL TABLE IF NOT EXISTS {embeddings_table_name_for_sql} USING vec0({embeddings_columns})"
551 ))?;
552 conn.execute_batch(&format!(
553 "CREATE TABLE IF NOT EXISTS {embedding_map_table_name_for_sql} (
554 embedding_rowid INTEGER PRIMARY KEY AUTOINCREMENT,
555 document_rowid INTEGER NOT NULL
556 )"
557 ))?;
558 conn.execute_batch(&format!(
559 "CREATE INDEX IF NOT EXISTS idx_{table_name}_embedding_map_document_rowid ON {embedding_map_table_name_for_sql}(document_rowid)"
560 ))?;
561
562 conn.execute_batch("COMMIT")?;
563
564 let schema_sql = conn
565 .query_row(
566 "SELECT sql FROM sqlite_schema WHERE name = ?1",
567 [&embeddings_table_name_for_sql],
568 |row| row.get::<_, String>(0),
569 )
570 .optional()?;
571
572 Ok(schema_sql)
573 })
574 .await
575 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
576
577 let schema_sql = embeddings_table_sql.ok_or_else(|| {
578 VectorStoreError::DatastoreError(Box::new(SqliteVectorTableMissingSchema {
579 table_name: embeddings_table_name.clone(),
580 }))
581 })?;
582
583 let configured = sqlite_distance_metric_from_schema(&schema_sql);
584 if configured != distance_metric {
585 return Err(VectorStoreError::DatastoreError(Box::new(
586 SqliteDistanceMetricMismatch {
587 table_name: embeddings_table_name,
588 requested: distance_metric,
589 configured,
590 },
591 )));
592 }
593 for column in metadata_columns_for_schema_check {
594 if !sqlite_schema_contains_metadata_column(&schema_sql, &column) {
595 return Err(VectorStoreError::DatastoreError(Box::new(
596 SqliteMetadataSchemaMismatch {
597 table_name: embeddings_table_name.clone(),
598 column_name: column.name,
599 column_type: column.metadata_type,
600 },
601 )));
602 }
603 }
604
605 Ok(Self {
606 conn,
607 distance_metric,
608 metadata_columns,
609 _phantom: PhantomData,
610 })
611 }
612
613 pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
614 SqliteVectorIndex::new(model, self)
615 }
616
617 pub fn add_rows_with_txn(
618 &self,
619 txn: &rusqlite::Transaction<'_>,
620 documents: Vec<(T, OneOrMany<Embedding>)>,
621 ) -> Result<i64, tokio_rusqlite::Error> {
622 info!("Adding {} documents to store", documents.len());
623 let table_name = T::name();
624 let embeddings_table_name = format!("{table_name}_embeddings");
625 let embedding_map_table_name = format!("{table_name}_embedding_map");
626 let mut last_id = 0;
627 let embedding_columns = std::iter::once("rowid")
628 .chain(std::iter::once("embedding"))
629 .chain(self.metadata_columns.iter().map(|column| column.name))
630 .collect::<Vec<_>>();
631 let embedding_placeholders = (1..=embedding_columns.len())
632 .map(|i| format!("?{i}"))
633 .collect::<Vec<_>>();
634 let embeddings_sql = format!(
635 "INSERT INTO {embeddings_table_name} ({}) VALUES ({})",
636 embedding_columns.join(", "),
637 embedding_placeholders.join(", ")
638 );
639 let existing_rowid_sql = format!("SELECT rowid FROM {table_name} WHERE id = ?1");
640 let existing_embedding_rowids_sql = format!(
641 "SELECT embedding_rowid FROM {embedding_map_table_name} WHERE document_rowid = ?1"
642 );
643 let insert_embedding_map_sql =
644 format!("INSERT INTO {embedding_map_table_name}(document_rowid) VALUES (?1)");
645 let delete_embedding_map_sql =
646 format!("DELETE FROM {embedding_map_table_name} WHERE document_rowid = ?1");
647 let delete_embeddings_sql = format!("DELETE FROM {embeddings_table_name} WHERE rowid = ?1");
648
649 for (doc, embeddings) in &documents {
650 debug!("Storing document with id {}", doc.id());
651
652 let values = doc.column_values();
653 let id_value = values
654 .iter()
655 .find(|(name, _)| *name == "id")
656 .map(|(_, value)| value.to_sql_value())
657 .unwrap_or_else(|| Value::Text(doc.id()));
658 if let Some(existing_rowid) = txn
659 .query_row(&existing_rowid_sql, rusqlite::params![id_value], |row| {
660 row.get::<_, i64>(0)
661 })
662 .optional()?
663 {
664 let existing_embedding_rowids = txn
665 .prepare(&existing_embedding_rowids_sql)?
666 .query_map([existing_rowid], |row| row.get::<_, i64>(0))?
667 .collect::<rusqlite::Result<Vec<_>>>()?;
668 for embedding_rowid in existing_embedding_rowids {
669 txn.execute(&delete_embeddings_sql, [embedding_rowid])?;
670 }
671 txn.execute(&delete_embedding_map_sql, [existing_rowid])?;
672 }
673
674 let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
675
676 let placeholders = (1..=values.len())
677 .map(|i| format!("?{i}"))
678 .collect::<Vec<_>>();
679
680 let insert_sql = format!(
681 "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
682 table_name,
683 columns.join(", "),
684 placeholders.join(", ")
685 );
686
687 txn.execute(
688 &insert_sql,
689 rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_value())),
690 )?;
691 last_id = txn.last_insert_rowid();
692
693 let metadata_values = self
694 .metadata_columns
695 .iter()
696 .map(|column| sqlite_metadata_value(&values, column))
697 .collect::<rusqlite::Result<Vec<_>>>()?;
698
699 let mut stmt = txn.prepare(&embeddings_sql)?;
700 for (i, embedding) in embeddings.iter().enumerate() {
701 let vec = serialize_embedding(embedding);
702 debug!(
703 "Storing embedding {} of {} (size: {} bytes)",
704 i + 1,
705 embeddings.len(),
706 vec.len() * 4
707 );
708 txn.execute(&insert_embedding_map_sql, [last_id])?;
709 let embedding_rowid = txn.last_insert_rowid();
710 let mut params = Vec::with_capacity(2 + metadata_values.len());
711 params.push(Value::Integer(embedding_rowid));
712 params.push(Value::Blob(vec.as_bytes().to_vec()));
713 params.extend(metadata_values.iter().cloned());
714 stmt.execute(rusqlite::params_from_iter(params))?;
715 }
716 }
717
718 Ok(last_id)
719 }
720
721 pub async fn add_rows(
722 &self,
723 documents: Vec<(T, OneOrMany<Embedding>)>,
724 ) -> Result<i64, VectorStoreError>
725 where
726 T: 'static,
727 Self: 'static,
728 {
729 let cloned = self.clone();
730
731 self.conn
732 .call(move |conn| {
733 let tx = conn.transaction()?;
734 let result = cloned.add_rows_with_txn(&tx, documents)?;
735 tx.commit()?;
736
737 Ok(result)
738 })
739 .await
740 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
741 }
742}
743
744impl<E, T> InsertDocuments for SqliteVectorStore<E, T>
745where
746 E: EmbeddingModel + Clone + WasmCompatSend + WasmCompatSync + 'static,
747 T: SqliteVectorStoreTable
748 + for<'de> Deserialize<'de>
749 + WasmCompatSend
750 + WasmCompatSync
751 + 'static,
752{
753 async fn insert_documents<Doc: Serialize + Embed + WasmCompatSend>(
754 &self,
755 documents: Vec<(Doc, OneOrMany<Embedding>)>,
756 ) -> Result<(), VectorStoreError> {
757 if documents.is_empty() {
758 return Ok(());
759 }
760
761 let rows = documents
762 .into_iter()
763 .map(|(document, embeddings)| {
764 let document = serde_json::to_value(document)?;
765 let row = serde_json::from_value::<T>(document)?;
766
767 Ok((row, embeddings))
768 })
769 .collect::<Result<Vec<_>, VectorStoreError>>()?;
770
771 self.add_rows(rows).await?;
772
773 Ok(())
774 }
775}
776
777#[derive(Clone, Deserialize, Serialize, Debug)]
789pub struct SqliteSearchFilter {
790 expr: SqliteSearchFilterExpr,
791}
792
793impl Default for SqliteSearchFilter {
794 fn default() -> Self {
795 Self {
796 expr: SqliteSearchFilterExpr::Raw {
797 condition: "1 = 1".to_string(),
798 params: Vec::new(),
799 },
800 }
801 }
802}
803
/// Expression tree backing `SqliteSearchFilter`.
#[derive(Clone, Deserialize, Serialize, Debug)]
enum SqliteSearchFilterExpr {
    /// `key <op> value`.
    Comparison {
        key: String,
        op: SqliteComparisonOp,
        value: serde_json::Value,
    },
    /// Both sub-expressions must match.
    And(Box<SqliteSearchFilterExpr>, Box<SqliteSearchFilterExpr>),
    /// At least one sub-expression must match.
    Or(Box<SqliteSearchFilterExpr>, Box<SqliteSearchFilterExpr>),
    /// Negation of the inner expression.
    Not(Box<SqliteSearchFilterExpr>),
    /// Inclusive range check: `lo <= key <= hi`.
    Between {
        key: String,
        lo: serde_json::Value,
        hi: serde_json::Value,
    },
    /// `key IS NULL`, or `key IS NOT NULL` when `negated` is true.
    NullCheck {
        key: String,
        negated: bool,
    },
    /// SQL `GLOB` / `LIKE` pattern match against `key`.
    Pattern {
        key: String,
        op: SqlitePatternOp,
        pattern: String,
    },
    /// Raw SQL condition with positional `?` parameters. Apart from the `1 = 1`
    /// default, this cannot be rendered as a split (native/post) filter.
    Raw {
        condition: String,
        params: Vec<serde_json::Value>,
    },
}
833
834#[derive(Clone, Copy, Deserialize, Eq, PartialEq, Serialize, Debug)]
835enum SqliteComparisonOp {
836 Eq,
837 Ne,
838 Gt,
839 Gte,
840 Lt,
841 Lte,
842}
843
844impl SqliteComparisonOp {
845 fn as_sql(self) -> &'static str {
846 match self {
847 Self::Eq => "=",
848 Self::Ne => "!=",
849 Self::Gt => ">",
850 Self::Gte => ">=",
851 Self::Lt => "<",
852 Self::Lte => "<=",
853 }
854 }
855
856 fn negate(self) -> Self {
857 match self {
858 Self::Eq => Self::Ne,
859 Self::Ne => Self::Eq,
860 Self::Gt => Self::Lte,
861 Self::Gte => Self::Lt,
862 Self::Lt => Self::Gte,
863 Self::Lte => Self::Gt,
864 }
865 }
866}
867
868#[derive(Clone, Copy, Deserialize, Serialize, Debug)]
869enum SqlitePatternOp {
870 Glob,
871 Like,
872}
873
874impl SqlitePatternOp {
875 fn as_sql(self) -> &'static str {
876 match self {
877 Self::Glob => "glob",
878 Self::Like => "like",
879 }
880 }
881}
882
883#[derive(Debug, Default)]
884struct SqliteRenderedFilters {
885 native: Vec<SqliteRenderedFilter>,
886 post: Vec<SqliteRenderedFilter>,
887}
888
889impl SqliteRenderedFilters {
890 fn extend(&mut self, rhs: Self) {
891 self.native.extend(rhs.native);
892 self.post.extend(rhs.post);
893 }
894
895 fn has_post_filters(&self) -> bool {
896 !self.post.is_empty()
897 }
898}
899
/// One SQL condition fragment plus its positional `?` parameters, in bind order.
#[derive(Debug)]
struct SqliteRenderedFilter {
    condition: String,
    params: Vec<Value>,
}

/// How values compared against a document-table key are encoded.
///
/// NOTE(review): presumably `Sql` binds plain SQL values and `JsonText` binds values as
/// JSON text for keys extracted from a JSON document — confirm against
/// `sqlite_qualify_document_key` / `sqlite_document_filter_param`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SqliteDocumentValueMode {
    Sql,
    JsonText,
}

/// A filter key resolved to a SQL expression over the document table.
#[derive(Debug)]
struct SqliteQualifiedDocumentKey {
    /// SQL expression placed on the left-hand side of rendered conditions.
    expression: String,
    /// How parameters compared against `expression` are encoded.
    value_mode: SqliteDocumentValueMode,
    /// NOTE(review): presumably the bare column name when the key maps directly onto a
    /// document column, `None` otherwise — confirm.
    plain_column: Option<String>,
}
918
919impl SearchFilter for SqliteSearchFilter {
920 type Value = serde_json::Value;
921
922 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
923 Self {
924 expr: SqliteSearchFilterExpr::Comparison {
925 key: key.as_ref().to_string(),
926 op: SqliteComparisonOp::Eq,
927 value,
928 },
929 }
930 }
931
932 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
933 Self {
934 expr: SqliteSearchFilterExpr::Comparison {
935 key: key.as_ref().to_string(),
936 op: SqliteComparisonOp::Gt,
937 value,
938 },
939 }
940 }
941
942 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
943 Self {
944 expr: SqliteSearchFilterExpr::Comparison {
945 key: key.as_ref().to_string(),
946 op: SqliteComparisonOp::Lt,
947 value,
948 },
949 }
950 }
951
952 fn and(self, rhs: Self) -> Self {
953 Self {
954 expr: SqliteSearchFilterExpr::And(Box::new(self.expr), Box::new(rhs.expr)),
955 }
956 }
957
958 fn or(self, rhs: Self) -> Self {
959 Self {
960 expr: SqliteSearchFilterExpr::Or(Box::new(self.expr), Box::new(rhs.expr)),
961 }
962 }
963}
964
965impl SqliteSearchFilter {
966 #[allow(clippy::should_implement_trait)]
967 pub fn not(self) -> Self {
974 Self {
975 expr: SqliteSearchFilterExpr::Not(Box::new(self.expr)),
976 }
977 }
978
979 pub fn between<N>(key: String, range: RangeInclusive<N>) -> Self
985 where
986 N: Into<serde_json::Value>,
987 {
988 let (lo, hi) = range.into_inner();
989
990 Self {
991 expr: SqliteSearchFilterExpr::Between {
992 key,
993 lo: lo.into(),
994 hi: hi.into(),
995 },
996 }
997 }
998
999 pub fn is_null(key: String) -> Self {
1001 Self {
1002 expr: SqliteSearchFilterExpr::NullCheck {
1003 key,
1004 negated: false,
1005 },
1006 }
1007 }
1008
1009 pub fn is_not_null(key: String) -> Self {
1010 Self {
1011 expr: SqliteSearchFilterExpr::NullCheck { key, negated: true },
1012 }
1013 }
1014
1015 pub fn glob(key: String, pattern: impl Into<String>) -> Self {
1020 Self {
1021 expr: SqliteSearchFilterExpr::Pattern {
1022 key,
1023 op: SqlitePatternOp::Glob,
1024 pattern: pattern.into(),
1025 },
1026 }
1027 }
1028
1029 pub fn like(key: String, pattern: impl Into<String>) -> Self {
1034 Self {
1035 expr: SqliteSearchFilterExpr::Pattern {
1036 key,
1037 op: SqlitePatternOp::Like,
1038 pattern: pattern.into(),
1039 },
1040 }
1041 }
1042}
1043
1044impl SqliteSearchFilter {
1045 fn raw(condition: impl Into<String>, params: Vec<serde_json::Value>) -> Self {
1046 Self {
1047 expr: SqliteSearchFilterExpr::Raw {
1048 condition: condition.into(),
1049 params,
1050 },
1051 }
1052 }
1053
1054 fn render_split(
1055 &self,
1056 metadata_columns: &[SqliteMetadataColumn],
1057 ) -> Result<SqliteRenderedFilters, FilterError> {
1058 self.expr.render_split(metadata_columns)
1059 }
1060}
1061
1062impl SqliteSearchFilterExpr {
1063 fn render_native_comparison(
1064 key: &str,
1065 op: SqliteComparisonOp,
1066 value: serde_json::Value,
1067 metadata_columns: &[SqliteMetadataColumn],
1068 ) -> Result<SqliteRenderedFilters, FilterError> {
1069 let Some(metadata_column) = sqlite_native_metadata_column(key, metadata_columns) else {
1070 return Ok(SqliteRenderedFilters {
1071 native: Vec::new(),
1072 post: vec![Self::render_document_comparison(
1073 key,
1074 op,
1075 value,
1076 metadata_columns,
1077 )?],
1078 });
1079 };
1080
1081 if !metadata_column.metadata_type.supports_native_comparison(op) {
1082 return Err(sqlite_unsupported_filter(format!(
1083 "`{key}` is a BOOLEAN metadata column, and sqlite-vec only supports `=` and `!=` filters for booleans"
1084 )));
1085 }
1086
1087 Ok(SqliteRenderedFilters {
1088 native: vec![SqliteRenderedFilter {
1089 condition: format!("e.{key} {} ?", op.as_sql()),
1090 params: vec![sqlite_metadata_filter_param(metadata_column, value)?],
1091 }],
1092 post: Vec::new(),
1093 })
1094 }
1095
1096 fn render_document_comparison(
1097 key: &str,
1098 op: SqliteComparisonOp,
1099 value: serde_json::Value,
1100 metadata_columns: &[SqliteMetadataColumn],
1101 ) -> Result<SqliteRenderedFilter, FilterError> {
1102 let key = sqlite_qualify_document_key(key)?;
1103 Ok(SqliteRenderedFilter {
1104 condition: format!("{} {} ?", key.expression, op.as_sql()),
1105 params: vec![sqlite_document_filter_param(&key, metadata_columns, value)?],
1106 })
1107 }
1108
1109 fn render_split(
1110 &self,
1111 metadata_columns: &[SqliteMetadataColumn],
1112 ) -> Result<SqliteRenderedFilters, FilterError> {
1113 match self {
1114 Self::Comparison { key, op, value } => {
1115 Self::render_native_comparison(key, *op, value.clone(), metadata_columns)
1116 }
1117 Self::And(lhs, rhs) => {
1118 let mut rendered = lhs.render_split(metadata_columns)?;
1119 rendered.extend(rhs.render_split(metadata_columns)?);
1120 Ok(rendered)
1121 }
1122 Self::Between { key, lo, hi } => {
1123 let Some(metadata_column) = sqlite_native_metadata_column(key, metadata_columns)
1124 else {
1125 return Ok(SqliteRenderedFilters {
1126 native: Vec::new(),
1127 post: vec![self.render_document(metadata_columns)?],
1128 });
1129 };
1130
1131 if metadata_column.metadata_type == SqliteMetadataType::Boolean {
1132 return Err(sqlite_unsupported_filter(format!(
1133 "`{key}` is a BOOLEAN metadata column, and sqlite-vec does not support range filters for booleans"
1134 )));
1135 }
1136
1137 Ok(SqliteRenderedFilters {
1138 native: vec![SqliteRenderedFilter {
1139 condition: format!("e.{key} >= ? AND e.{key} <= ?"),
1140 params: vec![
1141 sqlite_metadata_filter_param(metadata_column, lo.clone())?,
1142 sqlite_metadata_filter_param(metadata_column, hi.clone())?,
1143 ],
1144 }],
1145 post: Vec::new(),
1146 })
1147 }
1148 Self::Raw { condition, params } if condition == "1 = 1" && params.is_empty() => {
1149 Ok(SqliteRenderedFilters {
1150 native: Vec::new(),
1151 post: Vec::new(),
1152 })
1153 }
1154 Self::Or(_, _) => Ok(SqliteRenderedFilters {
1155 native: Vec::new(),
1156 post: vec![self.render_document(metadata_columns)?],
1157 }),
1158 Self::Not(expr) => expr.render_negated_split(metadata_columns),
1159 Self::NullCheck { .. } | Self::Pattern { .. } => Ok(SqliteRenderedFilters {
1160 native: Vec::new(),
1161 post: vec![self.render_document(metadata_columns)?],
1162 }),
1163 Self::Raw { .. } => Err(sqlite_unsupported_filter(
1164 "raw filters cannot be validated as sqlite-vec metadata constraints",
1165 )),
1166 }
1167 }
1168
1169 fn render_negated_split(
1170 &self,
1171 metadata_columns: &[SqliteMetadataColumn],
1172 ) -> Result<SqliteRenderedFilters, FilterError> {
1173 match self {
1174 Self::Comparison { key, op, value } => {
1175 Self::render_native_comparison(key, op.negate(), value.clone(), metadata_columns)
1176 }
1177 Self::Not(expr) => expr.render_split(metadata_columns),
1178 _ => {
1179 let rendered = self.render_document(metadata_columns)?;
1180 Ok(SqliteRenderedFilters {
1181 native: Vec::new(),
1182 post: vec![SqliteRenderedFilter {
1183 condition: format!("NOT ({})", rendered.condition),
1184 params: rendered.params,
1185 }],
1186 })
1187 }
1188 }
1189 }
1190
1191 fn render_vector(&self) -> Result<SqliteRenderedFilter, FilterError> {
1192 match self {
1193 Self::Comparison { key, op, value } => Ok(SqliteRenderedFilter {
1194 condition: format!("{} {} ?", sqlite_qualify_vector_key(key), op.as_sql()),
1195 params: vec![sqlite_filter_param(value.clone())?],
1196 }),
1197 Self::And(lhs, rhs) => {
1198 let lhs = lhs.render_vector()?;
1199 let rhs = rhs.render_vector()?;
1200 Ok(SqliteRenderedFilter {
1201 condition: format!("({}) AND ({})", lhs.condition, rhs.condition),
1202 params: lhs.params.into_iter().chain(rhs.params).collect(),
1203 })
1204 }
1205 Self::Or(lhs, rhs) => {
1206 let lhs = lhs.render_vector()?;
1207 let rhs = rhs.render_vector()?;
1208 Ok(SqliteRenderedFilter {
1209 condition: format!("({}) OR ({})", lhs.condition, rhs.condition),
1210 params: lhs.params.into_iter().chain(rhs.params).collect(),
1211 })
1212 }
1213 Self::Not(expr) => {
1214 let expr = expr.render_vector()?;
1215 Ok(SqliteRenderedFilter {
1216 condition: format!("NOT ({})", expr.condition),
1217 params: expr.params,
1218 })
1219 }
1220 Self::Between { key, lo, hi } => Ok(SqliteRenderedFilter {
1221 condition: format!("{} between ? and ?", sqlite_qualify_vector_key(key)),
1222 params: vec![
1223 sqlite_filter_param(lo.clone())?,
1224 sqlite_filter_param(hi.clone())?,
1225 ],
1226 }),
1227 Self::NullCheck { key, negated } => {
1228 let operator = if *negated { "is not null" } else { "is null" };
1229 Ok(SqliteRenderedFilter {
1230 condition: format!("{} {operator}", sqlite_qualify_vector_key(key)),
1231 params: Vec::new(),
1232 })
1233 }
1234 Self::Pattern { key, op, pattern } => Ok(SqliteRenderedFilter {
1235 condition: format!("{} {} ?", sqlite_qualify_vector_key(key), op.as_sql()),
1236 params: vec![Value::Text(pattern.clone())],
1237 }),
1238 Self::Raw { condition, params } => Ok(SqliteRenderedFilter {
1239 condition: condition.clone(),
1240 params: params
1241 .iter()
1242 .cloned()
1243 .map(sqlite_filter_param)
1244 .collect::<Result<Vec<_>, _>>()?,
1245 }),
1246 }
1247 }
1248
1249 fn render_document(
1250 &self,
1251 metadata_columns: &[SqliteMetadataColumn],
1252 ) -> Result<SqliteRenderedFilter, FilterError> {
1253 match self {
1254 Self::Comparison { key, op, value } => {
1255 Self::render_document_comparison(key, *op, value.clone(), metadata_columns)
1256 }
1257 Self::And(lhs, rhs) => {
1258 let lhs = lhs.render_document(metadata_columns)?;
1259 let rhs = rhs.render_document(metadata_columns)?;
1260 Ok(SqliteRenderedFilter {
1261 condition: format!("({}) AND ({})", lhs.condition, rhs.condition),
1262 params: lhs.params.into_iter().chain(rhs.params).collect(),
1263 })
1264 }
1265 Self::Or(lhs, rhs) => {
1266 let lhs = lhs.render_document(metadata_columns)?;
1267 let rhs = rhs.render_document(metadata_columns)?;
1268 Ok(SqliteRenderedFilter {
1269 condition: format!("({}) OR ({})", lhs.condition, rhs.condition),
1270 params: lhs.params.into_iter().chain(rhs.params).collect(),
1271 })
1272 }
1273 Self::Not(expr) => {
1274 let expr = expr.render_document(metadata_columns)?;
1275 Ok(SqliteRenderedFilter {
1276 condition: format!("NOT ({})", expr.condition),
1277 params: expr.params,
1278 })
1279 }
1280 Self::Between { key, lo, hi } => {
1281 let key = sqlite_qualify_document_key(key)?;
1282 Ok(SqliteRenderedFilter {
1283 condition: format!("{} between ? and ?", key.expression),
1284 params: vec![
1285 sqlite_document_filter_param(&key, metadata_columns, lo.clone())?,
1286 sqlite_document_filter_param(&key, metadata_columns, hi.clone())?,
1287 ],
1288 })
1289 }
1290 Self::NullCheck { key, negated } => {
1291 let key = sqlite_qualify_document_key(key)?;
1292 let operator = if *negated { "is not null" } else { "is null" };
1293 Ok(SqliteRenderedFilter {
1294 condition: format!("{} {operator}", key.expression),
1295 params: Vec::new(),
1296 })
1297 }
1298 Self::Pattern { key, op, pattern } => {
1299 let key = sqlite_qualify_document_key(key)?;
1300 Ok(SqliteRenderedFilter {
1301 condition: format!("{} {} ?", key.expression, op.as_sql()),
1302 params: vec![Value::Text(pattern.clone())],
1303 })
1304 }
1305 Self::Raw { .. } => Err(sqlite_unsupported_filter(
1306 "raw filters cannot be validated as document-table constraints",
1307 )),
1308 }
1309 }
1310}
1311
1312fn sqlite_native_metadata_column<'a>(
1313 key: &str,
1314 metadata_columns: &'a [SqliteMetadataColumn],
1315) -> Option<&'a SqliteMetadataColumn> {
1316 if !sqlite_is_plain_identifier(key) {
1317 return None;
1318 }
1319
1320 metadata_columns.iter().find(|column| column.name == key)
1321}
1322
/// Returns `true` when `key` is a bare ASCII identifier.
///
/// A bare identifier is a leading `_` or ASCII letter followed by any number of `_`
/// or ASCII alphanumerics. Empty strings and any non-ASCII character are rejected.
fn sqlite_is_plain_identifier(key: &str) -> bool {
    // Every accepted character is ASCII, so checking raw bytes is equivalent to
    // checking chars: any byte of a multi-byte UTF-8 sequence is rejected.
    match key.as_bytes() {
        [] => false,
        [head, tail @ ..] => {
            (*head == b'_' || head.is_ascii_alphabetic())
                && tail
                    .iter()
                    .all(|byte| *byte == b'_' || byte.is_ascii_alphanumeric())
        }
    }
}
1332
/// Returns the byte length of the bare identifier at the start of `key`.
///
/// Returns `None` when `key` is empty or does not start with `_` or an ASCII letter.
/// Scanning stops at the first byte that is not `_` or ASCII alphanumeric.
fn sqlite_leading_identifier_len(key: &str) -> Option<usize> {
    let (&head, tail) = key.as_bytes().split_first()?;
    if head != b'_' && !head.is_ascii_alphabetic() {
        return None;
    }

    // Identifier characters are single-byte ASCII, so counting bytes counts chars
    // and the result is always a char boundary.
    let tail_len = tail
        .iter()
        .take_while(|&&byte| byte == b'_' || byte.is_ascii_alphanumeric())
        .count();
    Some(1 + tail_len)
}
1351
1352fn sqlite_unsupported_filter(reason: impl Into<String>) -> FilterError {
1353 FilterError::TypeError(format!(
1354 "SQLite filter cannot be safely lowered; {}",
1355 reason.into()
1356 ))
1357}
1358
1359fn sqlite_json_type_name(value: &serde_json::Value) -> &'static str {
1360 match value {
1361 serde_json::Value::Null => "null",
1362 serde_json::Value::Bool(_) => "boolean",
1363 serde_json::Value::Number(_) => "number",
1364 serde_json::Value::String(_) => "string",
1365 serde_json::Value::Array(_) => "array",
1366 serde_json::Value::Object(_) => "object",
1367 }
1368}
1369
1370fn sqlite_metadata_filter_type_error(
1371 column: &SqliteMetadataColumn,
1372 value: &serde_json::Value,
1373 expected: &str,
1374) -> FilterError {
1375 sqlite_unsupported_filter(format!(
1376 "`{}` is a {} metadata column and requires {expected}; got {}",
1377 column.name,
1378 column.metadata_type.vec0_name(),
1379 sqlite_json_type_name(value)
1380 ))
1381}
1382
1383fn sqlite_metadata_filter_param(
1384 column: &SqliteMetadataColumn,
1385 value: serde_json::Value,
1386) -> Result<Value, FilterError> {
1387 match column.metadata_type {
1388 SqliteMetadataType::Text => match value {
1389 serde_json::Value::String(value) => Ok(Value::Text(value)),
1390 value => Err(sqlite_metadata_filter_type_error(
1391 column,
1392 &value,
1393 "a string filter value",
1394 )),
1395 },
1396 SqliteMetadataType::Integer => match value {
1397 serde_json::Value::Number(number) => {
1398 if let Some(value) = number.as_i64() {
1399 Ok(Value::Integer(value))
1400 } else if let Some(value) = number.as_u64() {
1401 i64::try_from(value).map(Value::Integer).map_err(|_| {
1402 FilterError::TypeError(format!(
1403 "SQLite integer filter value `{number}` exceeds i64::MAX"
1404 ))
1405 })
1406 } else {
1407 let value = serde_json::Value::Number(number);
1408 Err(sqlite_metadata_filter_type_error(
1409 column,
1410 &value,
1411 "an integer filter value",
1412 ))
1413 }
1414 }
1415 value => Err(sqlite_metadata_filter_type_error(
1416 column,
1417 &value,
1418 "an integer filter value",
1419 )),
1420 },
1421 SqliteMetadataType::Float => match value {
1422 serde_json::Value::Number(number) => {
1423 number.as_f64().map(Value::Real).ok_or_else(|| {
1424 let value = serde_json::Value::Number(number);
1425 sqlite_metadata_filter_type_error(
1426 column,
1427 &value,
1428 "a finite number filter value",
1429 )
1430 })
1431 }
1432 value => Err(sqlite_metadata_filter_type_error(
1433 column,
1434 &value,
1435 "a finite number filter value",
1436 )),
1437 },
1438 SqliteMetadataType::Boolean => match value {
1439 serde_json::Value::Bool(value) => Ok(Value::Integer(value as i64)),
1440 value => Err(sqlite_metadata_filter_type_error(
1441 column,
1442 &value,
1443 "a boolean filter value",
1444 )),
1445 },
1446 }
1447}
1448
1449fn sqlite_filter_param(value: serde_json::Value) -> Result<Value, FilterError> {
1450 use serde_json::Value::*;
1451
1452 match value {
1453 Null => Ok(Value::Null),
1454 Bool(b) => Ok(Value::Integer(b as i64)),
1455 String(s) => Ok(Value::Text(s)),
1456 Number(n) => Ok(if let Some(value) = n.as_i64() {
1457 Value::Integer(value)
1458 } else if let Some(value) = n.as_u64() {
1459 let value = i64::try_from(value).map_err(|_| {
1460 FilterError::TypeError(format!(
1461 "SQLite integer filter value `{n}` exceeds i64::MAX"
1462 ))
1463 })?;
1464 Value::Integer(value)
1465 } else if let Some(float) = n.as_f64() {
1466 Value::Real(float)
1467 } else {
1468 Value::Text(n.to_string())
1469 }),
1470 Array(arr) => {
1471 let blob =
1472 serde_json::to_vec(&arr).map_err(|e| FilterError::Serialization(e.to_string()))?;
1473
1474 Ok(Value::Blob(blob))
1475 }
1476 Object(obj) => {
1477 let blob =
1478 serde_json::to_vec(&obj).map_err(|e| FilterError::Serialization(e.to_string()))?;
1479
1480 Ok(Value::Blob(blob))
1481 }
1482 }
1483}
1484
/// Heuristically detects keys that are already SQL expressions rather than bare
/// column names: anything containing `.`, `(`, a space, or `?`.
fn sqlite_key_is_qualified(key: &str) -> bool {
    key.chars().any(|c| matches!(c, '.' | '(' | ' ' | '?'))
}
1488
1489fn sqlite_qualify_vector_key(key: &str) -> String {
1490 if sqlite_key_is_qualified(key) {
1491 key.to_string()
1492 } else {
1493 format!("e.{key}")
1494 }
1495}
1496
1497fn sqlite_qualify_document_key(key: &str) -> Result<SqliteQualifiedDocumentKey, FilterError> {
1498 if let Some(key_without_alias) = key.strip_prefix("d.") {
1499 if sqlite_is_plain_identifier(key_without_alias) {
1500 return Ok(SqliteQualifiedDocumentKey {
1501 expression: key.to_string(),
1502 value_mode: SqliteDocumentValueMode::Sql,
1503 plain_column: Some(key_without_alias.to_string()),
1504 });
1505 }
1506
1507 if let Some(value_mode) = sqlite_json_operator_value_mode(key_without_alias) {
1508 return Ok(SqliteQualifiedDocumentKey {
1509 expression: key.to_string(),
1510 value_mode,
1511 plain_column: None,
1512 });
1513 }
1514
1515 return Err(sqlite_unsupported_filter(format!(
1516 "`{key}` is not a supported SQLite document filter expression"
1517 )));
1518 }
1519
1520 if sqlite_is_plain_identifier(key) {
1521 return Ok(SqliteQualifiedDocumentKey {
1522 expression: format!("d.{key}"),
1523 value_mode: SqliteDocumentValueMode::Sql,
1524 plain_column: Some(key.to_string()),
1525 });
1526 }
1527
1528 if let Some(value_mode) = sqlite_json_operator_value_mode(key) {
1529 return Ok(SqliteQualifiedDocumentKey {
1530 expression: format!("d.{key}"),
1531 value_mode,
1532 plain_column: None,
1533 });
1534 }
1535
1536 Err(sqlite_unsupported_filter(format!(
1537 "`{key}` is not a supported SQLite document filter expression"
1538 )))
1539}
1540
1541fn sqlite_document_filter_param(
1542 key: &SqliteQualifiedDocumentKey,
1543 metadata_columns: &[SqliteMetadataColumn],
1544 value: serde_json::Value,
1545) -> Result<Value, FilterError> {
1546 match key.value_mode {
1547 SqliteDocumentValueMode::Sql => {
1548 if let Some(column_name) = key.plain_column.as_deref()
1549 && let Some(metadata_column) = metadata_columns
1550 .iter()
1551 .find(|column| column.name == column_name)
1552 {
1553 return sqlite_metadata_filter_param(metadata_column, value);
1554 }
1555
1556 sqlite_filter_param(value)
1557 }
1558 SqliteDocumentValueMode::JsonText => serde_json::to_string(&value)
1559 .map(Value::Text)
1560 .map_err(|e| FilterError::Serialization(e.to_string())),
1561 }
1562}
1563
1564fn sqlite_json_operator_value_mode(expr: &str) -> Option<SqliteDocumentValueMode> {
1565 let mut index = sqlite_leading_identifier_len(expr)?;
1566
1567 if index == expr.len() {
1568 return None;
1569 }
1570
1571 let mut value_mode = None;
1572 while index < expr.len() {
1573 let remaining = &expr[index..];
1574 let (operator_len, next_value_mode) = if remaining.starts_with("->>") {
1575 (3, SqliteDocumentValueMode::Sql)
1576 } else if remaining.starts_with("->") {
1577 (2, SqliteDocumentValueMode::JsonText)
1578 } else {
1579 return None;
1580 };
1581 value_mode = Some(next_value_mode);
1582 index += operator_len;
1583
1584 let operand_len = sqlite_json_operator_operand_len(&expr[index..])?;
1585 index += operand_len;
1586 }
1587
1588 value_mode
1589}
1590
/// Returns the byte length of the JSON-operator operand at the start of `operand`.
///
/// Two operand forms are accepted:
/// - a single-quoted literal with no embedded quote or control characters (its
///   length includes both quotes);
/// - an optionally negative run of ASCII digits.
///
/// Anything else yields `None`.
fn sqlite_json_operator_operand_len(operand: &str) -> Option<usize> {
    if let Some(quoted) = operand.strip_prefix('\'') {
        let (literal, _) = quoted.split_once('\'')?;
        let clean = !literal.chars().any(char::is_control);
        return clean.then_some(literal.len() + 2);
    }

    let sign_len = usize::from(operand.starts_with('-'));
    let digit_count = operand[sign_len..]
        .bytes()
        .take_while(u8::is_ascii_digit)
        .count();
    (digit_count > 0).then_some(sign_len + digit_count)
}
1625
/// Vector search index over a [`SqliteVectorStore`].
///
/// Pairs the store with the embedding model used to embed each search request's
/// query text.
pub struct SqliteVectorIndex<E, T>
where
    E: EmbeddingModel + 'static,
    T: SqliteVectorStoreTable + 'static,
{
    /// Backing store; searches use its connection, distance metric, metadata
    /// columns and candidate-limit policy.
    store: SqliteVectorStore<E, T>,
    /// Model used to embed the query text of each search request.
    embedding_model: E,
}
1732
1733impl<E, T> SqliteVectorIndex<E, T>
1734where
1735 E: EmbeddingModel + 'static,
1736 T: SqliteVectorStoreTable,
1737{
1738 pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
1739 Self {
1740 store,
1741 embedding_model,
1742 }
1743 }
1744}
1745
1746fn sqlite_distance_metric_from_schema(schema_sql: &str) -> SqliteDistanceMetric {
1747 let normalized = sqlite_normalized_schema(schema_sql);
1748
1749 if normalized.contains("distance_metric=cosine") {
1750 SqliteDistanceMetric::Cosine
1751 } else if normalized.contains("distance_metric=l1") {
1752 SqliteDistanceMetric::L1
1753 } else {
1754 SqliteDistanceMetric::L2
1755 }
1756}
1757
/// Strips all whitespace from `schema_sql` and lowercases it, using Unicode case
/// mapping, so schema checks can use plain substring matching.
fn sqlite_normalized_schema(schema_sql: &str) -> String {
    let mut normalized = String::with_capacity(schema_sql.len());
    for c in schema_sql.chars().filter(|c| !c.is_whitespace()) {
        normalized.extend(c.to_lowercase());
    }
    normalized
}
1765
1766fn sqlite_schema_contains_metadata_column(schema_sql: &str, column: &SqliteMetadataColumn) -> bool {
1767 let normalized = sqlite_normalized_schema(schema_sql);
1768 let column_sql = format!(
1769 ",{}{}",
1770 column.name.to_ascii_lowercase(),
1771 column.metadata_type.vec0_name().to_ascii_lowercase()
1772 );
1773
1774 normalized.contains(&column_sql)
1775}
1776
/// Rendered SQL fragments and positional parameters for one vector search,
/// as produced by `build_search_query`.
struct SqliteSearchQuery {
    /// `WHERE …` clause for the embeddings scan. It holds the `e.embedding MATCH ?`
    /// and `k = ?` constraints plus each native filter, joined with `AND`.
    vector_where_clause: String,
    /// `AND …` clause over the joined document rows, or an empty string when there
    /// are no post filters.
    document_filter_clause: String,
    /// Positional parameters, in order:
    /// - the query vector blob (twice);
    /// - the candidate limit;
    /// - the native filter params;
    /// - the post filter params.
    params: Vec<Value>,
}
1782
1783fn render_search_filters(
1784 req: &VectorSearchRequest<SqliteSearchFilter>,
1785 distance_metric: SqliteDistanceMetric,
1786 metadata_columns: &[SqliteMetadataColumn],
1787) -> Result<SqliteRenderedFilters, FilterError> {
1788 let score_expression = distance_metric.score_expression("?1", "e.embedding");
1789 let threshold_filter = req.threshold().map(|threshold| {
1790 SqliteSearchFilter::raw(format!("{score_expression} >= ?"), vec![threshold.into()])
1791 });
1792
1793 let mut filters = SqliteRenderedFilters::default();
1794 if let Some(threshold_filter) = threshold_filter {
1795 filters.native.push(threshold_filter.expr.render_vector()?);
1796 }
1797 if let Some(filter) = req.filter() {
1798 filters.extend(filter.render_split(metadata_columns)?);
1799 }
1800
1801 Ok(filters)
1802}
1803
1804fn build_search_query(
1805 query_vec: Vec<f32>,
1806 filters: SqliteRenderedFilters,
1807 candidate_limit: u64,
1808) -> Result<SqliteSearchQuery, FilterError> {
1809 let mut conditions = vec!["e.embedding MATCH ?".to_string(), "k = ?".to_string()];
1810 conditions.extend(
1811 filters
1812 .native
1813 .iter()
1814 .map(|filter| format!("({})", filter.condition)),
1815 );
1816
1817 let vector_where_clause = format!("WHERE {}", conditions.join(" AND "));
1818 let document_filter_clause = if filters.post.is_empty() {
1819 String::new()
1820 } else {
1821 format!(
1822 "AND {}",
1823 filters
1824 .post
1825 .iter()
1826 .map(|filter| format!("({})", filter.condition))
1827 .collect::<Vec<_>>()
1828 .join(" AND ")
1829 )
1830 };
1831
1832 let query_vec = query_vec.into_iter().flat_map(f32::to_le_bytes).collect();
1833 let query_vec = Value::Blob(query_vec);
1834 let candidate_limit = sqlite_limit_param(candidate_limit, "candidate limit")?;
1835
1836 let mut params = vec![query_vec.clone(), query_vec, candidate_limit];
1837 params.extend(filters.native.into_iter().flat_map(|filter| filter.params));
1838 params.extend(filters.post.into_iter().flat_map(|filter| filter.params));
1839
1840 Ok(SqliteSearchQuery {
1841 vector_where_clause,
1842 document_filter_clause,
1843 params,
1844 })
1845}
1846
/// Test helper: renders the request's filters and returns only the vector-side
/// `WHERE` clause together with the full parameter list.
#[cfg(test)]
fn build_where_clause(
    req: &VectorSearchRequest<SqliteSearchFilter>,
    query_vec: Vec<f32>,
    distance_metric: SqliteDistanceMetric,
    metadata_columns: &[SqliteMetadataColumn],
    candidate_limit: u64,
) -> Result<(String, Vec<Value>), FilterError> {
    let rendered = render_search_filters(req, distance_metric, metadata_columns)?;
    let SqliteSearchQuery {
        vector_where_clause,
        params,
        ..
    } = build_search_query(query_vec, rendered, candidate_limit)?;
    Ok((vector_where_clause, params))
}
1859
1860fn sqlite_limit_param(value: u64, name: &str) -> Result<Value, FilterError> {
1861 i64::try_from(value)
1862 .map(Value::Integer)
1863 .map_err(|_| FilterError::TypeError(format!("SQLite {name} `{value}` exceeds i64::MAX")))
1864}
1865
/// Error attached to [`rusqlite::Error::FromSqlConversionFailure`] when a stored
/// column value cannot be converted into JSON.
#[derive(Debug)]
struct SqliteColumnValueError {
    /// Name of the column being decoded.
    column_name: &'static str,
    /// Declared SQL type of that column.
    column_type: &'static str,
    /// Human-readable description of the failure.
    message: String,
}

impl Display for SqliteColumnValueError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            column_name,
            column_type,
            message,
        } = self;
        write!(
            f,
            "could not convert SQLite column `{column_name}` with declared type `{column_type}`: {message}"
        )
    }
}

impl std::error::Error for SqliteColumnValueError {}
1884
1885fn sqlite_column_value_error(
1886 index: usize,
1887 value_type: Type,
1888 column: &Column,
1889 message: impl Into<String>,
1890) -> rusqlite::Error {
1891 rusqlite::Error::FromSqlConversionFailure(
1892 index,
1893 value_type,
1894 Box::new(SqliteColumnValueError {
1895 column_name: column.name,
1896 column_type: column.col_type,
1897 message: message.into(),
1898 }),
1899 )
1900}
1901
1902fn sqlite_number_value(
1903 index: usize,
1904 value_type: Type,
1905 column: &Column,
1906 value: f64,
1907) -> rusqlite::Result<serde_json::Value> {
1908 let number = serde_json::Number::from_f64(value).ok_or_else(|| {
1909 sqlite_column_value_error(index, value_type, column, "non-finite float value")
1910 })?;
1911
1912 Ok(serde_json::Value::Number(number))
1913}
1914
1915fn sqlite_text_value(
1916 index: usize,
1917 value_type: Type,
1918 column: &Column,
1919 value: &[u8],
1920) -> rusqlite::Result<serde_json::Value> {
1921 let value = std::str::from_utf8(value).map_err(|e| {
1922 sqlite_column_value_error(
1923 index,
1924 value_type,
1925 column,
1926 format!("invalid UTF-8 text: {e}"),
1927 )
1928 })?;
1929
1930 Ok(serde_json::Value::String(value.to_string()))
1931}
1932
/// Returns `true` when a declared column type starts with the word `JSON`,
/// case-insensitively (e.g. `JSON`, `json NOT NULL`).
fn sqlite_column_declares_json(column_type: &str) -> bool {
    match column_type.split_whitespace().next() {
        Some(first_word) => first_word.eq_ignore_ascii_case("JSON"),
        None => false,
    }
}
1939
1940fn sqlite_json_text_value(
1941 index: usize,
1942 value_type: Type,
1943 column: &Column,
1944 value: &[u8],
1945) -> rusqlite::Result<serde_json::Value> {
1946 let value = std::str::from_utf8(value).map_err(|e| {
1947 sqlite_column_value_error(
1948 index,
1949 value_type,
1950 column,
1951 format!("invalid UTF-8 JSON text: {e}"),
1952 )
1953 })?;
1954
1955 serde_json::from_str(value).map_err(|e| {
1956 sqlite_column_value_error(index, value_type, column, format!("invalid JSON text: {e}"))
1957 })
1958}
1959
1960fn sqlite_column_value_to_json(
1961 index: usize,
1962 column: &Column,
1963 value: ValueRef<'_>,
1964) -> rusqlite::Result<serde_json::Value> {
1965 let value_type = value.data_type();
1966
1967 if sqlite_column_declares_json(column.col_type) {
1968 return match value {
1969 ValueRef::Null => Ok(serde_json::Value::Null),
1970 ValueRef::Text(value) => sqlite_json_text_value(index, value_type, column, value),
1971 ValueRef::Integer(value) => Ok(serde_json::Value::Number(value.into())),
1972 ValueRef::Real(value) => sqlite_number_value(index, value_type, column, value),
1973 ValueRef::Blob(value) => sqlite_json_text_value(index, value_type, column, value),
1974 };
1975 }
1976
1977 let column_affinity = SqliteColumnAffinity::from_column_type(column.col_type);
1978
1979 match (column_affinity, value) {
1980 (_, ValueRef::Null) => Ok(serde_json::Value::Null),
1981 (SqliteColumnAffinity::Boolean, ValueRef::Integer(0)) => Ok(serde_json::Value::Bool(false)),
1982 (SqliteColumnAffinity::Boolean, ValueRef::Integer(1)) => Ok(serde_json::Value::Bool(true)),
1983 (SqliteColumnAffinity::Boolean, _) => Err(sqlite_column_value_error(
1984 index,
1985 value_type,
1986 column,
1987 "stored SQLite boolean value must be 0 or 1",
1988 )),
1989 (_, ValueRef::Text(value)) => sqlite_text_value(index, value_type, column, value),
1990 (_, ValueRef::Integer(value)) => Ok(serde_json::Value::Number(value.into())),
1991 (_, ValueRef::Real(value)) => sqlite_number_value(index, value_type, column, value),
1992 (_, ValueRef::Blob(value)) => Ok(serde_json::to_value(value)
1993 .map_err(|e| sqlite_column_value_error(index, value_type, column, e.to_string()))?),
1994 }
1995}
1996
1997fn sqlite_id_value_to_string(index: usize, value: ValueRef<'_>) -> rusqlite::Result<String> {
1998 match value {
1999 ValueRef::Integer(value) => Ok(value.to_string()),
2000 ValueRef::Real(value) => Ok(value.to_string()),
2001 ValueRef::Text(value) => std::str::from_utf8(value)
2002 .map(ToString::to_string)
2003 .map_err(|e| {
2004 rusqlite::Error::FromSqlConversionFailure(
2005 index,
2006 Type::Text,
2007 Box::new(SqliteColumnValueError {
2008 column_name: "id",
2009 column_type: "TEXT",
2010 message: format!("invalid UTF-8 text: {e}"),
2011 }),
2012 )
2013 }),
2014 value => Err(rusqlite::Error::FromSqlConversionFailure(
2015 index,
2016 value.data_type(),
2017 Box::new(SqliteColumnValueError {
2018 column_name: "id",
2019 column_type: "TEXT or INTEGER",
2020 message: "id cannot be NULL or BLOB".to_string(),
2021 }),
2022 )),
2023 }
2024}
2025
2026impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
2027 for SqliteVectorIndex<E, T>
2028{
2029 type Filter = SqliteSearchFilter;
2030
2031 async fn top_n<D>(
2032 &self,
2033 req: VectorSearchRequest<SqliteSearchFilter>,
2034 ) -> Result<Vec<(f64, String, D)>, VectorStoreError>
2035 where
2036 D: for<'de> Deserialize<'de>,
2037 {
2038 tracing::debug!("Finding top {} matches for query", req.samples() as usize);
2039 if req.samples() == 0 {
2040 return Ok(Vec::new());
2041 }
2042
2043 let embedding = self.embedding_model.embed_text(req.query()).await?;
2044 let query_vec: Vec<f32> = serialize_embedding(&embedding);
2045 let table_name = T::name();
2046 let embedding_map_table_name = format!("{table_name}_embedding_map");
2047
2048 let columns = T::schema();
2049 let id_column_index = columns
2050 .iter()
2051 .position(|column| column.name == "id")
2052 .ok_or_else(|| {
2053 VectorStoreError::DatastoreError(Box::new(SqliteMissingIdColumn {
2054 table_name: table_name.to_string(),
2055 }))
2056 })?;
2057
2058 let outer_select_cols = columns
2059 .iter()
2060 .map(|column| format!("d.{} AS {}", column.name, column.name))
2061 .collect::<Vec<_>>()
2062 .join(", ");
2063
2064 let distance_metric = self.store.distance_metric;
2065 let score_expression = distance_metric.score_expression("?1", "e.embedding");
2066 let filters = render_search_filters(&req, distance_metric, &self.store.metadata_columns)?;
2067 let candidate_limit = self
2068 .store
2069 .candidate_limit(req.samples(), filters.has_post_filters())
2070 .await?;
2071 let search_query = build_search_query(query_vec, filters, candidate_limit)?;
2072 let where_clause = search_query.vector_where_clause;
2073 let document_filter_clause = search_query.document_filter_clause;
2074 let mut params = search_query.params;
2075 params.push(sqlite_limit_param(req.samples(), "result limit")?);
2076
2077 let rows = self
2078 .store
2079 .conn
2080 .call(move |conn| {
2081 let mut stmt = conn.prepare(&format!(
2082 "WITH scored AS (
2083 SELECT m.document_rowid AS __rig_document_rowid,
2084 {score_expression} AS __rig_score,
2085 ROW_NUMBER() OVER (
2086 PARTITION BY m.document_rowid
2087 ORDER BY {score_expression} DESC, e.rowid ASC
2088 ) AS __rig_rank
2089 FROM {table_name}_embeddings e
2090 JOIN {embedding_map_table_name} m ON e.rowid = m.embedding_rowid
2091 {where_clause}
2092 )
2093 SELECT {outer_select_cols}, scored.__rig_score
2094 FROM scored
2095 JOIN {table_name} d ON scored.__rig_document_rowid = d.rowid
2096 WHERE scored.__rig_rank = 1
2097 {document_filter_clause}
2098 ORDER BY scored.__rig_score DESC, d.id ASC
2099 LIMIT ?"
2100 ))?;
2101
2102 let rows = stmt
2103 .query_map(rusqlite::params_from_iter(params), |row| {
2104 let mut map = serde_json::Map::new();
2106 for (i, column) in columns.iter().enumerate() {
2107 let value = sqlite_column_value_to_json(i, column, row.get_ref(i)?)?;
2108 map.insert(column.name.to_string(), value);
2109 }
2110 let score: f64 = row.get(columns.len())?;
2111 let id = sqlite_id_value_to_string(
2112 id_column_index,
2113 row.get_ref(id_column_index)?,
2114 )?;
2115
2116 Ok((id, serde_json::Value::Object(map), score))
2117 })?
2118 .collect::<Result<Vec<_>, _>>()?;
2119 Ok(rows)
2120 })
2121 .await
2122 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
2123
2124 debug!("Found {} potential matches", rows.len());
2125 let mut top_n = Vec::new();
2126 for (id, doc_value, score) in rows {
2127 match serde_json::from_value::<D>(doc_value) {
2128 Ok(doc) => {
2129 top_n.push((score, id, doc));
2130 }
2131 Err(e) => {
2132 debug!("Failed to deserialize document {}: {}", id, e);
2133 continue;
2134 }
2135 }
2136 }
2137
2138 debug!("Returning {} matches", top_n.len());
2139 Ok(top_n)
2140 }
2141
2142 async fn top_n_ids(
2143 &self,
2144 req: VectorSearchRequest<SqliteSearchFilter>,
2145 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
2146 tracing::debug!(
2147 "Finding top {} document IDs for query",
2148 req.samples() as usize
2149 );
2150 if req.samples() == 0 {
2151 return Ok(Vec::new());
2152 }
2153
2154 let embedding = self.embedding_model.embed_text(req.query()).await?;
2155 let query_vec = serialize_embedding(&embedding);
2156 let table_name = T::name();
2157 let embedding_map_table_name = format!("{table_name}_embedding_map");
2158
2159 let distance_metric = self.store.distance_metric;
2160 let score_expression = distance_metric.score_expression("?1", "e.embedding");
2161 let filters = render_search_filters(&req, distance_metric, &self.store.metadata_columns)?;
2162 let candidate_limit = self
2163 .store
2164 .candidate_limit(req.samples(), filters.has_post_filters())
2165 .await?;
2166 let search_query = build_search_query(query_vec, filters, candidate_limit)?;
2167 let where_clause = search_query.vector_where_clause;
2168 let document_filter_clause = search_query.document_filter_clause;
2169 let mut params = search_query.params;
2170 params.push(sqlite_limit_param(req.samples(), "result limit")?);
2171
2172 let results = self
2173 .store
2174 .conn
2175 .call(move |conn| {
2176 let mut stmt = conn.prepare(&format!(
2177 "WITH scored AS (
2178 SELECT m.document_rowid AS __rig_document_rowid,
2179 {score_expression} AS __rig_score,
2180 ROW_NUMBER() OVER (
2181 PARTITION BY m.document_rowid
2182 ORDER BY {score_expression} DESC, e.rowid ASC
2183 ) AS __rig_rank
2184 FROM {table_name}_embeddings e
2185 JOIN {embedding_map_table_name} m ON e.rowid = m.embedding_rowid
2186 {where_clause}
2187 )
2188 SELECT d.id, scored.__rig_score
2189 FROM scored
2190 JOIN {table_name} d ON scored.__rig_document_rowid = d.rowid
2191 WHERE scored.__rig_rank = 1
2192 {document_filter_clause}
2193 ORDER BY scored.__rig_score DESC, d.id ASC
2194 LIMIT ?"
2195 ))?;
2196
2197 let results = stmt
2198 .query_map(rusqlite::params_from_iter(params), |row| {
2199 Ok((
2200 row.get::<_, f64>(1)?,
2201 sqlite_id_value_to_string(0, row.get_ref(0)?)?,
2202 ))
2203 })?
2204 .collect::<Result<Vec<_>, _>>()?;
2205 Ok(results)
2206 })
2207 .await
2208 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
2209
2210 debug!("Found {} matching document IDs", results.len());
2211 Ok(results)
2212 }
2213}
2214
2215fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
2216 embedding.vec.iter().map(|x| *x as f32).collect()
2217}
2218
2219impl ColumnValue for String {
2220 fn to_sql_value(&self) -> Value {
2221 Value::Text(self.clone())
2222 }
2223
2224 fn column_type(&self) -> &'static str {
2225 "TEXT"
2226 }
2227}
2228
2229impl ColumnValue for i64 {
2230 fn to_sql_value(&self) -> Value {
2231 Value::Integer(*self)
2232 }
2233
2234 fn column_type(&self) -> &'static str {
2235 "INTEGER"
2236 }
2237}
2238
2239impl ColumnValue for i32 {
2240 fn to_sql_value(&self) -> Value {
2241 Value::Integer(i64::from(*self))
2242 }
2243
2244 fn column_type(&self) -> &'static str {
2245 "INTEGER"
2246 }
2247}
2248
2249impl ColumnValue for f64 {
2250 fn to_sql_value(&self) -> Value {
2251 Value::Real(*self)
2252 }
2253
2254 fn column_type(&self) -> &'static str {
2255 "FLOAT"
2256 }
2257}
2258
2259impl ColumnValue for f32 {
2260 fn to_sql_value(&self) -> Value {
2261 Value::Real(f64::from(*self))
2262 }
2263
2264 fn column_type(&self) -> &'static str {
2265 "FLOAT"
2266 }
2267}
2268
2269impl ColumnValue for bool {
2270 fn to_sql_value(&self) -> Value {
2271 Value::Integer(if *self { 1 } else { 0 })
2272 }
2273
2274 fn column_type(&self) -> &'static str {
2275 "BOOLEAN"
2276 }
2277}
2278
2279impl ColumnValue for serde_json::Value {
2280 fn to_sql_value(&self) -> Value {
2281 Value::Text(self.to_string())
2282 }
2283
2284 fn column_type(&self) -> &'static str {
2285 "JSON"
2286 }
2287}
2288
2289#[cfg(test)]
2290mod tests {
2291 use super::*;
2292 use rig_core::embeddings::EmbeddingError;
2293 use rusqlite::ffi::{sqlite3, sqlite3_api_routines, sqlite3_auto_extension};
2294 use sqlite_vec::sqlite3_vec_init;
2295 use std::cmp::Ordering;
2296 use std::os::raw::c_char;
2297 use std::sync::Once;
2298 use tokio_rusqlite::Connection;
2299
2300 const SCORE_EPSILON: f64 = 1e-5;
2301
2302 fn test_metadata_columns() -> Vec<SqliteMetadataColumn> {
2303 vec![SqliteMetadataColumn {
2304 name: "category",
2305 metadata_type: SqliteMetadataType::Text,
2306 }]
2307 }
2308
2309 fn typed_metadata_columns() -> Vec<SqliteMetadataColumn> {
2310 vec![
2311 SqliteMetadataColumn {
2312 name: "priority",
2313 metadata_type: SqliteMetadataType::Integer,
2314 },
2315 SqliteMetadataColumn {
2316 name: "rating",
2317 metadata_type: SqliteMetadataType::Float,
2318 },
2319 SqliteMetadataColumn {
2320 name: "published",
2321 metadata_type: SqliteMetadataType::Boolean,
2322 },
2323 ]
2324 }
2325
2326 #[test]
2327 fn json_column_text_decodes_to_json_object() -> anyhow::Result<()> {
2328 let column = Column::new("metadata", "JSON");
2329 let value = sqlite_column_value_to_json(
2330 0,
2331 &column,
2332 ValueRef::Text(br#"{"knowledge_doc_id":361,"knowledge_id":1,"user_id":1}"#),
2333 )?;
2334
2335 let expected = serde_json::json!({
2336 "knowledge_doc_id": 361,
2337 "knowledge_id": 1,
2338 "user_id": 1
2339 });
2340 anyhow::ensure!(
2341 value == expected,
2342 "JSON column text should decode to a JSON object, got {value:?}"
2343 );
2344
2345 Ok(())
2346 }
2347
2348 #[test]
2349 fn text_column_json_looking_text_stays_string() -> anyhow::Result<()> {
2350 let column = Column::new("metadata", "TEXT");
2351 let value = sqlite_column_value_to_json(
2352 0,
2353 &column,
2354 ValueRef::Text(br#"{"knowledge_doc_id":361,"knowledge_id":1,"user_id":1}"#),
2355 )?;
2356
2357 let expected =
2358 serde_json::json!(r#"{"knowledge_doc_id":361,"knowledge_id":1,"user_id":1}"#);
2359 anyhow::ensure!(
2360 value == expected,
2361 "TEXT column should preserve JSON-looking text as a string, got {value:?}"
2362 );
2363
2364 Ok(())
2365 }
2366
2367 #[test]
2368 fn json_column_invalid_text_returns_conversion_error() -> anyhow::Result<()> {
2369 let column = Column::new("metadata", "JSON");
2370 let err = match sqlite_column_value_to_json(0, &column, ValueRef::Text(b"not json")) {
2371 Ok(value) => anyhow::bail!("invalid JSON column text should fail, got {value:?}"),
2372 Err(err) => err,
2373 };
2374
2375 anyhow::ensure!(
2376 matches!(
2377 err,
2378 rusqlite::Error::FromSqlConversionFailure(0, Type::Text, _)
2379 ),
2380 "invalid JSON column text should return a conversion error, got {err}"
2381 );
2382
2383 Ok(())
2384 }
2385
2386 #[test]
2387 fn serde_json_value_column_value_round_trips_json_column() -> anyhow::Result<()> {
2388 let value = serde_json::json!({
2389 "knowledge_doc_id": 361,
2390 "knowledge_id": 1,
2391 "user_id": 1
2392 });
2393 anyhow::ensure!(
2394 value.column_type() == "JSON",
2395 "serde_json::Value should declare JSON column type"
2396 );
2397
2398 let text = match value.to_sql_value() {
2399 Value::Text(text) => text,
2400 value => {
2401 anyhow::bail!("serde_json::Value should serialize as JSON text, got {value:?}")
2402 }
2403 };
2404
2405 let column = Column::new("metadata", "JSON");
2406 let round_trip = sqlite_column_value_to_json(0, &column, ValueRef::Text(text.as_bytes()))?;
2407 anyhow::ensure!(
2408 round_trip == value,
2409 "serde_json::Value should round-trip through a JSON column, got {round_trip:?}"
2410 );
2411
2412 Ok(())
2413 }
2414
2415 fn filter_error<T: std::fmt::Debug>(
2416 result: Result<T, FilterError>,
2417 context: &str,
2418 ) -> anyhow::Result<FilterError> {
2419 match result {
2420 Ok(value) => anyhow::bail!("{context} should have failed, got {value:?}"),
2421 Err(err) => Ok(err),
2422 }
2423 }
2424
2425 fn ensure_vector_store_filter_error<T: std::fmt::Debug>(
2426 result: Result<T, VectorStoreError>,
2427 context: &str,
2428 ) -> anyhow::Result<()> {
2429 match result {
2430 Err(VectorStoreError::FilterError(_)) => Ok(()),
2431 Err(err) => anyhow::bail!("{context} returned unexpected error: {err}"),
2432 Ok(value) => anyhow::bail!("{context} should have failed, got {value:?}"),
2433 }
2434 }
2435
2436 #[test]
2437 fn threshold_filter_uses_computed_similarity_expression() -> anyhow::Result<()> {
2438 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
2439 .query("needle")
2440 .samples(5)
2441 .threshold(0.95)
2442 .build();
2443
2444 let (where_clause, params) =
2445 build_where_clause(&req, vec![1.0, 0.0], SqliteDistanceMetric::Cosine, &[], 5)?;
2446
2447 anyhow::ensure!(
2448 where_clause.contains("e.embedding MATCH ?"),
2449 "missing vector match constraint: {where_clause}"
2450 );
2451 anyhow::ensure!(
2452 where_clause.contains("k = ?"),
2453 "missing vector k constraint: {where_clause}"
2454 );
2455 anyhow::ensure!(
2456 where_clause.contains("(1 - vec_distance_cosine(?1, e.embedding)) >= ?"),
2457 "threshold should use computed similarity expression: {where_clause}"
2458 );
2459 anyhow::ensure!(params.len() == 4, "unexpected params: {params:?}");
2460 anyhow::ensure!(
2461 params.get(3) == Some(&Value::Real(0.95)),
2462 "unexpected threshold param: {params:?}"
2463 );
2464
2465 Ok(())
2466 }
2467
2468 #[test]
2469 fn l2_threshold_filter_uses_l2_score_expression() -> anyhow::Result<()> {
2470 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
2471 .query("needle")
2472 .samples(5)
2473 .threshold(-1.5)
2474 .build();
2475
2476 let (where_clause, params) =
2477 build_where_clause(&req, vec![1.0, 0.0], SqliteDistanceMetric::L2, &[], 5)?;
2478
2479 anyhow::ensure!(
2480 where_clause.contains("(-vec_distance_l2(?1, e.embedding)) >= ?"),
2481 "threshold should use L2 score expression: {where_clause}"
2482 );
2483 anyhow::ensure!(params.len() == 4, "unexpected params: {params:?}");
2484 anyhow::ensure!(
2485 params.get(3) == Some(&Value::Real(-1.5)),
2486 "unexpected threshold param: {params:?}"
2487 );
2488
2489 Ok(())
2490 }
2491
2492 #[test]
2493 fn no_threshold_does_not_add_similarity_predicate() -> anyhow::Result<()> {
2494 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
2495 .query("needle")
2496 .samples(5)
2497 .build();
2498
2499 let (where_clause, params) =
2500 build_where_clause(&req, vec![1.0, 0.0], SqliteDistanceMetric::Cosine, &[], 5)?;
2501
2502 anyhow::ensure!(
2503 where_clause == "WHERE e.embedding MATCH ? AND k = ?",
2504 "unexpected where clause: {where_clause}"
2505 );
2506 anyhow::ensure!(params.len() == 3, "unexpected params: {params:?}");
2507
2508 Ok(())
2509 }
2510
2511 #[test]
2512 fn or_filter_uses_document_filter_to_preserve_boolean_semantics() -> anyhow::Result<()> {
2513 let filter = SqliteSearchFilter::eq("category", serde_json::json!("docs")).or(
2514 SqliteSearchFilter::eq("title", serde_json::json!("archive")),
2515 );
2516
2517 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
2518 .query("needle")
2519 .samples(5)
2520 .filter(filter)
2521 .build();
2522
2523 let filters =
2524 render_search_filters(&req, SqliteDistanceMetric::Cosine, &test_metadata_columns())?;
2525 anyhow::ensure!(
2526 filters.has_post_filters(),
2527 "OR filters should be applied after vector candidate search"
2528 );
2529 let query = build_search_query(vec![1.0, 0.0], filters, 5)?;
2530
2531 anyhow::ensure!(
2532 query.vector_where_clause == "WHERE e.embedding MATCH ? AND k = ?",
2533 "OR filters should not be partially pushed into sqlite-vec: {}",
2534 query.vector_where_clause
2535 );
2536 anyhow::ensure!(
2537 query.document_filter_clause == "AND ((d.category = ?) OR (d.title = ?))",
2538 "unexpected document filter clause: {}",
2539 query.document_filter_clause
2540 );
2541 anyhow::ensure!(
2542 query.params.get(3) == Some(&Value::Text("docs".to_string()))
2543 && query.params.get(4) == Some(&Value::Text("archive".to_string())),
2544 "unexpected OR filter params: {:?}",
2545 query.params
2546 );
2547
2548 Ok(())
2549 }
2550
2551 #[test]
2552 fn indexed_filter_uses_vec0_metadata_constraint() -> anyhow::Result<()> {
2553 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
2554 .query("needle")
2555 .samples(5)
2556 .filter(SqliteSearchFilter::eq(
2557 "category",
2558 serde_json::json!("docs"),
2559 ))
2560 .build();
2561
2562 let (where_clause, params) = build_where_clause(
2563 &req,
2564 vec![1.0, 0.0],
2565 SqliteDistanceMetric::Cosine,
2566 &test_metadata_columns(),
2567 5,
2568 )?;
2569
2570 anyhow::ensure!(
2571 where_clause == "WHERE e.embedding MATCH ? AND k = ? AND (e.category = ?)",
2572 "unexpected where clause: {where_clause}"
2573 );
2574 anyhow::ensure!(params.len() == 4, "unexpected params: {params:?}");
2575 anyhow::ensure!(
2576 params.get(3) == Some(&Value::Text("docs".to_string())),
2577 "unexpected filter param: {params:?}"
2578 );
2579
2580 Ok(())
2581 }
2582
2583 #[test]
2584 fn negated_eq_filter_uses_vec0_metadata_inequality() -> anyhow::Result<()> {
2585 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
2586 .query("needle")
2587 .samples(5)
2588 .filter(SqliteSearchFilter::eq("category", serde_json::json!("docs")).not())
2589 .build();
2590
2591 let (where_clause, params) = build_where_clause(
2592 &req,
2593 vec![1.0, 0.0],
2594 SqliteDistanceMetric::Cosine,
2595 &test_metadata_columns(),
2596 5,
2597 )?;
2598
2599 anyhow::ensure!(
2600 where_clause == "WHERE e.embedding MATCH ? AND k = ? AND (e.category != ?)",
2601 "unexpected where clause: {where_clause}"
2602 );
2603 anyhow::ensure!(params.len() == 4, "unexpected params: {params:?}");
2604 anyhow::ensure!(
2605 params.get(3) == Some(&Value::Text("docs".to_string())),
2606 "unexpected filter param: {params:?}"
2607 );
2608
2609 Ok(())
2610 }
2611
2612 #[test]
2613 fn negated_range_comparison_uses_vec0_metadata_boundary() -> anyhow::Result<()> {
2614 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
2615 .query("needle")
2616 .samples(5)
2617 .filter(SqliteSearchFilter::gt("priority", serde_json::json!(10)).not())
2618 .build();
2619
2620 let (where_clause, params) = build_where_clause(
2621 &req,
2622 vec![1.0, 0.0],
2623 SqliteDistanceMetric::Cosine,
2624 &typed_metadata_columns(),
2625 5,
2626 )?;
2627
2628 anyhow::ensure!(
2629 where_clause == "WHERE e.embedding MATCH ? AND k = ? AND (e.priority <= ?)",
2630 "unexpected where clause: {where_clause}"
2631 );
2632 anyhow::ensure!(params.len() == 4, "unexpected params: {params:?}");
2633 anyhow::ensure!(
2634 params.get(3) == Some(&Value::Integer(10)),
2635 "unexpected filter param: {params:?}"
2636 );
2637
2638 Ok(())
2639 }
2640
2641 #[test]
2642 fn negated_boolean_eq_filter_uses_vec0_metadata_inequality() -> anyhow::Result<()> {
2643 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
2644 .query("needle")
2645 .samples(5)
2646 .filter(SqliteSearchFilter::eq("published", serde_json::json!(true)).not())
2647 .build();
2648
2649 let (where_clause, params) = build_where_clause(
2650 &req,
2651 vec![1.0, 0.0],
2652 SqliteDistanceMetric::Cosine,
2653 &typed_metadata_columns(),
2654 5,
2655 )?;
2656
2657 anyhow::ensure!(
2658 where_clause == "WHERE e.embedding MATCH ? AND k = ? AND (e.published != ?)",
2659 "unexpected where clause: {where_clause}"
2660 );
2661 anyhow::ensure!(
2662 params.get(3) == Some(&Value::Integer(1)),
2663 "unexpected boolean filter param: {params:?}"
2664 );
2665
2666 Ok(())
2667 }
2668
2669 #[test]
2670 fn negated_between_filter_uses_document_filter() -> anyhow::Result<()> {
2671 let filter = SqliteSearchFilter::between("priority".to_string(), 1_i64..=10_i64).not();
2672 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
2673 .query("needle")
2674 .samples(5)
2675 .filter(filter)
2676 .build();
2677
2678 let filters = render_search_filters(
2679 &req,
2680 SqliteDistanceMetric::Cosine,
2681 &typed_metadata_columns(),
2682 )?;
2683 anyhow::ensure!(
2684 filters.has_post_filters(),
2685 "negated range filters should be applied after vector candidate search"
2686 );
2687 let query = build_search_query(vec![1.0, 0.0], filters, 5)?;
2688
2689 anyhow::ensure!(
2690 query.vector_where_clause == "WHERE e.embedding MATCH ? AND k = ?",
2691 "negated range filters should not be partially pushed into sqlite-vec: {}",
2692 query.vector_where_clause
2693 );
2694 anyhow::ensure!(
2695 query.document_filter_clause == "AND (NOT (d.priority between ? and ?))",
2696 "unexpected document filter clause: {}",
2697 query.document_filter_clause
2698 );
2699 anyhow::ensure!(
2700 query.params.get(3) == Some(&Value::Integer(1))
2701 && query.params.get(4) == Some(&Value::Integer(10)),
2702 "unexpected negated between params: {:?}",
2703 query.params
2704 );
2705
2706 Ok(())
2707 }
2708
2709 #[test]
2710 fn boolean_range_filter_is_rejected() -> anyhow::Result<()> {
2711 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
2712 .query("needle")
2713 .samples(5)
2714 .filter(SqliteSearchFilter::gt(
2715 "published",
2716 serde_json::json!(false),
2717 ))
2718 .build();
2719
2720 let err = filter_error(
2721 build_where_clause(
2722 &req,
2723 vec![1.0, 0.0],
2724 SqliteDistanceMetric::Cosine,
2725 &typed_metadata_columns(),
2726 5,
2727 ),
2728 "boolean range filters",
2729 )?;
2730
2731 anyhow::ensure!(
2732 err.to_string().contains("BOOLEAN"),
2733 "unexpected error for boolean range filter: {err}"
2734 );
2735
2736 Ok(())
2737 }
2738
2739 #[test]
2740 fn indexed_between_filter_uses_vec0_metadata_constraints() -> anyhow::Result<()> {
2741 let filter = SqliteSearchFilter::between("priority".to_string(), 1_i64..=10_i64);
2742 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
2743 .query("needle")
2744 .samples(5)
2745 .filter(filter)
2746 .build();
2747
2748 let (where_clause, params) = build_where_clause(
2749 &req,
2750 vec![1.0, 0.0],
2751 SqliteDistanceMetric::Cosine,
2752 &typed_metadata_columns(),
2753 5,
2754 )?;
2755
2756 anyhow::ensure!(
2757 where_clause
2758 == "WHERE e.embedding MATCH ? AND k = ? AND (e.priority >= ? AND e.priority <= ?)",
2759 "unexpected where clause: {where_clause}"
2760 );
2761 anyhow::ensure!(params.len() == 5, "unexpected params: {params:?}");
2762 anyhow::ensure!(
2763 params.get(3) == Some(&Value::Integer(1)) && params.get(4) == Some(&Value::Integer(10)),
2764 "between bounds should be bound as parameters: {params:?}"
2765 );
2766
2767 Ok(())
2768 }
2769
2770 #[test]
2771 fn mismatched_metadata_filter_value_types_are_rejected() -> anyhow::Result<()> {
2772 let cases = [
2773 (
2774 SqliteSearchFilter::eq("published", serde_json::json!("true")),
2775 "boolean filter value",
2776 ),
2777 (
2778 SqliteSearchFilter::gt("priority", serde_json::json!(1.5)),
2779 "integer filter value",
2780 ),
2781 (
2782 SqliteSearchFilter::eq("category", serde_json::json!({ "name": "docs" })),
2783 "string filter value",
2784 ),
2785 (
2786 SqliteSearchFilter::between(
2787 "priority".to_string(),
2788 "1".to_string()..="10".to_string(),
2789 ),
2790 "integer filter value",
2791 ),
2792 ];
2793
2794 for (filter, expected) in cases {
2795 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
2796 .query("needle")
2797 .samples(5)
2798 .filter(filter)
2799 .build();
2800
2801 let err = filter_error(
2802 build_where_clause(
2803 &req,
2804 vec![1.0, 0.0],
2805 SqliteDistanceMetric::Cosine,
2806 &typed_metadata_columns()
2807 .into_iter()
2808 .chain(test_metadata_columns())
2809 .collect::<Vec<_>>(),
2810 5,
2811 ),
2812 "mismatched metadata filter value",
2813 )?;
2814
2815 anyhow::ensure!(
2816 err.to_string().contains(expected),
2817 "unexpected error for mismatched metadata filter value: {err}"
2818 );
2819 }
2820
2821 Ok(())
2822 }
2823
2824 #[test]
2825 fn pattern_and_null_filters_use_document_filter() -> anyhow::Result<()> {
2826 let filter = SqliteSearchFilter::like("title".to_string(), "%O'Reilly%")
2827 .and(SqliteSearchFilter::glob("category".to_string(), "doc*"))
2828 .and(SqliteSearchFilter::is_null(
2829 "metadata->>'$.missing'".to_string(),
2830 ));
2831 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
2832 .query("needle")
2833 .samples(5)
2834 .filter(filter)
2835 .build();
2836
2837 let filters =
2838 render_search_filters(&req, SqliteDistanceMetric::Cosine, &test_metadata_columns())?;
2839 anyhow::ensure!(
2840 filters.has_post_filters(),
2841 "pattern and null filters should be applied after vector candidate search"
2842 );
2843 let query = build_search_query(vec![1.0, 0.0], filters, 5)?;
2844
2845 anyhow::ensure!(
2846 query.vector_where_clause == "WHERE e.embedding MATCH ? AND k = ?",
2847 "pattern filters should not be pushed into sqlite-vec: {}",
2848 query.vector_where_clause
2849 );
2850 anyhow::ensure!(
2851 query.document_filter_clause
2852 == "AND (d.title like ?) AND (d.category glob ?) AND (d.metadata->>'$.missing' is null)",
2853 "unexpected document filter clause: {}",
2854 query.document_filter_clause
2855 );
2856 anyhow::ensure!(
2857 query.params.get(3) == Some(&Value::Text("%O'Reilly%".to_string()))
2858 && query.params.get(4) == Some(&Value::Text("doc*".to_string())),
2859 "unexpected pattern filter params: {:?}",
2860 query.params
2861 );
2862
2863 Ok(())
2864 }
2865
2866 #[test]
2867 fn nonindexed_filters_use_document_filter() -> anyhow::Result<()> {
2868 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
2869 .query("needle")
2870 .samples(5)
2871 .filter(SqliteSearchFilter::eq("title", serde_json::json!("docs")))
2872 .build();
2873
2874 let filters =
2875 render_search_filters(&req, SqliteDistanceMetric::Cosine, &test_metadata_columns())?;
2876 anyhow::ensure!(
2877 filters.has_post_filters(),
2878 "non-indexed filters should be applied after vector candidate search"
2879 );
2880 let query = build_search_query(vec![1.0, 0.0], filters, 5)?;
2881
2882 anyhow::ensure!(
2883 query.vector_where_clause == "WHERE e.embedding MATCH ? AND k = ?",
2884 "unexpected vector where clause: {}",
2885 query.vector_where_clause
2886 );
2887 anyhow::ensure!(
2888 query.document_filter_clause == "AND (d.title = ?)",
2889 "unexpected document filter clause: {}",
2890 query.document_filter_clause
2891 );
2892 anyhow::ensure!(
2893 query.params.get(3) == Some(&Value::Text("docs".to_string())),
2894 "unexpected document filter param: {:?}",
2895 query.params
2896 );
2897
2898 Ok(())
2899 }
2900
2901 #[test]
2902 fn json_metadata_expression_uses_document_filter() -> anyhow::Result<()> {
2903 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
2904 .query("needle")
2905 .samples(5)
2906 .filter(SqliteSearchFilter::eq(
2907 "metadata->>'$.xxx'",
2908 serde_json::json!("vvv"),
2909 ))
2910 .build();
2911
2912 let filters =
2913 render_search_filters(&req, SqliteDistanceMetric::Cosine, &test_metadata_columns())?;
2914 anyhow::ensure!(
2915 filters.has_post_filters(),
2916 "JSON metadata expressions should be applied after vector candidate search"
2917 );
2918 let query = build_search_query(vec![1.0, 0.0], filters, 5)?;
2919
2920 anyhow::ensure!(
2921 query.vector_where_clause == "WHERE e.embedding MATCH ? AND k = ?",
2922 "unexpected vector where clause: {}",
2923 query.vector_where_clause
2924 );
2925 anyhow::ensure!(
2926 query.document_filter_clause == "AND (d.metadata->>'$.xxx' = ?)",
2927 "unexpected document filter clause: {}",
2928 query.document_filter_clause
2929 );
2930 anyhow::ensure!(
2931 query.params.get(3) == Some(&Value::Text("vvv".to_string())),
2932 "unexpected JSON metadata filter param: {:?}",
2933 query.params
2934 );
2935
2936 Ok(())
2937 }
2938
2939 #[test]
2940 fn json_metadata_arrow_expression_binds_rhs_as_json_text() -> anyhow::Result<()> {
2941 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
2942 .query("needle")
2943 .samples(5)
2944 .filter(SqliteSearchFilter::eq(
2945 "metadata->'$.xxx'",
2946 serde_json::json!("vvv"),
2947 ))
2948 .build();
2949
2950 let filters =
2951 render_search_filters(&req, SqliteDistanceMetric::Cosine, &test_metadata_columns())?;
2952 let query = build_search_query(vec![1.0, 0.0], filters, 5)?;
2953
2954 anyhow::ensure!(
2955 query.document_filter_clause == "AND (d.metadata->'$.xxx' = ?)",
2956 "unexpected document filter clause: {}",
2957 query.document_filter_clause
2958 );
2959 anyhow::ensure!(
2960 query.params.get(3) == Some(&Value::Text("\"vvv\"".to_string())),
2961 "SQLite `->` should compare against JSON text: {:?}",
2962 query.params
2963 );
2964
2965 Ok(())
2966 }
2967
2968 #[test]
2969 fn chained_json_metadata_expression_uses_final_operator_for_param_mode() -> anyhow::Result<()> {
2970 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
2971 .query("needle")
2972 .samples(5)
2973 .filter(SqliteSearchFilter::eq(
2974 "metadata->'$.nested'->>'$.xxx'",
2975 serde_json::json!("vvv"),
2976 ))
2977 .build();
2978
2979 let filters =
2980 render_search_filters(&req, SqliteDistanceMetric::Cosine, &test_metadata_columns())?;
2981 let query = build_search_query(vec![1.0, 0.0], filters, 5)?;
2982
2983 anyhow::ensure!(
2984 query.document_filter_clause == "AND (d.metadata->'$.nested'->>'$.xxx' = ?)",
2985 "unexpected document filter clause: {}",
2986 query.document_filter_clause
2987 );
2988 anyhow::ensure!(
2989 query.params.get(3) == Some(&Value::Text("vvv".to_string())),
2990 "final `->>` should compare against SQL scalar text: {:?}",
2991 query.params
2992 );
2993
2994 Ok(())
2995 }
2996
2997 #[test]
2998 fn unsupported_document_filter_expressions_are_rejected() -> anyhow::Result<()> {
2999 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3000 .query("needle")
3001 .samples(5)
3002 .filter(SqliteSearchFilter::eq(
3003 "metadata) OR 1 = 1 --",
3004 serde_json::json!("vvv"),
3005 ))
3006 .build();
3007
3008 let err = filter_error(
3009 render_search_filters(&req, SqliteDistanceMetric::Cosine, &test_metadata_columns()),
3010 "unsupported document filter expressions",
3011 )?;
3012
3013 anyhow::ensure!(
3014 err.to_string()
3015 .contains("supported SQLite document filter expression"),
3016 "unexpected error for unsupported document filter expression: {err}"
3017 );
3018
3019 Ok(())
3020 }
3021
3022 #[tokio::test]
3023 async fn live_search_orders_by_similarity_and_applies_threshold() -> anyhow::Result<()> {
3024 let index = live_test_index(
3025 "live_search_orders_by_similarity_and_applies_threshold",
3026 vec![
3027 row("exact", "docs", "exact match", vec![1.0, 0.0]),
3028 row("close", "docs", "close match", vec![0.8, 0.6]),
3029 row("opposite", "docs", "opposite match", vec![-1.0, 0.0]),
3030 ],
3031 )
3032 .await?;
3033
3034 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3035 .query("needle")
3036 .samples(3)
3037 .threshold(0.75)
3038 .build();
3039
3040 let results = index.top_n::<TestDocument>(req.clone()).await?;
3041 let ids = results
3042 .iter()
3043 .map(|(_, id, _)| id.as_str())
3044 .collect::<Vec<_>>();
3045 let exact_score = results.first().map(|(score, _, _)| *score);
3046 let close_score = results.get(1).map(|(score, _, _)| *score);
3047
3048 anyhow::ensure!(
3049 ids.as_slice() == ["exact", "close"],
3050 "unexpected ids: {ids:?}"
3051 );
3052 anyhow::ensure!(
3053 exact_score
3054 .zip(close_score)
3055 .is_some_and(|(exact, close)| exact > close),
3056 "expected exact score to be greater than close score: {results:?}"
3057 );
3058 anyhow::ensure!(
3059 results.iter().all(|(score, _, _)| *score > 0.75),
3060 "threshold should remove low-scoring rows: {results:?}"
3061 );
3062
3063 let id_results = index.top_n_ids(req).await?;
3064 let result_ids = id_results
3065 .iter()
3066 .map(|(_, id)| id.as_str())
3067 .collect::<Vec<_>>();
3068
3069 anyhow::ensure!(
3070 result_ids.as_slice() == ["exact", "close"],
3071 "unexpected top_n_ids ids: {id_results:?}"
3072 );
3073 anyhow::ensure!(
3074 id_results.iter().all(|(score, _)| *score > 0.75),
3075 "top_n_ids threshold should remove low-scoring rows: {id_results:?}"
3076 );
3077
3078 Ok(())
3079 }
3080
3081 #[tokio::test]
3082 async fn live_reinsert_same_document_id_removes_stale_vec0_candidates() -> anyhow::Result<()> {
3083 register_sqlite_vec_extension();
3084
3085 let conn = Connection::open(
3086 "file:live_reinsert_same_document_id_removes_stale_vec0_candidates?mode=memory",
3087 )
3088 .await?;
3089 let model = TestEmbeddingModel;
3090 let vector_store: SqliteVectorStore<_, TestDocument> =
3091 SqliteVectorStore::new(conn, &model).await?;
3092
3093 vector_store
3094 .add_rows(vec![row(
3095 "replace",
3096 "docs",
3097 "original near vector",
3098 vec![1.0, 0.0],
3099 )])
3100 .await?;
3101 vector_store
3102 .add_rows(vec![
3103 row("replace", "docs", "replacement far vector", vec![-1.0, 0.0]),
3104 row("fresh", "docs", "fresh near vector", vec![0.9, 0.1]),
3105 ])
3106 .await?;
3107
3108 let index = vector_store.index(model);
3109 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3110 .query("needle")
3111 .samples(1)
3112 .build();
3113
3114 let results = index.top_n::<TestDocument>(req.clone()).await?;
3115 let ids = results
3116 .iter()
3117 .map(|(_, id, _)| id.as_str())
3118 .collect::<Vec<_>>();
3119 anyhow::ensure!(
3120 ids.as_slice() == ["fresh"],
3121 "stale replaced vectors should not consume sqlite-vec candidates: {results:?}"
3122 );
3123
3124 let id_results = index.top_n_ids(req).await?;
3125 let result_ids = id_results
3126 .iter()
3127 .map(|(_, id)| id.as_str())
3128 .collect::<Vec<_>>();
3129 anyhow::ensure!(
3130 result_ids.as_slice() == ["fresh"],
3131 "top_n_ids should not return or be starved by stale replaced vectors: {id_results:?}"
3132 );
3133
3134 Ok(())
3135 }
3136
3137 #[tokio::test]
3138 async fn live_reinsert_preserves_unrelated_multivector_embeddings() -> anyhow::Result<()> {
3139 register_sqlite_vec_extension();
3140
3141 let conn = Connection::open(
3142 "file:live_reinsert_preserves_unrelated_multivector_embeddings?mode=memory",
3143 )
3144 .await?;
3145 let model = TestEmbeddingModel;
3146 let vector_store: SqliteVectorStore<_, TestDocument> =
3147 SqliteVectorStore::new(conn, &model).await?;
3148
3149 let multi_document = TestDocument {
3150 id: "multi".to_string(),
3151 category: "docs".to_string(),
3152 title: "multi-vector document".to_string(),
3153 };
3154 vector_store
3155 .add_rows(vec![
3156 (
3157 multi_document.clone(),
3158 OneOrMany::many(vec![
3159 Embedding {
3160 document: "far chunk".to_string(),
3161 vec: vec![-1.0, 0.0],
3162 },
3163 Embedding {
3164 document: "exact chunk".to_string(),
3165 vec: vec![1.0, 0.0],
3166 },
3167 ])?,
3168 ),
3169 row(
3170 "replace",
3171 "docs",
3172 "initial replacement vector",
3173 vec![0.8, 0.2],
3174 ),
3175 ])
3176 .await?;
3177 vector_store
3178 .add_rows(vec![row(
3179 "replace",
3180 "docs",
3181 "replacement far vector",
3182 vec![-1.0, 0.0],
3183 )])
3184 .await?;
3185
3186 let index = vector_store.index(model);
3187 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3188 .query("needle")
3189 .samples(1)
3190 .threshold(0.9)
3191 .build();
3192
3193 let results = index.top_n::<TestDocument>(req.clone()).await?;
3194 let ids = results
3195 .iter()
3196 .map(|(_, id, _)| id.as_str())
3197 .collect::<Vec<_>>();
3198 anyhow::ensure!(
3199 ids.as_slice() == ["multi"],
3200 "reinsert should not delete another document's best embedding: {results:?}"
3201 );
3202
3203 let id_results = index.top_n_ids(req).await?;
3204 let result_ids = id_results
3205 .iter()
3206 .map(|(_, id)| id.as_str())
3207 .collect::<Vec<_>>();
3208 anyhow::ensure!(
3209 result_ids.as_slice() == ["multi"],
3210 "top_n_ids should preserve unrelated multivector embeddings after reinsert: {id_results:?}"
3211 );
3212
3213 Ok(())
3214 }
3215
3216 #[tokio::test]
3217 async fn live_multiple_embeddings_per_document_use_best_embedding() -> anyhow::Result<()> {
3218 let multi_document = TestDocument {
3219 id: "multi".to_string(),
3220 category: "docs".to_string(),
3221 title: "multi-vector document".to_string(),
3222 };
3223 let index = live_test_index(
3224 "live_multiple_embeddings_per_document_use_best_embedding",
3225 vec![
3226 (
3227 multi_document.clone(),
3228 OneOrMany::many(vec![
3229 Embedding {
3230 document: "far chunk".to_string(),
3231 vec: vec![-1.0, 0.0],
3232 },
3233 Embedding {
3234 document: "exact chunk".to_string(),
3235 vec: vec![1.0, 0.0],
3236 },
3237 ])?,
3238 ),
3239 row("single", "docs", "single close chunk", vec![0.8, 0.6]),
3240 ],
3241 )
3242 .await?;
3243
3244 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3245 .query("needle")
3246 .samples(2)
3247 .build();
3248 let results = index.top_n::<TestDocument>(req.clone()).await?;
3249 let ids = results
3250 .iter()
3251 .map(|(_, id, _)| id.as_str())
3252 .collect::<Vec<_>>();
3253 anyhow::ensure!(
3254 ids.as_slice() == ["multi", "single"],
3255 "top_n should return each document once using its best embedding: {results:?}"
3256 );
3257
3258 let id_results = index.top_n_ids(req).await?;
3259 let result_ids = id_results
3260 .iter()
3261 .map(|(_, id)| id.as_str())
3262 .collect::<Vec<_>>();
3263 anyhow::ensure!(
3264 result_ids.as_slice() == ["multi", "single"],
3265 "top_n_ids should return each document once using its best embedding: {id_results:?}"
3266 );
3267
3268 let threshold_req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3269 .query("needle")
3270 .samples(2)
3271 .threshold(1.0)
3272 .build();
3273 let threshold_results = index.top_n::<TestDocument>(threshold_req.clone()).await?;
3274 let threshold_ids = threshold_results
3275 .iter()
3276 .map(|(_, id, _)| id.as_str())
3277 .collect::<Vec<_>>();
3278 anyhow::ensure!(
3279 threshold_ids.as_slice() == ["multi"],
3280 "threshold should include scores equal to the minimum and filter lower scores: {threshold_results:?}"
3281 );
3282
3283 let threshold_id_results = index.top_n_ids(threshold_req).await?;
3284 let threshold_result_ids = threshold_id_results
3285 .iter()
3286 .map(|(_, id)| id.as_str())
3287 .collect::<Vec<_>>();
3288 anyhow::ensure!(
3289 threshold_result_ids.as_slice() == ["multi"],
3290 "top_n_ids threshold should include scores equal to the minimum: {threshold_id_results:?}"
3291 );
3292
3293 Ok(())
3294 }
3295
3296 #[tokio::test]
3297 async fn live_equal_score_results_are_ordered_by_document_id() -> anyhow::Result<()> {
3298 let index = live_test_index(
3299 "live_equal_score_results_are_ordered_by_document_id",
3300 vec![
3301 row("b", "docs", "second id exact match", vec![1.0, 0.0]),
3302 row("a", "docs", "first id exact match", vec![1.0, 0.0]),
3303 ],
3304 )
3305 .await?;
3306
3307 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3308 .query("needle")
3309 .samples(2)
3310 .build();
3311
3312 let results = index.top_n::<TestDocument>(req.clone()).await?;
3313 let ids = results
3314 .iter()
3315 .map(|(_, id, _)| id.as_str())
3316 .collect::<Vec<_>>();
3317 anyhow::ensure!(
3318 ids.as_slice() == ["a", "b"],
3319 "equal-score top_n results should use document id as a stable tie-breaker: {results:?}"
3320 );
3321
3322 let id_results = index.top_n_ids(req).await?;
3323 let result_ids = id_results
3324 .iter()
3325 .map(|(_, id)| id.as_str())
3326 .collect::<Vec<_>>();
3327 anyhow::ensure!(
3328 result_ids.as_slice() == ["a", "b"],
3329 "equal-score top_n_ids results should use document id as a stable tie-breaker: {id_results:?}"
3330 );
3331
3332 Ok(())
3333 }
3334
3335 #[tokio::test]
3336 async fn live_common_sqlite_text_types_round_trip_in_top_n() -> anyhow::Result<()> {
3337 let index = live_common_type_test_index(
3338 "live_common_sqlite_text_types_round_trip_in_top_n",
3339 vec![common_type_row(
3340 "common",
3341 "varchar name",
3342 "clob notes",
3343 7,
3344 vec![1.0, 0.0],
3345 )],
3346 )
3347 .await?;
3348
3349 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3350 .query("needle")
3351 .samples(1)
3352 .build();
3353 let results = index.top_n::<CommonTypeDocument>(req).await?;
3354
3355 let Some((_, id, doc)) = results.first() else {
3356 anyhow::bail!("expected common type document result");
3357 };
3358 anyhow::ensure!(id == "common", "unexpected id: {id}");
3359 anyhow::ensure!(
3360 doc.name == "varchar name",
3361 "VARCHAR value should round-trip: {doc:?}"
3362 );
3363 anyhow::ensure!(
3364 doc.notes == "clob notes",
3365 "CLOB value should round-trip: {doc:?}"
3366 );
3367 anyhow::ensure!(doc.rank == 7, "NUMERIC value should round-trip: {doc:?}");
3368
3369 Ok(())
3370 }
3371
3372 #[tokio::test]
3373 async fn live_json_column_structured_metadata_round_trips_in_top_n() -> anyhow::Result<()> {
3374 let metadata = StructuredMetadata {
3375 user_id: 1,
3376 knowledge_id: 1,
3377 knowledge_doc_id: 361,
3378 };
3379 let index = live_structured_json_metadata_test_index(
3380 "live_json_column_structured_metadata_round_trips_in_top_n",
3381 vec![structured_json_metadata_row(
3382 "structured",
3383 metadata.clone(),
3384 "metadata document",
3385 vec![1.0, 0.0],
3386 )],
3387 )
3388 .await?;
3389
3390 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3391 .query("needle")
3392 .samples(1)
3393 .build();
3394 let results = index
3395 .top_n::<StructuredJsonMetadataDocument>(req.clone())
3396 .await?;
3397
3398 let Some((_, id, doc)) = results.first() else {
3399 anyhow::bail!("expected structured JSON metadata document result");
3400 };
3401 anyhow::ensure!(id == "structured", "unexpected id: {id}");
3402 anyhow::ensure!(
3403 doc.metadata == metadata,
3404 "JSON column should deserialize into structured metadata: {doc:?}"
3405 );
3406
3407 let id_results = index.top_n_ids(req).await?;
3408 anyhow::ensure!(
3409 id_results.first().is_some_and(|(_, id)| id == "structured"),
3410 "top_n_ids should still return the structured metadata document id: {id_results:?}"
3411 );
3412
3413 Ok(())
3414 }
3415
3416 #[tokio::test]
3417 async fn live_text_affinity_metadata_filters_during_candidate_search() -> anyhow::Result<()> {
3418 let index = live_common_type_test_index(
3419 "live_text_affinity_metadata_filters_during_candidate_search",
3420 vec![
3421 common_type_row("nearest", "misc", "nearest excluded", 1, vec![1.0, 0.0]),
3422 common_type_row("docs", "docs", "docs match", 2, vec![0.0, 1.0]),
3423 ],
3424 )
3425 .await?;
3426
3427 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3428 .query("needle")
3429 .samples(1)
3430 .filter(SqliteSearchFilter::eq("name", serde_json::json!("docs")))
3431 .build();
3432
3433 let results = index.top_n::<CommonTypeDocument>(req.clone()).await?;
3434 let ids = results
3435 .iter()
3436 .map(|(_, id, _)| id.as_str())
3437 .collect::<Vec<_>>();
3438
3439 anyhow::ensure!(
3440 ids.as_slice() == ["docs"],
3441 "VARCHAR metadata filters should constrain sqlite-vec candidate search: {results:?}"
3442 );
3443
3444 let id_results = index.top_n_ids(req).await?;
3445 let result_ids = id_results
3446 .iter()
3447 .map(|(_, id)| id.as_str())
3448 .collect::<Vec<_>>();
3449
3450 anyhow::ensure!(
3451 result_ids.as_slice() == ["docs"],
3452 "top_n_ids should use VARCHAR metadata filters during candidate search: {id_results:?}"
3453 );
3454
3455 Ok(())
3456 }
3457
3458 #[tokio::test]
3459 async fn live_l2_metric_is_consistent() -> anyhow::Result<()> {
3460 let index = live_test_index_with_metric(
3461 "live_l2_metric_is_consistent",
3462 vec![
3463 row("exact", "docs", "exact match", vec![1.0, 0.0]),
3464 row("l2-close", "docs", "l2 close match", vec![1.0, 1.0]),
3465 row(
3466 "same-direction-far",
3467 "docs",
3468 "same direction far away",
3469 vec![10.0, 0.0],
3470 ),
3471 ],
3472 SqliteDistanceMetric::L2,
3473 )
3474 .await?;
3475
3476 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3477 .query("needle")
3478 .samples(2)
3479 .threshold(-2.0)
3480 .build();
3481
3482 let results = index.top_n::<TestDocument>(req.clone()).await?;
3483 let ids = results
3484 .iter()
3485 .map(|(_, id, _)| id.as_str())
3486 .collect::<Vec<_>>();
3487 let exact_score = results
3488 .iter()
3489 .find(|(_, id, _)| id == "exact")
3490 .map(|(score, _, _)| *score);
3491 let close_score = results
3492 .iter()
3493 .find(|(_, id, _)| id == "l2-close")
3494 .map(|(score, _, _)| *score);
3495
3496 anyhow::ensure!(
3497 ids.as_slice() == ["exact", "l2-close"],
3498 "L2 search should return the nearest L2 candidates: {results:?}"
3499 );
3500 anyhow::ensure!(
3501 exact_score
3502 .zip(close_score)
3503 .is_some_and(|(exact, close)| exact > close && close > -2.0),
3504 "expected L2 scores to be ordered and thresholded: {results:?}"
3505 );
3506 anyhow::ensure!(
3507 results.iter().all(|(score, _, _)| *score > -2.0),
3508 "threshold should be applied to L2 scores: {results:?}"
3509 );
3510
3511 let id_results = index.top_n_ids(req).await?;
3512 let result_ids = id_results
3513 .iter()
3514 .map(|(_, id)| id.as_str())
3515 .collect::<Vec<_>>();
3516
3517 anyhow::ensure!(
3518 result_ids.as_slice() == ["exact", "l2-close"],
3519 "top_n_ids should use the same L2 metric: {id_results:?}"
3520 );
3521
3522 Ok(())
3523 }
3524
3525 #[tokio::test]
3526 async fn live_indexed_filter_is_applied_during_candidate_search() -> anyhow::Result<()> {
3527 let index = live_test_index(
3528 "live_indexed_filter_is_applied_during_candidate_search",
3529 vec![
3530 row(
3531 "nearest",
3532 "misc",
3533 "nearest excluded category",
3534 vec![1.0, 0.0],
3535 ),
3536 row("docs", "docs", "docs match", vec![0.0, 1.0]),
3537 ],
3538 )
3539 .await?;
3540
3541 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3542 .query("needle")
3543 .samples(1)
3544 .filter(SqliteSearchFilter::eq(
3545 "category",
3546 serde_json::json!("docs"),
3547 ))
3548 .build();
3549
3550 let results = index.top_n::<TestDocument>(req.clone()).await?;
3551 let ids = results
3552 .iter()
3553 .map(|(_, id, _)| id.as_str())
3554 .collect::<Vec<_>>();
3555
3556 anyhow::ensure!(
3557 ids.as_slice() == ["docs"],
3558 "indexed filters should constrain sqlite-vec candidate search: {results:?}"
3559 );
3560
3561 let id_results = index.top_n_ids(req).await?;
3562 let result_ids = id_results
3563 .iter()
3564 .map(|(_, id)| id.as_str())
3565 .collect::<Vec<_>>();
3566
3567 anyhow::ensure!(
3568 result_ids.as_slice() == ["docs"],
3569 "top_n_ids should use indexed filters during candidate search: {id_results:?}"
3570 );
3571
3572 Ok(())
3573 }
3574
3575 #[tokio::test]
3576 async fn live_nonindexed_filter_is_applied_after_candidate_search() -> anyhow::Result<()> {
3577 let index = live_test_index(
3578 "live_nonindexed_filter_is_applied_after_candidate_search",
3579 vec![
3580 row("nearest", "docs", "nearest excluded title", vec![1.0, 0.0]),
3581 row("wanted", "docs", "wanted title", vec![0.0, 1.0]),
3582 ],
3583 )
3584 .await?;
3585
3586 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3587 .query("needle")
3588 .samples(1)
3589 .filter(SqliteSearchFilter::eq(
3590 "title",
3591 serde_json::json!("wanted title"),
3592 ))
3593 .build();
3594
3595 let results = index.top_n::<TestDocument>(req.clone()).await?;
3596 let ids = results
3597 .iter()
3598 .map(|(_, id, _)| id.as_str())
3599 .collect::<Vec<_>>();
3600 anyhow::ensure!(
3601 ids.as_slice() == ["wanted"],
3602 "non-indexed filters should not be starved by the initial candidate limit: {results:?}"
3603 );
3604
3605 let id_results = index.top_n_ids(req).await?;
3606 let result_ids = id_results
3607 .iter()
3608 .map(|(_, id)| id.as_str())
3609 .collect::<Vec<_>>();
3610 anyhow::ensure!(
3611 result_ids.as_slice() == ["wanted"],
3612 "top_n_ids should apply non-indexed filters after candidate search: {id_results:?}"
3613 );
3614
3615 Ok(())
3616 }
3617
3618 #[tokio::test]
3619 async fn live_json_metadata_filter_is_applied_after_candidate_search() -> anyhow::Result<()> {
3620 let index = live_json_metadata_test_index(
3621 "live_json_metadata_filter_is_applied_after_candidate_search",
3622 vec![
3623 json_metadata_row("nearest", "docs", "skip", "nearest skipped", vec![1.0, 0.0]),
3624 json_metadata_row("matched", "docs", "vvv", "metadata match", vec![0.0, 1.0]),
3625 ],
3626 )
3627 .await?;
3628
3629 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3630 .query("needle")
3631 .samples(1)
3632 .filter(SqliteSearchFilter::eq(
3633 "metadata->>'$.xxx'",
3634 serde_json::json!("vvv"),
3635 ))
3636 .build();
3637
3638 let results = index.top_n::<JsonMetadataDocument>(req.clone()).await?;
3639 let ids = results
3640 .iter()
3641 .map(|(_, id, _)| id.as_str())
3642 .collect::<Vec<_>>();
3643 anyhow::ensure!(
3644 ids.as_slice() == ["matched"],
3645 "JSON metadata filters should not be starved by the initial candidate limit: {results:?}"
3646 );
3647
3648 let id_results = index.top_n_ids(req).await?;
3649 let result_ids = id_results
3650 .iter()
3651 .map(|(_, id)| id.as_str())
3652 .collect::<Vec<_>>();
3653 anyhow::ensure!(
3654 result_ids.as_slice() == ["matched"],
3655 "top_n_ids should apply JSON metadata filters after candidate search: {id_results:?}"
3656 );
3657
3658 Ok(())
3659 }
3660
3661 #[tokio::test]
3662 async fn live_json_arrow_filter_compares_against_json_text() -> anyhow::Result<()> {
3663 let index = live_json_metadata_test_index(
3664 "live_json_arrow_filter_compares_against_json_text",
3665 vec![
3666 json_metadata_row("nearest", "docs", "skip", "nearest skipped", vec![1.0, 0.0]),
3667 json_metadata_row("matched", "docs", "vvv", "metadata match", vec![0.0, 1.0]),
3668 ],
3669 )
3670 .await?;
3671
3672 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3673 .query("needle")
3674 .samples(1)
3675 .filter(SqliteSearchFilter::eq(
3676 "metadata->'$.xxx'",
3677 serde_json::json!("vvv"),
3678 ))
3679 .build();
3680
3681 let results = index.top_n::<JsonMetadataDocument>(req.clone()).await?;
3682 let ids = results
3683 .iter()
3684 .map(|(_, id, _)| id.as_str())
3685 .collect::<Vec<_>>();
3686 anyhow::ensure!(
3687 ids.as_slice() == ["matched"],
3688 "SQLite `->` JSON filters should compare against JSON text: {results:?}"
3689 );
3690
3691 let id_results = index.top_n_ids(req).await?;
3692 let result_ids = id_results
3693 .iter()
3694 .map(|(_, id)| id.as_str())
3695 .collect::<Vec<_>>();
3696 anyhow::ensure!(
3697 result_ids.as_slice() == ["matched"],
3698 "top_n_ids should apply SQLite `->` JSON filters: {id_results:?}"
3699 );
3700
3701 Ok(())
3702 }
3703
3704 #[tokio::test]
3705 async fn live_mixed_indexed_and_json_metadata_filters_are_applied() -> anyhow::Result<()> {
3706 let index = live_json_metadata_test_index(
3707 "live_mixed_indexed_and_json_metadata_filters_are_applied",
3708 vec![
3709 json_metadata_row(
3710 "nearest-docs",
3711 "docs",
3712 "skip",
3713 "nearest docs skipped by JSON metadata",
3714 vec![1.0, 0.0],
3715 ),
3716 json_metadata_row(
3717 "nearest-json",
3718 "misc",
3719 "vvv",
3720 "nearest JSON match skipped by category",
3721 vec![0.9, 0.1],
3722 ),
3723 json_metadata_row(
3724 "matched",
3725 "docs",
3726 "vvv",
3727 "matching category and JSON metadata",
3728 vec![0.0, 1.0],
3729 ),
3730 ],
3731 )
3732 .await?;
3733
3734 let filter = SqliteSearchFilter::eq("category", serde_json::json!("docs")).and(
3735 SqliteSearchFilter::eq("metadata->>'$.xxx'", serde_json::json!("vvv")),
3736 );
3737 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3738 .query("needle")
3739 .samples(1)
3740 .filter(filter)
3741 .build();
3742
3743 let results = index.top_n::<JsonMetadataDocument>(req.clone()).await?;
3744 let ids = results
3745 .iter()
3746 .map(|(_, id, _)| id.as_str())
3747 .collect::<Vec<_>>();
3748 anyhow::ensure!(
3749 ids.as_slice() == ["matched"],
3750 "indexed and JSON metadata filters should both be applied: {results:?}"
3751 );
3752
3753 let id_results = index.top_n_ids(req).await?;
3754 let result_ids = id_results
3755 .iter()
3756 .map(|(_, id)| id.as_str())
3757 .collect::<Vec<_>>();
3758 anyhow::ensure!(
3759 result_ids.as_slice() == ["matched"],
3760 "top_n_ids should apply both indexed and JSON metadata filters: {id_results:?}"
3761 );
3762
3763 Ok(())
3764 }
3765
3766 #[tokio::test]
3767 async fn live_negated_eq_filter_is_applied_during_candidate_search() -> anyhow::Result<()> {
3768 let index = live_test_index(
3769 "live_negated_eq_filter_is_applied_during_candidate_search",
3770 vec![
3771 row(
3772 "nearest",
3773 "misc",
3774 "nearest excluded category",
3775 vec![1.0, 0.0],
3776 ),
3777 row("docs", "docs", "docs match", vec![0.0, 1.0]),
3778 ],
3779 )
3780 .await?;
3781
3782 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3783 .query("needle")
3784 .samples(1)
3785 .filter(SqliteSearchFilter::eq("category", serde_json::json!("misc")).not())
3786 .build();
3787
3788 let results = index.top_n::<TestDocument>(req.clone()).await?;
3789 let ids = results
3790 .iter()
3791 .map(|(_, id, _)| id.as_str())
3792 .collect::<Vec<_>>();
3793
3794 anyhow::ensure!(
3795 ids.as_slice() == ["docs"],
3796 "negated filters should constrain sqlite-vec candidate search: {results:?}"
3797 );
3798
3799 let id_results = index.top_n_ids(req).await?;
3800 let result_ids = id_results
3801 .iter()
3802 .map(|(_, id)| id.as_str())
3803 .collect::<Vec<_>>();
3804
3805 anyhow::ensure!(
3806 result_ids.as_slice() == ["docs"],
3807 "top_n_ids should use negated filters during candidate search: {id_results:?}"
3808 );
3809
3810 Ok(())
3811 }
3812
3813 #[tokio::test]
3814 async fn live_top_n_reads_id_by_column_name_not_schema_position() -> anyhow::Result<()> {
3815 register_sqlite_vec_extension();
3816
3817 let conn = Connection::open(
3818 "file:live_top_n_reads_id_by_column_name_not_schema_position?mode=memory",
3819 )
3820 .await?;
3821 let model = TestEmbeddingModel;
3822 let vector_store: SqliteVectorStore<_, ReorderedIdDocument> =
3823 SqliteVectorStore::new(conn, &model).await?;
3824
3825 vector_store
3826 .add_rows(vec![
3827 reordered_id_row("winner", "winner title", "docs", vec![1.0, 0.0]),
3828 reordered_id_row("other", "other title", "docs", vec![0.0, 1.0]),
3829 ])
3830 .await?;
3831
3832 let index = vector_store.index(model);
3833 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3834 .query("needle")
3835 .samples(1)
3836 .build();
3837
3838 let results = index.top_n::<ReorderedIdDocument>(req.clone()).await?;
3839 let Some((_, id, doc)) = results.first() else {
3840 anyhow::bail!("expected reordered-id result");
3841 };
3842 anyhow::ensure!(
3843 id == "winner",
3844 "top_n should return the id column, not the first schema column: {results:?}"
3845 );
3846 anyhow::ensure!(
3847 doc.id == "winner" && doc.title == "winner title",
3848 "document columns should still deserialize in schema order: {doc:?}"
3849 );
3850
3851 let id_results = index.top_n_ids(req).await?;
3852 anyhow::ensure!(
3853 id_results.first().map(|(_, id)| id.as_str()) == Some("winner"),
3854 "top_n_ids should agree with top_n id handling: {id_results:?}"
3855 );
3856
3857 Ok(())
3858 }
3859
3860 #[tokio::test]
3861 async fn live_internal_score_and_rank_column_names_do_not_shadow_search_columns()
3862 -> anyhow::Result<()> {
3863 register_sqlite_vec_extension();
3864
3865 let conn = Connection::open(
3866 "file:live_internal_score_and_rank_column_names_do_not_shadow_search_columns?mode=memory",
3867 )
3868 .await?;
3869 let model = TestEmbeddingModel;
3870 let vector_store: SqliteVectorStore<_, InternalAliasDocument> =
3871 SqliteVectorStore::new(conn, &model).await?;
3872
3873 vector_store
3874 .add_rows(vec![
3875 internal_alias_row(
3876 "winner",
3877 "payload score",
3878 "payload rank",
3879 "winner title",
3880 vec![1.0, 0.0],
3881 ),
3882 internal_alias_row(
3883 "other",
3884 "other score",
3885 "other rank",
3886 "other title",
3887 vec![0.0, 1.0],
3888 ),
3889 ])
3890 .await?;
3891
3892 let index = vector_store.index(model);
3893 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3894 .query("needle")
3895 .samples(1)
3896 .threshold(0.9)
3897 .build();
3898
3899 let results = index.top_n::<InternalAliasDocument>(req.clone()).await?;
3900 let Some((score, id, doc)) = results.first() else {
3901 anyhow::bail!("expected internal-alias document result");
3902 };
3903
3904 anyhow::ensure!(id == "winner", "unexpected id: {results:?}");
3905 anyhow::ensure!(
3906 (*score - 1.0).abs() <= SCORE_EPSILON,
3907 "top_n should return computed score, not the document __rig_score column: {results:?}"
3908 );
3909 anyhow::ensure!(
3910 doc.rig_score == "payload score" && doc.rig_rank == "payload rank",
3911 "document columns with internal-looking names should still deserialize: {doc:?}"
3912 );
3913
3914 let id_results = index.top_n_ids(req).await?;
3915 anyhow::ensure!(
3916 id_results
3917 .first()
3918 .map(|(score, id)| ((*score - 1.0).abs() <= SCORE_EPSILON, id.as_str()))
3919 == Some((true, "winner")),
3920 "top_n_ids should agree with top_n despite internal-looking document columns: {id_results:?}"
3921 );
3922
3923 Ok(())
3924 }
3925
3926 #[tokio::test]
3927 async fn live_typed_columns_round_trip_and_filter_during_candidate_search() -> anyhow::Result<()>
3928 {
3929 let index = live_typed_test_index(
3930 "live_typed_columns_round_trip_and_filter_during_candidate_search",
3931 vec![
3932 typed_row(
3933 1,
3934 "misc",
3935 100,
3936 0.99,
3937 true,
3938 "nearest excluded by typed metadata",
3939 vec![1.0, 0.0],
3940 ),
3941 typed_row(2, "docs", 5, 0.95, true, "typed docs match", vec![0.0, 1.0]),
3942 typed_row(
3943 3,
3944 "docs",
3945 5,
3946 0.97,
3947 false,
3948 "unpublished docs match",
3949 vec![0.0, 0.9],
3950 ),
3951 ],
3952 )
3953 .await?;
3954
3955 let filter = SqliteSearchFilter::lt("priority", serde_json::json!(10))
3956 .and(SqliteSearchFilter::gt("rating", serde_json::json!(0.9)))
3957 .and(SqliteSearchFilter::eq("published", serde_json::json!(true)));
3958 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3959 .query("needle")
3960 .samples(1)
3961 .filter(filter)
3962 .build();
3963
3964 let results = index.top_n::<TypedTestDocument>(req.clone()).await?;
3965 anyhow::ensure!(
3966 results.len() == 1,
3967 "expected one typed document result: {results:?}"
3968 );
3969
3970 let Some((_, id, doc)) = results.first() else {
3971 anyhow::bail!("expected one typed document result");
3972 };
3973 anyhow::ensure!(id == "2", "expected integer id to be returned as string");
3974 anyhow::ensure!(doc.id == 2, "typed integer id should round-trip: {doc:?}");
3975 anyhow::ensure!(
3976 doc.priority == 5,
3977 "typed integer field should round-trip: {doc:?}"
3978 );
3979 anyhow::ensure!(
3980 (doc.rating - 0.95).abs() < f64::EPSILON,
3981 "typed float field should round-trip: {doc:?}"
3982 );
3983 anyhow::ensure!(
3984 doc.published,
3985 "typed boolean field should round-trip: {doc:?}"
3986 );
3987
3988 let id_results = index.top_n_ids(req).await?;
3989 let result_ids = id_results
3990 .iter()
3991 .map(|(_, id)| id.as_str())
3992 .collect::<Vec<_>>();
3993 anyhow::ensure!(
3994 result_ids.as_slice() == ["2"],
3995 "top_n_ids should use the same typed metadata filters: {id_results:?}"
3996 );
3997
3998 Ok(())
3999 }
4000
4001 #[tokio::test]
4002 async fn live_boolean_range_filter_is_rejected() -> anyhow::Result<()> {
4003 let index = live_typed_test_index(
4004 "live_boolean_range_filter_is_rejected",
4005 vec![
4006 typed_row(
4007 1,
4008 "misc",
4009 1,
4010 0.5,
4011 false,
4012 "nearest unpublished doc",
4013 vec![1.0, 0.0],
4014 ),
4015 typed_row(2, "docs", 2, 0.7, true, "published doc", vec![0.0, 1.0]),
4016 ],
4017 )
4018 .await?;
4019
4020 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
4021 .query("needle")
4022 .samples(2)
4023 .filter(SqliteSearchFilter::gt(
4024 "published",
4025 serde_json::json!(false),
4026 ))
4027 .build();
4028
4029 ensure_vector_store_filter_error(
4030 index.top_n::<TypedTestDocument>(req.clone()).await,
4031 "top_n boolean range filter",
4032 )?;
4033 ensure_vector_store_filter_error(
4034 index.top_n_ids(req).await,
4035 "top_n_ids boolean range filter",
4036 )?;
4037
4038 Ok(())
4039 }
4040
4041 #[tokio::test]
4042 async fn live_mismatched_metadata_filter_value_type_is_rejected() -> anyhow::Result<()> {
4043 let index = live_typed_test_index(
4044 "live_mismatched_metadata_filter_value_type_is_rejected",
4045 vec![typed_row(
4046 1,
4047 "docs",
4048 1,
4049 0.95,
4050 true,
4051 "published doc",
4052 vec![1.0, 0.0],
4053 )],
4054 )
4055 .await?;
4056
4057 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
4058 .query("needle")
4059 .samples(1)
4060 .filter(SqliteSearchFilter::eq(
4061 "published",
4062 serde_json::json!("true"),
4063 ))
4064 .build();
4065
4066 ensure_vector_store_filter_error(
4067 index.top_n::<TypedTestDocument>(req.clone()).await,
4068 "top_n mismatched metadata filter value type",
4069 )?;
4070 ensure_vector_store_filter_error(
4071 index.top_n_ids(req).await,
4072 "top_n_ids mismatched metadata filter value type",
4073 )?;
4074
4075 Ok(())
4076 }
4077
4078 #[tokio::test]
4079 async fn live_matches_exact_oracle_for_metrics_filters_and_thresholds() -> anyhow::Result<()> {
4080 let query = vec![1.0, 0.0];
4081 let rows = oracle_test_rows();
4082 let filter = SqliteSearchFilter::eq("category", serde_json::json!("docs"))
4083 .and(SqliteSearchFilter::lt("priority", serde_json::json!(10)))
4084 .and(SqliteSearchFilter::gt("rating", serde_json::json!(0.8)))
4085 .and(SqliteSearchFilter::eq("published", serde_json::json!(true)));
4086
4087 for distance_metric in [
4088 SqliteDistanceMetric::Cosine,
4089 SqliteDistanceMetric::L2,
4090 SqliteDistanceMetric::L1,
4091 ] {
4092 let threshold = oracle_threshold(distance_metric);
4093 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
4094 .query("needle")
4095 .samples(u64::try_from(rows.len())?)
4096 .threshold(threshold)
4097 .filter(filter.clone())
4098 .build();
4099 let expected = exact_oracle_results(
4100 &rows,
4101 &query,
4102 distance_metric,
4103 threshold,
4104 rows.len(),
4105 |row| {
4106 row.category == "docs" && row.priority < 10 && row.rating > 0.8 && row.published
4107 },
4108 )?;
4109 let test_name =
4110 format!("live_matches_exact_oracle_for_{distance_metric:?}").to_ascii_lowercase();
4111 let index = live_typed_test_index_with_metric(
4112 &test_name,
4113 sqlite_oracle_rows(&rows),
4114 distance_metric,
4115 )
4116 .await?;
4117
4118 let results = index.top_n::<TypedTestDocument>(req.clone()).await?;
4119 let scored_ids = results
4120 .iter()
4121 .map(|(score, id, doc)| {
4122 anyhow::ensure!(
4123 id == &doc.id.to_string(),
4124 "top_n returned mismatched id and document: id={id}, doc={doc:?}"
4125 );
4126 Ok((*score, id.clone()))
4127 })
4128 .collect::<anyhow::Result<Vec<_>>>()?;
4129 assert_scored_ids_match(&scored_ids, &expected, distance_metric, "top_n")?;
4130
4131 let id_results = index.top_n_ids(req).await?;
4132 assert_scored_ids_match(&id_results, &expected, distance_metric, "top_n_ids")?;
4133 }
4134
4135 Ok(())
4136 }
4137
4138 #[tokio::test]
4139 async fn live_or_filter_preserves_mixed_document_semantics() -> anyhow::Result<()> {
4140 let index = live_test_index(
4141 "live_or_filter_preserves_mixed_document_semantics",
4142 vec![
4143 row(
4144 "nearest",
4145 "misc",
4146 "nearest excluded category",
4147 vec![1.0, 0.0],
4148 ),
4149 row("special", "misc", "special title", vec![0.9, 0.1]),
4150 row("docs", "docs", "far docs match", vec![0.0, 1.0]),
4151 ],
4152 )
4153 .await?;
4154
4155 let filter = SqliteSearchFilter::eq("category", serde_json::json!("docs")).or(
4156 SqliteSearchFilter::eq("title", serde_json::json!("special title")),
4157 );
4158
4159 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
4160 .query("needle")
4161 .samples(1)
4162 .filter(filter)
4163 .build();
4164
4165 let results = index.top_n::<TestDocument>(req.clone()).await?;
4166 let ids = results
4167 .iter()
4168 .map(|(_, id, _)| id.as_str())
4169 .collect::<Vec<_>>();
4170 anyhow::ensure!(
4171 ids.as_slice() == ["special"],
4172 "OR filters should be applied as a whole document predicate: {results:?}"
4173 );
4174
4175 let id_results = index.top_n_ids(req).await?;
4176 let result_ids = id_results
4177 .iter()
4178 .map(|(_, id)| id.as_str())
4179 .collect::<Vec<_>>();
4180 anyhow::ensure!(
4181 result_ids.as_slice() == ["special"],
4182 "top_n_ids should preserve OR document semantics: {id_results:?}"
4183 );
4184
4185 Ok(())
4186 }
4187
4188 #[tokio::test]
4189 async fn live_pattern_and_null_filters_are_applied_after_candidate_search() -> anyhow::Result<()>
4190 {
4191 let index = live_json_metadata_test_index(
4192 "live_pattern_and_null_filters_are_applied_after_candidate_search",
4193 vec![
4194 json_metadata_row("nearest", "docs", "skip", "skip this", vec![1.0, 0.0]),
4195 json_metadata_row("matched", "docs", "vvv", "metadata match", vec![0.0, 1.0]),
4196 ],
4197 )
4198 .await?;
4199
4200 let filter = SqliteSearchFilter::is_null("metadata->>'$.missing'".to_string())
4201 .and(SqliteSearchFilter::like("title".to_string(), "metadata%"))
4202 .and(SqliteSearchFilter::glob("category".to_string(), "doc*"));
4203
4204 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
4205 .query("needle")
4206 .samples(1)
4207 .filter(filter)
4208 .build();
4209
4210 let results = index.top_n::<JsonMetadataDocument>(req.clone()).await?;
4211 let ids = results
4212 .iter()
4213 .map(|(_, id, _)| id.as_str())
4214 .collect::<Vec<_>>();
4215 anyhow::ensure!(
4216 ids.as_slice() == ["matched"],
4217 "pattern and null filters should not be starved by the initial candidate limit: {results:?}"
4218 );
4219
4220 let id_results = index.top_n_ids(req).await?;
4221 let result_ids = id_results
4222 .iter()
4223 .map(|(_, id)| id.as_str())
4224 .collect::<Vec<_>>();
4225 anyhow::ensure!(
4226 result_ids.as_slice() == ["matched"],
4227 "top_n_ids should apply pattern and null filters after candidate search: {id_results:?}"
4228 );
4229
4230 Ok(())
4231 }
4232
    /// Function-pointer type handed to `sqlite3_auto_extension` below: a SQLite
    /// extension entry point taking the database handle, an out-pointer for an
    /// error message, and the API routines table, and returning a status code.
    type SqliteExtensionFn =
        unsafe extern "C" fn(*mut sqlite3, *mut *mut c_char, *const sqlite3_api_routines) -> i32;
4235
    /// Registers `sqlite3_vec_init` as a SQLite auto-extension so that connections
    /// opened afterwards get sqlite-vec initialised.
    ///
    /// Guarded by a `Once`, so repeated calls from different tests register the
    /// extension only once per process.
    fn register_sqlite_vec_extension() {
        static REGISTER_SQLITE_VEC: Once = Once::new();

        // SAFETY: `sqlite3_vec_init` is sqlite-vec's extension entry point. Its Rust
        // declaration presumably differs from the entry-point signature SQLite
        // calls, hence the pointer transmute to `SqliteExtensionFn`.
        // NOTE(review): soundness relies on the underlying C symbol really having the
        // `SqliteExtensionFn` signature; confirm against the sqlite-vec version in use.
        REGISTER_SQLITE_VEC.call_once(|| unsafe {
            sqlite3_auto_extension(Some(std::mem::transmute::<*const (), SqliteExtensionFn>(
                sqlite3_vec_init as *const (),
            )));
        });
    }
4245
4246 async fn live_test_index(
4247 name: &str,
4248 rows: Vec<(TestDocument, OneOrMany<Embedding>)>,
4249 ) -> anyhow::Result<SqliteVectorIndex<TestEmbeddingModel, TestDocument>> {
4250 live_test_index_with_metric(name, rows, SqliteDistanceMetric::Cosine).await
4251 }
4252
4253 async fn live_test_index_with_metric(
4254 name: &str,
4255 rows: Vec<(TestDocument, OneOrMany<Embedding>)>,
4256 distance_metric: SqliteDistanceMetric,
4257 ) -> anyhow::Result<SqliteVectorIndex<TestEmbeddingModel, TestDocument>> {
4258 register_sqlite_vec_extension();
4259
4260 let conn = Connection::open(format!("file:{name}?mode=memory")).await?;
4261 let model = TestEmbeddingModel;
4262 let vector_store =
4263 SqliteVectorStore::with_distance_metric(conn, &model, distance_metric).await?;
4264
4265 vector_store.add_rows(rows).await?;
4266
4267 Ok(vector_store.index(model))
4268 }
4269
4270 async fn live_typed_test_index(
4271 name: &str,
4272 rows: Vec<(TypedTestDocument, OneOrMany<Embedding>)>,
4273 ) -> anyhow::Result<SqliteVectorIndex<TestEmbeddingModel, TypedTestDocument>> {
4274 live_typed_test_index_with_metric(name, rows, SqliteDistanceMetric::Cosine).await
4275 }
4276
4277 async fn live_typed_test_index_with_metric(
4278 name: &str,
4279 rows: Vec<(TypedTestDocument, OneOrMany<Embedding>)>,
4280 distance_metric: SqliteDistanceMetric,
4281 ) -> anyhow::Result<SqliteVectorIndex<TestEmbeddingModel, TypedTestDocument>> {
4282 register_sqlite_vec_extension();
4283
4284 let conn = Connection::open(format!("file:{name}?mode=memory")).await?;
4285 let model = TestEmbeddingModel;
4286 let vector_store: SqliteVectorStore<_, TypedTestDocument> =
4287 SqliteVectorStore::with_distance_metric(conn, &model, distance_metric).await?;
4288
4289 vector_store.add_rows(rows).await?;
4290
4291 Ok(vector_store.index(model))
4292 }
4293
4294 async fn live_common_type_test_index(
4295 name: &str,
4296 rows: Vec<(CommonTypeDocument, OneOrMany<Embedding>)>,
4297 ) -> anyhow::Result<SqliteVectorIndex<TestEmbeddingModel, CommonTypeDocument>> {
4298 register_sqlite_vec_extension();
4299
4300 let conn = Connection::open(format!("file:{name}?mode=memory")).await?;
4301 let model = TestEmbeddingModel;
4302 let vector_store: SqliteVectorStore<_, CommonTypeDocument> =
4303 SqliteVectorStore::new(conn, &model).await?;
4304
4305 vector_store.add_rows(rows).await?;
4306
4307 Ok(vector_store.index(model))
4308 }
4309
4310 async fn live_json_metadata_test_index(
4311 name: &str,
4312 rows: Vec<(JsonMetadataDocument, OneOrMany<Embedding>)>,
4313 ) -> anyhow::Result<SqliteVectorIndex<TestEmbeddingModel, JsonMetadataDocument>> {
4314 register_sqlite_vec_extension();
4315
4316 let conn = Connection::open(format!("file:{name}?mode=memory")).await?;
4317 let model = TestEmbeddingModel;
4318 let vector_store: SqliteVectorStore<_, JsonMetadataDocument> =
4319 SqliteVectorStore::new(conn, &model).await?;
4320
4321 vector_store.add_rows(rows).await?;
4322
4323 Ok(vector_store.index(model))
4324 }
4325
4326 async fn live_structured_json_metadata_test_index(
4327 name: &str,
4328 rows: Vec<(StructuredJsonMetadataDocument, OneOrMany<Embedding>)>,
4329 ) -> anyhow::Result<SqliteVectorIndex<TestEmbeddingModel, StructuredJsonMetadataDocument>> {
4330 register_sqlite_vec_extension();
4331
4332 let conn = Connection::open(format!("file:{name}?mode=memory")).await?;
4333 let model = TestEmbeddingModel;
4334 let vector_store: SqliteVectorStore<_, StructuredJsonMetadataDocument> =
4335 SqliteVectorStore::new(conn, &model).await?;
4336
4337 vector_store.add_rows(rows).await?;
4338
4339 Ok(vector_store.index(model))
4340 }
4341
4342 fn row(
4343 id: impl Into<String>,
4344 category: impl Into<String>,
4345 title: impl Into<String>,
4346 vec: Vec<f64>,
4347 ) -> (TestDocument, OneOrMany<Embedding>) {
4348 let document = TestDocument {
4349 id: id.into(),
4350 category: category.into(),
4351 title: title.into(),
4352 };
4353
4354 (
4355 document.clone(),
4356 OneOrMany::one(Embedding {
4357 document: document.title,
4358 vec,
4359 }),
4360 )
4361 }
4362
4363 fn common_type_row(
4364 id: impl Into<String>,
4365 name: impl Into<String>,
4366 notes: impl Into<String>,
4367 rank: i64,
4368 vec: Vec<f64>,
4369 ) -> (CommonTypeDocument, OneOrMany<Embedding>) {
4370 let document = CommonTypeDocument {
4371 id: id.into(),
4372 name: name.into(),
4373 notes: notes.into(),
4374 rank,
4375 };
4376
4377 (
4378 document.clone(),
4379 OneOrMany::one(Embedding {
4380 document: document.name.clone(),
4381 vec,
4382 }),
4383 )
4384 }
4385
4386 fn json_metadata_row(
4387 id: impl Into<String>,
4388 category: impl Into<String>,
4389 xxx: impl AsRef<str>,
4390 title: impl Into<String>,
4391 vec: Vec<f64>,
4392 ) -> (JsonMetadataDocument, OneOrMany<Embedding>) {
4393 let document = JsonMetadataDocument {
4394 id: id.into(),
4395 category: category.into(),
4396 metadata: serde_json::json!({ "xxx": xxx.as_ref() }).to_string(),
4397 title: title.into(),
4398 };
4399
4400 (
4401 document.clone(),
4402 OneOrMany::one(Embedding {
4403 document: document.title.clone(),
4404 vec,
4405 }),
4406 )
4407 }
4408
4409 fn structured_json_metadata_row(
4410 id: impl Into<String>,
4411 metadata: StructuredMetadata,
4412 title: impl Into<String>,
4413 vec: Vec<f64>,
4414 ) -> (StructuredJsonMetadataDocument, OneOrMany<Embedding>) {
4415 let document = StructuredJsonMetadataDocument {
4416 id: id.into(),
4417 metadata,
4418 title: title.into(),
4419 };
4420
4421 (
4422 document.clone(),
4423 OneOrMany::one(Embedding {
4424 document: document.title.clone(),
4425 vec,
4426 }),
4427 )
4428 }
4429
4430 fn reordered_id_row(
4431 id: impl Into<String>,
4432 title: impl Into<String>,
4433 category: impl Into<String>,
4434 vec: Vec<f64>,
4435 ) -> (ReorderedIdDocument, OneOrMany<Embedding>) {
4436 let document = ReorderedIdDocument {
4437 title: title.into(),
4438 id: id.into(),
4439 category: category.into(),
4440 };
4441
4442 (
4443 document.clone(),
4444 OneOrMany::one(Embedding {
4445 document: document.title.clone(),
4446 vec,
4447 }),
4448 )
4449 }
4450
4451 fn internal_alias_row(
4452 id: impl Into<String>,
4453 rig_score: impl Into<String>,
4454 rig_rank: impl Into<String>,
4455 title: impl Into<String>,
4456 vec: Vec<f64>,
4457 ) -> (InternalAliasDocument, OneOrMany<Embedding>) {
4458 let document = InternalAliasDocument {
4459 id: id.into(),
4460 rig_score: rig_score.into(),
4461 rig_rank: rig_rank.into(),
4462 title: title.into(),
4463 };
4464
4465 (
4466 document.clone(),
4467 OneOrMany::one(Embedding {
4468 document: document.title.clone(),
4469 vec,
4470 }),
4471 )
4472 }
4473
4474 fn typed_row(
4475 id: i64,
4476 category: impl Into<String>,
4477 priority: i64,
4478 rating: f64,
4479 published: bool,
4480 title: impl Into<String>,
4481 vec: Vec<f64>,
4482 ) -> (TypedTestDocument, OneOrMany<Embedding>) {
4483 let document = TypedTestDocument {
4484 id,
4485 category: category.into(),
4486 priority,
4487 rating,
4488 published,
4489 title: title.into(),
4490 };
4491
4492 (
4493 document.clone(),
4494 OneOrMany::one(Embedding {
4495 document: document.title,
4496 vec,
4497 }),
4498 )
4499 }
4500
    /// One fixture for the exact-search oracle test.
    ///
    /// `document` is the row stored in SQLite. `embedding` is used both as the
    /// inserted vector (see `sqlite_oracle_rows`) and as the input to the
    /// brute-force reference score (see `exact_oracle_results`).
    #[derive(Clone, Debug)]
    struct OracleRow {
        document: TypedTestDocument,
        embedding: Vec<f64>,
    }
4506
    /// An `(id, score)` pair that the brute-force oracle expects the store to
    /// return. `id` is the document id rendered as a string, as the store returns it.
    #[derive(Debug)]
    struct ExpectedScoredId {
        id: String,
        score: f64,
    }
4512
    /// Fixture rows for the exact-oracle test.
    ///
    /// Rows 1-3 and 8 pass the oracle test's filter (`category = docs`,
    /// `priority < 10`, `rating > 0.8`, `published`). Rows 4-7 each fail exactly
    /// one predicate, named in their title. Against the query `[1.0, 0.0]` and the
    /// thresholds from `oracle_threshold`, row 3 clears only the L2 threshold and
    /// row 8 clears none, so the expected result set differs per metric.
    fn oracle_test_rows() -> Vec<OracleRow> {
        vec![
            oracle_row(1, "docs", 1, 0.95, true, "exact match", vec![1.0, 0.0]),
            oracle_row(2, "docs", 2, 0.90, true, "close match", vec![0.8, 0.6]),
            oracle_row(3, "docs", 3, 0.81, true, "borderline match", vec![0.5, 0.5]),
            // rating 0.70 fails `rating > 0.8`.
            oracle_row(
                4,
                "docs",
                4,
                0.70,
                true,
                "filtered by rating",
                vec![0.95, 0.05],
            ),
            // priority 15 fails `priority < 10`.
            oracle_row(
                5,
                "docs",
                15,
                0.99,
                true,
                "filtered by priority",
                vec![1.0, 0.0],
            ),
            // published = false fails `published = true`.
            oracle_row(
                6,
                "docs",
                5,
                0.99,
                false,
                "filtered by published",
                vec![1.0, 0.0],
            ),
            // category "misc" fails `category = docs`.
            oracle_row(
                7,
                "misc",
                1,
                0.99,
                true,
                "filtered by category",
                vec![1.0, 0.0],
            ),
            oracle_row(8, "docs", 5, 0.95, true, "far match", vec![0.0, 1.0]),
        ]
    }
4557
4558 fn oracle_row(
4559 id: i64,
4560 category: impl Into<String>,
4561 priority: i64,
4562 rating: f64,
4563 published: bool,
4564 title: impl Into<String>,
4565 embedding: Vec<f64>,
4566 ) -> OracleRow {
4567 OracleRow {
4568 document: TypedTestDocument {
4569 id,
4570 category: category.into(),
4571 priority,
4572 rating,
4573 published,
4574 title: title.into(),
4575 },
4576 embedding,
4577 }
4578 }
4579
4580 fn sqlite_oracle_rows(rows: &[OracleRow]) -> Vec<(TypedTestDocument, OneOrMany<Embedding>)> {
4581 rows.iter()
4582 .map(|row| {
4583 (
4584 row.document.clone(),
4585 OneOrMany::one(Embedding {
4586 document: row.document.title.clone(),
4587 vec: row.embedding.clone(),
4588 }),
4589 )
4590 })
4591 .collect()
4592 }
4593
4594 fn oracle_threshold(distance_metric: SqliteDistanceMetric) -> f64 {
4595 match distance_metric {
4596 SqliteDistanceMetric::Cosine => 0.75,
4597 SqliteDistanceMetric::L2 => -0.8,
4598 SqliteDistanceMetric::L1 => -0.9,
4599 }
4600 }
4601
4602 fn exact_oracle_results(
4603 rows: &[OracleRow],
4604 query: &[f64],
4605 distance_metric: SqliteDistanceMetric,
4606 threshold: f64,
4607 samples: usize,
4608 filter: impl Fn(&TypedTestDocument) -> bool,
4609 ) -> anyhow::Result<Vec<ExpectedScoredId>> {
4610 let mut expected = Vec::new();
4611 for row in rows {
4612 if !filter(&row.document) {
4613 continue;
4614 }
4615
4616 let score = oracle_score(distance_metric, query, &row.embedding)?;
4617 if score >= threshold {
4618 expected.push(ExpectedScoredId {
4619 id: row.document.id.to_string(),
4620 score,
4621 });
4622 }
4623 }
4624
4625 sort_expected_scores(&mut expected);
4626 expected.truncate(samples);
4627 Ok(expected)
4628 }
4629
4630 fn sort_expected_scores(expected: &mut [ExpectedScoredId]) {
4631 expected.sort_by(|lhs, rhs| {
4632 rhs.score
4633 .partial_cmp(&lhs.score)
4634 .unwrap_or(Ordering::Equal)
4635 .then_with(|| lhs.id.cmp(&rhs.id))
4636 });
4637 }
4638
4639 fn oracle_score(
4640 distance_metric: SqliteDistanceMetric,
4641 query: &[f64],
4642 embedding: &[f64],
4643 ) -> anyhow::Result<f64> {
4644 anyhow::ensure!(
4645 query.len() == embedding.len(),
4646 "query and embedding dimensions differ: query={}, embedding={}",
4647 query.len(),
4648 embedding.len()
4649 );
4650
4651 let query = query.iter().map(|value| *value as f32).collect::<Vec<_>>();
4652 let embedding = embedding
4653 .iter()
4654 .map(|value| *value as f32)
4655 .collect::<Vec<_>>();
4656
4657 let score = match distance_metric {
4658 SqliteDistanceMetric::Cosine => {
4659 let dot = query
4660 .iter()
4661 .zip(&embedding)
4662 .map(|(lhs, rhs)| lhs * rhs)
4663 .sum::<f32>();
4664 let query_norm = query.iter().map(|value| value * value).sum::<f32>().sqrt();
4665 let embedding_norm = embedding
4666 .iter()
4667 .map(|value| value * value)
4668 .sum::<f32>()
4669 .sqrt();
4670 anyhow::ensure!(
4671 query_norm > 0.0 && embedding_norm > 0.0,
4672 "cosine oracle requires non-zero vectors"
4673 );
4674 dot / (query_norm * embedding_norm)
4675 }
4676 SqliteDistanceMetric::L2 => -query
4677 .iter()
4678 .zip(&embedding)
4679 .map(|(lhs, rhs)| {
4680 let delta = lhs - rhs;
4681 delta * delta
4682 })
4683 .sum::<f32>()
4684 .sqrt(),
4685 SqliteDistanceMetric::L1 => -query
4686 .iter()
4687 .zip(&embedding)
4688 .map(|(lhs, rhs)| (lhs - rhs).abs())
4689 .sum::<f32>(),
4690 };
4691
4692 Ok(f64::from(score))
4693 }
4694
4695 fn assert_scored_ids_match(
4696 actual: &[(f64, String)],
4697 expected: &[ExpectedScoredId],
4698 distance_metric: SqliteDistanceMetric,
4699 context: &str,
4700 ) -> anyhow::Result<()> {
4701 let actual_ids = actual.iter().map(|(_, id)| id.as_str()).collect::<Vec<_>>();
4702 let expected_ids = expected
4703 .iter()
4704 .map(|expected| expected.id.as_str())
4705 .collect::<Vec<_>>();
4706 anyhow::ensure!(
4707 actual_ids == expected_ids,
4708 "{context} ids for {distance_metric:?} did not match exact oracle: actual={actual:?}, expected={expected:?}"
4709 );
4710
4711 for ((actual_score, actual_id), expected) in actual.iter().zip(expected) {
4712 anyhow::ensure!(
4713 (actual_score - expected.score).abs() <= SCORE_EPSILON,
4714 "{context} score for {distance_metric:?} id `{actual_id}` did not match exact oracle: actual={actual_score}, expected={}",
4715 expected.score
4716 );
4717 }
4718
4719 Ok(())
4720 }
4721
    /// Text-keyed fixture document stored in the `live_test_documents` table.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    struct TestDocument {
        /// Value of the `TEXT PRIMARY KEY` `id` column.
        id: String,
        /// Value of the indexed `category` column.
        category: String,
        /// Value of the plain `title` column.
        title: String,
    }
4728
4729 impl SqliteVectorStoreTable for TestDocument {
4730 fn name() -> &'static str {
4731 "live_test_documents"
4732 }
4733
4734 fn schema() -> Vec<Column> {
4735 vec![
4736 Column::new("id", "TEXT PRIMARY KEY"),
4737 Column::new("category", "TEXT").indexed(),
4738 Column::new("title", "TEXT"),
4739 ]
4740 }
4741
4742 fn id(&self) -> String {
4743 self.id.clone()
4744 }
4745
4746 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
4747 vec![
4748 ("id", Box::new(self.id.clone())),
4749 ("category", Box::new(self.category.clone())),
4750 ("title", Box::new(self.title.clone())),
4751 ]
4752 }
4753 }
4754
    /// Fixture whose `id` field and column are not first.
    ///
    /// NOTE(review): presumably this checks that id handling does not depend on
    /// column position. Confirm against the test that uses it.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    struct ReorderedIdDocument {
        /// Value of the plain `title` column (declared first).
        title: String,
        /// Value of the `TEXT PRIMARY KEY` `id` column (declared second).
        id: String,
        /// Value of the indexed `category` column.
        category: String,
    }
4761
4762 impl SqliteVectorStoreTable for ReorderedIdDocument {
4763 fn name() -> &'static str {
4764 "live_reordered_id_test_documents"
4765 }
4766
4767 fn schema() -> Vec<Column> {
4768 vec![
4769 Column::new("title", "TEXT"),
4770 Column::new("id", "TEXT PRIMARY KEY"),
4771 Column::new("category", "TEXT").indexed(),
4772 ]
4773 }
4774
4775 fn id(&self) -> String {
4776 self.id.clone()
4777 }
4778
4779 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
4780 vec![
4781 ("title", Box::new(self.title.clone())),
4782 ("id", Box::new(self.id.clone())),
4783 ("category", Box::new(self.category.clone())),
4784 ]
4785 }
4786 }
4787
    /// Fixture with user columns literally named `__rig_score` and `__rig_rank`.
    ///
    /// NOTE(review): presumably these names collide with aliases the store uses
    /// internally in its queries, and this fixture checks that user data is not
    /// clobbered. Confirm against the store's SQL.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    struct InternalAliasDocument {
        /// Value of the `TEXT PRIMARY KEY` `id` column.
        id: String,
        // Serialized under the `__rig_score` key to match the column name.
        #[serde(rename = "__rig_score")]
        rig_score: String,
        // Serialized under the `__rig_rank` key to match the column name.
        #[serde(rename = "__rig_rank")]
        rig_rank: String,
        /// Value of the plain `title` column.
        title: String,
    }
4797
4798 impl SqliteVectorStoreTable for InternalAliasDocument {
4799 fn name() -> &'static str {
4800 "live_internal_alias_test_documents"
4801 }
4802
4803 fn schema() -> Vec<Column> {
4804 vec![
4805 Column::new("id", "TEXT PRIMARY KEY"),
4806 Column::new("__rig_score", "TEXT"),
4807 Column::new("__rig_rank", "TEXT"),
4808 Column::new("title", "TEXT"),
4809 ]
4810 }
4811
4812 fn id(&self) -> String {
4813 self.id.clone()
4814 }
4815
4816 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
4817 vec![
4818 ("id", Box::new(self.id.clone())),
4819 ("__rig_score", Box::new(self.rig_score.clone())),
4820 ("__rig_rank", Box::new(self.rig_rank.clone())),
4821 ("title", Box::new(self.title.clone())),
4822 ]
4823 }
4824 }
4825
    /// Fixture whose schema uses common non-canonical SQL type names:
    /// `VARCHAR(255)`, `CLOB` and `NUMERIC`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    struct CommonTypeDocument {
        /// Value of the `TEXT PRIMARY KEY` `id` column.
        id: String,
        /// Value of the indexed `VARCHAR(255)` `name` column.
        name: String,
        /// Value of the `CLOB` `notes` column.
        notes: String,
        /// Value of the `NUMERIC` `rank` column.
        rank: i64,
    }
4833
4834 impl SqliteVectorStoreTable for CommonTypeDocument {
4835 fn name() -> &'static str {
4836 "live_common_type_test_documents"
4837 }
4838
4839 fn schema() -> Vec<Column> {
4840 vec![
4841 Column::new("id", "TEXT PRIMARY KEY"),
4842 Column::new("name", "VARCHAR(255)").indexed(),
4843 Column::new("notes", "CLOB"),
4844 Column::new("rank", "NUMERIC"),
4845 ]
4846 }
4847
4848 fn id(&self) -> String {
4849 self.id.clone()
4850 }
4851
4852 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
4853 vec![
4854 ("id", Box::new(self.id.clone())),
4855 ("name", Box::new(self.name.clone())),
4856 ("notes", Box::new(self.notes.clone())),
4857 ("rank", Box::new(self.rank)),
4858 ]
4859 }
4860 }
4861
    /// Fixture carrying a `metadata` string stored in a plain `TEXT` column.
    ///
    /// NOTE(review): the type name suggests `metadata` holds serialized JSON.
    /// Nothing here enforces that; confirm against the tests that populate it.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    struct JsonMetadataDocument {
        /// Value of the `TEXT PRIMARY KEY` `id` column.
        id: String,
        /// Value of the indexed `category` column.
        category: String,
        /// Value of the `TEXT` `metadata` column.
        metadata: String,
        /// Value of the plain `title` column.
        title: String,
    }
4869
4870 impl SqliteVectorStoreTable for JsonMetadataDocument {
4871 fn name() -> &'static str {
4872 "live_json_metadata_test_documents"
4873 }
4874
4875 fn schema() -> Vec<Column> {
4876 vec![
4877 Column::new("id", "TEXT PRIMARY KEY"),
4878 Column::new("category", "TEXT").indexed(),
4879 Column::new("metadata", "TEXT"),
4880 Column::new("title", "TEXT"),
4881 ]
4882 }
4883
4884 fn id(&self) -> String {
4885 self.id.clone()
4886 }
4887
4888 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
4889 vec![
4890 ("id", Box::new(self.id.clone())),
4891 ("category", Box::new(self.category.clone())),
4892 ("metadata", Box::new(self.metadata.clone())),
4893 ("title", Box::new(self.title.clone())),
4894 ]
4895 }
4896 }
4897
    /// Structured metadata written as a JSON object into the `metadata` column
    /// of `StructuredJsonMetadataDocument`.
    ///
    /// It derives `PartialEq`, presumably so tests can compare values read back
    /// from the store.
    #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
    struct StructuredMetadata {
        /// Written under the `user_id` JSON key.
        user_id: i64,
        /// Written under the `knowledge_id` JSON key.
        knowledge_id: i64,
        /// Written under the `knowledge_doc_id` JSON key.
        knowledge_doc_id: i64,
    }
4904
    /// Fixture whose `metadata` is a nested struct, stored in a column declared
    /// with type `JSON`.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    struct StructuredJsonMetadataDocument {
        /// Value of the `TEXT PRIMARY KEY` `id` column.
        id: String,
        /// Nested metadata, written to the `metadata` column as a JSON object.
        metadata: StructuredMetadata,
        /// Value of the plain `title` column.
        title: String,
    }
4911
4912 impl SqliteVectorStoreTable for StructuredJsonMetadataDocument {
4913 fn name() -> &'static str {
4914 "live_structured_json_metadata_test_documents"
4915 }
4916
4917 fn schema() -> Vec<Column> {
4918 vec![
4919 Column::new("id", "TEXT PRIMARY KEY"),
4920 Column::new("metadata", "JSON"),
4921 Column::new("title", "TEXT"),
4922 ]
4923 }
4924
4925 fn id(&self) -> String {
4926 self.id.clone()
4927 }
4928
4929 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
4930 vec![
4931 ("id", Box::new(self.id.clone())),
4932 (
4933 "metadata",
4934 Box::new(serde_json::json!({
4935 "user_id": self.metadata.user_id,
4936 "knowledge_id": self.metadata.knowledge_id,
4937 "knowledge_doc_id": self.metadata.knowledge_doc_id,
4938 })),
4939 ),
4940 ("title", Box::new(self.title.clone())),
4941 ]
4942 }
4943 }
4944
    /// Fixture with typed columns: integer primary key, integer, float, boolean
    /// and text.
    ///
    /// The exact-search oracle helpers (`OracleRow`, `exact_oracle_results`)
    /// operate on this type.
    #[derive(Clone, Debug, Deserialize, Serialize)]
    struct TypedTestDocument {
        /// Value of the `INTEGER PRIMARY KEY` `id` column.
        id: i64,
        /// Value of the indexed `TEXT` `category` column.
        category: String,
        /// Value of the indexed `INTEGER` `priority` column.
        priority: i64,
        /// Value of the indexed `FLOAT` `rating` column.
        rating: f64,
        /// Value of the indexed `BOOLEAN` `published` column.
        published: bool,
        /// Value of the plain `TEXT` `title` column. The oracle also uses it as
        /// the embedding's document text.
        title: String,
    }
4954
4955 impl SqliteVectorStoreTable for TypedTestDocument {
4956 fn name() -> &'static str {
4957 "live_typed_test_documents"
4958 }
4959
4960 fn schema() -> Vec<Column> {
4961 vec![
4962 Column::new("id", "INTEGER PRIMARY KEY"),
4963 Column::new("category", "TEXT").indexed(),
4964 Column::new("priority", "INTEGER").indexed(),
4965 Column::new("rating", "FLOAT").indexed(),
4966 Column::new("published", "BOOLEAN").indexed(),
4967 Column::new("title", "TEXT"),
4968 ]
4969 }
4970
4971 fn id(&self) -> String {
4972 self.id.to_string()
4973 }
4974
4975 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
4976 vec![
4977 ("id", Box::new(self.id)),
4978 ("category", Box::new(self.category.clone())),
4979 ("priority", Box::new(self.priority)),
4980 ("rating", Box::new(self.rating)),
4981 ("published", Box::new(self.published)),
4982 ("title", Box::new(self.title.clone())),
4983 ]
4984 }
4985 }
4986
    /// Stateless stub embedding model. Every text is embedded as the same
    /// 2-dimensional vector.
    #[derive(Clone)]
    struct TestEmbeddingModel;
4989
4990 impl EmbeddingModel for TestEmbeddingModel {
4991 const MAX_DOCUMENTS: usize = 16;
4992
4993 type Client = ();
4994
4995 fn make(_: &Self::Client, _: impl Into<String>, _: Option<usize>) -> Self {
4996 Self
4997 }
4998
4999 fn ndims(&self) -> usize {
5000 2
5001 }
5002
5003 async fn embed_texts(
5004 &self,
5005 texts: impl IntoIterator<Item = String> + WasmCompatSend,
5006 ) -> Result<Vec<Embedding>, EmbeddingError> {
5007 Ok(texts
5008 .into_iter()
5009 .map(|text| Embedding {
5010 document: text,
5011 vec: vec![1.0, 0.0],
5012 })
5013 .collect())
5014 }
5015 }
5016}