Skip to main content

azoth_vector/
search.rs

//! Vector similarity (k-nearest-neighbor) search API over SQLite projection tables.
2
3use crate::types::{DistanceMetric, SearchResult, Vector};
4use azoth_core::Result;
5use azoth_sqlite::SqliteProjectionStore;
6use rusqlite::params;
7use std::sync::Arc;
8
9/// Vector search builder
10///
11/// Provides k-NN search with optional filtering and custom distance metrics.
12///
13/// # Example
14///
15/// ```no_run
16/// use azoth::prelude::*;
17/// use azoth_vector::{VectorSearch, Vector, DistanceMetric};
18///
19/// # async fn example() -> Result<()> {
20/// let db = AzothDb::open("./data")?;
21///
22/// let query = Vector::new(vec![0.1, 0.2, 0.3]);
23/// let search = VectorSearch::new(db.projection().clone(), "embeddings", "vector")
24///     .distance_metric(DistanceMetric::Cosine);
25///
26/// let results = search.knn(&query, 10).await?;
27/// # Ok(())
28/// # }
29/// ```
30pub struct VectorSearch {
31    projection: Arc<SqliteProjectionStore>,
32    table: String,
33    column: String,
34    distance_metric: DistanceMetric,
35}
36
37impl VectorSearch {
38    /// Create a new vector search builder
39    ///
40    /// # Arguments
41    ///
42    /// * `projection` - The SQLite projection store
43    /// * `table` - Table name containing the vector column
44    /// * `column` - Vector column name (must be initialized with vector_init)
45    pub fn new(
46        projection: Arc<SqliteProjectionStore>,
47        table: impl Into<String>,
48        column: impl Into<String>,
49    ) -> Self {
50        Self {
51            projection,
52            table: table.into(),
53            column: column.into(),
54            distance_metric: DistanceMetric::Cosine,
55        }
56    }
57
58    /// Set the distance metric
59    ///
60    /// Default is Cosine similarity.
61    pub fn distance_metric(mut self, metric: DistanceMetric) -> Self {
62        self.distance_metric = metric;
63        self
64    }
65
66    /// Perform k-nearest neighbors search
67    ///
68    /// Returns up to `k` results ordered by similarity (closest first).
69    ///
70    /// # Example
71    ///
72    /// ```no_run
73    /// # use azoth_vector::{VectorSearch, Vector};
74    /// # async fn example(search: VectorSearch) -> Result<(), Box<dyn std::error::Error>> {
75    /// let query_vector = Vector::new(vec![0.1, 0.2, 0.3]);
76    /// let results = search.knn(&query_vector, 10).await?;
77    ///
78    /// for result in results {
79    ///     println!("Row {}: distance = {}", result.rowid, result.distance);
80    /// }
81    /// # Ok(())
82    /// # }
83    /// ```
84    pub async fn knn(&self, query: &Vector, k: usize) -> Result<Vec<SearchResult>> {
85        let table = self.table.clone();
86        let column = self.column.clone();
87        let query_json = query.to_json();
88        let k_i64 = k as i64;
89
90        self.projection
91            .query_async(move |conn| {
92                let sql = format!(
93                    "SELECT rowid, distance
94                     FROM vector_quantize_scan('{}', '{}', ?, ?)
95                     ORDER BY distance ASC",
96                    table.replace('\'', "''"),
97                    column.replace('\'', "''")
98                );
99
100                let mut stmt = conn
101                    .prepare(&sql)
102                    .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?;
103
104                let results = stmt
105                    .query_map(params![query_json, k_i64], |row| {
106                        Ok(SearchResult {
107                            rowid: row.get(0)?,
108                            distance: row.get(1)?,
109                        })
110                    })
111                    .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?
112                    .collect::<rusqlite::Result<Vec<_>>>()
113                    .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?;
114
115                Ok(results)
116            })
117            .await
118    }
119
120    /// Search with distance threshold
121    ///
122    /// Returns all results within the given distance threshold, up to `k` results.
123    ///
124    /// # Example
125    ///
126    /// ```no_run
127    /// # use azoth_vector::{VectorSearch, Vector};
128    /// # async fn example(search: VectorSearch) -> Result<(), Box<dyn std::error::Error>> {
129    /// let query = Vector::new(vec![0.1, 0.2, 0.3]);
130    /// // Only return results with cosine distance < 0.3 (similarity > 0.7)
131    /// let results = search.threshold(&query, 0.3, 100).await?;
132    /// # Ok(())
133    /// # }
134    /// ```
135    pub async fn threshold(
136        &self,
137        query: &Vector,
138        max_distance: f32,
139        k: usize,
140    ) -> Result<Vec<SearchResult>> {
141        let results = self.knn(query, k).await?;
142        Ok(results
143            .into_iter()
144            .filter(|r| r.distance <= max_distance)
145            .collect())
146    }
147
148    /// Search with custom SQL filter
149    ///
150    /// Allows filtering results by additional columns in the table.
151    ///
152    /// # Example
153    ///
154    /// ```no_run
155    /// # use azoth_vector::{VectorSearch, Vector};
156    /// # async fn example(search: VectorSearch) -> Result<(), Box<dyn std::error::Error>> {
157    /// let query = Vector::new(vec![0.1, 0.2, 0.3]);
158    ///
159    /// // Only search within a specific category
160    /// let results = search
161    ///     .knn_filtered(&query, 10, "category = ?", vec!["tech".to_string()])
162    ///     .await?;
163    /// # Ok(())
164    /// # }
165    /// ```
166    pub async fn knn_filtered(
167        &self,
168        query: &Vector,
169        k: usize,
170        filter: &str,
171        filter_params: Vec<String>,
172    ) -> Result<Vec<SearchResult>> {
173        let table = self.table.clone();
174        let column = self.column.clone();
175        let query_json = query.to_json();
176        let k_i64 = k as i64;
177        let filter = filter.to_string();
178
179        self.projection
180            .query_async(move |conn| {
181                let sql = format!(
182                    "SELECT v.rowid, v.distance
183                     FROM vector_quantize_scan('{}', '{}', ?, ?) AS v
184                     JOIN {} AS t ON v.rowid = t.rowid
185                     WHERE {}
186                     ORDER BY v.distance ASC",
187                    table.replace('\'', "''"),
188                    column.replace('\'', "''"),
189                    table.replace('\'', "''"),
190                    filter
191                );
192
193                let mut params_vec: Vec<Box<dyn rusqlite::ToSql>> =
194                    vec![Box::new(query_json), Box::new(k_i64)];
195                for p in filter_params {
196                    params_vec.push(Box::new(p));
197                }
198
199                let mut stmt = conn
200                    .prepare(&sql)
201                    .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?;
202
203                let results = stmt
204                    .query_map(rusqlite::params_from_iter(params_vec.iter()), |row| {
205                        Ok(SearchResult {
206                            rowid: row.get(0)?,
207                            distance: row.get(1)?,
208                        })
209                    })
210                    .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?
211                    .collect::<rusqlite::Result<Vec<_>>>()
212                    .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?;
213
214                Ok(results)
215            })
216            .await
217    }
218
219    /// Get multiple results by rowids and include their distances from query
220    ///
221    /// Useful for retrieving full records after search.
222    ///
223    /// # Example
224    ///
225    /// ```no_run
226    /// # use azoth_vector::{VectorSearch, Vector};
227    /// # use azoth_core::Result;
228    /// # async fn example(search: VectorSearch) -> Result<()> {
229    /// let query = Vector::new(vec![0.1, 0.2, 0.3]);
230    /// let results = search.knn(&query, 10).await?;
231    ///
232    /// // Get full records
233    /// for result in results {
234    ///     let record: String = search.projection()
235    ///         .query(|conn: &rusqlite::Connection| {
236    ///             conn.query_row(
237    ///                 "SELECT content FROM embeddings WHERE rowid = ?",
238    ///                 [result.rowid],
239    ///                 |row: &rusqlite::Row| row.get(0),
240    ///             ).map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))
241    ///         })?;
242    ///     println!("Distance: {}, Content: {}", result.distance, record);
243    /// }
244    /// # Ok(())
245    /// # }
246    /// ```
247    pub fn projection(&self) -> &Arc<SqliteProjectionStore> {
248        &self.projection
249    }
250
251    /// Get the table name
252    pub fn table(&self) -> &str {
253        &self.table
254    }
255
256    /// Get the column name
257    pub fn column(&self) -> &str {
258        &self.column
259    }
260
261    /// Get the distance metric
262    pub fn distance_metric_value(&self) -> DistanceMetric {
263        self.distance_metric
264    }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use azoth_core::traits::ProjectionStore;
    use tempfile::tempdir;

    /// Builder setters and getters round-trip without needing the vector extension.
    #[test]
    fn test_search_builder() {
        // Keep the temp dir alive for the whole test so the database file stays put.
        let tmp = tempdir().unwrap();

        let store = {
            let config = azoth_core::ProjectionConfig {
                path: tmp.path().join("test.db"),
                wal_mode: true,
                synchronous: azoth_core::config::SynchronousMode::Normal,
                cache_size: -2000,
                schema_version: 1,
            };
            Arc::new(azoth_sqlite::SqliteProjectionStore::open(config).unwrap())
        };

        let search = VectorSearch::new(Arc::clone(&store), "test", "vector")
            .distance_metric(DistanceMetric::L2);

        assert_eq!(
            (search.table(), search.column(), search.distance_metric_value()),
            ("test", "vector", DistanceMetric::L2)
        );
    }

    // Full integration tests with vector extension in tests/ directory
}