azoth_vector/search.rs
1//! Vector similarity search API
2
3use crate::types::{DistanceMetric, SearchResult, Vector};
4use azoth_core::Result;
5use azoth_sqlite::SqliteProjectionStore;
6use rusqlite::params;
7use std::sync::Arc;
8
9/// Vector search builder
10///
11/// Provides k-NN search with optional filtering and custom distance metrics.
12///
13/// # Example
14///
15/// ```no_run
16/// use azoth::prelude::*;
17/// use azoth_vector::{VectorSearch, Vector, DistanceMetric};
18///
19/// # async fn example() -> Result<()> {
20/// let db = AzothDb::open("./data")?;
21///
22/// let query = Vector::new(vec![0.1, 0.2, 0.3]);
23/// let search = VectorSearch::new(db.projection().clone(), "embeddings", "vector")
24/// .distance_metric(DistanceMetric::Cosine);
25///
26/// let results = search.knn(&query, 10).await?;
27/// # Ok(())
28/// # }
29/// ```
30pub struct VectorSearch {
31 projection: Arc<SqliteProjectionStore>,
32 table: String,
33 column: String,
34 distance_metric: DistanceMetric,
35}
36
37impl VectorSearch {
38 /// Create a new vector search builder
39 ///
40 /// # Arguments
41 ///
42 /// * `projection` - The SQLite projection store
43 /// * `table` - Table name containing the vector column
44 /// * `column` - Vector column name (must be initialized with vector_init)
45 pub fn new(
46 projection: Arc<SqliteProjectionStore>,
47 table: impl Into<String>,
48 column: impl Into<String>,
49 ) -> Self {
50 Self {
51 projection,
52 table: table.into(),
53 column: column.into(),
54 distance_metric: DistanceMetric::Cosine,
55 }
56 }
57
58 /// Set the distance metric
59 ///
60 /// Default is Cosine similarity.
61 pub fn distance_metric(mut self, metric: DistanceMetric) -> Self {
62 self.distance_metric = metric;
63 self
64 }
65
66 /// Perform k-nearest neighbors search
67 ///
68 /// Returns up to `k` results ordered by similarity (closest first).
69 ///
70 /// # Example
71 ///
72 /// ```no_run
73 /// # use azoth_vector::{VectorSearch, Vector};
74 /// # async fn example(search: VectorSearch) -> Result<(), Box<dyn std::error::Error>> {
75 /// let query_vector = Vector::new(vec![0.1, 0.2, 0.3]);
76 /// let results = search.knn(&query_vector, 10).await?;
77 ///
78 /// for result in results {
79 /// println!("Row {}: distance = {}", result.rowid, result.distance);
80 /// }
81 /// # Ok(())
82 /// # }
83 /// ```
84 pub async fn knn(&self, query: &Vector, k: usize) -> Result<Vec<SearchResult>> {
85 let table = self.table.clone();
86 let column = self.column.clone();
87 let query_json = query.to_json();
88 let k_i64 = k as i64;
89
90 self.projection
91 .query_async(move |conn| {
92 let sql = format!(
93 "SELECT rowid, distance
94 FROM vector_quantize_scan('{}', '{}', ?, ?)
95 ORDER BY distance ASC",
96 table.replace('\'', "''"),
97 column.replace('\'', "''")
98 );
99
100 let mut stmt = conn
101 .prepare(&sql)
102 .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?;
103
104 let results = stmt
105 .query_map(params![query_json, k_i64], |row| {
106 Ok(SearchResult {
107 rowid: row.get(0)?,
108 distance: row.get(1)?,
109 })
110 })
111 .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?
112 .collect::<rusqlite::Result<Vec<_>>>()
113 .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?;
114
115 Ok(results)
116 })
117 .await
118 }
119
120 /// Search with distance threshold
121 ///
122 /// Returns all results within the given distance threshold, up to `k` results.
123 ///
124 /// # Example
125 ///
126 /// ```no_run
127 /// # use azoth_vector::{VectorSearch, Vector};
128 /// # async fn example(search: VectorSearch) -> Result<(), Box<dyn std::error::Error>> {
129 /// let query = Vector::new(vec![0.1, 0.2, 0.3]);
130 /// // Only return results with cosine distance < 0.3 (similarity > 0.7)
131 /// let results = search.threshold(&query, 0.3, 100).await?;
132 /// # Ok(())
133 /// # }
134 /// ```
135 pub async fn threshold(
136 &self,
137 query: &Vector,
138 max_distance: f32,
139 k: usize,
140 ) -> Result<Vec<SearchResult>> {
141 let results = self.knn(query, k).await?;
142 Ok(results
143 .into_iter()
144 .filter(|r| r.distance <= max_distance)
145 .collect())
146 }
147
148 /// Search with custom SQL filter
149 ///
150 /// Allows filtering results by additional columns in the table.
151 ///
152 /// # Example
153 ///
154 /// ```no_run
155 /// # use azoth_vector::{VectorSearch, Vector};
156 /// # async fn example(search: VectorSearch) -> Result<(), Box<dyn std::error::Error>> {
157 /// let query = Vector::new(vec![0.1, 0.2, 0.3]);
158 ///
159 /// // Only search within a specific category
160 /// let results = search
161 /// .knn_filtered(&query, 10, "category = ?", vec!["tech".to_string()])
162 /// .await?;
163 /// # Ok(())
164 /// # }
165 /// ```
166 pub async fn knn_filtered(
167 &self,
168 query: &Vector,
169 k: usize,
170 filter: &str,
171 filter_params: Vec<String>,
172 ) -> Result<Vec<SearchResult>> {
173 let table = self.table.clone();
174 let column = self.column.clone();
175 let query_json = query.to_json();
176 let k_i64 = k as i64;
177 let filter = filter.to_string();
178
179 self.projection
180 .query_async(move |conn| {
181 let sql = format!(
182 "SELECT v.rowid, v.distance
183 FROM vector_quantize_scan('{}', '{}', ?, ?) AS v
184 JOIN {} AS t ON v.rowid = t.rowid
185 WHERE {}
186 ORDER BY v.distance ASC",
187 table.replace('\'', "''"),
188 column.replace('\'', "''"),
189 table.replace('\'', "''"),
190 filter
191 );
192
193 let mut params_vec: Vec<Box<dyn rusqlite::ToSql>> =
194 vec![Box::new(query_json), Box::new(k_i64)];
195 for p in filter_params {
196 params_vec.push(Box::new(p));
197 }
198
199 let mut stmt = conn
200 .prepare(&sql)
201 .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?;
202
203 let results = stmt
204 .query_map(rusqlite::params_from_iter(params_vec.iter()), |row| {
205 Ok(SearchResult {
206 rowid: row.get(0)?,
207 distance: row.get(1)?,
208 })
209 })
210 .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?
211 .collect::<rusqlite::Result<Vec<_>>>()
212 .map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))?;
213
214 Ok(results)
215 })
216 .await
217 }
218
219 /// Get multiple results by rowids and include their distances from query
220 ///
221 /// Useful for retrieving full records after search.
222 ///
223 /// # Example
224 ///
225 /// ```no_run
226 /// # use azoth_vector::{VectorSearch, Vector};
227 /// # use azoth_core::Result;
228 /// # async fn example(search: VectorSearch) -> Result<()> {
229 /// let query = Vector::new(vec![0.1, 0.2, 0.3]);
230 /// let results = search.knn(&query, 10).await?;
231 ///
232 /// // Get full records
233 /// for result in results {
234 /// let record: String = search.projection()
235 /// .query(|conn: &rusqlite::Connection| {
236 /// conn.query_row(
237 /// "SELECT content FROM embeddings WHERE rowid = ?",
238 /// [result.rowid],
239 /// |row: &rusqlite::Row| row.get(0),
240 /// ).map_err(|e| azoth_core::error::AzothError::Projection(e.to_string()))
241 /// })?;
242 /// println!("Distance: {}, Content: {}", result.distance, record);
243 /// }
244 /// # Ok(())
245 /// # }
246 /// ```
247 pub fn projection(&self) -> &Arc<SqliteProjectionStore> {
248 &self.projection
249 }
250
251 /// Get the table name
252 pub fn table(&self) -> &str {
253 &self.table
254 }
255
256 /// Get the column name
257 pub fn column(&self) -> &str {
258 &self.column
259 }
260
261 /// Get the distance metric
262 pub fn distance_metric_value(&self) -> DistanceMetric {
263 self.distance_metric
264 }
265}
266
#[cfg(test)]
mod tests {
    use super::*;
    use azoth_core::traits::ProjectionStore;

    /// The builder stores the table, column and metric it was given and
    /// exposes them through its getters.
    #[test]
    fn test_search_builder() {
        // Keep the temp directory bound for the whole test so the database
        // file outlives the store.
        let scratch = tempfile::tempdir().unwrap();

        let projection_config = azoth_core::ProjectionConfig {
            path: scratch.path().join("test.db"),
            wal_mode: true,
            synchronous: azoth_core::config::SynchronousMode::Normal,
            cache_size: -2000,
            schema_version: 1,
            read_pool: azoth_core::config::ReadPoolConfig::default(),
        };

        let projection =
            Arc::new(azoth_sqlite::SqliteProjectionStore::open(projection_config).unwrap());

        let builder = VectorSearch::new(projection, "test", "vector");
        let builder = builder.distance_metric(DistanceMetric::L2);

        assert_eq!(builder.table(), "test");
        assert_eq!(builder.column(), "vector");
        assert_eq!(builder.distance_metric_value(), DistanceMetric::L2);
    }

    // Full integration tests with vector extension in tests/ directory
}