azoth_vector/search.rs
1//! Vector similarity search API
2
3use crate::types::{DistanceMetric, SearchResult, Vector};
4use azoth_core::Result;
5use azoth_sqlite::SqliteProjectionStore;
6use rusqlite::params;
7use std::sync::Arc;
8
9/// Validate that a SQL identifier (table or column name) is safe.
10///
11/// Only allows `[a-zA-Z_][a-zA-Z0-9_]*` to prevent SQL injection via
12/// identifier manipulation. Returns an error if the identifier is invalid.
13pub(crate) fn validate_sql_identifier(name: &str, kind: &str) -> Result<()> {
14 if name.is_empty() {
15 return Err(azoth_core::error::AzothError::Config(format!(
16 "{} name must not be empty",
17 kind
18 )));
19 }
20 if name.len() > 128 {
21 return Err(azoth_core::error::AzothError::Config(format!(
22 "{} name must be 128 characters or fewer, got {}",
23 kind,
24 name.len()
25 )));
26 }
27 let mut chars = name.chars();
28 let first = chars.next().unwrap(); // safe: name is non-empty
29 if !first.is_ascii_alphabetic() && first != '_' {
30 return Err(azoth_core::error::AzothError::Config(format!(
31 "{} name '{}' must start with a letter or underscore",
32 kind, name
33 )));
34 }
35 for c in chars {
36 if !c.is_ascii_alphanumeric() && c != '_' {
37 return Err(azoth_core::error::AzothError::Config(format!(
38 "{} name '{}' contains invalid character '{}'. \
39 Only ASCII alphanumeric and underscore are allowed.",
40 kind, name, c
41 )));
42 }
43 }
44 Ok(())
45}
46
/// A typed value destined for a `?` placeholder in a SQL query.
///
/// [`VectorFilter`] keeps filter values in this form so they are handed to
/// the database as bound parameters instead of being spliced into SQL text.
#[derive(Debug, Clone)]
pub enum FilterValue {
    /// Text value
    String(String),
    /// Signed 64-bit integer value
    I64(i64),
    /// 64-bit floating-point value
    F64(f64),
}
60
61impl FilterValue {
62 /// Convert into a boxed `ToSql` trait object for rusqlite parameter binding.
63 pub fn to_boxed_sql(self) -> Box<dyn rusqlite::ToSql> {
64 match self {
65 Self::String(s) => Box::new(s),
66 Self::I64(i) => Box::new(i),
67 Self::F64(f) => Box::new(f),
68 }
69 }
70}
71
72/// A single condition in a vector search filter.
73#[derive(Clone, Debug)]
74struct FilterCondition {
75 /// SQL fragment, e.g. `t.category = ?` or `t.in_stock = ?`
76 sql: String,
77 /// Bound parameter values (one per `?` placeholder in `sql`)
78 params: Vec<FilterValue>,
79}
80
81/// Type-safe filter builder for vector search queries.
82///
83/// All column names are validated as safe SQL identifiers and all values are
84/// bound via parameterized queries, eliminating SQL injection by construction.
85///
86/// # Example
87///
88/// ```
89/// use azoth_vector::VectorFilter;
90///
91/// let filter = VectorFilter::new()
92/// .eq("category", "electronics")
93/// .eq_i64("in_stock", 1)
94/// .gt("price", "9.99");
95///
96/// let (sql, params) = filter.to_sql().unwrap();
97/// assert_eq!(sql, "t.category = ? AND t.in_stock = ? AND t.price > ?");
98/// assert_eq!(params.len(), 3);
99/// ```
100#[derive(Clone, Debug, Default)]
101pub struct VectorFilter {
102 conditions: Vec<FilterCondition>,
103}
104
105impl VectorFilter {
106 /// Create an empty filter (matches all rows).
107 pub fn new() -> Self {
108 Self::default()
109 }
110
111 /// Add a string equality condition: `t.<column> = ?`
112 pub fn eq(self, column: &str, value: impl Into<String>) -> Self {
113 self.add_op(column, "=", FilterValue::String(value.into()))
114 }
115
116 /// Add a string inequality condition: `t.<column> != ?`
117 pub fn neq(self, column: &str, value: impl Into<String>) -> Self {
118 self.add_op(column, "!=", FilterValue::String(value.into()))
119 }
120
121 /// Add a string greater-than condition: `t.<column> > ?`
122 pub fn gt(self, column: &str, value: impl Into<String>) -> Self {
123 self.add_op(column, ">", FilterValue::String(value.into()))
124 }
125
126 /// Add a string greater-or-equal condition: `t.<column> >= ?`
127 pub fn gte(self, column: &str, value: impl Into<String>) -> Self {
128 self.add_op(column, ">=", FilterValue::String(value.into()))
129 }
130
131 /// Add a string less-than condition: `t.<column> < ?`
132 pub fn lt(self, column: &str, value: impl Into<String>) -> Self {
133 self.add_op(column, "<", FilterValue::String(value.into()))
134 }
135
136 /// Add a string less-or-equal condition: `t.<column> <= ?`
137 pub fn lte(self, column: &str, value: impl Into<String>) -> Self {
138 self.add_op(column, "<=", FilterValue::String(value.into()))
139 }
140
141 /// Add a LIKE condition: `t.<column> LIKE ?`
142 pub fn like(self, column: &str, pattern: impl Into<String>) -> Self {
143 self.add_op(column, "LIKE", FilterValue::String(pattern.into()))
144 }
145
146 /// Add an integer equality condition: `t.<column> = ?`
147 pub fn eq_i64(self, column: &str, value: i64) -> Self {
148 self.add_op(column, "=", FilterValue::I64(value))
149 }
150
151 /// Add an integer greater-than condition: `t.<column> > ?`
152 pub fn gt_i64(self, column: &str, value: i64) -> Self {
153 self.add_op(column, ">", FilterValue::I64(value))
154 }
155
156 /// Add an integer greater-or-equal condition: `t.<column> >= ?`
157 pub fn gte_i64(self, column: &str, value: i64) -> Self {
158 self.add_op(column, ">=", FilterValue::I64(value))
159 }
160
161 /// Add an integer less-than condition: `t.<column> < ?`
162 pub fn lt_i64(self, column: &str, value: i64) -> Self {
163 self.add_op(column, "<", FilterValue::I64(value))
164 }
165
166 /// Add an integer less-or-equal condition: `t.<column> <= ?`
167 pub fn lte_i64(self, column: &str, value: i64) -> Self {
168 self.add_op(column, "<=", FilterValue::I64(value))
169 }
170
171 /// Add a float equality condition: `t.<column> = ?`
172 pub fn eq_f64(self, column: &str, value: f64) -> Self {
173 self.add_op(column, "=", FilterValue::F64(value))
174 }
175
176 /// Add a float greater-than condition: `t.<column> > ?`
177 pub fn gt_f64(self, column: &str, value: f64) -> Self {
178 self.add_op(column, ">", FilterValue::F64(value))
179 }
180
181 /// Add a float less-than condition: `t.<column> < ?`
182 pub fn lt_f64(self, column: &str, value: f64) -> Self {
183 self.add_op(column, "<", FilterValue::F64(value))
184 }
185
186 /// Internal helper: validate column and push a condition.
187 fn add_op(mut self, column: &str, op: &str, value: FilterValue) -> Self {
188 self.conditions.push(FilterCondition {
189 // We store validated column + op; validation happens in to_sql()
190 sql: format!("t.{column} {op} ?"),
191 params: vec![value],
192 });
193 self
194 }
195
196 /// Emit the WHERE clause and its bound parameters.
197 ///
198 /// Returns `("1 = 1", [])` for an empty filter (matches all rows).
199 ///
200 /// # Errors
201 ///
202 /// Returns `AzothError::Config` if any column name fails identifier validation.
203 pub fn to_sql(&self) -> Result<(String, Vec<FilterValue>)> {
204 if self.conditions.is_empty() {
205 return Ok(("1 = 1".to_string(), Vec::new()));
206 }
207
208 // Validate all column names before emitting SQL
209 for cond in &self.conditions {
210 // Extract column name from `t.<col> <op> ?`
211 let col_name = cond
212 .sql
213 .strip_prefix("t.")
214 .and_then(|rest| rest.split_whitespace().next())
215 .unwrap_or("");
216 validate_sql_identifier(col_name, "Filter column")?;
217 }
218
219 let sql_parts: Vec<&str> = self.conditions.iter().map(|c| c.sql.as_str()).collect();
220 let sql = sql_parts.join(" AND ");
221
222 let params: Vec<FilterValue> = self
223 .conditions
224 .iter()
225 .flat_map(|c| c.params.clone())
226 .collect();
227
228 Ok((sql, params))
229 }
230}
231
232/// Vector search builder
233///
234/// Provides k-NN search with optional filtering and custom distance metrics.
235///
236/// # Example
237///
238/// ```no_run
239/// use azoth::prelude::*;
240/// use azoth_vector::{VectorSearch, Vector, DistanceMetric};
241///
242/// # async fn example() -> Result<()> {
243/// let db = AzothDb::open("./data")?;
244///
245/// let query = Vector::new(vec![0.1, 0.2, 0.3]);
246/// let search = VectorSearch::new(db.projection().clone(), "embeddings", "vector")?
247/// .distance_metric(DistanceMetric::Cosine);
248///
249/// let results = search.knn(&query, 10).await?;
250/// # Ok(())
251/// # }
252/// ```
253pub struct VectorSearch {
254 projection: Arc<SqliteProjectionStore>,
255 table: String,
256 column: String,
257 distance_metric: DistanceMetric,
258}
259
260impl VectorSearch {
261 /// Create a new vector search builder
262 ///
263 /// # Arguments
264 ///
265 /// * `projection` - The SQLite projection store
266 /// * `table` - Table name containing the vector column (must be a valid SQL identifier)
267 /// * `column` - Vector column name (must be a valid SQL identifier, initialized with vector_init)
268 ///
269 /// # Errors
270 ///
271 /// Returns an error if `table` or `column` contain characters other than
272 /// ASCII alphanumeric and underscore, or don't start with a letter/underscore.
273 pub fn new(
274 projection: Arc<SqliteProjectionStore>,
275 table: impl Into<String>,
276 column: impl Into<String>,
277 ) -> Result<Self> {
278 let table = table.into();
279 let column = column.into();
280 validate_sql_identifier(&table, "Table")?;
281 validate_sql_identifier(&column, "Column")?;
282 Ok(Self {
283 projection,
284 table,
285 column,
286 distance_metric: DistanceMetric::Cosine,
287 })
288 }
289
290 /// Set the distance metric
291 ///
292 /// Default is Cosine similarity.
293 pub fn distance_metric(mut self, metric: DistanceMetric) -> Self {
294 self.distance_metric = metric;
295 self
296 }
297
298 /// Perform k-nearest neighbors search
299 ///
300 /// Returns up to `k` results ordered by similarity (closest first).
301 ///
302 /// # Example
303 ///
304 /// ```no_run
305 /// # use azoth_vector::{VectorSearch, Vector};
306 /// # async fn example(search: VectorSearch) -> Result<(), Box<dyn std::error::Error>> {
307 /// let query_vector = Vector::new(vec![0.1, 0.2, 0.3]);
308 /// let results = search.knn(&query_vector, 10).await?;
309 ///
310 /// for result in results {
311 /// println!("Row {}: distance = {}", result.rowid, result.distance);
312 /// }
313 /// # Ok(())
314 /// # }
315 /// ```
316 pub async fn knn(&self, query: &Vector, k: usize) -> Result<Vec<SearchResult>> {
317 // Table and column are validated at construction time via validate_sql_identifier
318 let table = self.table.clone();
319 let column = self.column.clone();
320 let query_json = query.to_json();
321 let k_i64 = k as i64;
322
323 self.projection
324 .query_async(move |conn| {
325 let sql = format!(
326 "SELECT rowid, distance
327 FROM vector_quantize_scan('{table}', '{column}', ?, ?)
328 ORDER BY distance ASC",
329 );
330
331 let mut stmt = conn
332 .prepare(&sql)
333 .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?;
334
335 let results = stmt
336 .query_map(params![query_json, k_i64], |row| {
337 Ok(SearchResult {
338 rowid: row.get(0)?,
339 distance: row.get(1)?,
340 })
341 })
342 .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?
343 .collect::<rusqlite::Result<Vec<_>>>()
344 .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?;
345
346 Ok(results)
347 })
348 .await
349 }
350
351 /// Search with distance threshold
352 ///
353 /// Returns all results within the given distance threshold, up to `k` results.
354 ///
355 /// # Example
356 ///
357 /// ```no_run
358 /// # use azoth_vector::{VectorSearch, Vector};
359 /// # async fn example(search: VectorSearch) -> Result<(), Box<dyn std::error::Error>> {
360 /// let query = Vector::new(vec![0.1, 0.2, 0.3]);
361 /// // Only return results with cosine distance < 0.3 (similarity > 0.7)
362 /// let results = search.threshold(&query, 0.3, 100).await?;
363 /// # Ok(())
364 /// # }
365 /// ```
366 pub async fn threshold(
367 &self,
368 query: &Vector,
369 max_distance: f32,
370 k: usize,
371 ) -> Result<Vec<SearchResult>> {
372 let results = self.knn(query, k).await?;
373 Ok(results
374 .into_iter()
375 .filter(|r| r.distance <= max_distance)
376 .collect())
377 }
378
379 /// Search with structured filter conditions
380 ///
381 /// Allows filtering results by additional columns in the table using a
382 /// type-safe [`VectorFilter`] builder. All column names are validated as
383 /// safe SQL identifiers, and all values are bound via parameterized queries,
384 /// preventing SQL injection by construction.
385 ///
386 /// # Example
387 ///
388 /// ```no_run
389 /// # use azoth_vector::{VectorSearch, Vector, VectorFilter};
390 /// # async fn example(search: VectorSearch) -> Result<(), Box<dyn std::error::Error>> {
391 /// let query = Vector::new(vec![0.1, 0.2, 0.3]);
392 ///
393 /// let filter = VectorFilter::new()
394 /// .eq("category", "tech")
395 /// .eq_i64("in_stock", 1);
396 ///
397 /// let results = search.knn_filtered(&query, 10, &filter).await?;
398 /// # Ok(())
399 /// # }
400 /// ```
401 pub async fn knn_filtered(
402 &self,
403 query: &Vector,
404 k: usize,
405 filter: &VectorFilter,
406 ) -> Result<Vec<SearchResult>> {
407 let (where_clause, filter_params) = filter.to_sql()?;
408
409 let table = self.table.clone();
410 let column = self.column.clone();
411 let query_json = query.to_json();
412 let k_i64 = k as i64;
413
414 self.projection
415 .query_async(move |conn| {
416 let sql = format!(
417 "SELECT v.rowid, v.distance
418 FROM vector_quantize_scan('{table}', '{column}', ?, ?) AS v
419 JOIN {table} AS t ON v.rowid = t.rowid
420 WHERE {where_clause}
421 ORDER BY v.distance ASC",
422 );
423
424 let mut params_vec: Vec<Box<dyn rusqlite::ToSql>> =
425 vec![Box::new(query_json), Box::new(k_i64)];
426 for p in filter_params {
427 params_vec.push(p.to_boxed_sql());
428 }
429
430 let mut stmt = conn
431 .prepare(&sql)
432 .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?;
433
434 let results = stmt
435 .query_map(rusqlite::params_from_iter(params_vec.iter()), |row| {
436 Ok(SearchResult {
437 rowid: row.get(0)?,
438 distance: row.get(1)?,
439 })
440 })
441 .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?
442 .collect::<rusqlite::Result<Vec<_>>>()
443 .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?;
444
445 Ok(results)
446 })
447 .await
448 }
449
450 /// Get multiple results by rowids and include their distances from query
451 ///
452 /// Useful for retrieving full records after search.
453 ///
454 /// # Example
455 ///
456 /// ```no_run
457 /// # use azoth_vector::{VectorSearch, Vector};
458 /// # use azoth_core::Result;
459 /// # async fn example(search: VectorSearch) -> Result<()> {
460 /// let query = Vector::new(vec![0.1, 0.2, 0.3]);
461 /// let results = search.knn(&query, 10).await?;
462 ///
463 /// // Get full records
464 /// for result in results {
465 /// let record: String = search.projection()
466 /// .query(|conn: &rusqlite::Connection| {
467 /// conn.query_row(
468 /// "SELECT content FROM embeddings WHERE rowid = ?",
469 /// [result.rowid],
470 /// |row: &rusqlite::Row| row.get(0),
471 /// ).map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))
472 /// })?;
473 /// println!("Distance: {}, Content: {}", result.distance, record);
474 /// }
475 /// # Ok(())
476 /// # }
477 /// ```
478 pub fn projection(&self) -> &Arc<SqliteProjectionStore> {
479 &self.projection
480 }
481
482 /// Get the table name
483 pub fn table(&self) -> &str {
484 &self.table
485 }
486
487 /// Get the column name
488 pub fn column(&self) -> &str {
489 &self.column
490 }
491
492 /// Get the distance metric
493 pub fn distance_metric_value(&self) -> DistanceMetric {
494 self.distance_metric
495 }
496}
497
#[cfg(test)]
mod tests {
    use super::*;
    use azoth_core::traits::ProjectionStore;

    /// Open a projection store on a database file inside a fresh temp dir.
    fn make_store() -> Arc<SqliteProjectionStore> {
        let dir = tempfile::tempdir().unwrap();

        let config = azoth_core::ProjectionConfig {
            path: dir.path().join("test.db"),
            wal_mode: true,
            synchronous: azoth_core::config::SynchronousMode::Normal,
            cache_size: -2000,
            schema_version: 1,
            read_pool: azoth_core::config::ReadPoolConfig::default(),
        };

        // Leak the tempdir so it lives long enough for the test
        std::mem::forget(dir);
        Arc::new(SqliteProjectionStore::open(config).unwrap())
    }

    #[test]
    fn test_search_builder() {
        let store = make_store();

        let built = VectorSearch::new(store.clone(), "test", "vector").unwrap();
        let search = built.distance_metric(DistanceMetric::L2);

        assert_eq!(search.table(), "test");
        assert_eq!(search.column(), "vector");
        assert_eq!(search.distance_metric_value(), DistanceMetric::L2);
    }

    #[test]
    fn test_identifier_validation_rejects_injection() {
        let store = make_store();

        // (table, column, what is being rejected)
        let rejected = [
            ("x; DROP TABLE y; --", "vector", "SQL injection in table name"),
            ("test", "v'; DROP TABLE y; --", "SQL injection in column name"),
            ("", "vector", "Empty names"),
            ("123table", "vector", "Names starting with digits"),
        ];
        for &(table, column, what) in &rejected {
            let result = VectorSearch::new(store.clone(), table, column);
            assert!(result.is_err(), "{} should be rejected", what);
        }

        // Valid identifiers should work
        let accepted = [("my_table", "embedding_col"), ("_private", "_col")];
        for &(table, column) in &accepted {
            let result = VectorSearch::new(store.clone(), table, column);
            assert!(result.is_ok());
        }
    }

    // Full integration tests with vector extension in tests/ directory
}