Skip to main content

azoth_vector/
search.rs

1//! Vector similarity search API
2
3use crate::types::{DistanceMetric, SearchResult, Vector};
4use azoth_core::Result;
5use azoth_sqlite::SqliteProjectionStore;
6use rusqlite::params;
7use std::sync::Arc;
8
9/// Validate that a SQL identifier (table or column name) is safe.
10///
11/// Only allows `[a-zA-Z_][a-zA-Z0-9_]*` to prevent SQL injection via
12/// identifier manipulation. Returns an error if the identifier is invalid.
13pub(crate) fn validate_sql_identifier(name: &str, kind: &str) -> Result<()> {
14    if name.is_empty() {
15        return Err(azoth_core::error::AzothError::Config(format!(
16            "{} name must not be empty",
17            kind
18        )));
19    }
20    if name.len() > 128 {
21        return Err(azoth_core::error::AzothError::Config(format!(
22            "{} name must be 128 characters or fewer, got {}",
23            kind,
24            name.len()
25        )));
26    }
27    let mut chars = name.chars();
28    let first = chars.next().unwrap(); // safe: name is non-empty
29    if !first.is_ascii_alphabetic() && first != '_' {
30        return Err(azoth_core::error::AzothError::Config(format!(
31            "{} name '{}' must start with a letter or underscore",
32            kind, name
33        )));
34    }
35    for c in chars {
36        if !c.is_ascii_alphanumeric() && c != '_' {
37            return Err(azoth_core::error::AzothError::Config(format!(
38                "{} name '{}' contains invalid character '{}'. \
39                 Only ASCII alphanumeric and underscore are allowed.",
40                kind, name, c
41            )));
42        }
43    }
44    Ok(())
45}
46
47/// A bound parameter value for SQL queries.
48///
49/// Used internally by [`VectorFilter`] to hold typed values that are bound
50/// via parameterized queries, preventing SQL injection.
51#[derive(Clone, Debug)]
52pub enum FilterValue {
53    /// String parameter
54    String(String),
55    /// 64-bit integer parameter
56    I64(i64),
57    /// 64-bit float parameter
58    F64(f64),
59}
60
61impl FilterValue {
62    /// Convert into a boxed `ToSql` trait object for rusqlite parameter binding.
63    pub fn to_boxed_sql(self) -> Box<dyn rusqlite::ToSql> {
64        match self {
65            Self::String(s) => Box::new(s),
66            Self::I64(i) => Box::new(i),
67            Self::F64(f) => Box::new(f),
68        }
69    }
70}
71
72/// A single condition in a vector search filter.
73#[derive(Clone, Debug)]
74struct FilterCondition {
75    /// SQL fragment, e.g. `t.category = ?` or `t.in_stock = ?`
76    sql: String,
77    /// Bound parameter values (one per `?` placeholder in `sql`)
78    params: Vec<FilterValue>,
79}
80
81/// Type-safe filter builder for vector search queries.
82///
83/// All column names are validated as safe SQL identifiers and all values are
84/// bound via parameterized queries, eliminating SQL injection by construction.
85///
86/// # Example
87///
88/// ```
89/// use azoth_vector::VectorFilter;
90///
91/// let filter = VectorFilter::new()
92///     .eq("category", "electronics")
93///     .eq_i64("in_stock", 1)
94///     .gt("price", "9.99");
95///
96/// let (sql, params) = filter.to_sql().unwrap();
97/// assert_eq!(sql, "t.category = ? AND t.in_stock = ? AND t.price > ?");
98/// assert_eq!(params.len(), 3);
99/// ```
100#[derive(Clone, Debug, Default)]
101pub struct VectorFilter {
102    conditions: Vec<FilterCondition>,
103}
104
105impl VectorFilter {
106    /// Create an empty filter (matches all rows).
107    pub fn new() -> Self {
108        Self::default()
109    }
110
111    /// Add a string equality condition: `t.<column> = ?`
112    pub fn eq(self, column: &str, value: impl Into<String>) -> Self {
113        self.add_op(column, "=", FilterValue::String(value.into()))
114    }
115
116    /// Add a string inequality condition: `t.<column> != ?`
117    pub fn neq(self, column: &str, value: impl Into<String>) -> Self {
118        self.add_op(column, "!=", FilterValue::String(value.into()))
119    }
120
121    /// Add a string greater-than condition: `t.<column> > ?`
122    pub fn gt(self, column: &str, value: impl Into<String>) -> Self {
123        self.add_op(column, ">", FilterValue::String(value.into()))
124    }
125
126    /// Add a string greater-or-equal condition: `t.<column> >= ?`
127    pub fn gte(self, column: &str, value: impl Into<String>) -> Self {
128        self.add_op(column, ">=", FilterValue::String(value.into()))
129    }
130
131    /// Add a string less-than condition: `t.<column> < ?`
132    pub fn lt(self, column: &str, value: impl Into<String>) -> Self {
133        self.add_op(column, "<", FilterValue::String(value.into()))
134    }
135
136    /// Add a string less-or-equal condition: `t.<column> <= ?`
137    pub fn lte(self, column: &str, value: impl Into<String>) -> Self {
138        self.add_op(column, "<=", FilterValue::String(value.into()))
139    }
140
141    /// Add a LIKE condition: `t.<column> LIKE ?`
142    pub fn like(self, column: &str, pattern: impl Into<String>) -> Self {
143        self.add_op(column, "LIKE", FilterValue::String(pattern.into()))
144    }
145
146    /// Add an integer equality condition: `t.<column> = ?`
147    pub fn eq_i64(self, column: &str, value: i64) -> Self {
148        self.add_op(column, "=", FilterValue::I64(value))
149    }
150
151    /// Add an integer greater-than condition: `t.<column> > ?`
152    pub fn gt_i64(self, column: &str, value: i64) -> Self {
153        self.add_op(column, ">", FilterValue::I64(value))
154    }
155
156    /// Add an integer greater-or-equal condition: `t.<column> >= ?`
157    pub fn gte_i64(self, column: &str, value: i64) -> Self {
158        self.add_op(column, ">=", FilterValue::I64(value))
159    }
160
161    /// Add an integer less-than condition: `t.<column> < ?`
162    pub fn lt_i64(self, column: &str, value: i64) -> Self {
163        self.add_op(column, "<", FilterValue::I64(value))
164    }
165
166    /// Add an integer less-or-equal condition: `t.<column> <= ?`
167    pub fn lte_i64(self, column: &str, value: i64) -> Self {
168        self.add_op(column, "<=", FilterValue::I64(value))
169    }
170
171    /// Add a float equality condition: `t.<column> = ?`
172    pub fn eq_f64(self, column: &str, value: f64) -> Self {
173        self.add_op(column, "=", FilterValue::F64(value))
174    }
175
176    /// Add a float greater-than condition: `t.<column> > ?`
177    pub fn gt_f64(self, column: &str, value: f64) -> Self {
178        self.add_op(column, ">", FilterValue::F64(value))
179    }
180
181    /// Add a float less-than condition: `t.<column> < ?`
182    pub fn lt_f64(self, column: &str, value: f64) -> Self {
183        self.add_op(column, "<", FilterValue::F64(value))
184    }
185
186    /// Internal helper: validate column and push a condition.
187    fn add_op(mut self, column: &str, op: &str, value: FilterValue) -> Self {
188        self.conditions.push(FilterCondition {
189            // We store validated column + op; validation happens in to_sql()
190            sql: format!("t.{column} {op} ?"),
191            params: vec![value],
192        });
193        self
194    }
195
196    /// Emit the WHERE clause and its bound parameters.
197    ///
198    /// Returns `("1 = 1", [])` for an empty filter (matches all rows).
199    ///
200    /// # Errors
201    ///
202    /// Returns `AzothError::Config` if any column name fails identifier validation.
203    pub fn to_sql(&self) -> Result<(String, Vec<FilterValue>)> {
204        if self.conditions.is_empty() {
205            return Ok(("1 = 1".to_string(), Vec::new()));
206        }
207
208        // Validate all column names before emitting SQL
209        for cond in &self.conditions {
210            // Extract column name from `t.<col> <op> ?`
211            let col_name = cond
212                .sql
213                .strip_prefix("t.")
214                .and_then(|rest| rest.split_whitespace().next())
215                .unwrap_or("");
216            validate_sql_identifier(col_name, "Filter column")?;
217        }
218
219        let sql_parts: Vec<&str> = self.conditions.iter().map(|c| c.sql.as_str()).collect();
220        let sql = sql_parts.join(" AND ");
221
222        let params: Vec<FilterValue> = self
223            .conditions
224            .iter()
225            .flat_map(|c| c.params.clone())
226            .collect();
227
228        Ok((sql, params))
229    }
230}
231
232/// Vector search builder
233///
234/// Provides k-NN search with optional filtering and custom distance metrics.
235///
236/// # Example
237///
238/// ```no_run
239/// use azoth::prelude::*;
240/// use azoth_vector::{VectorSearch, Vector, DistanceMetric};
241///
242/// # async fn example() -> Result<()> {
243/// let db = AzothDb::open("./data")?;
244///
245/// let query = Vector::new(vec![0.1, 0.2, 0.3]);
246/// let search = VectorSearch::new(db.projection().clone(), "embeddings", "vector")?
247///     .distance_metric(DistanceMetric::Cosine);
248///
249/// let results = search.knn(&query, 10).await?;
250/// # Ok(())
251/// # }
252/// ```
253pub struct VectorSearch {
254    projection: Arc<SqliteProjectionStore>,
255    table: String,
256    column: String,
257    distance_metric: DistanceMetric,
258}
259
260impl VectorSearch {
261    /// Create a new vector search builder
262    ///
263    /// # Arguments
264    ///
265    /// * `projection` - The SQLite projection store
266    /// * `table` - Table name containing the vector column (must be a valid SQL identifier)
267    /// * `column` - Vector column name (must be a valid SQL identifier, initialized with vector_init)
268    ///
269    /// # Errors
270    ///
271    /// Returns an error if `table` or `column` contain characters other than
272    /// ASCII alphanumeric and underscore, or don't start with a letter/underscore.
273    pub fn new(
274        projection: Arc<SqliteProjectionStore>,
275        table: impl Into<String>,
276        column: impl Into<String>,
277    ) -> Result<Self> {
278        let table = table.into();
279        let column = column.into();
280        validate_sql_identifier(&table, "Table")?;
281        validate_sql_identifier(&column, "Column")?;
282        Ok(Self {
283            projection,
284            table,
285            column,
286            distance_metric: DistanceMetric::Cosine,
287        })
288    }
289
290    /// Set the distance metric
291    ///
292    /// Default is Cosine similarity.
293    pub fn distance_metric(mut self, metric: DistanceMetric) -> Self {
294        self.distance_metric = metric;
295        self
296    }
297
298    /// Perform k-nearest neighbors search
299    ///
300    /// Returns up to `k` results ordered by similarity (closest first).
301    ///
302    /// # Example
303    ///
304    /// ```no_run
305    /// # use azoth_vector::{VectorSearch, Vector};
306    /// # async fn example(search: VectorSearch) -> Result<(), Box<dyn std::error::Error>> {
307    /// let query_vector = Vector::new(vec![0.1, 0.2, 0.3]);
308    /// let results = search.knn(&query_vector, 10).await?;
309    ///
310    /// for result in results {
311    ///     println!("Row {}: distance = {}", result.rowid, result.distance);
312    /// }
313    /// # Ok(())
314    /// # }
315    /// ```
316    pub async fn knn(&self, query: &Vector, k: usize) -> Result<Vec<SearchResult>> {
317        // Table and column are validated at construction time via validate_sql_identifier
318        let table = self.table.clone();
319        let column = self.column.clone();
320        let query_json = query.to_json();
321        let k_i64 = k as i64;
322
323        self.projection
324            .query_async(move |conn| {
325                let sql = format!(
326                    "SELECT rowid, distance
327                     FROM vector_quantize_scan('{table}', '{column}', ?, ?)
328                     ORDER BY distance ASC",
329                );
330
331                let mut stmt = conn
332                    .prepare(&sql)
333                    .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?;
334
335                let results = stmt
336                    .query_map(params![query_json, k_i64], |row| {
337                        Ok(SearchResult {
338                            rowid: row.get(0)?,
339                            distance: row.get(1)?,
340                        })
341                    })
342                    .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?
343                    .collect::<rusqlite::Result<Vec<_>>>()
344                    .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?;
345
346                Ok(results)
347            })
348            .await
349    }
350
351    /// Search with distance threshold
352    ///
353    /// Returns all results within the given distance threshold, up to `k` results.
354    ///
355    /// # Example
356    ///
357    /// ```no_run
358    /// # use azoth_vector::{VectorSearch, Vector};
359    /// # async fn example(search: VectorSearch) -> Result<(), Box<dyn std::error::Error>> {
360    /// let query = Vector::new(vec![0.1, 0.2, 0.3]);
361    /// // Only return results with cosine distance < 0.3 (similarity > 0.7)
362    /// let results = search.threshold(&query, 0.3, 100).await?;
363    /// # Ok(())
364    /// # }
365    /// ```
366    pub async fn threshold(
367        &self,
368        query: &Vector,
369        max_distance: f32,
370        k: usize,
371    ) -> Result<Vec<SearchResult>> {
372        let results = self.knn(query, k).await?;
373        Ok(results
374            .into_iter()
375            .filter(|r| r.distance <= max_distance)
376            .collect())
377    }
378
379    /// Search with structured filter conditions
380    ///
381    /// Allows filtering results by additional columns in the table using a
382    /// type-safe [`VectorFilter`] builder. All column names are validated as
383    /// safe SQL identifiers, and all values are bound via parameterized queries,
384    /// preventing SQL injection by construction.
385    ///
386    /// # Example
387    ///
388    /// ```no_run
389    /// # use azoth_vector::{VectorSearch, Vector, VectorFilter};
390    /// # async fn example(search: VectorSearch) -> Result<(), Box<dyn std::error::Error>> {
391    /// let query = Vector::new(vec![0.1, 0.2, 0.3]);
392    ///
393    /// let filter = VectorFilter::new()
394    ///     .eq("category", "tech")
395    ///     .eq_i64("in_stock", 1);
396    ///
397    /// let results = search.knn_filtered(&query, 10, &filter).await?;
398    /// # Ok(())
399    /// # }
400    /// ```
401    pub async fn knn_filtered(
402        &self,
403        query: &Vector,
404        k: usize,
405        filter: &VectorFilter,
406    ) -> Result<Vec<SearchResult>> {
407        let (where_clause, filter_params) = filter.to_sql()?;
408
409        let table = self.table.clone();
410        let column = self.column.clone();
411        let query_json = query.to_json();
412        let k_i64 = k as i64;
413
414        self.projection
415            .query_async(move |conn| {
416                let sql = format!(
417                    "SELECT v.rowid, v.distance
418                     FROM vector_quantize_scan('{table}', '{column}', ?, ?) AS v
419                     JOIN {table} AS t ON v.rowid = t.rowid
420                     WHERE {where_clause}
421                     ORDER BY v.distance ASC",
422                );
423
424                let mut params_vec: Vec<Box<dyn rusqlite::ToSql>> =
425                    vec![Box::new(query_json), Box::new(k_i64)];
426                for p in filter_params {
427                    params_vec.push(p.to_boxed_sql());
428                }
429
430                let mut stmt = conn
431                    .prepare(&sql)
432                    .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?;
433
434                let results = stmt
435                    .query_map(rusqlite::params_from_iter(params_vec.iter()), |row| {
436                        Ok(SearchResult {
437                            rowid: row.get(0)?,
438                            distance: row.get(1)?,
439                        })
440                    })
441                    .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?
442                    .collect::<rusqlite::Result<Vec<_>>>()
443                    .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?;
444
445                Ok(results)
446            })
447            .await
448    }
449
450    /// Get multiple results by rowids and include their distances from query
451    ///
452    /// Useful for retrieving full records after search.
453    ///
454    /// # Example
455    ///
456    /// ```no_run
457    /// # use azoth_vector::{VectorSearch, Vector};
458    /// # use azoth_core::Result;
459    /// # async fn example(search: VectorSearch) -> Result<()> {
460    /// let query = Vector::new(vec![0.1, 0.2, 0.3]);
461    /// let results = search.knn(&query, 10).await?;
462    ///
463    /// // Get full records
464    /// for result in results {
465    ///     let record: String = search.projection()
466    ///         .query(|conn: &rusqlite::Connection| {
467    ///             conn.query_row(
468    ///                 "SELECT content FROM embeddings WHERE rowid = ?",
469    ///                 [result.rowid],
470    ///                 |row: &rusqlite::Row| row.get(0),
471    ///             ).map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))
472    ///         })?;
473    ///     println!("Distance: {}, Content: {}", result.distance, record);
474    /// }
475    /// # Ok(())
476    /// # }
477    /// ```
478    pub fn projection(&self) -> &Arc<SqliteProjectionStore> {
479        &self.projection
480    }
481
482    /// Get the table name
483    pub fn table(&self) -> &str {
484        &self.table
485    }
486
487    /// Get the column name
488    pub fn column(&self) -> &str {
489        &self.column
490    }
491
492    /// Get the distance metric
493    pub fn distance_metric_value(&self) -> DistanceMetric {
494        self.distance_metric
495    }
496}
497
#[cfg(test)]
mod tests {
    use super::*;
    use azoth_core::traits::ProjectionStore;

    /// Open a throwaway projection store backed by a fresh temporary directory.
    fn make_store() -> Arc<SqliteProjectionStore> {
        let dir = tempfile::tempdir().unwrap();

        let config = azoth_core::ProjectionConfig {
            path: dir.path().join("test.db"),
            wal_mode: true,
            synchronous: azoth_core::config::SynchronousMode::Normal,
            cache_size: -2000,
            schema_version: 1,
            read_pool: azoth_core::config::ReadPoolConfig::default(),
        };

        // Deliberately leak the tempdir so the directory outlives this helper
        // and stays available for the duration of the test.
        std::mem::forget(dir);
        Arc::new(SqliteProjectionStore::open(config).unwrap())
    }

    #[test]
    fn test_search_builder() {
        let store = make_store();

        let search = VectorSearch::new(store.clone(), "test", "vector")
            .unwrap()
            .distance_metric(DistanceMetric::L2);

        assert_eq!(search.table(), "test");
        assert_eq!(search.column(), "vector");
        assert_eq!(search.distance_metric_value(), DistanceMetric::L2);
    }

    #[test]
    fn test_identifier_validation_rejects_injection() {
        let store = make_store();

        // (table, column) pairs that must be refused.
        let rejected = [
            // SQL injection in table name
            ("x; DROP TABLE y; --", "vector"),
            // SQL injection in column name
            ("test", "v'; DROP TABLE y; --"),
            // Empty name
            ("", "vector"),
            // Name starting with a digit
            ("123table", "vector"),
        ];
        for &(table, column) in &rejected {
            assert!(VectorSearch::new(store.clone(), table, column).is_err());
        }

        // (table, column) pairs that are valid identifiers.
        let accepted = [("my_table", "embedding_col"), ("_private", "_col")];
        for &(table, column) in &accepted {
            assert!(VectorSearch::new(store.clone(), table, column).is_ok());
        }
    }

    // Full integration tests with vector extension in tests/ directory
}