1use rig_core::embeddings::{Embedding, EmbeddingModel};
11use rig_core::vector_store::request::{FilterError, SearchFilter, VectorSearchRequest};
12use rig_core::vector_store::{InsertDocuments, VectorStoreError, VectorStoreIndex};
13use rig_core::wasm_compat::{WasmCompatSend, WasmCompatSync};
14use rig_core::{Embed, OneOrMany};
15use rusqlite::OptionalExtension;
16use rusqlite::types::{Type, Value, ValueRef};
17use serde::{Deserialize, Serialize};
18use std::fmt::{self, Display};
19use std::marker::PhantomData;
20use std::ops::RangeInclusive;
21use tokio_rusqlite::Connection;
22use tracing::{debug, info};
23use zerocopy::IntoBytes;
24
/// Largest `k` value this store is willing to request from sqlite-vec.
///
/// NOTE(review): presumably mirrors sqlite-vec's own upper bound on `k` for KNN queries;
/// confirm against the sqlite-vec version in use if that limit changes.
const SQLITE_VEC_MAX_K: u64 = 4096;
33
/// Error type for the SQLite vector store.
///
/// NOTE(review): this enum is not constructed in the reviewed portion of the file; the variant
/// descriptions below are inferred from their names and should be confirmed.
#[derive(Debug)]
pub enum SqliteError {
    /// Presumably wraps a failure reported by the database layer.
    DatabaseError(Box<dyn std::error::Error + Send + Sync>),
    /// Presumably wraps a (de)serialization failure.
    SerializationError(Box<dyn std::error::Error + Send + Sync>),
    /// Presumably describes a column whose declared type is not usable.
    InvalidColumnType(String),
}
40
/// A single column value produced by [`SqliteVectorStoreTable::column_values`].
pub trait ColumnValue: Send + Sync {
    /// Converts the value into a SQLite [`Value`].
    ///
    /// The store binds this result when writing the document row, and also when copying
    /// indexed columns into the sqlite-vec metadata columns of the embeddings table.
    fn to_sql_value(&self) -> Value;

    /// SQL type name associated with this value.
    ///
    /// NOTE(review): not read in the reviewed portion of this file — confirm how it is used.
    fn column_type(&self) -> &'static str;
}
51
/// One column of the document table managed by [`SqliteVectorStore`].
///
/// `name` and `col_type` are spliced verbatim into the generated `CREATE TABLE` statement,
/// so they must be a valid SQLite identifier and type declaration respectively.
#[derive(Clone, Debug)]
pub struct Column {
    name: &'static str,
    col_type: &'static str,
    indexed: bool,
}

impl Column {
    /// Declares a plain (non-indexed) column named `name` with SQL type `col_type`.
    pub fn new(name: &'static str, col_type: &'static str) -> Self {
        let indexed = false;
        Self {
            name,
            col_type,
            indexed,
        }
    }

    /// Marks the column as indexed.
    ///
    /// Indexed columns get a B-tree index on the document table. They are also mirrored as
    /// sqlite-vec metadata columns, which lets filters on them run inside the KNN query.
    pub fn indexed(self) -> Self {
        Self {
            indexed: true,
            ..self
        }
    }
}
80
/// A Rust type that can be persisted as one row of a SQLite vector store table.
pub trait SqliteVectorStoreTable: Send + Sync + Clone {
    /// Name of the document table.
    ///
    /// Also used as the prefix of the `{name}_embeddings` vec0 table and of the
    /// `{name}_embedding_map` table.
    fn name() -> &'static str;
    /// Columns of the document table, in `CREATE TABLE` order.
    ///
    /// Must include an `id` column: the store creates an index on it and looks up
    /// existing documents with `WHERE id = ?`.
    fn schema() -> Vec<Column>;
    /// Document identifier.
    ///
    /// Used for logging, and as the `id` value when `column_values` has no `id` entry.
    fn id(&self) -> String;
    /// `(column name, value)` pairs written into the document table for this row.
    fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)>;
}
124
125#[derive(Clone, Copy, Debug, Default, Deserialize, Eq, PartialEq, Serialize)]
133pub enum SqliteDistanceMetric {
134 #[default]
136 Cosine,
137 L2,
139 L1,
141}
142
143impl SqliteDistanceMetric {
144 fn vec0_name(self) -> &'static str {
145 match self {
146 Self::Cosine => "cosine",
147 Self::L2 => "l2",
148 Self::L1 => "l1",
149 }
150 }
151
152 fn score_expression(self, query_param: &str, embedding_expr: &str) -> String {
153 match self {
154 Self::Cosine => {
155 format!("(1 - vec_distance_cosine({query_param}, {embedding_expr}))")
156 }
157 Self::L2 => format!("(-vec_distance_l2({query_param}, {embedding_expr}))"),
158 Self::L1 => format!("(-vec_distance_l1({query_param}, {embedding_expr}))"),
159 }
160 }
161}
162
163#[derive(Debug)]
164struct SqliteDistanceMetricMismatch {
165 table_name: String,
166 requested: SqliteDistanceMetric,
167 configured: SqliteDistanceMetric,
168}
169
170impl Display for SqliteDistanceMetricMismatch {
171 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
172 write!(
173 f,
174 "SQLite vector table `{}` uses {:?}, but {:?} was requested",
175 self.table_name, self.configured, self.requested
176 )
177 }
178}
179
180impl std::error::Error for SqliteDistanceMetricMismatch {}
181
/// Raised when the embeddings table cannot be found in `sqlite_schema` right after the
/// `CREATE VIRTUAL TABLE IF NOT EXISTS` statement ran.
#[derive(Debug)]
struct SqliteVectorTableMissingSchema {
    table_name: String,
}

impl Display for SqliteVectorTableMissingSchema {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let table = &self.table_name;
        write!(
            f,
            "SQLite vector table `{table}` was created but is missing from sqlite_schema"
        )
    }
}

impl std::error::Error for SqliteVectorTableMissingSchema {}
198
199#[derive(Clone, Copy, Debug, Eq, PartialEq)]
200enum SqliteMetadataType {
201 Text,
202 Integer,
203 Float,
204 Boolean,
205}
206
207impl SqliteMetadataType {
208 fn from_column_type(column_type: &str) -> Option<Self> {
209 let first_type_token = column_type
210 .split_whitespace()
211 .next()
212 .unwrap_or_default()
213 .to_ascii_uppercase();
214
215 match first_type_token.as_str() {
216 "TEXT" => Some(Self::Text),
217 "INTEGER" | "INT" | "INT64" | "INTEGER64" => Some(Self::Integer),
218 "FLOAT" | "REAL" | "DOUBLE" | "FLOAT64" | "F64" => Some(Self::Float),
219 "BOOLEAN" | "BOOL" => Some(Self::Boolean),
220 _ => match SqliteColumnAffinity::from_column_type(column_type) {
221 SqliteColumnAffinity::Text => Some(Self::Text),
222 SqliteColumnAffinity::Integer => Some(Self::Integer),
223 SqliteColumnAffinity::Float => Some(Self::Float),
224 SqliteColumnAffinity::Boolean => Some(Self::Boolean),
225 SqliteColumnAffinity::Numeric | SqliteColumnAffinity::Blob => None,
226 },
227 }
228 }
229
230 fn vec0_name(self) -> &'static str {
231 match self {
232 Self::Text => "TEXT",
233 Self::Integer => "INTEGER",
234 Self::Float => "FLOAT",
235 Self::Boolean => "BOOLEAN",
236 }
237 }
238
239 fn supports_native_comparison(self, op: SqliteComparisonOp) -> bool {
240 !matches!(
241 (self, op),
242 (
243 Self::Boolean,
244 SqliteComparisonOp::Gt
245 | SqliteComparisonOp::Lt
246 | SqliteComparisonOp::Gte
247 | SqliteComparisonOp::Lte
248 )
249 )
250 }
251}
252
/// Affinity of a declared column type.
///
/// Follows SQLite's affinity rules (INT, then CHAR/CLOB/TEXT, then BLOB-or-empty, then
/// REAL/FLOA/DOUB), with an extra BOOL check just before the NUMERIC fallback.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SqliteColumnAffinity {
    Text,
    Integer,
    Float,
    Boolean,
    Numeric,
    Blob,
}

impl SqliteColumnAffinity {
    /// Classifies `column_type` by case-insensitive substring tests, in SQLite's rule order.
    fn from_column_type(column_type: &str) -> Self {
        let upper = column_type.to_ascii_uppercase();
        let mentions_any =
            |needles: &[&str]| needles.iter().any(|needle| upper.contains(needle));

        if mentions_any(&["INT"]) {
            Self::Integer
        } else if mentions_any(&["CHAR", "CLOB", "TEXT"]) {
            Self::Text
        } else if mentions_any(&["BLOB"]) || upper.trim().is_empty() {
            Self::Blob
        } else if mentions_any(&["REAL", "FLOA", "DOUB"]) {
            Self::Float
        } else if mentions_any(&["BOOL"]) {
            Self::Boolean
        } else {
            Self::Numeric
        }
    }
}
288
/// An indexed schema column that is also declared as a metadata column of the sqlite-vec
/// embeddings table, together with its resolved metadata type.
#[derive(Clone, Debug, Eq, PartialEq)]
struct SqliteMetadataColumn {
    name: &'static str,
    metadata_type: SqliteMetadataType,
}
294
/// Raised when an indexed column's declared type has no sqlite-vec metadata equivalent.
#[derive(Debug)]
struct SqliteUnsupportedMetadataColumn {
    column_name: &'static str,
    column_type: &'static str,
}

impl Display for SqliteUnsupportedMetadataColumn {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            column_name,
            column_type,
        } = self;
        write!(
            f,
            "SQLite metadata column `{column_name}` has unsupported type `{column_type}`"
        )
    }
}

impl std::error::Error for SqliteUnsupportedMetadataColumn {}
312
313#[derive(Debug)]
314struct SqliteMetadataSchemaMismatch {
315 table_name: String,
316 column_name: &'static str,
317 column_type: SqliteMetadataType,
318}
319
320impl Display for SqliteMetadataSchemaMismatch {
321 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
322 write!(
323 f,
324 "SQLite vector table `{}` is missing metadata column `{} {}`",
325 self.table_name,
326 self.column_name,
327 self.column_type.vec0_name()
328 )
329 }
330}
331
332impl std::error::Error for SqliteMetadataSchemaMismatch {}
333
334#[derive(Debug)]
335struct SqliteMetadataValueError {
336 column_name: &'static str,
337 column_type: SqliteMetadataType,
338 value_type: Type,
339}
340
341impl Display for SqliteMetadataValueError {
342 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
343 write!(
344 f,
345 "could not convert SQLite value type `{:?}` for metadata column `{} {}`",
346 self.value_type,
347 self.column_name,
348 self.column_type.vec0_name()
349 )
350 }
351}
352
353impl std::error::Error for SqliteMetadataValueError {}
354
/// Raised when a vector store table definition has no `id` column.
#[derive(Debug)]
struct SqliteMissingIdColumn {
    table_name: String,
}

impl Display for SqliteMissingIdColumn {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let table = &self.table_name;
        write!(
            f,
            "SQLite vector store table `{table}` is missing an `id` column"
        )
    }
}

impl std::error::Error for SqliteMissingIdColumn {}
371
372fn sqlite_metadata_columns(
373 schema: &[Column],
374) -> Result<Vec<SqliteMetadataColumn>, VectorStoreError> {
375 schema
376 .iter()
377 .filter(|column| column.indexed)
378 .map(|column| {
379 let metadata_type =
380 SqliteMetadataType::from_column_type(column.col_type).ok_or_else(|| {
381 VectorStoreError::DatastoreError(Box::new(SqliteUnsupportedMetadataColumn {
382 column_name: column.name,
383 column_type: column.col_type,
384 }))
385 })?;
386
387 Ok(SqliteMetadataColumn {
388 name: column.name,
389 metadata_type,
390 })
391 })
392 .collect()
393}
394
395fn sqlite_metadata_value(
396 values: &[(&'static str, Box<dyn ColumnValue>)],
397 column: &SqliteMetadataColumn,
398) -> rusqlite::Result<Value> {
399 let value = values
400 .iter()
401 .find(|(name, _)| *name == column.name)
402 .ok_or_else(|| rusqlite::Error::InvalidParameterName(column.name.to_string()))?
403 .1
404 .to_sql_value();
405
406 match (column.metadata_type, value) {
407 (SqliteMetadataType::Text, Value::Text(value)) => Ok(Value::Text(value)),
408 (SqliteMetadataType::Integer, Value::Integer(value)) => Ok(Value::Integer(value)),
409 (SqliteMetadataType::Float, Value::Real(value)) => Ok(Value::Real(value)),
410 (SqliteMetadataType::Float, Value::Integer(value)) => Ok(Value::Real(value as f64)),
411 (SqliteMetadataType::Boolean, Value::Integer(value @ (0 | 1))) => Ok(Value::Integer(value)),
412 (_, value) => Err(rusqlite::Error::ToSqlConversionFailure(Box::new(
413 SqliteMetadataValueError {
414 column_name: column.name,
415 column_type: column.metadata_type,
416 value_type: value.data_type(),
417 },
418 ))),
419 }
420}
421
/// SQLite-backed vector store for documents of type `T`, embedded with model type `E`.
///
/// Documents live in the `T::name()` table. Embeddings live in the sqlite-vec
/// `{name}_embeddings` table. The `{name}_embedding_map` table links them, and allows
/// several embeddings per document.
///
/// `Clone` is derived; `add_rows` clones the store to move it into the connection closure.
/// NOTE(review): presumably clones share the same underlying `tokio_rusqlite::Connection`
/// worker — confirm against tokio_rusqlite's `Clone` semantics.
#[derive(Clone)]
pub struct SqliteVectorStore<E, T>
where
    E: EmbeddingModel + 'static,
    T: SqliteVectorStoreTable + 'static,
{
    // Connection used for all schema setup, inserts and searches.
    conn: Connection,
    // Metric requested at construction; checked against the stored vec0 definition.
    distance_metric: SqliteDistanceMetric,
    // Indexed schema columns mirrored as sqlite-vec metadata columns, in schema order.
    metadata_columns: Vec<SqliteMetadataColumn>,
    _phantom: PhantomData<(E, T)>,
}
433
434impl<E, T> SqliteVectorStore<E, T>
435where
436 E: EmbeddingModel + 'static,
437 T: SqliteVectorStoreTable + 'static,
438{
439 async fn candidate_limit(&self, samples: u64, exhaustive: bool) -> Result<u64, VectorStoreError>
440 where
441 Self: 'static,
442 {
443 if samples == 0 {
444 return Ok(0);
445 }
446
447 let embedding_map_table_name = format!("{}_embedding_map", T::name());
448 let (embedding_count, document_count) = self
449 .conn
450 .call(move |conn| {
451 Ok(conn.query_row(
452 &format!(
453 "SELECT COUNT(*), COUNT(DISTINCT document_rowid) FROM {embedding_map_table_name}"
454 ),
455 [],
456 |row| Ok((row.get::<_, i64>(0)?, row.get::<_, i64>(1)?)),
457 )?)
458 })
459 .await
460 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
461
462 let embedding_count = u64::try_from(embedding_count).unwrap_or(0);
463 let document_count = u64::try_from(document_count).unwrap_or(0);
464
465 if exhaustive {
466 Ok(embedding_count.max(samples))
470 } else if embedding_count > document_count {
471 Ok(samples
479 .saturating_add(embedding_count - document_count)
480 .min(embedding_count))
481 } else {
482 Ok(samples)
483 }
484 }
485}
486
487impl<E, T> SqliteVectorStore<E, T>
488where
489 E: EmbeddingModel + Clone + 'static,
490 T: SqliteVectorStoreTable + 'static,
491{
492 pub async fn new(conn: Connection, embedding_model: &E) -> Result<Self, VectorStoreError> {
494 Self::with_distance_metric(conn, embedding_model, SqliteDistanceMetric::default()).await
495 }
496
497 pub async fn with_distance_metric(
503 conn: Connection,
504 embedding_model: &E,
505 distance_metric: SqliteDistanceMetric,
506 ) -> Result<Self, VectorStoreError> {
507 let dims = embedding_model.ndims();
508 let table_name = T::name();
509 let embeddings_table_name = format!("{table_name}_embeddings");
510 let embeddings_table_name_for_sql = embeddings_table_name.clone();
511 let embedding_map_table_name_for_sql = format!("{table_name}_embedding_map");
512 let schema = T::schema();
513 let metadata_columns = sqlite_metadata_columns(&schema)?;
514 let metadata_columns_for_schema_check = metadata_columns.clone();
515 let distance_metric_name = distance_metric.vec0_name();
516 let mut embeddings_columns =
517 format!("embedding float[{dims}] distance_metric={distance_metric_name}");
518 for column in &metadata_columns {
519 embeddings_columns.push_str(&format!(
520 ", {} {}",
521 column.name,
522 column.metadata_type.vec0_name()
523 ));
524 }
525
526 let mut create_table = format!("CREATE TABLE IF NOT EXISTS {table_name} (");
528
529 let mut first = true;
531 for column in &schema {
532 if !first {
533 create_table.push(',');
534 }
535 create_table.push_str(&format!("\n {} {}", column.name, column.col_type));
536 first = false;
537 }
538
539 create_table.push_str("\n)");
540
541 let mut create_indexes = vec![format!(
543 "CREATE INDEX IF NOT EXISTS idx_{}_id ON {}(id)",
544 table_name, table_name
545 )];
546
547 for column in schema {
549 if column.indexed {
550 create_indexes.push(format!(
551 "CREATE INDEX IF NOT EXISTS idx_{}_{} ON {}({})",
552 table_name, column.name, table_name, column.name
553 ));
554 }
555 }
556
557 let embeddings_table_sql = conn
558 .call(move |conn| {
559 conn.execute_batch("BEGIN")?;
560
561 conn.execute_batch(&create_table)?;
563
564 for index_stmt in create_indexes {
566 conn.execute_batch(&index_stmt)?;
567 }
568
569 conn.execute_batch(&format!(
571 "CREATE VIRTUAL TABLE IF NOT EXISTS {embeddings_table_name_for_sql} USING vec0({embeddings_columns})"
572 ))?;
573 conn.execute_batch(&format!(
574 "CREATE TABLE IF NOT EXISTS {embedding_map_table_name_for_sql} (
575 embedding_rowid INTEGER PRIMARY KEY AUTOINCREMENT,
576 document_rowid INTEGER NOT NULL
577 )"
578 ))?;
579 conn.execute_batch(&format!(
580 "CREATE INDEX IF NOT EXISTS idx_{table_name}_embedding_map_document_rowid ON {embedding_map_table_name_for_sql}(document_rowid)"
581 ))?;
582
583 conn.execute_batch("COMMIT")?;
584
585 let schema_sql = conn
586 .query_row(
587 "SELECT sql FROM sqlite_schema WHERE name = ?1",
588 [&embeddings_table_name_for_sql],
589 |row| row.get::<_, String>(0),
590 )
591 .optional()?;
592
593 Ok(schema_sql)
594 })
595 .await
596 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
597
598 let schema_sql = embeddings_table_sql.ok_or_else(|| {
599 VectorStoreError::DatastoreError(Box::new(SqliteVectorTableMissingSchema {
600 table_name: embeddings_table_name.clone(),
601 }))
602 })?;
603
604 let configured = sqlite_distance_metric_from_schema(&schema_sql);
605 if configured != distance_metric {
606 return Err(VectorStoreError::DatastoreError(Box::new(
607 SqliteDistanceMetricMismatch {
608 table_name: embeddings_table_name,
609 requested: distance_metric,
610 configured,
611 },
612 )));
613 }
614 for column in metadata_columns_for_schema_check {
615 if !sqlite_schema_contains_metadata_column(&schema_sql, &column) {
616 return Err(VectorStoreError::DatastoreError(Box::new(
617 SqliteMetadataSchemaMismatch {
618 table_name: embeddings_table_name.clone(),
619 column_name: column.name,
620 column_type: column.metadata_type,
621 },
622 )));
623 }
624 }
625
626 Ok(Self {
627 conn,
628 distance_metric,
629 metadata_columns,
630 _phantom: PhantomData,
631 })
632 }
633
634 pub fn index(self, model: E) -> SqliteVectorIndex<E, T> {
635 SqliteVectorIndex::new(model, self)
636 }
637
638 pub fn add_rows_with_txn(
639 &self,
640 txn: &rusqlite::Transaction<'_>,
641 documents: Vec<(T, OneOrMany<Embedding>)>,
642 ) -> Result<i64, tokio_rusqlite::Error> {
643 info!("Adding {} documents to store", documents.len());
644 let table_name = T::name();
645 let embeddings_table_name = format!("{table_name}_embeddings");
646 let embedding_map_table_name = format!("{table_name}_embedding_map");
647 let mut last_id = 0;
648 let embedding_columns = std::iter::once("rowid")
649 .chain(std::iter::once("embedding"))
650 .chain(self.metadata_columns.iter().map(|column| column.name))
651 .collect::<Vec<_>>();
652 let embedding_placeholders = (1..=embedding_columns.len())
653 .map(|i| format!("?{i}"))
654 .collect::<Vec<_>>();
655 let embeddings_sql = format!(
656 "INSERT INTO {embeddings_table_name} ({}) VALUES ({})",
657 embedding_columns.join(", "),
658 embedding_placeholders.join(", ")
659 );
660 let existing_rowid_sql = format!("SELECT rowid FROM {table_name} WHERE id = ?1");
661 let existing_embedding_rowids_sql = format!(
662 "SELECT embedding_rowid FROM {embedding_map_table_name} WHERE document_rowid = ?1"
663 );
664 let insert_embedding_map_sql =
665 format!("INSERT INTO {embedding_map_table_name}(document_rowid) VALUES (?1)");
666 let delete_embedding_map_sql =
667 format!("DELETE FROM {embedding_map_table_name} WHERE document_rowid = ?1");
668 let delete_embeddings_sql = format!("DELETE FROM {embeddings_table_name} WHERE rowid = ?1");
669
670 for (doc, embeddings) in &documents {
671 debug!("Storing document with id {}", doc.id());
672
673 let values = doc.column_values();
674 let id_value = values
675 .iter()
676 .find(|(name, _)| *name == "id")
677 .map(|(_, value)| value.to_sql_value())
678 .unwrap_or_else(|| Value::Text(doc.id()));
679 if let Some(existing_rowid) = txn
680 .query_row(&existing_rowid_sql, rusqlite::params![id_value], |row| {
681 row.get::<_, i64>(0)
682 })
683 .optional()?
684 {
685 let existing_embedding_rowids = txn
686 .prepare(&existing_embedding_rowids_sql)?
687 .query_map([existing_rowid], |row| row.get::<_, i64>(0))?
688 .collect::<rusqlite::Result<Vec<_>>>()?;
689 for embedding_rowid in existing_embedding_rowids {
690 txn.execute(&delete_embeddings_sql, [embedding_rowid])?;
691 }
692 txn.execute(&delete_embedding_map_sql, [existing_rowid])?;
693 }
694
695 let columns = values.iter().map(|(col, _)| *col).collect::<Vec<_>>();
696
697 let placeholders = (1..=values.len())
698 .map(|i| format!("?{i}"))
699 .collect::<Vec<_>>();
700
701 let insert_sql = format!(
702 "INSERT OR REPLACE INTO {} ({}) VALUES ({})",
703 table_name,
704 columns.join(", "),
705 placeholders.join(", ")
706 );
707
708 txn.execute(
709 &insert_sql,
710 rusqlite::params_from_iter(values.iter().map(|(_, val)| val.to_sql_value())),
711 )?;
712 last_id = txn.last_insert_rowid();
713
714 let metadata_values = self
715 .metadata_columns
716 .iter()
717 .map(|column| sqlite_metadata_value(&values, column))
718 .collect::<rusqlite::Result<Vec<_>>>()?;
719
720 let mut stmt = txn.prepare(&embeddings_sql)?;
721 for (i, embedding) in embeddings.iter().enumerate() {
722 let vec = serialize_embedding(embedding);
723 debug!(
724 "Storing embedding {} of {} (size: {} bytes)",
725 i + 1,
726 embeddings.len(),
727 vec.len() * 4
728 );
729 txn.execute(&insert_embedding_map_sql, [last_id])?;
730 let embedding_rowid = txn.last_insert_rowid();
731 let mut params = Vec::with_capacity(2 + metadata_values.len());
732 params.push(Value::Integer(embedding_rowid));
733 params.push(Value::Blob(vec.as_bytes().to_vec()));
734 params.extend(metadata_values.iter().cloned());
735 stmt.execute(rusqlite::params_from_iter(params))?;
736 }
737 }
738
739 Ok(last_id)
740 }
741
742 pub async fn add_rows(
743 &self,
744 documents: Vec<(T, OneOrMany<Embedding>)>,
745 ) -> Result<i64, VectorStoreError>
746 where
747 T: 'static,
748 Self: 'static,
749 {
750 let cloned = self.clone();
751
752 self.conn
753 .call(move |conn| {
754 let tx = conn.transaction()?;
755 let result = cloned.add_rows_with_txn(&tx, documents)?;
756 tx.commit()?;
757
758 Ok(result)
759 })
760 .await
761 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))
762 }
763}
764
765impl<E, T> InsertDocuments for SqliteVectorStore<E, T>
766where
767 E: EmbeddingModel + Clone + WasmCompatSend + WasmCompatSync + 'static,
768 T: SqliteVectorStoreTable
769 + for<'de> Deserialize<'de>
770 + WasmCompatSend
771 + WasmCompatSync
772 + 'static,
773{
774 async fn insert_documents<Doc: Serialize + Embed + WasmCompatSend>(
775 &self,
776 documents: Vec<(Doc, OneOrMany<Embedding>)>,
777 ) -> Result<(), VectorStoreError> {
778 if documents.is_empty() {
779 return Ok(());
780 }
781
782 let rows = documents
783 .into_iter()
784 .map(|(document, embeddings)| {
785 let document = serde_json::to_value(document)?;
786 let row = serde_json::from_value::<T>(document)?;
787
788 Ok((row, embeddings))
789 })
790 .collect::<Result<Vec<_>, VectorStoreError>>()?;
791
792 self.add_rows(rows).await?;
793
794 Ok(())
795 }
796}
797
798#[derive(Clone, Deserialize, Serialize, Debug)]
810pub struct SqliteSearchFilter {
811 expr: SqliteSearchFilterExpr,
812}
813
814impl Default for SqliteSearchFilter {
815 fn default() -> Self {
816 Self {
817 expr: SqliteSearchFilterExpr::Raw {
818 condition: "1 = 1".to_string(),
819 params: Vec::new(),
820 },
821 }
822 }
823}
824
/// Internal expression tree behind a [`SqliteSearchFilter`].
#[derive(Clone, Deserialize, Serialize, Debug)]
enum SqliteSearchFilterExpr {
    /// `key op value`.
    Comparison {
        key: String,
        op: SqliteComparisonOp,
        value: serde_json::Value,
    },
    /// Both sub-expressions must hold.
    And(Box<SqliteSearchFilterExpr>, Box<SqliteSearchFilterExpr>),
    /// At least one sub-expression must hold.
    Or(Box<SqliteSearchFilterExpr>, Box<SqliteSearchFilterExpr>),
    /// Logical negation of the inner expression.
    Not(Box<SqliteSearchFilterExpr>),
    /// Inclusive range check: `lo <= key <= hi`.
    Between {
        key: String,
        lo: serde_json::Value,
        hi: serde_json::Value,
    },
    /// `key IS NULL`, or `key IS NOT NULL` when `negated` is true.
    NullCheck {
        key: String,
        negated: bool,
    },
    /// `key GLOB pattern` or `key LIKE pattern`.
    Pattern {
        key: String,
        op: SqlitePatternOp,
        pattern: String,
    },
    /// Verbatim SQL `condition` with positional `params`.
    ///
    /// Rendered as-is by `render_vector`. `render_split` accepts only the `1 = 1`/no-params
    /// default filter; `render_document` rejects every raw condition.
    Raw {
        condition: String,
        params: Vec<serde_json::Value>,
    },
}
854
855#[derive(Clone, Copy, Deserialize, Eq, PartialEq, Serialize, Debug)]
856enum SqliteComparisonOp {
857 Eq,
858 Ne,
859 Gt,
860 Gte,
861 Lt,
862 Lte,
863}
864
865impl SqliteComparisonOp {
866 fn as_sql(self) -> &'static str {
867 match self {
868 Self::Eq => "=",
869 Self::Ne => "!=",
870 Self::Gt => ">",
871 Self::Gte => ">=",
872 Self::Lt => "<",
873 Self::Lte => "<=",
874 }
875 }
876
877 fn negate(self) -> Self {
878 match self {
879 Self::Eq => Self::Ne,
880 Self::Ne => Self::Eq,
881 Self::Gt => Self::Lte,
882 Self::Gte => Self::Lt,
883 Self::Lt => Self::Gte,
884 Self::Lte => Self::Gt,
885 }
886 }
887}
888
889#[derive(Clone, Copy, Deserialize, Serialize, Debug)]
890enum SqlitePatternOp {
891 Glob,
892 Like,
893}
894
895impl SqlitePatternOp {
896 fn as_sql(self) -> &'static str {
897 match self {
898 Self::Glob => "glob",
899 Self::Like => "like",
900 }
901 }
902}
903
904#[derive(Debug, Default)]
905struct SqliteRenderedFilters {
906 native: Vec<SqliteRenderedFilter>,
907 post: Vec<SqliteRenderedFilter>,
908}
909
910impl SqliteRenderedFilters {
911 fn extend(&mut self, rhs: Self) {
912 self.native.extend(rhs.native);
913 self.post.extend(rhs.post);
914 }
915
916 fn has_post_filters(&self) -> bool {
917 !self.post.is_empty()
918 }
919}
920
/// One rendered SQL condition fragment.
///
/// `params` are listed in the same order as the `?` placeholders in `condition`.
#[derive(Debug)]
struct SqliteRenderedFilter {
    condition: String,
    params: Vec<Value>,
}

/// How values for a document-table key are represented when bound as filter parameters.
///
/// NOTE(review): presumably `Sql` binds values for a plain SQL column and `JsonText`
/// binds values compared against JSON text — confirm in `sqlite_qualify_document_key` and
/// `sqlite_document_filter_param`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum SqliteDocumentValueMode {
    Sql,
    JsonText,
}

/// A filter key resolved against the document table.
#[derive(Debug)]
struct SqliteQualifiedDocumentKey {
    // SQL expression substituted for the key in rendered conditions.
    expression: String,
    // How parameter values for this key are converted (see `SqliteDocumentValueMode`).
    value_mode: SqliteDocumentValueMode,
    // NOTE(review): presumably the bare column name when the key is a plain column — confirm.
    plain_column: Option<String>,
}
939
940impl SearchFilter for SqliteSearchFilter {
941 type Value = serde_json::Value;
942
943 fn eq(key: impl AsRef<str>, value: Self::Value) -> Self {
944 Self {
945 expr: SqliteSearchFilterExpr::Comparison {
946 key: key.as_ref().to_string(),
947 op: SqliteComparisonOp::Eq,
948 value,
949 },
950 }
951 }
952
953 fn gt(key: impl AsRef<str>, value: Self::Value) -> Self {
954 Self {
955 expr: SqliteSearchFilterExpr::Comparison {
956 key: key.as_ref().to_string(),
957 op: SqliteComparisonOp::Gt,
958 value,
959 },
960 }
961 }
962
963 fn lt(key: impl AsRef<str>, value: Self::Value) -> Self {
964 Self {
965 expr: SqliteSearchFilterExpr::Comparison {
966 key: key.as_ref().to_string(),
967 op: SqliteComparisonOp::Lt,
968 value,
969 },
970 }
971 }
972
973 fn and(self, rhs: Self) -> Self {
974 Self {
975 expr: SqliteSearchFilterExpr::And(Box::new(self.expr), Box::new(rhs.expr)),
976 }
977 }
978
979 fn or(self, rhs: Self) -> Self {
980 Self {
981 expr: SqliteSearchFilterExpr::Or(Box::new(self.expr), Box::new(rhs.expr)),
982 }
983 }
984}
985
986impl SqliteSearchFilter {
987 #[allow(clippy::should_implement_trait)]
988 pub fn not(self) -> Self {
995 Self {
996 expr: SqliteSearchFilterExpr::Not(Box::new(self.expr)),
997 }
998 }
999
1000 pub fn between<N>(key: String, range: RangeInclusive<N>) -> Self
1006 where
1007 N: Into<serde_json::Value>,
1008 {
1009 let (lo, hi) = range.into_inner();
1010
1011 Self {
1012 expr: SqliteSearchFilterExpr::Between {
1013 key,
1014 lo: lo.into(),
1015 hi: hi.into(),
1016 },
1017 }
1018 }
1019
1020 pub fn is_null(key: String) -> Self {
1022 Self {
1023 expr: SqliteSearchFilterExpr::NullCheck {
1024 key,
1025 negated: false,
1026 },
1027 }
1028 }
1029
1030 pub fn is_not_null(key: String) -> Self {
1031 Self {
1032 expr: SqliteSearchFilterExpr::NullCheck { key, negated: true },
1033 }
1034 }
1035
1036 pub fn glob(key: String, pattern: impl Into<String>) -> Self {
1041 Self {
1042 expr: SqliteSearchFilterExpr::Pattern {
1043 key,
1044 op: SqlitePatternOp::Glob,
1045 pattern: pattern.into(),
1046 },
1047 }
1048 }
1049
1050 pub fn like(key: String, pattern: impl Into<String>) -> Self {
1055 Self {
1056 expr: SqliteSearchFilterExpr::Pattern {
1057 key,
1058 op: SqlitePatternOp::Like,
1059 pattern: pattern.into(),
1060 },
1061 }
1062 }
1063}
1064
1065impl SqliteSearchFilter {
1066 fn raw(condition: impl Into<String>, params: Vec<serde_json::Value>) -> Self {
1067 Self {
1068 expr: SqliteSearchFilterExpr::Raw {
1069 condition: condition.into(),
1070 params,
1071 },
1072 }
1073 }
1074
1075 fn render_split(
1076 &self,
1077 metadata_columns: &[SqliteMetadataColumn],
1078 ) -> Result<SqliteRenderedFilters, FilterError> {
1079 self.expr.render_split(metadata_columns)
1080 }
1081}
1082
// Renders filter expressions into SQL fragments.
//
// Every `SqliteRenderedFilter` pairs a `condition` with `params` listed in the same order as
// its `?` placeholders. Composite arms therefore always concatenate conditions and params in
// the same (lhs, rhs) order.
//
// - `render_split` separates constraints that sqlite-vec can evaluate natively on metadata
//   columns (`native`, written as `e.<column>`) from constraints rendered against the
//   document table (`post`). Entries of each list are flattened from `And`, so they must be
//   combined with AND. NOTE(review): `e` is presumably the alias of the embeddings table in
//   the search query built elsewhere — confirm.
// - `render_vector` renders the whole expression with keys qualified by
//   `sqlite_qualify_vector_key`.
// - `render_document` renders the whole expression with keys resolved by
//   `sqlite_qualify_document_key`.
impl SqliteSearchFilterExpr {
    /// Renders `key op value`.
    ///
    /// Produces a native sqlite-vec constraint when `key` resolves to a metadata column, and
    /// a document-table post filter otherwise.
    ///
    /// # Errors
    /// - An unsupported-filter error for ordering comparisons on a BOOLEAN metadata column.
    /// - Errors from converting `value` into a parameter.
    fn render_native_comparison(
        key: &str,
        op: SqliteComparisonOp,
        value: serde_json::Value,
        metadata_columns: &[SqliteMetadataColumn],
    ) -> Result<SqliteRenderedFilters, FilterError> {
        let Some(metadata_column) = sqlite_native_metadata_column(key, metadata_columns) else {
            return Ok(SqliteRenderedFilters {
                native: Vec::new(),
                post: vec![Self::render_document_comparison(
                    key,
                    op,
                    value,
                    metadata_columns,
                )?],
            });
        };

        if !metadata_column.metadata_type.supports_native_comparison(op) {
            return Err(sqlite_unsupported_filter(format!(
                "`{key}` is a BOOLEAN metadata column, and sqlite-vec only supports `=` and `!=` filters for booleans"
            )));
        }

        // `key` is spliced into SQL here. This relies on `sqlite_native_metadata_column`
        // returning `None` for keys that are not plain identifiers.
        Ok(SqliteRenderedFilters {
            native: vec![SqliteRenderedFilter {
                condition: format!("e.{key} {} ?", op.as_sql()),
                params: vec![sqlite_metadata_filter_param(metadata_column, value)?],
            }],
            post: Vec::new(),
        })
    }

    /// Renders `key op ?` against the document table.
    ///
    /// The key is resolved by `sqlite_qualify_document_key`, and `value` is converted by
    /// `sqlite_document_filter_param`.
    fn render_document_comparison(
        key: &str,
        op: SqliteComparisonOp,
        value: serde_json::Value,
        metadata_columns: &[SqliteMetadataColumn],
    ) -> Result<SqliteRenderedFilter, FilterError> {
        let key = sqlite_qualify_document_key(key)?;
        Ok(SqliteRenderedFilter {
            condition: format!("{} {} ?", key.expression, op.as_sql()),
            params: vec![sqlite_document_filter_param(&key, metadata_columns, value)?],
        })
    }

    /// Splits this expression into native sqlite-vec constraints and document post filters.
    ///
    /// # Errors
    /// - Raw conditions other than the `1 = 1` default filter.
    /// - BOOLEAN range/ordering filters on metadata columns.
    /// - Errors from parameter conversion or key qualification.
    fn render_split(
        &self,
        metadata_columns: &[SqliteMetadataColumn],
    ) -> Result<SqliteRenderedFilters, FilterError> {
        match self {
            Self::Comparison { key, op, value } => {
                Self::render_native_comparison(key, *op, value.clone(), metadata_columns)
            }
            // Conjunctions split cleanly: each side contributes to both lists independently.
            Self::And(lhs, rhs) => {
                let mut rendered = lhs.render_split(metadata_columns)?;
                rendered.extend(rhs.render_split(metadata_columns)?);
                Ok(rendered)
            }
            Self::Between { key, lo, hi } => {
                let Some(metadata_column) = sqlite_native_metadata_column(key, metadata_columns)
                else {
                    return Ok(SqliteRenderedFilters {
                        native: Vec::new(),
                        post: vec![self.render_document(metadata_columns)?],
                    });
                };

                if metadata_column.metadata_type == SqliteMetadataType::Boolean {
                    return Err(sqlite_unsupported_filter(format!(
                        "`{key}` is a BOOLEAN metadata column, and sqlite-vec does not support range filters for booleans"
                    )));
                }

                // Rendered as two native comparisons instead of SQL `BETWEEN`.
                Ok(SqliteRenderedFilters {
                    native: vec![SqliteRenderedFilter {
                        condition: format!("e.{key} >= ? AND e.{key} <= ?"),
                        params: vec![
                            sqlite_metadata_filter_param(metadata_column, lo.clone())?,
                            sqlite_metadata_filter_param(metadata_column, hi.clone())?,
                        ],
                    }],
                    post: Vec::new(),
                })
            }
            // The match-everything default filter contributes no constraints.
            Self::Raw { condition, params } if condition == "1 = 1" && params.is_empty() => {
                Ok(SqliteRenderedFilters {
                    native: Vec::new(),
                    post: Vec::new(),
                })
            }
            // A disjunction cannot be divided between the two lists, so the whole OR is
            // evaluated as a single post filter.
            Self::Or(_, _) => Ok(SqliteRenderedFilters {
                native: Vec::new(),
                post: vec![self.render_document(metadata_columns)?],
            }),
            Self::Not(expr) => expr.render_negated_split(metadata_columns),
            Self::NullCheck { .. } | Self::Pattern { .. } => Ok(SqliteRenderedFilters {
                native: Vec::new(),
                post: vec![self.render_document(metadata_columns)?],
            }),
            Self::Raw { .. } => Err(sqlite_unsupported_filter(
                "raw filters cannot be validated as sqlite-vec metadata constraints",
            )),
        }
    }

    /// Splits `NOT self`.
    ///
    /// - A negated comparison flips its operator, so it can stay native.
    /// - A double negation cancels out.
    /// - Anything else becomes a single `NOT (...)` post filter.
    fn render_negated_split(
        &self,
        metadata_columns: &[SqliteMetadataColumn],
    ) -> Result<SqliteRenderedFilters, FilterError> {
        match self {
            Self::Comparison { key, op, value } => {
                Self::render_native_comparison(key, op.negate(), value.clone(), metadata_columns)
            }
            Self::Not(expr) => expr.render_split(metadata_columns),
            _ => {
                let rendered = self.render_document(metadata_columns)?;
                Ok(SqliteRenderedFilters {
                    native: Vec::new(),
                    post: vec![SqliteRenderedFilter {
                        condition: format!("NOT ({})", rendered.condition),
                        params: rendered.params,
                    }],
                })
            }
        }
    }

    /// Renders the whole expression as one condition, with keys qualified by
    /// `sqlite_qualify_vector_key`.
    ///
    /// Raw conditions are emitted verbatim. Their params are converted with
    /// `sqlite_filter_param`.
    fn render_vector(&self) -> Result<SqliteRenderedFilter, FilterError> {
        match self {
            Self::Comparison { key, op, value } => Ok(SqliteRenderedFilter {
                condition: format!("{} {} ?", sqlite_qualify_vector_key(key), op.as_sql()),
                params: vec![sqlite_filter_param(value.clone())?],
            }),
            Self::And(lhs, rhs) => {
                let lhs = lhs.render_vector()?;
                let rhs = rhs.render_vector()?;
                Ok(SqliteRenderedFilter {
                    condition: format!("({}) AND ({})", lhs.condition, rhs.condition),
                    params: lhs.params.into_iter().chain(rhs.params).collect(),
                })
            }
            Self::Or(lhs, rhs) => {
                let lhs = lhs.render_vector()?;
                let rhs = rhs.render_vector()?;
                Ok(SqliteRenderedFilter {
                    condition: format!("({}) OR ({})", lhs.condition, rhs.condition),
                    params: lhs.params.into_iter().chain(rhs.params).collect(),
                })
            }
            Self::Not(expr) => {
                let expr = expr.render_vector()?;
                Ok(SqliteRenderedFilter {
                    condition: format!("NOT ({})", expr.condition),
                    params: expr.params,
                })
            }
            Self::Between { key, lo, hi } => Ok(SqliteRenderedFilter {
                condition: format!("{} between ? and ?", sqlite_qualify_vector_key(key)),
                params: vec![
                    sqlite_filter_param(lo.clone())?,
                    sqlite_filter_param(hi.clone())?,
                ],
            }),
            Self::NullCheck { key, negated } => {
                let operator = if *negated { "is not null" } else { "is null" };
                Ok(SqliteRenderedFilter {
                    condition: format!("{} {operator}", sqlite_qualify_vector_key(key)),
                    params: Vec::new(),
                })
            }
            Self::Pattern { key, op, pattern } => Ok(SqliteRenderedFilter {
                condition: format!("{} {} ?", sqlite_qualify_vector_key(key), op.as_sql()),
                params: vec![Value::Text(pattern.clone())],
            }),
            Self::Raw { condition, params } => Ok(SqliteRenderedFilter {
                condition: condition.clone(),
                params: params
                    .iter()
                    .cloned()
                    .map(sqlite_filter_param)
                    .collect::<Result<Vec<_>, _>>()?,
            }),
        }
    }

    /// Renders the whole expression as one condition against the document table.
    ///
    /// Keys are resolved by `sqlite_qualify_document_key`.
    ///
    /// # Errors
    /// Every `Raw` condition is rejected, including the `1 = 1` default. Key-qualification and
    /// parameter-conversion errors are propagated.
    fn render_document(
        &self,
        metadata_columns: &[SqliteMetadataColumn],
    ) -> Result<SqliteRenderedFilter, FilterError> {
        match self {
            Self::Comparison { key, op, value } => {
                Self::render_document_comparison(key, *op, value.clone(), metadata_columns)
            }
            Self::And(lhs, rhs) => {
                let lhs = lhs.render_document(metadata_columns)?;
                let rhs = rhs.render_document(metadata_columns)?;
                Ok(SqliteRenderedFilter {
                    condition: format!("({}) AND ({})", lhs.condition, rhs.condition),
                    params: lhs.params.into_iter().chain(rhs.params).collect(),
                })
            }
            Self::Or(lhs, rhs) => {
                let lhs = lhs.render_document(metadata_columns)?;
                let rhs = rhs.render_document(metadata_columns)?;
                Ok(SqliteRenderedFilter {
                    condition: format!("({}) OR ({})", lhs.condition, rhs.condition),
                    params: lhs.params.into_iter().chain(rhs.params).collect(),
                })
            }
            Self::Not(expr) => {
                let expr = expr.render_document(metadata_columns)?;
                Ok(SqliteRenderedFilter {
                    condition: format!("NOT ({})", expr.condition),
                    params: expr.params,
                })
            }
            Self::Between { key, lo, hi } => {
                let key = sqlite_qualify_document_key(key)?;
                Ok(SqliteRenderedFilter {
                    condition: format!("{} between ? and ?", key.expression),
                    params: vec![
                        sqlite_document_filter_param(&key, metadata_columns, lo.clone())?,
                        sqlite_document_filter_param(&key, metadata_columns, hi.clone())?,
                    ],
                })
            }
            Self::NullCheck { key, negated } => {
                let key = sqlite_qualify_document_key(key)?;
                let operator = if *negated { "is not null" } else { "is null" };
                Ok(SqliteRenderedFilter {
                    condition: format!("{} {operator}", key.expression),
                    params: Vec::new(),
                })
            }
            Self::Pattern { key, op, pattern } => {
                let key = sqlite_qualify_document_key(key)?;
                Ok(SqliteRenderedFilter {
                    condition: format!("{} {} ?", key.expression, op.as_sql()),
                    params: vec![Value::Text(pattern.clone())],
                })
            }
            Self::Raw { .. } => Err(sqlite_unsupported_filter(
                "raw filters cannot be validated as document-table constraints",
            )),
        }
    }
}
1332
1333fn sqlite_native_metadata_column<'a>(
1334 key: &str,
1335 metadata_columns: &'a [SqliteMetadataColumn],
1336) -> Option<&'a SqliteMetadataColumn> {
1337 if !sqlite_is_plain_identifier(key) {
1338 return None;
1339 }
1340
1341 metadata_columns.iter().find(|column| column.name == key)
1342}
1343
/// Returns `true` when `key` is a bare ASCII identifier: `[A-Za-z_][A-Za-z0-9_]*`.
///
/// Any non-ASCII character, dot, operator or whitespace disqualifies the key, which
/// makes it safe to splice directly into SQL as a column name.
fn sqlite_is_plain_identifier(key: &str) -> bool {
    match key.as_bytes().split_first() {
        None => false,
        Some((&head, tail)) => {
            (head == b'_' || head.is_ascii_alphabetic())
                && tail
                    .iter()
                    .all(|&byte| byte == b'_' || byte.is_ascii_alphanumeric())
        }
    }
}
1353
/// Returns the byte length of the ASCII identifier (`[A-Za-z_][A-Za-z0-9_]*`) that
/// prefixes `key`, or `None` when `key` does not start with an identifier character.
///
/// Every identifier character is single-byte ASCII, so the returned length is always
/// a valid `str` slicing boundary.
fn sqlite_leading_identifier_len(key: &str) -> Option<usize> {
    let bytes = key.as_bytes();
    let &first = bytes.first()?;
    if first != b'_' && !first.is_ascii_alphabetic() {
        return None;
    }

    let identifier_len = bytes
        .iter()
        .take_while(|&&byte| byte == b'_' || byte.is_ascii_alphanumeric())
        .count();
    Some(identifier_len)
}
1372
1373fn sqlite_unsupported_filter(reason: impl Into<String>) -> FilterError {
1374 FilterError::TypeError(format!(
1375 "SQLite filter cannot be safely lowered; {}",
1376 reason.into()
1377 ))
1378}
1379
1380fn sqlite_json_type_name(value: &serde_json::Value) -> &'static str {
1381 match value {
1382 serde_json::Value::Null => "null",
1383 serde_json::Value::Bool(_) => "boolean",
1384 serde_json::Value::Number(_) => "number",
1385 serde_json::Value::String(_) => "string",
1386 serde_json::Value::Array(_) => "array",
1387 serde_json::Value::Object(_) => "object",
1388 }
1389}
1390
1391fn sqlite_metadata_filter_type_error(
1392 column: &SqliteMetadataColumn,
1393 value: &serde_json::Value,
1394 expected: &str,
1395) -> FilterError {
1396 sqlite_unsupported_filter(format!(
1397 "`{}` is a {} metadata column and requires {expected}; got {}",
1398 column.name,
1399 column.metadata_type.vec0_name(),
1400 sqlite_json_type_name(value)
1401 ))
1402}
1403
1404fn sqlite_metadata_filter_param(
1405 column: &SqliteMetadataColumn,
1406 value: serde_json::Value,
1407) -> Result<Value, FilterError> {
1408 match column.metadata_type {
1409 SqliteMetadataType::Text => match value {
1410 serde_json::Value::String(value) => Ok(Value::Text(value)),
1411 value => Err(sqlite_metadata_filter_type_error(
1412 column,
1413 &value,
1414 "a string filter value",
1415 )),
1416 },
1417 SqliteMetadataType::Integer => match value {
1418 serde_json::Value::Number(number) => {
1419 if let Some(value) = number.as_i64() {
1420 Ok(Value::Integer(value))
1421 } else if let Some(value) = number.as_u64() {
1422 i64::try_from(value).map(Value::Integer).map_err(|_| {
1423 FilterError::TypeError(format!(
1424 "SQLite integer filter value `{number}` exceeds i64::MAX"
1425 ))
1426 })
1427 } else {
1428 let value = serde_json::Value::Number(number);
1429 Err(sqlite_metadata_filter_type_error(
1430 column,
1431 &value,
1432 "an integer filter value",
1433 ))
1434 }
1435 }
1436 value => Err(sqlite_metadata_filter_type_error(
1437 column,
1438 &value,
1439 "an integer filter value",
1440 )),
1441 },
1442 SqliteMetadataType::Float => match value {
1443 serde_json::Value::Number(number) => {
1444 number.as_f64().map(Value::Real).ok_or_else(|| {
1445 let value = serde_json::Value::Number(number);
1446 sqlite_metadata_filter_type_error(
1447 column,
1448 &value,
1449 "a finite number filter value",
1450 )
1451 })
1452 }
1453 value => Err(sqlite_metadata_filter_type_error(
1454 column,
1455 &value,
1456 "a finite number filter value",
1457 )),
1458 },
1459 SqliteMetadataType::Boolean => match value {
1460 serde_json::Value::Bool(value) => Ok(Value::Integer(value as i64)),
1461 value => Err(sqlite_metadata_filter_type_error(
1462 column,
1463 &value,
1464 "a boolean filter value",
1465 )),
1466 },
1467 }
1468}
1469
1470fn sqlite_filter_param(value: serde_json::Value) -> Result<Value, FilterError> {
1471 use serde_json::Value::*;
1472
1473 match value {
1474 Null => Ok(Value::Null),
1475 Bool(b) => Ok(Value::Integer(b as i64)),
1476 String(s) => Ok(Value::Text(s)),
1477 Number(n) => Ok(if let Some(value) = n.as_i64() {
1478 Value::Integer(value)
1479 } else if let Some(value) = n.as_u64() {
1480 let value = i64::try_from(value).map_err(|_| {
1481 FilterError::TypeError(format!(
1482 "SQLite integer filter value `{n}` exceeds i64::MAX"
1483 ))
1484 })?;
1485 Value::Integer(value)
1486 } else if let Some(float) = n.as_f64() {
1487 Value::Real(float)
1488 } else {
1489 Value::Text(n.to_string())
1490 }),
1491 Array(arr) => {
1492 let blob =
1493 serde_json::to_vec(&arr).map_err(|e| FilterError::Serialization(e.to_string()))?;
1494
1495 Ok(Value::Blob(blob))
1496 }
1497 Object(obj) => {
1498 let blob =
1499 serde_json::to_vec(&obj).map_err(|e| FilterError::Serialization(e.to_string()))?;
1500
1501 Ok(Value::Blob(blob))
1502 }
1503 }
1504}
1505
/// Heuristic: treats a vector-filter key as already qualified (or as an expression)
/// when it contains a dot, an opening parenthesis, a space, or a `?` placeholder.
fn sqlite_key_is_qualified(key: &str) -> bool {
    key.chars()
        .any(|ch| matches!(ch, '.' | '(' | ' ' | '?'))
}
1509
1510fn sqlite_qualify_vector_key(key: &str) -> String {
1511 if sqlite_key_is_qualified(key) {
1512 key.to_string()
1513 } else {
1514 format!("e.{key}")
1515 }
1516}
1517
1518fn sqlite_qualify_document_key(key: &str) -> Result<SqliteQualifiedDocumentKey, FilterError> {
1519 if let Some(key_without_alias) = key.strip_prefix("d.") {
1520 if sqlite_is_plain_identifier(key_without_alias) {
1521 return Ok(SqliteQualifiedDocumentKey {
1522 expression: key.to_string(),
1523 value_mode: SqliteDocumentValueMode::Sql,
1524 plain_column: Some(key_without_alias.to_string()),
1525 });
1526 }
1527
1528 if let Some(value_mode) = sqlite_json_operator_value_mode(key_without_alias) {
1529 return Ok(SqliteQualifiedDocumentKey {
1530 expression: key.to_string(),
1531 value_mode,
1532 plain_column: None,
1533 });
1534 }
1535
1536 return Err(sqlite_unsupported_filter(format!(
1537 "`{key}` is not a supported SQLite document filter expression"
1538 )));
1539 }
1540
1541 if sqlite_is_plain_identifier(key) {
1542 return Ok(SqliteQualifiedDocumentKey {
1543 expression: format!("d.{key}"),
1544 value_mode: SqliteDocumentValueMode::Sql,
1545 plain_column: Some(key.to_string()),
1546 });
1547 }
1548
1549 if let Some(value_mode) = sqlite_json_operator_value_mode(key) {
1550 return Ok(SqliteQualifiedDocumentKey {
1551 expression: format!("d.{key}"),
1552 value_mode,
1553 plain_column: None,
1554 });
1555 }
1556
1557 Err(sqlite_unsupported_filter(format!(
1558 "`{key}` is not a supported SQLite document filter expression"
1559 )))
1560}
1561
1562fn sqlite_document_filter_param(
1563 key: &SqliteQualifiedDocumentKey,
1564 metadata_columns: &[SqliteMetadataColumn],
1565 value: serde_json::Value,
1566) -> Result<Value, FilterError> {
1567 match key.value_mode {
1568 SqliteDocumentValueMode::Sql => {
1569 if let Some(column_name) = key.plain_column.as_deref()
1570 && let Some(metadata_column) = metadata_columns
1571 .iter()
1572 .find(|column| column.name == column_name)
1573 {
1574 return sqlite_metadata_filter_param(metadata_column, value);
1575 }
1576
1577 sqlite_filter_param(value)
1578 }
1579 SqliteDocumentValueMode::JsonText => serde_json::to_string(&value)
1580 .map(Value::Text)
1581 .map_err(|e| FilterError::Serialization(e.to_string())),
1582 }
1583}
1584
1585fn sqlite_json_operator_value_mode(expr: &str) -> Option<SqliteDocumentValueMode> {
1586 let mut index = sqlite_leading_identifier_len(expr)?;
1587
1588 if index == expr.len() {
1589 return None;
1590 }
1591
1592 let mut value_mode = None;
1593 while index < expr.len() {
1594 let remaining = &expr[index..];
1595 let (operator_len, next_value_mode) = if remaining.starts_with("->>") {
1596 (3, SqliteDocumentValueMode::Sql)
1597 } else if remaining.starts_with("->") {
1598 (2, SqliteDocumentValueMode::JsonText)
1599 } else {
1600 return None;
1601 };
1602 value_mode = Some(next_value_mode);
1603 index += operator_len;
1604
1605 let operand_len = sqlite_json_operator_operand_len(&expr[index..])?;
1606 index += operand_len;
1607 }
1608
1609 value_mode
1610}
1611
/// Measures the JSON-operator operand at the start of `operand`, in bytes.
///
/// Two forms are accepted:
/// * a single-quoted literal with no embedded quote or control character; its length
///   includes both quotes;
/// * an optionally negative run of ASCII digits.
///
/// Returns `None` for anything else, including an empty string.
fn sqlite_json_operator_operand_len(operand: &str) -> Option<usize> {
    if operand.starts_with('\'') {
        let body = &operand[1..];
        let closing = body.find('\'')?;
        return if body[..closing].chars().any(char::is_control) {
            None
        } else {
            Some(closing + 2)
        };
    }

    let sign_len = usize::from(operand.starts_with('-'));
    let digit_count = operand[sign_len..]
        .bytes()
        .take_while(u8::is_ascii_digit)
        .count();
    (digit_count > 0).then_some(sign_len + digit_count)
}
1646
1647pub struct SqliteVectorIndex<E, T>
1746where
1747 E: EmbeddingModel + 'static,
1748 T: SqliteVectorStoreTable + 'static,
1749{
1750 store: SqliteVectorStore<E, T>,
1751 embedding_model: E,
1752}
1753
1754impl<E, T> SqliteVectorIndex<E, T>
1755where
1756 E: EmbeddingModel + 'static,
1757 T: SqliteVectorStoreTable,
1758{
1759 pub fn new(embedding_model: E, store: SqliteVectorStore<E, T>) -> Self {
1760 Self {
1761 store,
1762 embedding_model,
1763 }
1764 }
1765}
1766
1767fn sqlite_distance_metric_from_schema(schema_sql: &str) -> SqliteDistanceMetric {
1768 let normalized = sqlite_normalized_schema(schema_sql);
1769
1770 if normalized.contains("distance_metric=cosine") {
1771 SqliteDistanceMetric::Cosine
1772 } else if normalized.contains("distance_metric=l1") {
1773 SqliteDistanceMetric::L1
1774 } else {
1775 SqliteDistanceMetric::L2
1776 }
1777}
1778
/// Canonicalizes schema SQL for substring matching by dropping all whitespace and
/// lowercasing every remaining character.
fn sqlite_normalized_schema(schema_sql: &str) -> String {
    let mut normalized = String::with_capacity(schema_sql.len());
    for ch in schema_sql.chars().filter(|ch| !ch.is_whitespace()) {
        normalized.extend(ch.to_lowercase());
    }
    normalized
}
1786
1787fn sqlite_schema_contains_metadata_column(schema_sql: &str, column: &SqliteMetadataColumn) -> bool {
1788 let normalized = sqlite_normalized_schema(schema_sql);
1789 let column_sql = format!(
1790 ",{}{}",
1791 column.name.to_ascii_lowercase(),
1792 column.metadata_type.vec0_name().to_ascii_lowercase()
1793 );
1794
1795 normalized.contains(&column_sql)
1796}
1797
1798struct SqliteSearchQuery {
1799 vector_where_clause: String,
1800 document_filter_clause: String,
1801 params: Vec<Value>,
1802}
1803
1804fn render_search_filters(
1805 req: &VectorSearchRequest<SqliteSearchFilter>,
1806 distance_metric: SqliteDistanceMetric,
1807 metadata_columns: &[SqliteMetadataColumn],
1808) -> Result<SqliteRenderedFilters, FilterError> {
1809 let score_expression = distance_metric.score_expression("?1", "e.embedding");
1810 let threshold_filter = req.threshold().map(|threshold| {
1811 SqliteSearchFilter::raw(format!("{score_expression} >= ?"), vec![threshold.into()])
1812 });
1813
1814 let mut filters = SqliteRenderedFilters::default();
1815 if let Some(threshold_filter) = threshold_filter {
1816 filters.native.push(threshold_filter.expr.render_vector()?);
1817 }
1818 if let Some(filter) = req.filter() {
1819 filters.extend(filter.render_split(metadata_columns)?);
1820 }
1821
1822 Ok(filters)
1823}
1824
1825fn build_search_query(
1826 query_vec: Vec<f32>,
1827 filters: SqliteRenderedFilters,
1828 candidate_limit: u64,
1829) -> Result<SqliteSearchQuery, FilterError> {
1830 let brute_force = candidate_limit > SQLITE_VEC_MAX_K;
1837
1838 let mut conditions = Vec::new();
1839 if !brute_force {
1840 conditions.push("e.embedding MATCH ?".to_string());
1841 conditions.push("k = ?".to_string());
1842 }
1843 conditions.extend(
1844 filters
1845 .native
1846 .iter()
1847 .map(|filter| format!("({})", filter.condition)),
1848 );
1849
1850 let vector_where_clause = if conditions.is_empty() {
1853 String::new()
1854 } else {
1855 format!("WHERE {}", conditions.join(" AND "))
1856 };
1857 let document_filter_clause = if filters.post.is_empty() {
1858 String::new()
1859 } else {
1860 format!(
1861 "AND {}",
1862 filters
1863 .post
1864 .iter()
1865 .map(|filter| format!("({})", filter.condition))
1866 .collect::<Vec<_>>()
1867 .join(" AND ")
1868 )
1869 };
1870
1871 let query_vec = query_vec.into_iter().flat_map(f32::to_le_bytes).collect();
1872 let query_vec = Value::Blob(query_vec);
1873
1874 let mut params = if brute_force {
1882 vec![query_vec]
1883 } else {
1884 let candidate_limit = sqlite_limit_param(candidate_limit, "candidate limit")?;
1885 vec![query_vec.clone(), query_vec, candidate_limit]
1886 };
1887 params.extend(filters.native.into_iter().flat_map(|filter| filter.params));
1888 params.extend(filters.post.into_iter().flat_map(|filter| filter.params));
1889
1890 Ok(SqliteSearchQuery {
1891 vector_where_clause,
1892 document_filter_clause,
1893 params,
1894 })
1895}
1896
#[cfg(test)]
/// Test helper: renders `req` and returns only the vector-side WHERE clause together
/// with the complete bind-parameter list (the result limit is not appended).
fn build_where_clause(
    req: &VectorSearchRequest<SqliteSearchFilter>,
    query_vec: Vec<f32>,
    distance_metric: SqliteDistanceMetric,
    metadata_columns: &[SqliteMetadataColumn],
    candidate_limit: u64,
) -> Result<(String, Vec<Value>), FilterError> {
    let rendered = render_search_filters(req, distance_metric, metadata_columns)?;
    let SqliteSearchQuery {
        vector_where_clause,
        params,
        ..
    } = build_search_query(query_vec, rendered, candidate_limit)?;
    Ok((vector_where_clause, params))
}
1909
1910fn sqlite_limit_param(value: u64, name: &str) -> Result<Value, FilterError> {
1911 i64::try_from(value)
1912 .map(Value::Integer)
1913 .map_err(|_| FilterError::TypeError(format!("SQLite {name} `{value}` exceeds i64::MAX")))
1914}
1915
/// Error payload attached to `rusqlite::Error::FromSqlConversionFailure` when a stored
/// column value cannot be converted to JSON.
#[derive(Debug)]
struct SqliteColumnValueError {
    /// Schema name of the offending column.
    column_name: &'static str,
    /// Declared SQL type of the column.
    column_type: &'static str,
    /// Description of what went wrong.
    message: String,
}
1922
1923impl Display for SqliteColumnValueError {
1924 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
1925 write!(
1926 f,
1927 "could not convert SQLite column `{}` with declared type `{}`: {}",
1928 self.column_name, self.column_type, self.message
1929 )
1930 }
1931}
1932
1933impl std::error::Error for SqliteColumnValueError {}
1934
1935fn sqlite_column_value_error(
1936 index: usize,
1937 value_type: Type,
1938 column: &Column,
1939 message: impl Into<String>,
1940) -> rusqlite::Error {
1941 rusqlite::Error::FromSqlConversionFailure(
1942 index,
1943 value_type,
1944 Box::new(SqliteColumnValueError {
1945 column_name: column.name,
1946 column_type: column.col_type,
1947 message: message.into(),
1948 }),
1949 )
1950}
1951
1952fn sqlite_number_value(
1953 index: usize,
1954 value_type: Type,
1955 column: &Column,
1956 value: f64,
1957) -> rusqlite::Result<serde_json::Value> {
1958 let number = serde_json::Number::from_f64(value).ok_or_else(|| {
1959 sqlite_column_value_error(index, value_type, column, "non-finite float value")
1960 })?;
1961
1962 Ok(serde_json::Value::Number(number))
1963}
1964
1965fn sqlite_text_value(
1966 index: usize,
1967 value_type: Type,
1968 column: &Column,
1969 value: &[u8],
1970) -> rusqlite::Result<serde_json::Value> {
1971 let value = std::str::from_utf8(value).map_err(|e| {
1972 sqlite_column_value_error(
1973 index,
1974 value_type,
1975 column,
1976 format!("invalid UTF-8 text: {e}"),
1977 )
1978 })?;
1979
1980 Ok(serde_json::Value::String(value.to_string()))
1981}
1982
/// Returns `true` when the first whitespace-separated token of a declared column type
/// is `JSON` (ASCII case-insensitive). For example, `JSON NOT NULL` qualifies but
/// `JSONB` does not.
fn sqlite_column_declares_json(column_type: &str) -> bool {
    match column_type.split_whitespace().next() {
        Some(first_token) => first_token.eq_ignore_ascii_case("JSON"),
        None => false,
    }
}
1989
1990fn sqlite_json_text_value(
1991 index: usize,
1992 value_type: Type,
1993 column: &Column,
1994 value: &[u8],
1995) -> rusqlite::Result<serde_json::Value> {
1996 let value = std::str::from_utf8(value).map_err(|e| {
1997 sqlite_column_value_error(
1998 index,
1999 value_type,
2000 column,
2001 format!("invalid UTF-8 JSON text: {e}"),
2002 )
2003 })?;
2004
2005 serde_json::from_str(value).map_err(|e| {
2006 sqlite_column_value_error(index, value_type, column, format!("invalid JSON text: {e}"))
2007 })
2008}
2009
2010fn sqlite_column_value_to_json(
2011 index: usize,
2012 column: &Column,
2013 value: ValueRef<'_>,
2014) -> rusqlite::Result<serde_json::Value> {
2015 let value_type = value.data_type();
2016
2017 if sqlite_column_declares_json(column.col_type) {
2018 return match value {
2019 ValueRef::Null => Ok(serde_json::Value::Null),
2020 ValueRef::Text(value) => sqlite_json_text_value(index, value_type, column, value),
2021 ValueRef::Integer(value) => Ok(serde_json::Value::Number(value.into())),
2022 ValueRef::Real(value) => sqlite_number_value(index, value_type, column, value),
2023 ValueRef::Blob(value) => sqlite_json_text_value(index, value_type, column, value),
2024 };
2025 }
2026
2027 let column_affinity = SqliteColumnAffinity::from_column_type(column.col_type);
2028
2029 match (column_affinity, value) {
2030 (_, ValueRef::Null) => Ok(serde_json::Value::Null),
2031 (SqliteColumnAffinity::Boolean, ValueRef::Integer(0)) => Ok(serde_json::Value::Bool(false)),
2032 (SqliteColumnAffinity::Boolean, ValueRef::Integer(1)) => Ok(serde_json::Value::Bool(true)),
2033 (SqliteColumnAffinity::Boolean, _) => Err(sqlite_column_value_error(
2034 index,
2035 value_type,
2036 column,
2037 "stored SQLite boolean value must be 0 or 1",
2038 )),
2039 (_, ValueRef::Text(value)) => sqlite_text_value(index, value_type, column, value),
2040 (_, ValueRef::Integer(value)) => Ok(serde_json::Value::Number(value.into())),
2041 (_, ValueRef::Real(value)) => sqlite_number_value(index, value_type, column, value),
2042 (_, ValueRef::Blob(value)) => Ok(serde_json::to_value(value)
2043 .map_err(|e| sqlite_column_value_error(index, value_type, column, e.to_string()))?),
2044 }
2045}
2046
2047fn sqlite_id_value_to_string(index: usize, value: ValueRef<'_>) -> rusqlite::Result<String> {
2048 match value {
2049 ValueRef::Integer(value) => Ok(value.to_string()),
2050 ValueRef::Real(value) => Ok(value.to_string()),
2051 ValueRef::Text(value) => std::str::from_utf8(value)
2052 .map(ToString::to_string)
2053 .map_err(|e| {
2054 rusqlite::Error::FromSqlConversionFailure(
2055 index,
2056 Type::Text,
2057 Box::new(SqliteColumnValueError {
2058 column_name: "id",
2059 column_type: "TEXT",
2060 message: format!("invalid UTF-8 text: {e}"),
2061 }),
2062 )
2063 }),
2064 value => Err(rusqlite::Error::FromSqlConversionFailure(
2065 index,
2066 value.data_type(),
2067 Box::new(SqliteColumnValueError {
2068 column_name: "id",
2069 column_type: "TEXT or INTEGER",
2070 message: "id cannot be NULL or BLOB".to_string(),
2071 }),
2072 )),
2073 }
2074}
2075
2076impl<E: EmbeddingModel + std::marker::Sync, T: SqliteVectorStoreTable> VectorStoreIndex
2077 for SqliteVectorIndex<E, T>
2078{
2079 type Filter = SqliteSearchFilter;
2080
2081 async fn top_n<D>(
2082 &self,
2083 req: VectorSearchRequest<SqliteSearchFilter>,
2084 ) -> Result<Vec<(f64, String, D)>, VectorStoreError>
2085 where
2086 D: for<'de> Deserialize<'de>,
2087 {
2088 tracing::debug!("Finding top {} matches for query", req.samples() as usize);
2089 if req.samples() == 0 {
2090 return Ok(Vec::new());
2091 }
2092
2093 let embedding = self.embedding_model.embed_text(req.query()).await?;
2094 let query_vec: Vec<f32> = serialize_embedding(&embedding);
2095 let table_name = T::name();
2096 let embedding_map_table_name = format!("{table_name}_embedding_map");
2097
2098 let columns = T::schema();
2099 let id_column_index = columns
2100 .iter()
2101 .position(|column| column.name == "id")
2102 .ok_or_else(|| {
2103 VectorStoreError::DatastoreError(Box::new(SqliteMissingIdColumn {
2104 table_name: table_name.to_string(),
2105 }))
2106 })?;
2107
2108 let outer_select_cols = columns
2109 .iter()
2110 .map(|column| format!("d.{} AS {}", column.name, column.name))
2111 .collect::<Vec<_>>()
2112 .join(", ");
2113
2114 let distance_metric = self.store.distance_metric;
2115 let score_expression = distance_metric.score_expression("?1", "e.embedding");
2116 let filters = render_search_filters(&req, distance_metric, &self.store.metadata_columns)?;
2117 let candidate_limit = self
2118 .store
2119 .candidate_limit(req.samples(), filters.has_post_filters())
2120 .await?;
2121 let search_query = build_search_query(query_vec, filters, candidate_limit)?;
2122 let where_clause = search_query.vector_where_clause;
2123 let document_filter_clause = search_query.document_filter_clause;
2124 let mut params = search_query.params;
2125 params.push(sqlite_limit_param(req.samples(), "result limit")?);
2126
2127 let rows = self
2128 .store
2129 .conn
2130 .call(move |conn| {
2131 let mut stmt = conn.prepare(&format!(
2132 "WITH scored AS (
2133 SELECT m.document_rowid AS __rig_document_rowid,
2134 {score_expression} AS __rig_score,
2135 ROW_NUMBER() OVER (
2136 PARTITION BY m.document_rowid
2137 ORDER BY {score_expression} DESC, e.rowid ASC
2138 ) AS __rig_rank
2139 FROM {table_name}_embeddings e
2140 JOIN {embedding_map_table_name} m ON e.rowid = m.embedding_rowid
2141 {where_clause}
2142 )
2143 SELECT {outer_select_cols}, scored.__rig_score
2144 FROM scored
2145 JOIN {table_name} d ON scored.__rig_document_rowid = d.rowid
2146 WHERE scored.__rig_rank = 1
2147 {document_filter_clause}
2148 ORDER BY scored.__rig_score DESC, d.id ASC
2149 LIMIT ?"
2150 ))?;
2151
2152 let rows = stmt
2153 .query_map(rusqlite::params_from_iter(params), |row| {
2154 let mut map = serde_json::Map::new();
2156 for (i, column) in columns.iter().enumerate() {
2157 let value = sqlite_column_value_to_json(i, column, row.get_ref(i)?)?;
2158 map.insert(column.name.to_string(), value);
2159 }
2160 let score: f64 = row.get(columns.len())?;
2161 let id = sqlite_id_value_to_string(
2162 id_column_index,
2163 row.get_ref(id_column_index)?,
2164 )?;
2165
2166 Ok((id, serde_json::Value::Object(map), score))
2167 })?
2168 .collect::<Result<Vec<_>, _>>()?;
2169 Ok(rows)
2170 })
2171 .await
2172 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
2173
2174 debug!("Found {} potential matches", rows.len());
2175 let mut top_n = Vec::new();
2176 for (id, doc_value, score) in rows {
2177 match serde_json::from_value::<D>(doc_value) {
2178 Ok(doc) => {
2179 top_n.push((score, id, doc));
2180 }
2181 Err(e) => {
2182 debug!("Failed to deserialize document {}: {}", id, e);
2183 continue;
2184 }
2185 }
2186 }
2187
2188 debug!("Returning {} matches", top_n.len());
2189 Ok(top_n)
2190 }
2191
2192 async fn top_n_ids(
2193 &self,
2194 req: VectorSearchRequest<SqliteSearchFilter>,
2195 ) -> Result<Vec<(f64, String)>, VectorStoreError> {
2196 tracing::debug!(
2197 "Finding top {} document IDs for query",
2198 req.samples() as usize
2199 );
2200 if req.samples() == 0 {
2201 return Ok(Vec::new());
2202 }
2203
2204 let embedding = self.embedding_model.embed_text(req.query()).await?;
2205 let query_vec = serialize_embedding(&embedding);
2206 let table_name = T::name();
2207 let embedding_map_table_name = format!("{table_name}_embedding_map");
2208
2209 let distance_metric = self.store.distance_metric;
2210 let score_expression = distance_metric.score_expression("?1", "e.embedding");
2211 let filters = render_search_filters(&req, distance_metric, &self.store.metadata_columns)?;
2212 let candidate_limit = self
2213 .store
2214 .candidate_limit(req.samples(), filters.has_post_filters())
2215 .await?;
2216 let search_query = build_search_query(query_vec, filters, candidate_limit)?;
2217 let where_clause = search_query.vector_where_clause;
2218 let document_filter_clause = search_query.document_filter_clause;
2219 let mut params = search_query.params;
2220 params.push(sqlite_limit_param(req.samples(), "result limit")?);
2221
2222 let results = self
2223 .store
2224 .conn
2225 .call(move |conn| {
2226 let mut stmt = conn.prepare(&format!(
2227 "WITH scored AS (
2228 SELECT m.document_rowid AS __rig_document_rowid,
2229 {score_expression} AS __rig_score,
2230 ROW_NUMBER() OVER (
2231 PARTITION BY m.document_rowid
2232 ORDER BY {score_expression} DESC, e.rowid ASC
2233 ) AS __rig_rank
2234 FROM {table_name}_embeddings e
2235 JOIN {embedding_map_table_name} m ON e.rowid = m.embedding_rowid
2236 {where_clause}
2237 )
2238 SELECT d.id, scored.__rig_score
2239 FROM scored
2240 JOIN {table_name} d ON scored.__rig_document_rowid = d.rowid
2241 WHERE scored.__rig_rank = 1
2242 {document_filter_clause}
2243 ORDER BY scored.__rig_score DESC, d.id ASC
2244 LIMIT ?"
2245 ))?;
2246
2247 let results = stmt
2248 .query_map(rusqlite::params_from_iter(params), |row| {
2249 Ok((
2250 row.get::<_, f64>(1)?,
2251 sqlite_id_value_to_string(0, row.get_ref(0)?)?,
2252 ))
2253 })?
2254 .collect::<Result<Vec<_>, _>>()?;
2255 Ok(results)
2256 })
2257 .await
2258 .map_err(|e| VectorStoreError::DatastoreError(Box::new(e)))?;
2259
2260 debug!("Found {} matching document IDs", results.len());
2261 Ok(results)
2262 }
2263}
2264
2265fn serialize_embedding(embedding: &Embedding) -> Vec<f32> {
2266 embedding.vec.iter().map(|x| *x as f32).collect()
2267}
2268
2269impl ColumnValue for String {
2270 fn to_sql_value(&self) -> Value {
2271 Value::Text(self.clone())
2272 }
2273
2274 fn column_type(&self) -> &'static str {
2275 "TEXT"
2276 }
2277}
2278
2279impl ColumnValue for i64 {
2280 fn to_sql_value(&self) -> Value {
2281 Value::Integer(*self)
2282 }
2283
2284 fn column_type(&self) -> &'static str {
2285 "INTEGER"
2286 }
2287}
2288
2289impl ColumnValue for i32 {
2290 fn to_sql_value(&self) -> Value {
2291 Value::Integer(i64::from(*self))
2292 }
2293
2294 fn column_type(&self) -> &'static str {
2295 "INTEGER"
2296 }
2297}
2298
2299impl ColumnValue for f64 {
2300 fn to_sql_value(&self) -> Value {
2301 Value::Real(*self)
2302 }
2303
2304 fn column_type(&self) -> &'static str {
2305 "FLOAT"
2306 }
2307}
2308
2309impl ColumnValue for f32 {
2310 fn to_sql_value(&self) -> Value {
2311 Value::Real(f64::from(*self))
2312 }
2313
2314 fn column_type(&self) -> &'static str {
2315 "FLOAT"
2316 }
2317}
2318
2319impl ColumnValue for bool {
2320 fn to_sql_value(&self) -> Value {
2321 Value::Integer(if *self { 1 } else { 0 })
2322 }
2323
2324 fn column_type(&self) -> &'static str {
2325 "BOOLEAN"
2326 }
2327}
2328
2329impl ColumnValue for serde_json::Value {
2330 fn to_sql_value(&self) -> Value {
2331 Value::Text(self.to_string())
2332 }
2333
2334 fn column_type(&self) -> &'static str {
2335 "JSON"
2336 }
2337}
2338
2339#[cfg(test)]
2340mod tests {
2341 use super::*;
2342 use rig_core::embeddings::EmbeddingError;
2343 use rusqlite::ffi::{sqlite3, sqlite3_api_routines, sqlite3_auto_extension};
2344 use sqlite_vec::sqlite3_vec_init;
2345 use std::cmp::Ordering;
2346 use std::os::raw::c_char;
2347 use std::sync::Once;
2348 use tokio_rusqlite::Connection;
2349
2350 const SCORE_EPSILON: f64 = 1e-5;
2351
2352 fn test_metadata_columns() -> Vec<SqliteMetadataColumn> {
2353 vec![SqliteMetadataColumn {
2354 name: "category",
2355 metadata_type: SqliteMetadataType::Text,
2356 }]
2357 }
2358
2359 fn typed_metadata_columns() -> Vec<SqliteMetadataColumn> {
2360 vec![
2361 SqliteMetadataColumn {
2362 name: "priority",
2363 metadata_type: SqliteMetadataType::Integer,
2364 },
2365 SqliteMetadataColumn {
2366 name: "rating",
2367 metadata_type: SqliteMetadataType::Float,
2368 },
2369 SqliteMetadataColumn {
2370 name: "published",
2371 metadata_type: SqliteMetadataType::Boolean,
2372 },
2373 ]
2374 }
2375
2376 #[test]
2377 fn json_column_text_decodes_to_json_object() -> anyhow::Result<()> {
2378 let column = Column::new("metadata", "JSON");
2379 let value = sqlite_column_value_to_json(
2380 0,
2381 &column,
2382 ValueRef::Text(br#"{"knowledge_doc_id":361,"knowledge_id":1,"user_id":1}"#),
2383 )?;
2384
2385 let expected = serde_json::json!({
2386 "knowledge_doc_id": 361,
2387 "knowledge_id": 1,
2388 "user_id": 1
2389 });
2390 anyhow::ensure!(
2391 value == expected,
2392 "JSON column text should decode to a JSON object, got {value:?}"
2393 );
2394
2395 Ok(())
2396 }
2397
2398 #[test]
2399 fn text_column_json_looking_text_stays_string() -> anyhow::Result<()> {
2400 let column = Column::new("metadata", "TEXT");
2401 let value = sqlite_column_value_to_json(
2402 0,
2403 &column,
2404 ValueRef::Text(br#"{"knowledge_doc_id":361,"knowledge_id":1,"user_id":1}"#),
2405 )?;
2406
2407 let expected =
2408 serde_json::json!(r#"{"knowledge_doc_id":361,"knowledge_id":1,"user_id":1}"#);
2409 anyhow::ensure!(
2410 value == expected,
2411 "TEXT column should preserve JSON-looking text as a string, got {value:?}"
2412 );
2413
2414 Ok(())
2415 }
2416
2417 #[test]
2418 fn json_column_invalid_text_returns_conversion_error() -> anyhow::Result<()> {
2419 let column = Column::new("metadata", "JSON");
2420 let err = match sqlite_column_value_to_json(0, &column, ValueRef::Text(b"not json")) {
2421 Ok(value) => anyhow::bail!("invalid JSON column text should fail, got {value:?}"),
2422 Err(err) => err,
2423 };
2424
2425 anyhow::ensure!(
2426 matches!(
2427 err,
2428 rusqlite::Error::FromSqlConversionFailure(0, Type::Text, _)
2429 ),
2430 "invalid JSON column text should return a conversion error, got {err}"
2431 );
2432
2433 Ok(())
2434 }
2435
2436 #[test]
2437 fn serde_json_value_column_value_round_trips_json_column() -> anyhow::Result<()> {
2438 let value = serde_json::json!({
2439 "knowledge_doc_id": 361,
2440 "knowledge_id": 1,
2441 "user_id": 1
2442 });
2443 anyhow::ensure!(
2444 value.column_type() == "JSON",
2445 "serde_json::Value should declare JSON column type"
2446 );
2447
2448 let text = match value.to_sql_value() {
2449 Value::Text(text) => text,
2450 value => {
2451 anyhow::bail!("serde_json::Value should serialize as JSON text, got {value:?}")
2452 }
2453 };
2454
2455 let column = Column::new("metadata", "JSON");
2456 let round_trip = sqlite_column_value_to_json(0, &column, ValueRef::Text(text.as_bytes()))?;
2457 anyhow::ensure!(
2458 round_trip == value,
2459 "serde_json::Value should round-trip through a JSON column, got {round_trip:?}"
2460 );
2461
2462 Ok(())
2463 }
2464
2465 fn filter_error<T: std::fmt::Debug>(
2466 result: Result<T, FilterError>,
2467 context: &str,
2468 ) -> anyhow::Result<FilterError> {
2469 match result {
2470 Ok(value) => anyhow::bail!("{context} should have failed, got {value:?}"),
2471 Err(err) => Ok(err),
2472 }
2473 }
2474
2475 fn ensure_vector_store_filter_error<T: std::fmt::Debug>(
2476 result: Result<T, VectorStoreError>,
2477 context: &str,
2478 ) -> anyhow::Result<()> {
2479 match result {
2480 Err(VectorStoreError::FilterError(_)) => Ok(()),
2481 Err(err) => anyhow::bail!("{context} returned unexpected error: {err}"),
2482 Ok(value) => anyhow::bail!("{context} should have failed, got {value:?}"),
2483 }
2484 }
2485
#[test]
fn threshold_filter_uses_computed_similarity_expression() -> anyhow::Result<()> {
    // A cosine threshold keeps the KNN constraints and adds a predicate on the computed
    // similarity `1 - distance`, with the threshold bound as the fourth parameter.
    let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
        .query("needle")
        .samples(5)
        .threshold(0.95)
        .build();
    let query_vector = vec![1.0, 0.0];

    let (where_clause, params) =
        build_where_clause(&req, query_vector, SqliteDistanceMetric::Cosine, &[], 5)?;

    // Each fragment is checked in order, so the first missing one is reported.
    let required_fragments = [
        ("e.embedding MATCH ?", "missing vector match constraint"),
        ("k = ?", "missing vector k constraint"),
        (
            "(1 - vec_distance_cosine(?1, e.embedding)) >= ?",
            "threshold should use computed similarity expression",
        ),
    ];
    for (fragment, problem) in required_fragments {
        anyhow::ensure!(where_clause.contains(fragment), "{problem}: {where_clause}");
    }

    anyhow::ensure!(params.len() == 4, "unexpected params: {params:?}");
    let threshold_param = params.get(3);
    anyhow::ensure!(
        threshold_param == Some(&Value::Real(0.95)),
        "unexpected threshold param: {params:?}"
    );

    Ok(())
}
2517
#[test]
fn l2_threshold_filter_uses_l2_score_expression() -> anyhow::Result<()> {
    // Under L2 the score is the negated distance, so a negative threshold is compared
    // against `-vec_distance_l2(...)` and bound as the fourth parameter.
    let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
        .query("needle")
        .samples(5)
        .threshold(-1.5)
        .build();

    let (where_clause, params) =
        build_where_clause(&req, vec![1.0, 0.0], SqliteDistanceMetric::L2, &[], 5)?;

    let score_predicate = "(-vec_distance_l2(?1, e.embedding)) >= ?";
    anyhow::ensure!(
        where_clause.contains(score_predicate),
        "threshold should use L2 score expression: {where_clause}"
    );
    anyhow::ensure!(params.len() == 4, "unexpected params: {params:?}");
    anyhow::ensure!(
        params.get(3).is_some_and(|param| *param == Value::Real(-1.5)),
        "unexpected threshold param: {params:?}"
    );

    Ok(())
}
2541
#[test]
fn no_threshold_does_not_add_similarity_predicate() -> anyhow::Result<()> {
    // Without a threshold, the clause is exactly the sqlite-vec MATCH/k pair and only
    // three parameters are bound.
    let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
        .query("needle")
        .samples(5)
        .build();

    let (where_clause, params) =
        build_where_clause(&req, vec![1.0, 0.0], SqliteDistanceMetric::Cosine, &[], 5)?;

    let expected_clause = "WHERE e.embedding MATCH ? AND k = ?";
    anyhow::ensure!(
        where_clause == expected_clause,
        "unexpected where clause: {where_clause}"
    );
    let bound_count = params.len();
    anyhow::ensure!(bound_count == 3, "unexpected params: {params:?}");

    Ok(())
}
2560
#[test]
fn candidate_limit_at_k_cap_still_uses_knn_path() -> anyhow::Result<()> {
    // A candidate limit exactly equal to sqlite-vec's maximum `k` is still a valid KNN
    // query; only limits above the cap fall back to a brute-force scan.
    let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
        .query("needle")
        .samples(5)
        .build();

    let (where_clause, params) = build_where_clause(
        &req,
        vec![1.0, 0.0],
        SqliteDistanceMetric::Cosine,
        &[],
        SQLITE_VEC_MAX_K,
    )?;

    anyhow::ensure!(
        where_clause == "WHERE e.embedding MATCH ? AND k = ?",
        "candidate limit at the cap should keep the KNN path: {where_clause}"
    );
    anyhow::ensure!(params.len() == 3, "unexpected params: {params:?}");

    // SQLite integers are i64; use a checked conversion rather than an `as` cast, which
    // would silently wrap if the cap ever exceeded i64::MAX.
    let expected_k = i64::try_from(SQLITE_VEC_MAX_K)?;
    anyhow::ensure!(
        params.get(2) == Some(&Value::Integer(expected_k)),
        "k param should be the candidate limit: {params:?}"
    );

    Ok(())
}
2589
#[test]
fn candidate_limit_above_k_cap_falls_back_to_brute_force_scan() -> anyhow::Result<()> {
    // One past sqlite-vec's maximum `k`, the KNN constraints are dropped entirely. With no
    // filters, that leaves no WHERE clause, and the query vector is the only bound param.
    let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
        .query("needle")
        .samples(5)
        .build();
    let over_cap = SQLITE_VEC_MAX_K + 1;

    let (where_clause, params) = build_where_clause(
        &req,
        vec![1.0, 0.0],
        SqliteDistanceMetric::Cosine,
        &[],
        over_cap,
    )?;

    let keeps_knn_constraint = where_clause.contains("MATCH") || where_clause.contains("k = ?");
    anyhow::ensure!(
        !keeps_knn_constraint,
        "brute-force scan should drop the KNN constraints: {where_clause}"
    );
    anyhow::ensure!(
        where_clause.is_empty(),
        "brute-force scan without filters should emit no WHERE clause: {where_clause:?}"
    );
    anyhow::ensure!(params.len() == 1, "unexpected params: {params:?}");

    let first_is_vector = matches!(params.first(), Some(Value::Blob(_)));
    anyhow::ensure!(
        first_is_vector,
        "remaining param should be the query vector: {params:?}"
    );

    Ok(())
}
2626
#[test]
fn brute_force_scan_keeps_filter_params_aligned() -> anyhow::Result<()> {
    // Dropping the KNN placeholders in brute-force mode must not leave gaps: the
    // threshold parameter has to come directly after the query vector.
    let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
        .query("needle")
        .samples(5)
        .threshold(0.95)
        .build();

    let (where_clause, params) = build_where_clause(
        &req,
        vec![1.0, 0.0],
        SqliteDistanceMetric::Cosine,
        &[],
        SQLITE_VEC_MAX_K + 1,
    )?;

    let expected_clause = "WHERE ((1 - vec_distance_cosine(?1, e.embedding)) >= ?)";
    anyhow::ensure!(
        where_clause == expected_clause,
        "brute-force scan should keep native filters: {where_clause}"
    );
    anyhow::ensure!(params.len() == 2, "unexpected params: {params:?}");

    let (first, second) = (params.first(), params.get(1));
    anyhow::ensure!(
        matches!(first, Some(Value::Blob(_))),
        "first param should be the query vector: {params:?}"
    );
    anyhow::ensure!(
        second == Some(&Value::Real(0.95)),
        "threshold param should follow the query vector: {params:?}"
    );

    Ok(())
}
2661
#[test]
fn or_filter_uses_document_filter_to_preserve_boolean_semantics() -> anyhow::Result<()> {
    // An OR that mixes the indexed `category` column with the non-indexed `title` column
    // must not be split across sqlite-vec and the document query. The whole disjunction
    // runs as a document post-filter.
    let category_is_docs = SqliteSearchFilter::eq("category", serde_json::json!("docs"));
    let title_is_archive = SqliteSearchFilter::eq("title", serde_json::json!("archive"));

    let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
        .query("needle")
        .samples(5)
        .filter(category_is_docs.or(title_is_archive))
        .build();

    let columns = test_metadata_columns();
    let filters = render_search_filters(&req, SqliteDistanceMetric::Cosine, &columns)?;
    anyhow::ensure!(
        filters.has_post_filters(),
        "OR filters should be applied after vector candidate search"
    );
    let query = build_search_query(vec![1.0, 0.0], filters, 5)?;

    let vector_clause = &query.vector_where_clause;
    let document_clause = &query.document_filter_clause;
    anyhow::ensure!(
        vector_clause == "WHERE e.embedding MATCH ? AND k = ?",
        "OR filters should not be partially pushed into sqlite-vec: {}",
        vector_clause
    );
    anyhow::ensure!(
        document_clause == "AND ((d.category = ?) OR (d.title = ?))",
        "unexpected document filter clause: {}",
        document_clause
    );

    // Document-filter params follow the three KNN params.
    let bound = &query.params;
    anyhow::ensure!(
        bound.get(3) == Some(&Value::Text("docs".to_string()))
            && bound.get(4) == Some(&Value::Text("archive".to_string())),
        "unexpected OR filter params: {:?}",
        bound
    );

    Ok(())
}
2701
#[test]
fn indexed_filter_uses_vec0_metadata_constraint() -> anyhow::Result<()> {
    // Equality on the indexed `category` column is pushed into the sqlite-vec query as an
    // `e.category` constraint and bound after the three KNN parameters.
    let category_filter = SqliteSearchFilter::eq("category", serde_json::json!("docs"));
    let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
        .query("needle")
        .samples(5)
        .filter(category_filter)
        .build();
    let metadata_columns = test_metadata_columns();

    let (where_clause, params) = build_where_clause(
        &req,
        vec![1.0, 0.0],
        SqliteDistanceMetric::Cosine,
        &metadata_columns,
        5,
    )?;

    let expected = "WHERE e.embedding MATCH ? AND k = ? AND (e.category = ?)";
    anyhow::ensure!(
        where_clause == expected,
        "unexpected where clause: {where_clause}"
    );
    anyhow::ensure!(params.len() == 4, "unexpected params: {params:?}");
    let filter_param = params.get(3);
    anyhow::ensure!(
        filter_param == Some(&Value::Text("docs".to_string())),
        "unexpected filter param: {params:?}"
    );

    Ok(())
}
2733
#[test]
fn negated_eq_filter_uses_vec0_metadata_inequality() -> anyhow::Result<()> {
    // NOT (category = 'docs') on an indexed column becomes the vec0 inequality
    // `e.category != ?` rather than being deferred to the document filter.
    let not_docs = SqliteSearchFilter::eq("category", serde_json::json!("docs")).not();
    let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
        .query("needle")
        .samples(5)
        .filter(not_docs)
        .build();
    let metadata_columns = test_metadata_columns();

    let (where_clause, params) = build_where_clause(
        &req,
        vec![1.0, 0.0],
        SqliteDistanceMetric::Cosine,
        &metadata_columns,
        5,
    )?;

    let expected = "WHERE e.embedding MATCH ? AND k = ? AND (e.category != ?)";
    anyhow::ensure!(
        where_clause == expected,
        "unexpected where clause: {where_clause}"
    );
    anyhow::ensure!(params.len() == 4, "unexpected params: {params:?}");
    let filter_param = params.get(3);
    anyhow::ensure!(
        filter_param == Some(&Value::Text("docs".to_string())),
        "unexpected filter param: {params:?}"
    );

    Ok(())
}
2762
#[test]
fn negated_range_comparison_uses_vec0_metadata_boundary() -> anyhow::Result<()> {
    // NOT (priority > 10) becomes the complementary boundary `e.priority <= ?`, so it can
    // still be evaluated inside the sqlite-vec query.
    let not_above_ten = SqliteSearchFilter::gt("priority", serde_json::json!(10)).not();
    let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
        .query("needle")
        .samples(5)
        .filter(not_above_ten)
        .build();
    let columns = typed_metadata_columns();

    let (where_clause, params) = build_where_clause(
        &req,
        vec![1.0, 0.0],
        SqliteDistanceMetric::Cosine,
        &columns,
        5,
    )?;

    let expected = "WHERE e.embedding MATCH ? AND k = ? AND (e.priority <= ?)";
    anyhow::ensure!(
        where_clause == expected,
        "unexpected where clause: {where_clause}"
    );
    anyhow::ensure!(params.len() == 4, "unexpected params: {params:?}");
    anyhow::ensure!(
        params.get(3).is_some_and(|param| *param == Value::Integer(10)),
        "unexpected filter param: {params:?}"
    );

    Ok(())
}
2791
#[test]
fn negated_boolean_eq_filter_uses_vec0_metadata_inequality() -> anyhow::Result<()> {
    // NOT (published = true) on a BOOLEAN metadata column becomes `e.published != ?`,
    // with the boolean bound as the integer 1.
    let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
        .query("needle")
        .samples(5)
        .filter(SqliteSearchFilter::eq("published", serde_json::json!(true)).not())
        .build();

    let (where_clause, params) = build_where_clause(
        &req,
        vec![1.0, 0.0],
        SqliteDistanceMetric::Cosine,
        &typed_metadata_columns(),
        5,
    )?;

    anyhow::ensure!(
        where_clause == "WHERE e.embedding MATCH ? AND k = ? AND (e.published != ?)",
        "unexpected where clause: {where_clause}"
    );
    // Like the other single-filter KNN tests: three KNN params plus exactly one filter
    // param. Without this check, extra bound params would go unnoticed.
    anyhow::ensure!(params.len() == 4, "unexpected params: {params:?}");
    anyhow::ensure!(
        params.get(3) == Some(&Value::Integer(1)),
        "unexpected boolean filter param: {params:?}"
    );

    Ok(())
}
2819
#[test]
fn negated_between_filter_uses_document_filter() -> anyhow::Result<()> {
    // A negated BETWEEN is not pushed into sqlite-vec. It runs as a document post-filter,
    // with both bounds bound as parameters after the KNN params.
    let outside_range =
        SqliteSearchFilter::between("priority".to_string(), 1_i64..=10_i64).not();
    let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
        .query("needle")
        .samples(5)
        .filter(outside_range)
        .build();

    let columns = typed_metadata_columns();
    let filters = render_search_filters(&req, SqliteDistanceMetric::Cosine, &columns)?;
    anyhow::ensure!(
        filters.has_post_filters(),
        "negated range filters should be applied after vector candidate search"
    );
    let query = build_search_query(vec![1.0, 0.0], filters, 5)?;

    let vector_clause = &query.vector_where_clause;
    anyhow::ensure!(
        vector_clause == "WHERE e.embedding MATCH ? AND k = ?",
        "negated range filters should not be partially pushed into sqlite-vec: {}",
        vector_clause
    );
    let document_clause = &query.document_filter_clause;
    anyhow::ensure!(
        document_clause == "AND (NOT (d.priority between ? and ?))",
        "unexpected document filter clause: {}",
        document_clause
    );

    let bound = &query.params;
    let bounds_ok =
        bound.get(3) == Some(&Value::Integer(1)) && bound.get(4) == Some(&Value::Integer(10));
    anyhow::ensure!(bounds_ok, "unexpected negated between params: {:?}", bound);

    Ok(())
}
2859
#[test]
fn boolean_range_filter_is_rejected() -> anyhow::Result<()> {
    // An ordering comparison on the BOOLEAN `published` column must fail with an error
    // whose message mentions BOOLEAN.
    let published_after_false = SqliteSearchFilter::gt("published", serde_json::json!(false));
    let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
        .query("needle")
        .samples(5)
        .filter(published_after_false)
        .build();

    let outcome = build_where_clause(
        &req,
        vec![1.0, 0.0],
        SqliteDistanceMetric::Cosine,
        &typed_metadata_columns(),
        5,
    );
    let err = filter_error(outcome, "boolean range filters")?;

    let message = err.to_string();
    anyhow::ensure!(
        message.contains("BOOLEAN"),
        "unexpected error for boolean range filter: {err}"
    );

    Ok(())
}
2889
#[test]
fn indexed_between_filter_uses_vec0_metadata_constraints() -> anyhow::Result<()> {
    // BETWEEN on the indexed integer `priority` column expands into a `>= ? AND <= ?`
    // pair inside the sqlite-vec query, with both bounds bound as parameters.
    let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
        .query("needle")
        .samples(5)
        .filter(SqliteSearchFilter::between(
            "priority".to_string(),
            1_i64..=10_i64,
        ))
        .build();
    let columns = typed_metadata_columns();

    let (where_clause, params) = build_where_clause(
        &req,
        vec![1.0, 0.0],
        SqliteDistanceMetric::Cosine,
        &columns,
        5,
    )?;

    let expected =
        "WHERE e.embedding MATCH ? AND k = ? AND (e.priority >= ? AND e.priority <= ?)";
    anyhow::ensure!(
        where_clause == expected,
        "unexpected where clause: {where_clause}"
    );
    anyhow::ensure!(params.len() == 5, "unexpected params: {params:?}");

    let (lower, upper) = (params.get(3), params.get(4));
    anyhow::ensure!(
        lower == Some(&Value::Integer(1)) && upper == Some(&Value::Integer(10)),
        "between bounds should be bound as parameters: {params:?}"
    );

    Ok(())
}
2920
#[test]
fn mismatched_metadata_filter_value_types_are_rejected() -> anyhow::Result<()> {
    // Each case pairs a filter whose value type disagrees with the declared metadata
    // column type with a fragment the resulting error message must contain.
    let cases = [
        (
            SqliteSearchFilter::eq("published", serde_json::json!("true")),
            "boolean filter value",
        ),
        (
            SqliteSearchFilter::gt("priority", serde_json::json!(1.5)),
            "integer filter value",
        ),
        (
            SqliteSearchFilter::eq("category", serde_json::json!({ "name": "docs" })),
            "string filter value",
        ),
        (
            SqliteSearchFilter::between(
                "priority".to_string(),
                "1".to_string()..="10".to_string(),
            ),
            "integer filter value",
        ),
    ];

    // The column set does not depend on the case, so build it once, not per iteration.
    let columns = typed_metadata_columns()
        .into_iter()
        .chain(test_metadata_columns())
        .collect::<Vec<_>>();

    for (case, (filter, expected)) in cases.into_iter().enumerate() {
        let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
            .query("needle")
            .samples(5)
            .filter(filter)
            .build();

        // Include the case index: two cases expect the same fragment, so the fragment
        // alone would not identify which one failed.
        let err = filter_error(
            build_where_clause(
                &req,
                vec![1.0, 0.0],
                SqliteDistanceMetric::Cosine,
                &columns,
                5,
            ),
            &format!("mismatched metadata filter value (case {case}, expected `{expected}`)"),
        )?;

        anyhow::ensure!(
            err.to_string().contains(expected),
            "unexpected error for mismatched metadata filter value (case {case}, expected `{expected}`): {err}"
        );
    }

    Ok(())
}
2974
#[test]
fn pattern_and_null_filters_use_document_filter() -> anyhow::Result<()> {
    // LIKE, GLOB and IS NULL run as document post-filters rather than inside sqlite-vec,
    // even for the indexed `category` column. The pattern texts, including the embedded
    // quote in `O'Reilly`, are bound as parameters.
    let title_like = SqliteSearchFilter::like("title".to_string(), "%O'Reilly%");
    let category_glob = SqliteSearchFilter::glob("category".to_string(), "doc*");
    let missing_is_null = SqliteSearchFilter::is_null("metadata->>'$.missing'".to_string());

    let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
        .query("needle")
        .samples(5)
        .filter(title_like.and(category_glob).and(missing_is_null))
        .build();

    let columns = test_metadata_columns();
    let filters = render_search_filters(&req, SqliteDistanceMetric::Cosine, &columns)?;
    anyhow::ensure!(
        filters.has_post_filters(),
        "pattern and null filters should be applied after vector candidate search"
    );
    let query = build_search_query(vec![1.0, 0.0], filters, 5)?;

    let vector_clause = &query.vector_where_clause;
    anyhow::ensure!(
        vector_clause == "WHERE e.embedding MATCH ? AND k = ?",
        "pattern filters should not be pushed into sqlite-vec: {}",
        vector_clause
    );
    let document_clause = &query.document_filter_clause;
    let expected_document_clause =
        "AND (d.title like ?) AND (d.category glob ?) AND (d.metadata->>'$.missing' is null)";
    anyhow::ensure!(
        document_clause == expected_document_clause,
        "unexpected document filter clause: {}",
        document_clause
    );

    let bound = &query.params;
    anyhow::ensure!(
        bound.get(3) == Some(&Value::Text("%O'Reilly%".to_string()))
            && bound.get(4) == Some(&Value::Text("doc*".to_string())),
        "unexpected pattern filter params: {:?}",
        bound
    );

    Ok(())
}
3016
#[test]
fn nonindexed_filters_use_document_filter() -> anyhow::Result<()> {
    // Equality on the non-indexed `title` column leaves the sqlite-vec clause untouched
    // and becomes a `d.title` document post-filter.
    let title_filter = SqliteSearchFilter::eq("title", serde_json::json!("docs"));
    let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
        .query("needle")
        .samples(5)
        .filter(title_filter)
        .build();

    let columns = test_metadata_columns();
    let filters = render_search_filters(&req, SqliteDistanceMetric::Cosine, &columns)?;
    anyhow::ensure!(
        filters.has_post_filters(),
        "non-indexed filters should be applied after vector candidate search"
    );
    let query = build_search_query(vec![1.0, 0.0], filters, 5)?;

    let vector_clause = &query.vector_where_clause;
    anyhow::ensure!(
        vector_clause == "WHERE e.embedding MATCH ? AND k = ?",
        "unexpected vector where clause: {}",
        vector_clause
    );
    let document_clause = &query.document_filter_clause;
    anyhow::ensure!(
        document_clause == "AND (d.title = ?)",
        "unexpected document filter clause: {}",
        document_clause
    );
    let bound = &query.params;
    anyhow::ensure!(
        bound.get(3) == Some(&Value::Text("docs".to_string())),
        "unexpected document filter param: {:?}",
        bound
    );

    Ok(())
}
3051
#[test]
fn json_metadata_expression_uses_document_filter() -> anyhow::Result<()> {
    // A `->>` JSON path on the metadata column is applied as a document post-filter, and
    // its right-hand side is bound as plain SQL text.
    let json_path_filter =
        SqliteSearchFilter::eq("metadata->>'$.xxx'", serde_json::json!("vvv"));
    let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
        .query("needle")
        .samples(5)
        .filter(json_path_filter)
        .build();

    let columns = test_metadata_columns();
    let filters = render_search_filters(&req, SqliteDistanceMetric::Cosine, &columns)?;
    anyhow::ensure!(
        filters.has_post_filters(),
        "JSON metadata expressions should be applied after vector candidate search"
    );
    let query = build_search_query(vec![1.0, 0.0], filters, 5)?;

    let vector_clause = &query.vector_where_clause;
    anyhow::ensure!(
        vector_clause == "WHERE e.embedding MATCH ? AND k = ?",
        "unexpected vector where clause: {}",
        vector_clause
    );
    let document_clause = &query.document_filter_clause;
    anyhow::ensure!(
        document_clause == "AND (d.metadata->>'$.xxx' = ?)",
        "unexpected document filter clause: {}",
        document_clause
    );
    let bound = &query.params;
    anyhow::ensure!(
        bound.get(3) == Some(&Value::Text("vvv".to_string())),
        "unexpected JSON metadata filter param: {:?}",
        bound
    );

    Ok(())
}
3089
#[test]
fn json_metadata_arrow_expression_binds_rhs_as_json_text() -> anyhow::Result<()> {
    // SQLite's `->` yields JSON text, so the comparison value must be bound in its JSON
    // form (a quoted string) rather than as a bare SQL string.
    let arrow_filter = SqliteSearchFilter::eq("metadata->'$.xxx'", serde_json::json!("vvv"));
    let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
        .query("needle")
        .samples(5)
        .filter(arrow_filter)
        .build();

    let columns = test_metadata_columns();
    let filters = render_search_filters(&req, SqliteDistanceMetric::Cosine, &columns)?;
    let query = build_search_query(vec![1.0, 0.0], filters, 5)?;

    let document_clause = &query.document_filter_clause;
    anyhow::ensure!(
        document_clause == "AND (d.metadata->'$.xxx' = ?)",
        "unexpected document filter clause: {}",
        document_clause
    );
    let json_encoded = Value::Text("\"vvv\"".to_string());
    anyhow::ensure!(
        query.params.get(3) == Some(&json_encoded),
        "SQLite `->` should compare against JSON text: {:?}",
        query.params
    );

    Ok(())
}
3118
#[test]
fn chained_json_metadata_expression_uses_final_operator_for_param_mode() -> anyhow::Result<()> {
    // In a chained path (`->` then `->>`), the last operator decides how the comparison
    // value is bound. A trailing `->>` yields a SQL scalar, so plain text is bound.
    let chained_filter = SqliteSearchFilter::eq(
        "metadata->'$.nested'->>'$.xxx'",
        serde_json::json!("vvv"),
    );
    let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
        .query("needle")
        .samples(5)
        .filter(chained_filter)
        .build();

    let columns = test_metadata_columns();
    let filters = render_search_filters(&req, SqliteDistanceMetric::Cosine, &columns)?;
    let query = build_search_query(vec![1.0, 0.0], filters, 5)?;

    let document_clause = &query.document_filter_clause;
    anyhow::ensure!(
        document_clause == "AND (d.metadata->'$.nested'->>'$.xxx' = ?)",
        "unexpected document filter clause: {}",
        document_clause
    );
    let scalar_text = Value::Text("vvv".to_string());
    anyhow::ensure!(
        query.params.get(3) == Some(&scalar_text),
        "final `->>` should compare against SQL scalar text: {:?}",
        query.params
    );

    Ok(())
}
3147
#[test]
fn unsupported_document_filter_expressions_are_rejected() -> anyhow::Result<()> {
    // A filter key that is not a supported column or JSON path expression (here a
    // SQL-injection attempt) must be rejected by `render_search_filters`.
    let injected_key = "metadata) OR 1 = 1 --";
    let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
        .query("needle")
        .samples(5)
        .filter(SqliteSearchFilter::eq(injected_key, serde_json::json!("vvv")))
        .build();

    let outcome =
        render_search_filters(&req, SqliteDistanceMetric::Cosine, &test_metadata_columns());
    let err = filter_error(outcome, "unsupported document filter expressions")?;

    let message = err.to_string();
    anyhow::ensure!(
        message.contains("supported SQLite document filter expression"),
        "unexpected error for unsupported document filter expression: {err}"
    );

    Ok(())
}
3172
3173 #[tokio::test]
3174 async fn live_search_orders_by_similarity_and_applies_threshold() -> anyhow::Result<()> {
3175 let index = live_test_index(
3176 "live_search_orders_by_similarity_and_applies_threshold",
3177 vec![
3178 row("exact", "docs", "exact match", vec![1.0, 0.0]),
3179 row("close", "docs", "close match", vec![0.8, 0.6]),
3180 row("opposite", "docs", "opposite match", vec![-1.0, 0.0]),
3181 ],
3182 )
3183 .await?;
3184
3185 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3186 .query("needle")
3187 .samples(3)
3188 .threshold(0.75)
3189 .build();
3190
3191 let results = index.top_n::<TestDocument>(req.clone()).await?;
3192 let ids = results
3193 .iter()
3194 .map(|(_, id, _)| id.as_str())
3195 .collect::<Vec<_>>();
3196 let exact_score = results.first().map(|(score, _, _)| *score);
3197 let close_score = results.get(1).map(|(score, _, _)| *score);
3198
3199 anyhow::ensure!(
3200 ids.as_slice() == ["exact", "close"],
3201 "unexpected ids: {ids:?}"
3202 );
3203 anyhow::ensure!(
3204 exact_score
3205 .zip(close_score)
3206 .is_some_and(|(exact, close)| exact > close),
3207 "expected exact score to be greater than close score: {results:?}"
3208 );
3209 anyhow::ensure!(
3210 results.iter().all(|(score, _, _)| *score > 0.75),
3211 "threshold should remove low-scoring rows: {results:?}"
3212 );
3213
3214 let id_results = index.top_n_ids(req).await?;
3215 let result_ids = id_results
3216 .iter()
3217 .map(|(_, id)| id.as_str())
3218 .collect::<Vec<_>>();
3219
3220 anyhow::ensure!(
3221 result_ids.as_slice() == ["exact", "close"],
3222 "unexpected top_n_ids ids: {id_results:?}"
3223 );
3224 anyhow::ensure!(
3225 id_results.iter().all(|(score, _)| *score > 0.75),
3226 "top_n_ids threshold should remove low-scoring rows: {id_results:?}"
3227 );
3228
3229 Ok(())
3230 }
3231
3232 #[tokio::test]
3233 async fn live_reinsert_same_document_id_removes_stale_vec0_candidates() -> anyhow::Result<()> {
3234 register_sqlite_vec_extension();
3235
3236 let conn = Connection::open(
3237 "file:live_reinsert_same_document_id_removes_stale_vec0_candidates?mode=memory",
3238 )
3239 .await?;
3240 let model = TestEmbeddingModel;
3241 let vector_store: SqliteVectorStore<_, TestDocument> =
3242 SqliteVectorStore::new(conn, &model).await?;
3243
3244 vector_store
3245 .add_rows(vec![row(
3246 "replace",
3247 "docs",
3248 "original near vector",
3249 vec![1.0, 0.0],
3250 )])
3251 .await?;
3252 vector_store
3253 .add_rows(vec![
3254 row("replace", "docs", "replacement far vector", vec![-1.0, 0.0]),
3255 row("fresh", "docs", "fresh near vector", vec![0.9, 0.1]),
3256 ])
3257 .await?;
3258
3259 let index = vector_store.index(model);
3260 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3261 .query("needle")
3262 .samples(1)
3263 .build();
3264
3265 let results = index.top_n::<TestDocument>(req.clone()).await?;
3266 let ids = results
3267 .iter()
3268 .map(|(_, id, _)| id.as_str())
3269 .collect::<Vec<_>>();
3270 anyhow::ensure!(
3271 ids.as_slice() == ["fresh"],
3272 "stale replaced vectors should not consume sqlite-vec candidates: {results:?}"
3273 );
3274
3275 let id_results = index.top_n_ids(req).await?;
3276 let result_ids = id_results
3277 .iter()
3278 .map(|(_, id)| id.as_str())
3279 .collect::<Vec<_>>();
3280 anyhow::ensure!(
3281 result_ids.as_slice() == ["fresh"],
3282 "top_n_ids should not return or be starved by stale replaced vectors: {id_results:?}"
3283 );
3284
3285 Ok(())
3286 }
3287
3288 #[tokio::test]
3289 async fn live_reinsert_preserves_unrelated_multivector_embeddings() -> anyhow::Result<()> {
3290 register_sqlite_vec_extension();
3291
3292 let conn = Connection::open(
3293 "file:live_reinsert_preserves_unrelated_multivector_embeddings?mode=memory",
3294 )
3295 .await?;
3296 let model = TestEmbeddingModel;
3297 let vector_store: SqliteVectorStore<_, TestDocument> =
3298 SqliteVectorStore::new(conn, &model).await?;
3299
3300 let multi_document = TestDocument {
3301 id: "multi".to_string(),
3302 category: "docs".to_string(),
3303 title: "multi-vector document".to_string(),
3304 };
3305 vector_store
3306 .add_rows(vec![
3307 (
3308 multi_document.clone(),
3309 OneOrMany::many(vec![
3310 Embedding {
3311 document: "far chunk".to_string(),
3312 vec: vec![-1.0, 0.0],
3313 },
3314 Embedding {
3315 document: "exact chunk".to_string(),
3316 vec: vec![1.0, 0.0],
3317 },
3318 ])?,
3319 ),
3320 row(
3321 "replace",
3322 "docs",
3323 "initial replacement vector",
3324 vec![0.8, 0.2],
3325 ),
3326 ])
3327 .await?;
3328 vector_store
3329 .add_rows(vec![row(
3330 "replace",
3331 "docs",
3332 "replacement far vector",
3333 vec![-1.0, 0.0],
3334 )])
3335 .await?;
3336
3337 let index = vector_store.index(model);
3338 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3339 .query("needle")
3340 .samples(1)
3341 .threshold(0.9)
3342 .build();
3343
3344 let results = index.top_n::<TestDocument>(req.clone()).await?;
3345 let ids = results
3346 .iter()
3347 .map(|(_, id, _)| id.as_str())
3348 .collect::<Vec<_>>();
3349 anyhow::ensure!(
3350 ids.as_slice() == ["multi"],
3351 "reinsert should not delete another document's best embedding: {results:?}"
3352 );
3353
3354 let id_results = index.top_n_ids(req).await?;
3355 let result_ids = id_results
3356 .iter()
3357 .map(|(_, id)| id.as_str())
3358 .collect::<Vec<_>>();
3359 anyhow::ensure!(
3360 result_ids.as_slice() == ["multi"],
3361 "top_n_ids should preserve unrelated multivector embeddings after reinsert: {id_results:?}"
3362 );
3363
3364 Ok(())
3365 }
3366
3367 #[tokio::test]
3368 async fn live_multiple_embeddings_per_document_use_best_embedding() -> anyhow::Result<()> {
3369 let multi_document = TestDocument {
3370 id: "multi".to_string(),
3371 category: "docs".to_string(),
3372 title: "multi-vector document".to_string(),
3373 };
3374 let index = live_test_index(
3375 "live_multiple_embeddings_per_document_use_best_embedding",
3376 vec![
3377 (
3378 multi_document.clone(),
3379 OneOrMany::many(vec![
3380 Embedding {
3381 document: "far chunk".to_string(),
3382 vec: vec![-1.0, 0.0],
3383 },
3384 Embedding {
3385 document: "exact chunk".to_string(),
3386 vec: vec![1.0, 0.0],
3387 },
3388 ])?,
3389 ),
3390 row("single", "docs", "single close chunk", vec![0.8, 0.6]),
3391 ],
3392 )
3393 .await?;
3394
3395 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3396 .query("needle")
3397 .samples(2)
3398 .build();
3399 let results = index.top_n::<TestDocument>(req.clone()).await?;
3400 let ids = results
3401 .iter()
3402 .map(|(_, id, _)| id.as_str())
3403 .collect::<Vec<_>>();
3404 anyhow::ensure!(
3405 ids.as_slice() == ["multi", "single"],
3406 "top_n should return each document once using its best embedding: {results:?}"
3407 );
3408
3409 let id_results = index.top_n_ids(req).await?;
3410 let result_ids = id_results
3411 .iter()
3412 .map(|(_, id)| id.as_str())
3413 .collect::<Vec<_>>();
3414 anyhow::ensure!(
3415 result_ids.as_slice() == ["multi", "single"],
3416 "top_n_ids should return each document once using its best embedding: {id_results:?}"
3417 );
3418
3419 let threshold_req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3420 .query("needle")
3421 .samples(2)
3422 .threshold(1.0)
3423 .build();
3424 let threshold_results = index.top_n::<TestDocument>(threshold_req.clone()).await?;
3425 let threshold_ids = threshold_results
3426 .iter()
3427 .map(|(_, id, _)| id.as_str())
3428 .collect::<Vec<_>>();
3429 anyhow::ensure!(
3430 threshold_ids.as_slice() == ["multi"],
3431 "threshold should include scores equal to the minimum and filter lower scores: {threshold_results:?}"
3432 );
3433
3434 let threshold_id_results = index.top_n_ids(threshold_req).await?;
3435 let threshold_result_ids = threshold_id_results
3436 .iter()
3437 .map(|(_, id)| id.as_str())
3438 .collect::<Vec<_>>();
3439 anyhow::ensure!(
3440 threshold_result_ids.as_slice() == ["multi"],
3441 "top_n_ids threshold should include scores equal to the minimum: {threshold_id_results:?}"
3442 );
3443
3444 Ok(())
3445 }
3446
3447 #[tokio::test]
3453 async fn live_multivector_search_beyond_knn_k_cap_succeeds() -> anyhow::Result<()> {
3454 let filler_chunks = (0..4100)
3459 .map(|i| Embedding {
3460 document: format!("filler chunk {i}"),
3461 vec: vec![0.0, 1.0],
3462 })
3463 .collect::<Vec<_>>();
3464 let filler_document = TestDocument {
3465 id: "filler".to_string(),
3466 category: "docs".to_string(),
3467 title: "many-embedding document".to_string(),
3468 };
3469
3470 let index = live_test_index(
3471 "live_multivector_search_beyond_knn_k_cap_succeeds",
3472 vec![
3473 (filler_document, OneOrMany::many(filler_chunks)?),
3474 row("best", "docs", "best", vec![1.0, 0.0]),
3475 row("mid", "docs", "mid", vec![0.5, 0.5]),
3476 row("worst", "docs", "worst", vec![-1.0, 0.0]),
3477 ],
3478 )
3479 .await?;
3480
3481 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3482 .query("needle")
3483 .samples(3)
3484 .build();
3485
3486 let results = index.top_n::<TestDocument>(req.clone()).await?;
3487 let ids = results
3488 .iter()
3489 .map(|(_, id, _)| id.as_str())
3490 .collect::<Vec<_>>();
3491 anyhow::ensure!(
3492 ids.as_slice() == ["best", "mid", "filler"],
3493 "brute-force scan should return the exact top-n past the knn k cap: {results:?}"
3494 );
3495
3496 let id_results = index.top_n_ids(req).await?;
3497 let result_ids = id_results
3498 .iter()
3499 .map(|(_, id)| id.as_str())
3500 .collect::<Vec<_>>();
3501 anyhow::ensure!(
3502 result_ids.as_slice() == ["best", "mid", "filler"],
3503 "top_n_ids should also brute-force past the knn k cap: {id_results:?}"
3504 );
3505
3506 Ok(())
3507 }
3508
3509 #[tokio::test]
3515 async fn live_post_filter_search_beyond_knn_k_cap_succeeds() -> anyhow::Result<()> {
3516 let mut rows = (0..4096)
3517 .map(|i| row(format!("noise-{i}"), "docs", "noise title", vec![1.0, 0.0]))
3518 .collect::<Vec<_>>();
3519 rows.push(row("wanted", "docs", "wanted title", vec![-1.0, 0.0]));
3523
3524 let index =
3525 live_test_index("live_post_filter_search_beyond_knn_k_cap_succeeds", rows).await?;
3526
3527 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3528 .query("needle")
3529 .samples(1)
3530 .filter(SqliteSearchFilter::eq(
3531 "title",
3532 serde_json::json!("wanted title"),
3533 ))
3534 .build();
3535
3536 let results = index.top_n::<TestDocument>(req.clone()).await?;
3537 let ids = results
3538 .iter()
3539 .map(|(_, id, _)| id.as_str())
3540 .collect::<Vec<_>>();
3541 anyhow::ensure!(
3542 ids.as_slice() == ["wanted"],
3543 "exhaustive non-indexed filter past the knn k cap should still find the match: {results:?}"
3544 );
3545
3546 let id_results = index.top_n_ids(req).await?;
3547 let result_ids = id_results
3548 .iter()
3549 .map(|(_, id)| id.as_str())
3550 .collect::<Vec<_>>();
3551 anyhow::ensure!(
3552 result_ids.as_slice() == ["wanted"],
3553 "top_n_ids should also apply the exhaustive filter past the knn k cap: {id_results:?}"
3554 );
3555
3556 Ok(())
3557 }
3558
3559 #[tokio::test]
3560 async fn live_equal_score_results_are_ordered_by_document_id() -> anyhow::Result<()> {
3561 let index = live_test_index(
3562 "live_equal_score_results_are_ordered_by_document_id",
3563 vec![
3564 row("b", "docs", "second id exact match", vec![1.0, 0.0]),
3565 row("a", "docs", "first id exact match", vec![1.0, 0.0]),
3566 ],
3567 )
3568 .await?;
3569
3570 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3571 .query("needle")
3572 .samples(2)
3573 .build();
3574
3575 let results = index.top_n::<TestDocument>(req.clone()).await?;
3576 let ids = results
3577 .iter()
3578 .map(|(_, id, _)| id.as_str())
3579 .collect::<Vec<_>>();
3580 anyhow::ensure!(
3581 ids.as_slice() == ["a", "b"],
3582 "equal-score top_n results should use document id as a stable tie-breaker: {results:?}"
3583 );
3584
3585 let id_results = index.top_n_ids(req).await?;
3586 let result_ids = id_results
3587 .iter()
3588 .map(|(_, id)| id.as_str())
3589 .collect::<Vec<_>>();
3590 anyhow::ensure!(
3591 result_ids.as_slice() == ["a", "b"],
3592 "equal-score top_n_ids results should use document id as a stable tie-breaker: {id_results:?}"
3593 );
3594
3595 Ok(())
3596 }
3597
3598 #[tokio::test]
3599 async fn live_common_sqlite_text_types_round_trip_in_top_n() -> anyhow::Result<()> {
3600 let index = live_common_type_test_index(
3601 "live_common_sqlite_text_types_round_trip_in_top_n",
3602 vec![common_type_row(
3603 "common",
3604 "varchar name",
3605 "clob notes",
3606 7,
3607 vec![1.0, 0.0],
3608 )],
3609 )
3610 .await?;
3611
3612 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3613 .query("needle")
3614 .samples(1)
3615 .build();
3616 let results = index.top_n::<CommonTypeDocument>(req).await?;
3617
3618 let Some((_, id, doc)) = results.first() else {
3619 anyhow::bail!("expected common type document result");
3620 };
3621 anyhow::ensure!(id == "common", "unexpected id: {id}");
3622 anyhow::ensure!(
3623 doc.name == "varchar name",
3624 "VARCHAR value should round-trip: {doc:?}"
3625 );
3626 anyhow::ensure!(
3627 doc.notes == "clob notes",
3628 "CLOB value should round-trip: {doc:?}"
3629 );
3630 anyhow::ensure!(doc.rank == 7, "NUMERIC value should round-trip: {doc:?}");
3631
3632 Ok(())
3633 }
3634
3635 #[tokio::test]
3636 async fn live_json_column_structured_metadata_round_trips_in_top_n() -> anyhow::Result<()> {
3637 let metadata = StructuredMetadata {
3638 user_id: 1,
3639 knowledge_id: 1,
3640 knowledge_doc_id: 361,
3641 };
3642 let index = live_structured_json_metadata_test_index(
3643 "live_json_column_structured_metadata_round_trips_in_top_n",
3644 vec![structured_json_metadata_row(
3645 "structured",
3646 metadata.clone(),
3647 "metadata document",
3648 vec![1.0, 0.0],
3649 )],
3650 )
3651 .await?;
3652
3653 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3654 .query("needle")
3655 .samples(1)
3656 .build();
3657 let results = index
3658 .top_n::<StructuredJsonMetadataDocument>(req.clone())
3659 .await?;
3660
3661 let Some((_, id, doc)) = results.first() else {
3662 anyhow::bail!("expected structured JSON metadata document result");
3663 };
3664 anyhow::ensure!(id == "structured", "unexpected id: {id}");
3665 anyhow::ensure!(
3666 doc.metadata == metadata,
3667 "JSON column should deserialize into structured metadata: {doc:?}"
3668 );
3669
3670 let id_results = index.top_n_ids(req).await?;
3671 anyhow::ensure!(
3672 id_results.first().is_some_and(|(_, id)| id == "structured"),
3673 "top_n_ids should still return the structured metadata document id: {id_results:?}"
3674 );
3675
3676 Ok(())
3677 }
3678
3679 #[tokio::test]
3680 async fn live_text_affinity_metadata_filters_during_candidate_search() -> anyhow::Result<()> {
3681 let index = live_common_type_test_index(
3682 "live_text_affinity_metadata_filters_during_candidate_search",
3683 vec![
3684 common_type_row("nearest", "misc", "nearest excluded", 1, vec![1.0, 0.0]),
3685 common_type_row("docs", "docs", "docs match", 2, vec![0.0, 1.0]),
3686 ],
3687 )
3688 .await?;
3689
3690 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3691 .query("needle")
3692 .samples(1)
3693 .filter(SqliteSearchFilter::eq("name", serde_json::json!("docs")))
3694 .build();
3695
3696 let results = index.top_n::<CommonTypeDocument>(req.clone()).await?;
3697 let ids = results
3698 .iter()
3699 .map(|(_, id, _)| id.as_str())
3700 .collect::<Vec<_>>();
3701
3702 anyhow::ensure!(
3703 ids.as_slice() == ["docs"],
3704 "VARCHAR metadata filters should constrain sqlite-vec candidate search: {results:?}"
3705 );
3706
3707 let id_results = index.top_n_ids(req).await?;
3708 let result_ids = id_results
3709 .iter()
3710 .map(|(_, id)| id.as_str())
3711 .collect::<Vec<_>>();
3712
3713 anyhow::ensure!(
3714 result_ids.as_slice() == ["docs"],
3715 "top_n_ids should use VARCHAR metadata filters during candidate search: {id_results:?}"
3716 );
3717
3718 Ok(())
3719 }
3720
3721 #[tokio::test]
3722 async fn live_l2_metric_is_consistent() -> anyhow::Result<()> {
3723 let index = live_test_index_with_metric(
3724 "live_l2_metric_is_consistent",
3725 vec![
3726 row("exact", "docs", "exact match", vec![1.0, 0.0]),
3727 row("l2-close", "docs", "l2 close match", vec![1.0, 1.0]),
3728 row(
3729 "same-direction-far",
3730 "docs",
3731 "same direction far away",
3732 vec![10.0, 0.0],
3733 ),
3734 ],
3735 SqliteDistanceMetric::L2,
3736 )
3737 .await?;
3738
3739 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3740 .query("needle")
3741 .samples(2)
3742 .threshold(-2.0)
3743 .build();
3744
3745 let results = index.top_n::<TestDocument>(req.clone()).await?;
3746 let ids = results
3747 .iter()
3748 .map(|(_, id, _)| id.as_str())
3749 .collect::<Vec<_>>();
3750 let exact_score = results
3751 .iter()
3752 .find(|(_, id, _)| id == "exact")
3753 .map(|(score, _, _)| *score);
3754 let close_score = results
3755 .iter()
3756 .find(|(_, id, _)| id == "l2-close")
3757 .map(|(score, _, _)| *score);
3758
3759 anyhow::ensure!(
3760 ids.as_slice() == ["exact", "l2-close"],
3761 "L2 search should return the nearest L2 candidates: {results:?}"
3762 );
3763 anyhow::ensure!(
3764 exact_score
3765 .zip(close_score)
3766 .is_some_and(|(exact, close)| exact > close && close > -2.0),
3767 "expected L2 scores to be ordered and thresholded: {results:?}"
3768 );
3769 anyhow::ensure!(
3770 results.iter().all(|(score, _, _)| *score > -2.0),
3771 "threshold should be applied to L2 scores: {results:?}"
3772 );
3773
3774 let id_results = index.top_n_ids(req).await?;
3775 let result_ids = id_results
3776 .iter()
3777 .map(|(_, id)| id.as_str())
3778 .collect::<Vec<_>>();
3779
3780 anyhow::ensure!(
3781 result_ids.as_slice() == ["exact", "l2-close"],
3782 "top_n_ids should use the same L2 metric: {id_results:?}"
3783 );
3784
3785 Ok(())
3786 }
3787
3788 #[tokio::test]
3789 async fn live_indexed_filter_is_applied_during_candidate_search() -> anyhow::Result<()> {
3790 let index = live_test_index(
3791 "live_indexed_filter_is_applied_during_candidate_search",
3792 vec![
3793 row(
3794 "nearest",
3795 "misc",
3796 "nearest excluded category",
3797 vec![1.0, 0.0],
3798 ),
3799 row("docs", "docs", "docs match", vec![0.0, 1.0]),
3800 ],
3801 )
3802 .await?;
3803
3804 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3805 .query("needle")
3806 .samples(1)
3807 .filter(SqliteSearchFilter::eq(
3808 "category",
3809 serde_json::json!("docs"),
3810 ))
3811 .build();
3812
3813 let results = index.top_n::<TestDocument>(req.clone()).await?;
3814 let ids = results
3815 .iter()
3816 .map(|(_, id, _)| id.as_str())
3817 .collect::<Vec<_>>();
3818
3819 anyhow::ensure!(
3820 ids.as_slice() == ["docs"],
3821 "indexed filters should constrain sqlite-vec candidate search: {results:?}"
3822 );
3823
3824 let id_results = index.top_n_ids(req).await?;
3825 let result_ids = id_results
3826 .iter()
3827 .map(|(_, id)| id.as_str())
3828 .collect::<Vec<_>>();
3829
3830 anyhow::ensure!(
3831 result_ids.as_slice() == ["docs"],
3832 "top_n_ids should use indexed filters during candidate search: {id_results:?}"
3833 );
3834
3835 Ok(())
3836 }
3837
3838 #[tokio::test]
3839 async fn live_nonindexed_filter_is_applied_after_candidate_search() -> anyhow::Result<()> {
3840 let index = live_test_index(
3841 "live_nonindexed_filter_is_applied_after_candidate_search",
3842 vec![
3843 row("nearest", "docs", "nearest excluded title", vec![1.0, 0.0]),
3844 row("wanted", "docs", "wanted title", vec![0.0, 1.0]),
3845 ],
3846 )
3847 .await?;
3848
3849 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3850 .query("needle")
3851 .samples(1)
3852 .filter(SqliteSearchFilter::eq(
3853 "title",
3854 serde_json::json!("wanted title"),
3855 ))
3856 .build();
3857
3858 let results = index.top_n::<TestDocument>(req.clone()).await?;
3859 let ids = results
3860 .iter()
3861 .map(|(_, id, _)| id.as_str())
3862 .collect::<Vec<_>>();
3863 anyhow::ensure!(
3864 ids.as_slice() == ["wanted"],
3865 "non-indexed filters should not be starved by the initial candidate limit: {results:?}"
3866 );
3867
3868 let id_results = index.top_n_ids(req).await?;
3869 let result_ids = id_results
3870 .iter()
3871 .map(|(_, id)| id.as_str())
3872 .collect::<Vec<_>>();
3873 anyhow::ensure!(
3874 result_ids.as_slice() == ["wanted"],
3875 "top_n_ids should apply non-indexed filters after candidate search: {id_results:?}"
3876 );
3877
3878 Ok(())
3879 }
3880
3881 #[tokio::test]
3882 async fn live_json_metadata_filter_is_applied_after_candidate_search() -> anyhow::Result<()> {
3883 let index = live_json_metadata_test_index(
3884 "live_json_metadata_filter_is_applied_after_candidate_search",
3885 vec![
3886 json_metadata_row("nearest", "docs", "skip", "nearest skipped", vec![1.0, 0.0]),
3887 json_metadata_row("matched", "docs", "vvv", "metadata match", vec![0.0, 1.0]),
3888 ],
3889 )
3890 .await?;
3891
3892 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3893 .query("needle")
3894 .samples(1)
3895 .filter(SqliteSearchFilter::eq(
3896 "metadata->>'$.xxx'",
3897 serde_json::json!("vvv"),
3898 ))
3899 .build();
3900
3901 let results = index.top_n::<JsonMetadataDocument>(req.clone()).await?;
3902 let ids = results
3903 .iter()
3904 .map(|(_, id, _)| id.as_str())
3905 .collect::<Vec<_>>();
3906 anyhow::ensure!(
3907 ids.as_slice() == ["matched"],
3908 "JSON metadata filters should not be starved by the initial candidate limit: {results:?}"
3909 );
3910
3911 let id_results = index.top_n_ids(req).await?;
3912 let result_ids = id_results
3913 .iter()
3914 .map(|(_, id)| id.as_str())
3915 .collect::<Vec<_>>();
3916 anyhow::ensure!(
3917 result_ids.as_slice() == ["matched"],
3918 "top_n_ids should apply JSON metadata filters after candidate search: {id_results:?}"
3919 );
3920
3921 Ok(())
3922 }
3923
3924 #[tokio::test]
3925 async fn live_json_arrow_filter_compares_against_json_text() -> anyhow::Result<()> {
3926 let index = live_json_metadata_test_index(
3927 "live_json_arrow_filter_compares_against_json_text",
3928 vec![
3929 json_metadata_row("nearest", "docs", "skip", "nearest skipped", vec![1.0, 0.0]),
3930 json_metadata_row("matched", "docs", "vvv", "metadata match", vec![0.0, 1.0]),
3931 ],
3932 )
3933 .await?;
3934
3935 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
3936 .query("needle")
3937 .samples(1)
3938 .filter(SqliteSearchFilter::eq(
3939 "metadata->'$.xxx'",
3940 serde_json::json!("vvv"),
3941 ))
3942 .build();
3943
3944 let results = index.top_n::<JsonMetadataDocument>(req.clone()).await?;
3945 let ids = results
3946 .iter()
3947 .map(|(_, id, _)| id.as_str())
3948 .collect::<Vec<_>>();
3949 anyhow::ensure!(
3950 ids.as_slice() == ["matched"],
3951 "SQLite `->` JSON filters should compare against JSON text: {results:?}"
3952 );
3953
3954 let id_results = index.top_n_ids(req).await?;
3955 let result_ids = id_results
3956 .iter()
3957 .map(|(_, id)| id.as_str())
3958 .collect::<Vec<_>>();
3959 anyhow::ensure!(
3960 result_ids.as_slice() == ["matched"],
3961 "top_n_ids should apply SQLite `->` JSON filters: {id_results:?}"
3962 );
3963
3964 Ok(())
3965 }
3966
3967 #[tokio::test]
3968 async fn live_mixed_indexed_and_json_metadata_filters_are_applied() -> anyhow::Result<()> {
3969 let index = live_json_metadata_test_index(
3970 "live_mixed_indexed_and_json_metadata_filters_are_applied",
3971 vec![
3972 json_metadata_row(
3973 "nearest-docs",
3974 "docs",
3975 "skip",
3976 "nearest docs skipped by JSON metadata",
3977 vec![1.0, 0.0],
3978 ),
3979 json_metadata_row(
3980 "nearest-json",
3981 "misc",
3982 "vvv",
3983 "nearest JSON match skipped by category",
3984 vec![0.9, 0.1],
3985 ),
3986 json_metadata_row(
3987 "matched",
3988 "docs",
3989 "vvv",
3990 "matching category and JSON metadata",
3991 vec![0.0, 1.0],
3992 ),
3993 ],
3994 )
3995 .await?;
3996
3997 let filter = SqliteSearchFilter::eq("category", serde_json::json!("docs")).and(
3998 SqliteSearchFilter::eq("metadata->>'$.xxx'", serde_json::json!("vvv")),
3999 );
4000 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
4001 .query("needle")
4002 .samples(1)
4003 .filter(filter)
4004 .build();
4005
4006 let results = index.top_n::<JsonMetadataDocument>(req.clone()).await?;
4007 let ids = results
4008 .iter()
4009 .map(|(_, id, _)| id.as_str())
4010 .collect::<Vec<_>>();
4011 anyhow::ensure!(
4012 ids.as_slice() == ["matched"],
4013 "indexed and JSON metadata filters should both be applied: {results:?}"
4014 );
4015
4016 let id_results = index.top_n_ids(req).await?;
4017 let result_ids = id_results
4018 .iter()
4019 .map(|(_, id)| id.as_str())
4020 .collect::<Vec<_>>();
4021 anyhow::ensure!(
4022 result_ids.as_slice() == ["matched"],
4023 "top_n_ids should apply both indexed and JSON metadata filters: {id_results:?}"
4024 );
4025
4026 Ok(())
4027 }
4028
4029 #[tokio::test]
4030 async fn live_negated_eq_filter_is_applied_during_candidate_search() -> anyhow::Result<()> {
4031 let index = live_test_index(
4032 "live_negated_eq_filter_is_applied_during_candidate_search",
4033 vec![
4034 row(
4035 "nearest",
4036 "misc",
4037 "nearest excluded category",
4038 vec![1.0, 0.0],
4039 ),
4040 row("docs", "docs", "docs match", vec![0.0, 1.0]),
4041 ],
4042 )
4043 .await?;
4044
4045 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
4046 .query("needle")
4047 .samples(1)
4048 .filter(SqliteSearchFilter::eq("category", serde_json::json!("misc")).not())
4049 .build();
4050
4051 let results = index.top_n::<TestDocument>(req.clone()).await?;
4052 let ids = results
4053 .iter()
4054 .map(|(_, id, _)| id.as_str())
4055 .collect::<Vec<_>>();
4056
4057 anyhow::ensure!(
4058 ids.as_slice() == ["docs"],
4059 "negated filters should constrain sqlite-vec candidate search: {results:?}"
4060 );
4061
4062 let id_results = index.top_n_ids(req).await?;
4063 let result_ids = id_results
4064 .iter()
4065 .map(|(_, id)| id.as_str())
4066 .collect::<Vec<_>>();
4067
4068 anyhow::ensure!(
4069 result_ids.as_slice() == ["docs"],
4070 "top_n_ids should use negated filters during candidate search: {id_results:?}"
4071 );
4072
4073 Ok(())
4074 }
4075
4076 #[tokio::test]
4077 async fn live_top_n_reads_id_by_column_name_not_schema_position() -> anyhow::Result<()> {
4078 register_sqlite_vec_extension();
4079
4080 let conn = Connection::open(
4081 "file:live_top_n_reads_id_by_column_name_not_schema_position?mode=memory",
4082 )
4083 .await?;
4084 let model = TestEmbeddingModel;
4085 let vector_store: SqliteVectorStore<_, ReorderedIdDocument> =
4086 SqliteVectorStore::new(conn, &model).await?;
4087
4088 vector_store
4089 .add_rows(vec![
4090 reordered_id_row("winner", "winner title", "docs", vec![1.0, 0.0]),
4091 reordered_id_row("other", "other title", "docs", vec![0.0, 1.0]),
4092 ])
4093 .await?;
4094
4095 let index = vector_store.index(model);
4096 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
4097 .query("needle")
4098 .samples(1)
4099 .build();
4100
4101 let results = index.top_n::<ReorderedIdDocument>(req.clone()).await?;
4102 let Some((_, id, doc)) = results.first() else {
4103 anyhow::bail!("expected reordered-id result");
4104 };
4105 anyhow::ensure!(
4106 id == "winner",
4107 "top_n should return the id column, not the first schema column: {results:?}"
4108 );
4109 anyhow::ensure!(
4110 doc.id == "winner" && doc.title == "winner title",
4111 "document columns should still deserialize in schema order: {doc:?}"
4112 );
4113
4114 let id_results = index.top_n_ids(req).await?;
4115 anyhow::ensure!(
4116 id_results.first().map(|(_, id)| id.as_str()) == Some("winner"),
4117 "top_n_ids should agree with top_n id handling: {id_results:?}"
4118 );
4119
4120 Ok(())
4121 }
4122
4123 #[tokio::test]
4124 async fn live_internal_score_and_rank_column_names_do_not_shadow_search_columns()
4125 -> anyhow::Result<()> {
4126 register_sqlite_vec_extension();
4127
4128 let conn = Connection::open(
4129 "file:live_internal_score_and_rank_column_names_do_not_shadow_search_columns?mode=memory",
4130 )
4131 .await?;
4132 let model = TestEmbeddingModel;
4133 let vector_store: SqliteVectorStore<_, InternalAliasDocument> =
4134 SqliteVectorStore::new(conn, &model).await?;
4135
4136 vector_store
4137 .add_rows(vec![
4138 internal_alias_row(
4139 "winner",
4140 "payload score",
4141 "payload rank",
4142 "winner title",
4143 vec![1.0, 0.0],
4144 ),
4145 internal_alias_row(
4146 "other",
4147 "other score",
4148 "other rank",
4149 "other title",
4150 vec![0.0, 1.0],
4151 ),
4152 ])
4153 .await?;
4154
4155 let index = vector_store.index(model);
4156 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
4157 .query("needle")
4158 .samples(1)
4159 .threshold(0.9)
4160 .build();
4161
4162 let results = index.top_n::<InternalAliasDocument>(req.clone()).await?;
4163 let Some((score, id, doc)) = results.first() else {
4164 anyhow::bail!("expected internal-alias document result");
4165 };
4166
4167 anyhow::ensure!(id == "winner", "unexpected id: {results:?}");
4168 anyhow::ensure!(
4169 (*score - 1.0).abs() <= SCORE_EPSILON,
4170 "top_n should return computed score, not the document __rig_score column: {results:?}"
4171 );
4172 anyhow::ensure!(
4173 doc.rig_score == "payload score" && doc.rig_rank == "payload rank",
4174 "document columns with internal-looking names should still deserialize: {doc:?}"
4175 );
4176
4177 let id_results = index.top_n_ids(req).await?;
4178 anyhow::ensure!(
4179 id_results
4180 .first()
4181 .map(|(score, id)| ((*score - 1.0).abs() <= SCORE_EPSILON, id.as_str()))
4182 == Some((true, "winner")),
4183 "top_n_ids should agree with top_n despite internal-looking document columns: {id_results:?}"
4184 );
4185
4186 Ok(())
4187 }
4188
4189 #[tokio::test]
4190 async fn live_typed_columns_round_trip_and_filter_during_candidate_search() -> anyhow::Result<()>
4191 {
4192 let index = live_typed_test_index(
4193 "live_typed_columns_round_trip_and_filter_during_candidate_search",
4194 vec![
4195 typed_row(
4196 1,
4197 "misc",
4198 100,
4199 0.99,
4200 true,
4201 "nearest excluded by typed metadata",
4202 vec![1.0, 0.0],
4203 ),
4204 typed_row(2, "docs", 5, 0.95, true, "typed docs match", vec![0.0, 1.0]),
4205 typed_row(
4206 3,
4207 "docs",
4208 5,
4209 0.97,
4210 false,
4211 "unpublished docs match",
4212 vec![0.0, 0.9],
4213 ),
4214 ],
4215 )
4216 .await?;
4217
4218 let filter = SqliteSearchFilter::lt("priority", serde_json::json!(10))
4219 .and(SqliteSearchFilter::gt("rating", serde_json::json!(0.9)))
4220 .and(SqliteSearchFilter::eq("published", serde_json::json!(true)));
4221 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
4222 .query("needle")
4223 .samples(1)
4224 .filter(filter)
4225 .build();
4226
4227 let results = index.top_n::<TypedTestDocument>(req.clone()).await?;
4228 anyhow::ensure!(
4229 results.len() == 1,
4230 "expected one typed document result: {results:?}"
4231 );
4232
4233 let Some((_, id, doc)) = results.first() else {
4234 anyhow::bail!("expected one typed document result");
4235 };
4236 anyhow::ensure!(id == "2", "expected integer id to be returned as string");
4237 anyhow::ensure!(doc.id == 2, "typed integer id should round-trip: {doc:?}");
4238 anyhow::ensure!(
4239 doc.priority == 5,
4240 "typed integer field should round-trip: {doc:?}"
4241 );
4242 anyhow::ensure!(
4243 (doc.rating - 0.95).abs() < f64::EPSILON,
4244 "typed float field should round-trip: {doc:?}"
4245 );
4246 anyhow::ensure!(
4247 doc.published,
4248 "typed boolean field should round-trip: {doc:?}"
4249 );
4250
4251 let id_results = index.top_n_ids(req).await?;
4252 let result_ids = id_results
4253 .iter()
4254 .map(|(_, id)| id.as_str())
4255 .collect::<Vec<_>>();
4256 anyhow::ensure!(
4257 result_ids.as_slice() == ["2"],
4258 "top_n_ids should use the same typed metadata filters: {id_results:?}"
4259 );
4260
4261 Ok(())
4262 }
4263
4264 #[tokio::test]
4265 async fn live_boolean_range_filter_is_rejected() -> anyhow::Result<()> {
4266 let index = live_typed_test_index(
4267 "live_boolean_range_filter_is_rejected",
4268 vec![
4269 typed_row(
4270 1,
4271 "misc",
4272 1,
4273 0.5,
4274 false,
4275 "nearest unpublished doc",
4276 vec![1.0, 0.0],
4277 ),
4278 typed_row(2, "docs", 2, 0.7, true, "published doc", vec![0.0, 1.0]),
4279 ],
4280 )
4281 .await?;
4282
4283 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
4284 .query("needle")
4285 .samples(2)
4286 .filter(SqliteSearchFilter::gt(
4287 "published",
4288 serde_json::json!(false),
4289 ))
4290 .build();
4291
4292 ensure_vector_store_filter_error(
4293 index.top_n::<TypedTestDocument>(req.clone()).await,
4294 "top_n boolean range filter",
4295 )?;
4296 ensure_vector_store_filter_error(
4297 index.top_n_ids(req).await,
4298 "top_n_ids boolean range filter",
4299 )?;
4300
4301 Ok(())
4302 }
4303
4304 #[tokio::test]
4305 async fn live_mismatched_metadata_filter_value_type_is_rejected() -> anyhow::Result<()> {
4306 let index = live_typed_test_index(
4307 "live_mismatched_metadata_filter_value_type_is_rejected",
4308 vec![typed_row(
4309 1,
4310 "docs",
4311 1,
4312 0.95,
4313 true,
4314 "published doc",
4315 vec![1.0, 0.0],
4316 )],
4317 )
4318 .await?;
4319
4320 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
4321 .query("needle")
4322 .samples(1)
4323 .filter(SqliteSearchFilter::eq(
4324 "published",
4325 serde_json::json!("true"),
4326 ))
4327 .build();
4328
4329 ensure_vector_store_filter_error(
4330 index.top_n::<TypedTestDocument>(req.clone()).await,
4331 "top_n mismatched metadata filter value type",
4332 )?;
4333 ensure_vector_store_filter_error(
4334 index.top_n_ids(req).await,
4335 "top_n_ids mismatched metadata filter value type",
4336 )?;
4337
4338 Ok(())
4339 }
4340
4341 #[tokio::test]
4342 async fn live_matches_exact_oracle_for_metrics_filters_and_thresholds() -> anyhow::Result<()> {
4343 let query = vec![1.0, 0.0];
4344 let rows = oracle_test_rows();
4345 let filter = SqliteSearchFilter::eq("category", serde_json::json!("docs"))
4346 .and(SqliteSearchFilter::lt("priority", serde_json::json!(10)))
4347 .and(SqliteSearchFilter::gt("rating", serde_json::json!(0.8)))
4348 .and(SqliteSearchFilter::eq("published", serde_json::json!(true)));
4349
4350 for distance_metric in [
4351 SqliteDistanceMetric::Cosine,
4352 SqliteDistanceMetric::L2,
4353 SqliteDistanceMetric::L1,
4354 ] {
4355 let threshold = oracle_threshold(distance_metric);
4356 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
4357 .query("needle")
4358 .samples(u64::try_from(rows.len())?)
4359 .threshold(threshold)
4360 .filter(filter.clone())
4361 .build();
4362 let expected = exact_oracle_results(
4363 &rows,
4364 &query,
4365 distance_metric,
4366 threshold,
4367 rows.len(),
4368 |row| {
4369 row.category == "docs" && row.priority < 10 && row.rating > 0.8 && row.published
4370 },
4371 )?;
4372 let test_name =
4373 format!("live_matches_exact_oracle_for_{distance_metric:?}").to_ascii_lowercase();
4374 let index = live_typed_test_index_with_metric(
4375 &test_name,
4376 sqlite_oracle_rows(&rows),
4377 distance_metric,
4378 )
4379 .await?;
4380
4381 let results = index.top_n::<TypedTestDocument>(req.clone()).await?;
4382 let scored_ids = results
4383 .iter()
4384 .map(|(score, id, doc)| {
4385 anyhow::ensure!(
4386 id == &doc.id.to_string(),
4387 "top_n returned mismatched id and document: id={id}, doc={doc:?}"
4388 );
4389 Ok((*score, id.clone()))
4390 })
4391 .collect::<anyhow::Result<Vec<_>>>()?;
4392 assert_scored_ids_match(&scored_ids, &expected, distance_metric, "top_n")?;
4393
4394 let id_results = index.top_n_ids(req).await?;
4395 assert_scored_ids_match(&id_results, &expected, distance_metric, "top_n_ids")?;
4396 }
4397
4398 Ok(())
4399 }
4400
4401 #[tokio::test]
4402 async fn live_or_filter_preserves_mixed_document_semantics() -> anyhow::Result<()> {
4403 let index = live_test_index(
4404 "live_or_filter_preserves_mixed_document_semantics",
4405 vec![
4406 row(
4407 "nearest",
4408 "misc",
4409 "nearest excluded category",
4410 vec![1.0, 0.0],
4411 ),
4412 row("special", "misc", "special title", vec![0.9, 0.1]),
4413 row("docs", "docs", "far docs match", vec![0.0, 1.0]),
4414 ],
4415 )
4416 .await?;
4417
4418 let filter = SqliteSearchFilter::eq("category", serde_json::json!("docs")).or(
4419 SqliteSearchFilter::eq("title", serde_json::json!("special title")),
4420 );
4421
4422 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
4423 .query("needle")
4424 .samples(1)
4425 .filter(filter)
4426 .build();
4427
4428 let results = index.top_n::<TestDocument>(req.clone()).await?;
4429 let ids = results
4430 .iter()
4431 .map(|(_, id, _)| id.as_str())
4432 .collect::<Vec<_>>();
4433 anyhow::ensure!(
4434 ids.as_slice() == ["special"],
4435 "OR filters should be applied as a whole document predicate: {results:?}"
4436 );
4437
4438 let id_results = index.top_n_ids(req).await?;
4439 let result_ids = id_results
4440 .iter()
4441 .map(|(_, id)| id.as_str())
4442 .collect::<Vec<_>>();
4443 anyhow::ensure!(
4444 result_ids.as_slice() == ["special"],
4445 "top_n_ids should preserve OR document semantics: {id_results:?}"
4446 );
4447
4448 Ok(())
4449 }
4450
4451 #[tokio::test]
4452 async fn live_pattern_and_null_filters_are_applied_after_candidate_search() -> anyhow::Result<()>
4453 {
4454 let index = live_json_metadata_test_index(
4455 "live_pattern_and_null_filters_are_applied_after_candidate_search",
4456 vec![
4457 json_metadata_row("nearest", "docs", "skip", "skip this", vec![1.0, 0.0]),
4458 json_metadata_row("matched", "docs", "vvv", "metadata match", vec![0.0, 1.0]),
4459 ],
4460 )
4461 .await?;
4462
4463 let filter = SqliteSearchFilter::is_null("metadata->>'$.missing'".to_string())
4464 .and(SqliteSearchFilter::like("title".to_string(), "metadata%"))
4465 .and(SqliteSearchFilter::glob("category".to_string(), "doc*"));
4466
4467 let req = VectorSearchRequest::<SqliteSearchFilter>::builder()
4468 .query("needle")
4469 .samples(1)
4470 .filter(filter)
4471 .build();
4472
4473 let results = index.top_n::<JsonMetadataDocument>(req.clone()).await?;
4474 let ids = results
4475 .iter()
4476 .map(|(_, id, _)| id.as_str())
4477 .collect::<Vec<_>>();
4478 anyhow::ensure!(
4479 ids.as_slice() == ["matched"],
4480 "pattern and null filters should not be starved by the initial candidate limit: {results:?}"
4481 );
4482
4483 let id_results = index.top_n_ids(req).await?;
4484 let result_ids = id_results
4485 .iter()
4486 .map(|(_, id)| id.as_str())
4487 .collect::<Vec<_>>();
4488 anyhow::ensure!(
4489 result_ids.as_slice() == ["matched"],
4490 "top_n_ids should apply pattern and null filters after candidate search: {id_results:?}"
4491 );
4492
4493 Ok(())
4494 }
4495
4496 type SqliteExtensionFn =
4497 unsafe extern "C" fn(*mut sqlite3, *mut *mut c_char, *const sqlite3_api_routines) -> i32;
4498
4499 fn register_sqlite_vec_extension() {
4500 static REGISTER_SQLITE_VEC: Once = Once::new();
4501
4502 REGISTER_SQLITE_VEC.call_once(|| unsafe {
4503 sqlite3_auto_extension(Some(std::mem::transmute::<*const (), SqliteExtensionFn>(
4504 sqlite3_vec_init as *const (),
4505 )));
4506 });
4507 }
4508
4509 async fn live_test_index(
4510 name: &str,
4511 rows: Vec<(TestDocument, OneOrMany<Embedding>)>,
4512 ) -> anyhow::Result<SqliteVectorIndex<TestEmbeddingModel, TestDocument>> {
4513 live_test_index_with_metric(name, rows, SqliteDistanceMetric::Cosine).await
4514 }
4515
4516 async fn live_test_index_with_metric(
4517 name: &str,
4518 rows: Vec<(TestDocument, OneOrMany<Embedding>)>,
4519 distance_metric: SqliteDistanceMetric,
4520 ) -> anyhow::Result<SqliteVectorIndex<TestEmbeddingModel, TestDocument>> {
4521 register_sqlite_vec_extension();
4522
4523 let conn = Connection::open(format!("file:{name}?mode=memory")).await?;
4524 let model = TestEmbeddingModel;
4525 let vector_store =
4526 SqliteVectorStore::with_distance_metric(conn, &model, distance_metric).await?;
4527
4528 vector_store.add_rows(rows).await?;
4529
4530 Ok(vector_store.index(model))
4531 }
4532
4533 async fn live_typed_test_index(
4534 name: &str,
4535 rows: Vec<(TypedTestDocument, OneOrMany<Embedding>)>,
4536 ) -> anyhow::Result<SqliteVectorIndex<TestEmbeddingModel, TypedTestDocument>> {
4537 live_typed_test_index_with_metric(name, rows, SqliteDistanceMetric::Cosine).await
4538 }
4539
4540 async fn live_typed_test_index_with_metric(
4541 name: &str,
4542 rows: Vec<(TypedTestDocument, OneOrMany<Embedding>)>,
4543 distance_metric: SqliteDistanceMetric,
4544 ) -> anyhow::Result<SqliteVectorIndex<TestEmbeddingModel, TypedTestDocument>> {
4545 register_sqlite_vec_extension();
4546
4547 let conn = Connection::open(format!("file:{name}?mode=memory")).await?;
4548 let model = TestEmbeddingModel;
4549 let vector_store: SqliteVectorStore<_, TypedTestDocument> =
4550 SqliteVectorStore::with_distance_metric(conn, &model, distance_metric).await?;
4551
4552 vector_store.add_rows(rows).await?;
4553
4554 Ok(vector_store.index(model))
4555 }
4556
4557 async fn live_common_type_test_index(
4558 name: &str,
4559 rows: Vec<(CommonTypeDocument, OneOrMany<Embedding>)>,
4560 ) -> anyhow::Result<SqliteVectorIndex<TestEmbeddingModel, CommonTypeDocument>> {
4561 register_sqlite_vec_extension();
4562
4563 let conn = Connection::open(format!("file:{name}?mode=memory")).await?;
4564 let model = TestEmbeddingModel;
4565 let vector_store: SqliteVectorStore<_, CommonTypeDocument> =
4566 SqliteVectorStore::new(conn, &model).await?;
4567
4568 vector_store.add_rows(rows).await?;
4569
4570 Ok(vector_store.index(model))
4571 }
4572
4573 async fn live_json_metadata_test_index(
4574 name: &str,
4575 rows: Vec<(JsonMetadataDocument, OneOrMany<Embedding>)>,
4576 ) -> anyhow::Result<SqliteVectorIndex<TestEmbeddingModel, JsonMetadataDocument>> {
4577 register_sqlite_vec_extension();
4578
4579 let conn = Connection::open(format!("file:{name}?mode=memory")).await?;
4580 let model = TestEmbeddingModel;
4581 let vector_store: SqliteVectorStore<_, JsonMetadataDocument> =
4582 SqliteVectorStore::new(conn, &model).await?;
4583
4584 vector_store.add_rows(rows).await?;
4585
4586 Ok(vector_store.index(model))
4587 }
4588
4589 async fn live_structured_json_metadata_test_index(
4590 name: &str,
4591 rows: Vec<(StructuredJsonMetadataDocument, OneOrMany<Embedding>)>,
4592 ) -> anyhow::Result<SqliteVectorIndex<TestEmbeddingModel, StructuredJsonMetadataDocument>> {
4593 register_sqlite_vec_extension();
4594
4595 let conn = Connection::open(format!("file:{name}?mode=memory")).await?;
4596 let model = TestEmbeddingModel;
4597 let vector_store: SqliteVectorStore<_, StructuredJsonMetadataDocument> =
4598 SqliteVectorStore::new(conn, &model).await?;
4599
4600 vector_store.add_rows(rows).await?;
4601
4602 Ok(vector_store.index(model))
4603 }
4604
4605 fn row(
4606 id: impl Into<String>,
4607 category: impl Into<String>,
4608 title: impl Into<String>,
4609 vec: Vec<f64>,
4610 ) -> (TestDocument, OneOrMany<Embedding>) {
4611 let document = TestDocument {
4612 id: id.into(),
4613 category: category.into(),
4614 title: title.into(),
4615 };
4616
4617 (
4618 document.clone(),
4619 OneOrMany::one(Embedding {
4620 document: document.title,
4621 vec,
4622 }),
4623 )
4624 }
4625
4626 fn common_type_row(
4627 id: impl Into<String>,
4628 name: impl Into<String>,
4629 notes: impl Into<String>,
4630 rank: i64,
4631 vec: Vec<f64>,
4632 ) -> (CommonTypeDocument, OneOrMany<Embedding>) {
4633 let document = CommonTypeDocument {
4634 id: id.into(),
4635 name: name.into(),
4636 notes: notes.into(),
4637 rank,
4638 };
4639
4640 (
4641 document.clone(),
4642 OneOrMany::one(Embedding {
4643 document: document.name.clone(),
4644 vec,
4645 }),
4646 )
4647 }
4648
4649 fn json_metadata_row(
4650 id: impl Into<String>,
4651 category: impl Into<String>,
4652 xxx: impl AsRef<str>,
4653 title: impl Into<String>,
4654 vec: Vec<f64>,
4655 ) -> (JsonMetadataDocument, OneOrMany<Embedding>) {
4656 let document = JsonMetadataDocument {
4657 id: id.into(),
4658 category: category.into(),
4659 metadata: serde_json::json!({ "xxx": xxx.as_ref() }).to_string(),
4660 title: title.into(),
4661 };
4662
4663 (
4664 document.clone(),
4665 OneOrMany::one(Embedding {
4666 document: document.title.clone(),
4667 vec,
4668 }),
4669 )
4670 }
4671
4672 fn structured_json_metadata_row(
4673 id: impl Into<String>,
4674 metadata: StructuredMetadata,
4675 title: impl Into<String>,
4676 vec: Vec<f64>,
4677 ) -> (StructuredJsonMetadataDocument, OneOrMany<Embedding>) {
4678 let document = StructuredJsonMetadataDocument {
4679 id: id.into(),
4680 metadata,
4681 title: title.into(),
4682 };
4683
4684 (
4685 document.clone(),
4686 OneOrMany::one(Embedding {
4687 document: document.title.clone(),
4688 vec,
4689 }),
4690 )
4691 }
4692
4693 fn reordered_id_row(
4694 id: impl Into<String>,
4695 title: impl Into<String>,
4696 category: impl Into<String>,
4697 vec: Vec<f64>,
4698 ) -> (ReorderedIdDocument, OneOrMany<Embedding>) {
4699 let document = ReorderedIdDocument {
4700 title: title.into(),
4701 id: id.into(),
4702 category: category.into(),
4703 };
4704
4705 (
4706 document.clone(),
4707 OneOrMany::one(Embedding {
4708 document: document.title.clone(),
4709 vec,
4710 }),
4711 )
4712 }
4713
4714 fn internal_alias_row(
4715 id: impl Into<String>,
4716 rig_score: impl Into<String>,
4717 rig_rank: impl Into<String>,
4718 title: impl Into<String>,
4719 vec: Vec<f64>,
4720 ) -> (InternalAliasDocument, OneOrMany<Embedding>) {
4721 let document = InternalAliasDocument {
4722 id: id.into(),
4723 rig_score: rig_score.into(),
4724 rig_rank: rig_rank.into(),
4725 title: title.into(),
4726 };
4727
4728 (
4729 document.clone(),
4730 OneOrMany::one(Embedding {
4731 document: document.title.clone(),
4732 vec,
4733 }),
4734 )
4735 }
4736
4737 fn typed_row(
4738 id: i64,
4739 category: impl Into<String>,
4740 priority: i64,
4741 rating: f64,
4742 published: bool,
4743 title: impl Into<String>,
4744 vec: Vec<f64>,
4745 ) -> (TypedTestDocument, OneOrMany<Embedding>) {
4746 let document = TypedTestDocument {
4747 id,
4748 category: category.into(),
4749 priority,
4750 rating,
4751 published,
4752 title: title.into(),
4753 };
4754
4755 (
4756 document.clone(),
4757 OneOrMany::one(Embedding {
4758 document: document.title,
4759 vec,
4760 }),
4761 )
4762 }
4763
    /// One fixture row for the exact-search oracle tests.
    ///
    /// A row pairs a typed document with the raw embedding vector scored for it.
    #[derive(Clone, Debug)]
    struct OracleRow {
        // Document whose columns the oracle filters on; converted into a store row by
        // `sqlite_oracle_rows`.
        document: TypedTestDocument,
        // Embedding vector that `oracle_score` compares against the query.
        embedding: Vec<f64>,
    }
4769
    /// A result the exact oracle expects the SQLite search to return.
    #[derive(Debug)]
    struct ExpectedScoredId {
        // Stringified document id (built with `i64::to_string` in `exact_oracle_results`).
        id: String,
        // Oracle score; higher is better (results are sorted descending by this value).
        score: f64,
    }
4775
4776 fn oracle_test_rows() -> Vec<OracleRow> {
4777 vec![
4778 oracle_row(1, "docs", 1, 0.95, true, "exact match", vec![1.0, 0.0]),
4779 oracle_row(2, "docs", 2, 0.90, true, "close match", vec![0.8, 0.6]),
4780 oracle_row(3, "docs", 3, 0.81, true, "borderline match", vec![0.5, 0.5]),
4781 oracle_row(
4782 4,
4783 "docs",
4784 4,
4785 0.70,
4786 true,
4787 "filtered by rating",
4788 vec![0.95, 0.05],
4789 ),
4790 oracle_row(
4791 5,
4792 "docs",
4793 15,
4794 0.99,
4795 true,
4796 "filtered by priority",
4797 vec![1.0, 0.0],
4798 ),
4799 oracle_row(
4800 6,
4801 "docs",
4802 5,
4803 0.99,
4804 false,
4805 "filtered by published",
4806 vec![1.0, 0.0],
4807 ),
4808 oracle_row(
4809 7,
4810 "misc",
4811 1,
4812 0.99,
4813 true,
4814 "filtered by category",
4815 vec![1.0, 0.0],
4816 ),
4817 oracle_row(8, "docs", 5, 0.95, true, "far match", vec![0.0, 1.0]),
4818 ]
4819 }
4820
4821 fn oracle_row(
4822 id: i64,
4823 category: impl Into<String>,
4824 priority: i64,
4825 rating: f64,
4826 published: bool,
4827 title: impl Into<String>,
4828 embedding: Vec<f64>,
4829 ) -> OracleRow {
4830 OracleRow {
4831 document: TypedTestDocument {
4832 id,
4833 category: category.into(),
4834 priority,
4835 rating,
4836 published,
4837 title: title.into(),
4838 },
4839 embedding,
4840 }
4841 }
4842
4843 fn sqlite_oracle_rows(rows: &[OracleRow]) -> Vec<(TypedTestDocument, OneOrMany<Embedding>)> {
4844 rows.iter()
4845 .map(|row| {
4846 (
4847 row.document.clone(),
4848 OneOrMany::one(Embedding {
4849 document: row.document.title.clone(),
4850 vec: row.embedding.clone(),
4851 }),
4852 )
4853 })
4854 .collect()
4855 }
4856
4857 fn oracle_threshold(distance_metric: SqliteDistanceMetric) -> f64 {
4858 match distance_metric {
4859 SqliteDistanceMetric::Cosine => 0.75,
4860 SqliteDistanceMetric::L2 => -0.8,
4861 SqliteDistanceMetric::L1 => -0.9,
4862 }
4863 }
4864
4865 fn exact_oracle_results(
4866 rows: &[OracleRow],
4867 query: &[f64],
4868 distance_metric: SqliteDistanceMetric,
4869 threshold: f64,
4870 samples: usize,
4871 filter: impl Fn(&TypedTestDocument) -> bool,
4872 ) -> anyhow::Result<Vec<ExpectedScoredId>> {
4873 let mut expected = Vec::new();
4874 for row in rows {
4875 if !filter(&row.document) {
4876 continue;
4877 }
4878
4879 let score = oracle_score(distance_metric, query, &row.embedding)?;
4880 if score >= threshold {
4881 expected.push(ExpectedScoredId {
4882 id: row.document.id.to_string(),
4883 score,
4884 });
4885 }
4886 }
4887
4888 sort_expected_scores(&mut expected);
4889 expected.truncate(samples);
4890 Ok(expected)
4891 }
4892
4893 fn sort_expected_scores(expected: &mut [ExpectedScoredId]) {
4894 expected.sort_by(|lhs, rhs| {
4895 rhs.score
4896 .partial_cmp(&lhs.score)
4897 .unwrap_or(Ordering::Equal)
4898 .then_with(|| lhs.id.cmp(&rhs.id))
4899 });
4900 }
4901
4902 fn oracle_score(
4903 distance_metric: SqliteDistanceMetric,
4904 query: &[f64],
4905 embedding: &[f64],
4906 ) -> anyhow::Result<f64> {
4907 anyhow::ensure!(
4908 query.len() == embedding.len(),
4909 "query and embedding dimensions differ: query={}, embedding={}",
4910 query.len(),
4911 embedding.len()
4912 );
4913
4914 let query = query.iter().map(|value| *value as f32).collect::<Vec<_>>();
4915 let embedding = embedding
4916 .iter()
4917 .map(|value| *value as f32)
4918 .collect::<Vec<_>>();
4919
4920 let score = match distance_metric {
4921 SqliteDistanceMetric::Cosine => {
4922 let dot = query
4923 .iter()
4924 .zip(&embedding)
4925 .map(|(lhs, rhs)| lhs * rhs)
4926 .sum::<f32>();
4927 let query_norm = query.iter().map(|value| value * value).sum::<f32>().sqrt();
4928 let embedding_norm = embedding
4929 .iter()
4930 .map(|value| value * value)
4931 .sum::<f32>()
4932 .sqrt();
4933 anyhow::ensure!(
4934 query_norm > 0.0 && embedding_norm > 0.0,
4935 "cosine oracle requires non-zero vectors"
4936 );
4937 dot / (query_norm * embedding_norm)
4938 }
4939 SqliteDistanceMetric::L2 => -query
4940 .iter()
4941 .zip(&embedding)
4942 .map(|(lhs, rhs)| {
4943 let delta = lhs - rhs;
4944 delta * delta
4945 })
4946 .sum::<f32>()
4947 .sqrt(),
4948 SqliteDistanceMetric::L1 => -query
4949 .iter()
4950 .zip(&embedding)
4951 .map(|(lhs, rhs)| (lhs - rhs).abs())
4952 .sum::<f32>(),
4953 };
4954
4955 Ok(f64::from(score))
4956 }
4957
4958 fn assert_scored_ids_match(
4959 actual: &[(f64, String)],
4960 expected: &[ExpectedScoredId],
4961 distance_metric: SqliteDistanceMetric,
4962 context: &str,
4963 ) -> anyhow::Result<()> {
4964 let actual_ids = actual.iter().map(|(_, id)| id.as_str()).collect::<Vec<_>>();
4965 let expected_ids = expected
4966 .iter()
4967 .map(|expected| expected.id.as_str())
4968 .collect::<Vec<_>>();
4969 anyhow::ensure!(
4970 actual_ids == expected_ids,
4971 "{context} ids for {distance_metric:?} did not match exact oracle: actual={actual:?}, expected={expected:?}"
4972 );
4973
4974 for ((actual_score, actual_id), expected) in actual.iter().zip(expected) {
4975 anyhow::ensure!(
4976 (actual_score - expected.score).abs() <= SCORE_EPSILON,
4977 "{context} score for {distance_metric:?} id `{actual_id}` did not match exact oracle: actual={actual_score}, expected={}",
4978 expected.score
4979 );
4980 }
4981
4982 Ok(())
4983 }
4984
4985 #[derive(Clone, Debug, Deserialize, Serialize)]
4986 struct TestDocument {
4987 id: String,
4988 category: String,
4989 title: String,
4990 }
4991
4992 impl SqliteVectorStoreTable for TestDocument {
4993 fn name() -> &'static str {
4994 "live_test_documents"
4995 }
4996
4997 fn schema() -> Vec<Column> {
4998 vec![
4999 Column::new("id", "TEXT PRIMARY KEY"),
5000 Column::new("category", "TEXT").indexed(),
5001 Column::new("title", "TEXT"),
5002 ]
5003 }
5004
5005 fn id(&self) -> String {
5006 self.id.clone()
5007 }
5008
5009 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
5010 vec![
5011 ("id", Box::new(self.id.clone())),
5012 ("category", Box::new(self.category.clone())),
5013 ("title", Box::new(self.title.clone())),
5014 ]
5015 }
5016 }
5017
5018 #[derive(Clone, Debug, Deserialize, Serialize)]
5019 struct ReorderedIdDocument {
5020 title: String,
5021 id: String,
5022 category: String,
5023 }
5024
5025 impl SqliteVectorStoreTable for ReorderedIdDocument {
5026 fn name() -> &'static str {
5027 "live_reordered_id_test_documents"
5028 }
5029
5030 fn schema() -> Vec<Column> {
5031 vec![
5032 Column::new("title", "TEXT"),
5033 Column::new("id", "TEXT PRIMARY KEY"),
5034 Column::new("category", "TEXT").indexed(),
5035 ]
5036 }
5037
5038 fn id(&self) -> String {
5039 self.id.clone()
5040 }
5041
5042 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
5043 vec![
5044 ("title", Box::new(self.title.clone())),
5045 ("id", Box::new(self.id.clone())),
5046 ("category", Box::new(self.category.clone())),
5047 ]
5048 }
5049 }
5050
5051 #[derive(Clone, Debug, Deserialize, Serialize)]
5052 struct InternalAliasDocument {
5053 id: String,
5054 #[serde(rename = "__rig_score")]
5055 rig_score: String,
5056 #[serde(rename = "__rig_rank")]
5057 rig_rank: String,
5058 title: String,
5059 }
5060
5061 impl SqliteVectorStoreTable for InternalAliasDocument {
5062 fn name() -> &'static str {
5063 "live_internal_alias_test_documents"
5064 }
5065
5066 fn schema() -> Vec<Column> {
5067 vec![
5068 Column::new("id", "TEXT PRIMARY KEY"),
5069 Column::new("__rig_score", "TEXT"),
5070 Column::new("__rig_rank", "TEXT"),
5071 Column::new("title", "TEXT"),
5072 ]
5073 }
5074
5075 fn id(&self) -> String {
5076 self.id.clone()
5077 }
5078
5079 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
5080 vec![
5081 ("id", Box::new(self.id.clone())),
5082 ("__rig_score", Box::new(self.rig_score.clone())),
5083 ("__rig_rank", Box::new(self.rig_rank.clone())),
5084 ("title", Box::new(self.title.clone())),
5085 ]
5086 }
5087 }
5088
5089 #[derive(Clone, Debug, Deserialize, Serialize)]
5090 struct CommonTypeDocument {
5091 id: String,
5092 name: String,
5093 notes: String,
5094 rank: i64,
5095 }
5096
5097 impl SqliteVectorStoreTable for CommonTypeDocument {
5098 fn name() -> &'static str {
5099 "live_common_type_test_documents"
5100 }
5101
5102 fn schema() -> Vec<Column> {
5103 vec![
5104 Column::new("id", "TEXT PRIMARY KEY"),
5105 Column::new("name", "VARCHAR(255)").indexed(),
5106 Column::new("notes", "CLOB"),
5107 Column::new("rank", "NUMERIC"),
5108 ]
5109 }
5110
5111 fn id(&self) -> String {
5112 self.id.clone()
5113 }
5114
5115 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
5116 vec![
5117 ("id", Box::new(self.id.clone())),
5118 ("name", Box::new(self.name.clone())),
5119 ("notes", Box::new(self.notes.clone())),
5120 ("rank", Box::new(self.rank)),
5121 ]
5122 }
5123 }
5124
5125 #[derive(Clone, Debug, Deserialize, Serialize)]
5126 struct JsonMetadataDocument {
5127 id: String,
5128 category: String,
5129 metadata: String,
5130 title: String,
5131 }
5132
5133 impl SqliteVectorStoreTable for JsonMetadataDocument {
5134 fn name() -> &'static str {
5135 "live_json_metadata_test_documents"
5136 }
5137
5138 fn schema() -> Vec<Column> {
5139 vec![
5140 Column::new("id", "TEXT PRIMARY KEY"),
5141 Column::new("category", "TEXT").indexed(),
5142 Column::new("metadata", "TEXT"),
5143 Column::new("title", "TEXT"),
5144 ]
5145 }
5146
5147 fn id(&self) -> String {
5148 self.id.clone()
5149 }
5150
5151 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
5152 vec![
5153 ("id", Box::new(self.id.clone())),
5154 ("category", Box::new(self.category.clone())),
5155 ("metadata", Box::new(self.metadata.clone())),
5156 ("title", Box::new(self.title.clone())),
5157 ]
5158 }
5159 }
5160
5161 #[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
5162 struct StructuredMetadata {
5163 user_id: i64,
5164 knowledge_id: i64,
5165 knowledge_doc_id: i64,
5166 }
5167
5168 #[derive(Clone, Debug, Deserialize, Serialize)]
5169 struct StructuredJsonMetadataDocument {
5170 id: String,
5171 metadata: StructuredMetadata,
5172 title: String,
5173 }
5174
5175 impl SqliteVectorStoreTable for StructuredJsonMetadataDocument {
5176 fn name() -> &'static str {
5177 "live_structured_json_metadata_test_documents"
5178 }
5179
5180 fn schema() -> Vec<Column> {
5181 vec![
5182 Column::new("id", "TEXT PRIMARY KEY"),
5183 Column::new("metadata", "JSON"),
5184 Column::new("title", "TEXT"),
5185 ]
5186 }
5187
5188 fn id(&self) -> String {
5189 self.id.clone()
5190 }
5191
5192 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
5193 vec![
5194 ("id", Box::new(self.id.clone())),
5195 (
5196 "metadata",
5197 Box::new(serde_json::json!({
5198 "user_id": self.metadata.user_id,
5199 "knowledge_id": self.metadata.knowledge_id,
5200 "knowledge_doc_id": self.metadata.knowledge_doc_id,
5201 })),
5202 ),
5203 ("title", Box::new(self.title.clone())),
5204 ]
5205 }
5206 }
5207
5208 #[derive(Clone, Debug, Deserialize, Serialize)]
5209 struct TypedTestDocument {
5210 id: i64,
5211 category: String,
5212 priority: i64,
5213 rating: f64,
5214 published: bool,
5215 title: String,
5216 }
5217
5218 impl SqliteVectorStoreTable for TypedTestDocument {
5219 fn name() -> &'static str {
5220 "live_typed_test_documents"
5221 }
5222
5223 fn schema() -> Vec<Column> {
5224 vec![
5225 Column::new("id", "INTEGER PRIMARY KEY"),
5226 Column::new("category", "TEXT").indexed(),
5227 Column::new("priority", "INTEGER").indexed(),
5228 Column::new("rating", "FLOAT").indexed(),
5229 Column::new("published", "BOOLEAN").indexed(),
5230 Column::new("title", "TEXT"),
5231 ]
5232 }
5233
5234 fn id(&self) -> String {
5235 self.id.to_string()
5236 }
5237
5238 fn column_values(&self) -> Vec<(&'static str, Box<dyn ColumnValue>)> {
5239 vec![
5240 ("id", Box::new(self.id)),
5241 ("category", Box::new(self.category.clone())),
5242 ("priority", Box::new(self.priority)),
5243 ("rating", Box::new(self.rating)),
5244 ("published", Box::new(self.published)),
5245 ("title", Box::new(self.title.clone())),
5246 ]
5247 }
5248 }
5249
5250 #[derive(Clone)]
5251 struct TestEmbeddingModel;
5252
5253 impl EmbeddingModel for TestEmbeddingModel {
5254 const MAX_DOCUMENTS: usize = 16;
5255
5256 type Client = ();
5257
5258 fn make(_: &Self::Client, _: impl Into<String>, _: Option<usize>) -> Self {
5259 Self
5260 }
5261
5262 fn ndims(&self) -> usize {
5263 2
5264 }
5265
5266 async fn embed_texts(
5267 &self,
5268 texts: impl IntoIterator<Item = String> + WasmCompatSend,
5269 ) -> Result<Vec<Embedding>, EmbeddingError> {
5270 Ok(texts
5271 .into_iter()
5272 .map(|text| Embedding {
5273 document: text,
5274 vec: vec![1.0, 0.0],
5275 })
5276 .collect())
5277 }
5278 }
5279}