1use serde::{Deserialize, Serialize};
22
23use crate::ops::DistanceMetric;
24use crate::types::Embedding;
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct VectorFilter {
29 pub column: String,
31 pub query_vector: Embedding,
33 pub metric: DistanceMetric,
35 pub filter_type: VectorFilterType,
37}
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
41#[non_exhaustive]
42pub enum VectorFilterType {
43 Nearest {
45 limit: usize,
47 },
48
49 WithinDistance {
51 max_distance: f64,
53 limit: Option<usize>,
55 },
56
57 DistanceRange {
59 min_distance: f64,
61 max_distance: f64,
63 limit: Option<usize>,
65 },
66}
67
68impl VectorFilter {
69 pub fn nearest(
73 column: impl Into<String>,
74 query_vector: Embedding,
75 metric: DistanceMetric,
76 limit: usize,
77 ) -> Self {
78 Self {
79 column: column.into(),
80 query_vector,
81 metric,
82 filter_type: VectorFilterType::Nearest { limit },
83 }
84 }
85
86 pub fn within_distance(
90 column: impl Into<String>,
91 query_vector: Embedding,
92 metric: DistanceMetric,
93 max_distance: f64,
94 ) -> Self {
95 Self {
96 column: column.into(),
97 query_vector,
98 metric,
99 filter_type: VectorFilterType::WithinDistance {
100 max_distance,
101 limit: None,
102 },
103 }
104 }
105
106 pub fn distance_range(
108 column: impl Into<String>,
109 query_vector: Embedding,
110 metric: DistanceMetric,
111 min_distance: f64,
112 max_distance: f64,
113 ) -> Self {
114 Self {
115 column: column.into(),
116 query_vector,
117 metric,
118 filter_type: VectorFilterType::DistanceRange {
119 min_distance,
120 max_distance,
121 limit: None,
122 },
123 }
124 }
125
126 pub fn with_limit(mut self, limit: usize) -> Self {
128 match &mut self.filter_type {
129 VectorFilterType::Nearest { limit: l } => *l = limit,
130 VectorFilterType::WithinDistance { limit: l, .. } => *l = Some(limit),
131 VectorFilterType::DistanceRange { limit: l, .. } => *l = Some(limit),
132 }
133 self
134 }
135
136 pub fn distance_expr_sql(&self, param_index: usize) -> String {
140 format!(
141 "{} {} ${}",
142 self.column,
143 self.metric.operator(),
144 param_index
145 )
146 }
147
148 pub fn where_sql(&self, param_index: usize) -> Option<String> {
152 let distance_expr = self.distance_expr_sql(param_index);
153
154 match &self.filter_type {
155 VectorFilterType::Nearest { .. } => None,
156 VectorFilterType::WithinDistance { max_distance, .. } => {
157 Some(format!("{distance_expr} < {max_distance}"))
158 }
159 VectorFilterType::DistanceRange {
160 min_distance,
161 max_distance,
162 ..
163 } => Some(format!(
164 "{distance_expr} BETWEEN {min_distance} AND {max_distance}"
165 )),
166 }
167 }
168
169 pub fn order_by_sql(&self, param_index: usize) -> String {
171 self.distance_expr_sql(param_index)
172 }
173
174 pub fn limit_sql(&self) -> Option<String> {
176 let limit = match &self.filter_type {
177 VectorFilterType::Nearest { limit } => Some(*limit),
178 VectorFilterType::WithinDistance { limit, .. } => *limit,
179 VectorFilterType::DistanceRange { limit, .. } => *limit,
180 };
181
182 limit.map(|l| format!("LIMIT {l}"))
183 }
184
185 pub fn to_select_sql(
196 &self,
197 table: &str,
198 param_index: usize,
199 extra_where: Option<&str>,
200 select_columns: &str,
201 ) -> String {
202 let distance_expr = self.distance_expr_sql(param_index);
203
204 let mut sql = format!(
205 "SELECT {}, {} AS distance FROM {}",
206 select_columns, distance_expr, table
207 );
208
209 let mut where_parts = Vec::new();
211 if let Some(vec_where) = self.where_sql(param_index) {
212 where_parts.push(vec_where);
213 }
214 if let Some(extra) = extra_where {
215 where_parts.push(extra.to_string());
216 }
217 if !where_parts.is_empty() {
218 sql.push_str(&format!(" WHERE {}", where_parts.join(" AND ")));
219 }
220
221 sql.push_str(&format!(" ORDER BY {}", self.order_by_sql(param_index)));
223
224 if let Some(limit) = self.limit_sql() {
226 sql.push_str(&format!(" {limit}"));
227 }
228
229 sql
230 }
231}
232
233#[derive(Debug, Clone, Serialize, Deserialize)]
235pub struct VectorOrderBy {
236 pub column: String,
238 pub query_vector: Embedding,
240 pub metric: DistanceMetric,
242 pub include_distance: bool,
244 pub distance_alias: String,
246}
247
248impl VectorOrderBy {
249 pub fn new(column: impl Into<String>, query_vector: Embedding, metric: DistanceMetric) -> Self {
251 Self {
252 column: column.into(),
253 query_vector,
254 metric,
255 include_distance: true,
256 distance_alias: "distance".to_string(),
257 }
258 }
259
260 pub fn alias(mut self, alias: impl Into<String>) -> Self {
262 self.distance_alias = alias.into();
263 self
264 }
265
266 pub fn without_distance(mut self) -> Self {
268 self.include_distance = false;
269 self
270 }
271
272 pub fn select_distance_sql(&self, param_index: usize) -> Option<String> {
274 if self.include_distance {
275 Some(format!(
276 "{} {} ${} AS {}",
277 self.column,
278 self.metric.operator(),
279 param_index,
280 self.distance_alias
281 ))
282 } else {
283 None
284 }
285 }
286
287 pub fn order_by_sql(&self, param_index: usize) -> String {
289 if self.include_distance {
290 self.distance_alias.clone()
291 } else {
292 format!(
293 "{} {} ${}",
294 self.column,
295 self.metric.operator(),
296 param_index
297 )
298 }
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Small fixed embedding shared by every test.
    fn test_embedding() -> Embedding {
        Embedding::new(vec![0.1, 0.2, 0.3])
    }

    /// Asserts that `haystack` contains every string in `needles`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                haystack.contains(needle),
                "expected {haystack:?} to contain {needle:?}"
            );
        }
    }

    #[test]
    fn test_nearest_filter() {
        let nearest =
            VectorFilter::nearest("embedding", test_embedding(), DistanceMetric::Cosine, 10);

        assert_eq!(nearest.where_sql(1), None);
        assert_eq!(nearest.order_by_sql(1), "embedding <=> $1");
        assert_eq!(nearest.limit_sql().as_deref(), Some("LIMIT 10"));
    }

    #[test]
    fn test_within_distance_filter() {
        let within =
            VectorFilter::within_distance("embedding", test_embedding(), DistanceMetric::L2, 0.5);

        let predicate = within.where_sql(1).unwrap();
        assert_contains_all(&predicate, &["<->", "< 0.5"]);
    }

    #[test]
    fn test_distance_range_filter() {
        let ranged = VectorFilter::distance_range(
            "embedding",
            test_embedding(),
            DistanceMetric::L2,
            0.1,
            0.5,
        );

        let predicate = ranged.where_sql(1).unwrap();
        assert_contains_all(&predicate, &["BETWEEN", "0.1", "0.5"]);
    }

    #[test]
    fn test_filter_with_limit() {
        let capped =
            VectorFilter::within_distance("embedding", test_embedding(), DistanceMetric::L2, 0.5)
                .with_limit(50);

        assert_eq!(capped.limit_sql().as_deref(), Some("LIMIT 50"));
    }

    #[test]
    fn test_to_select_sql_nearest() {
        let nearest =
            VectorFilter::nearest("embedding", test_embedding(), DistanceMetric::Cosine, 5);

        let statement = nearest.to_select_sql("documents", 1, None, "*");

        assert_contains_all(
            &statement,
            &[
                "SELECT *, embedding <=> $1 AS distance",
                "FROM documents",
                "ORDER BY",
                "LIMIT 5",
            ],
        );
        // A pure top-k query carries no WHERE clause.
        assert!(!statement.contains("WHERE"));
    }

    #[test]
    fn test_to_select_sql_with_extra_where() {
        let within =
            VectorFilter::within_distance("embedding", test_embedding(), DistanceMetric::L2, 0.5)
                .with_limit(20);

        let statement = within.to_select_sql("documents", 1, Some("category = 'tech'"), "*");

        assert_contains_all(
            &statement,
            &["WHERE", "< 0.5", "category = 'tech'", "AND"],
        );
    }

    #[test]
    fn test_vector_order_by() {
        let ordering = VectorOrderBy::new("embedding", test_embedding(), DistanceMetric::Cosine);
        assert!(ordering.include_distance);

        let select_expr = ordering.select_distance_sql(1).unwrap();
        assert_contains_all(&select_expr, &["<=>", "AS distance"]);

        // With the distance column selected, ordering refers to its alias.
        assert_eq!(ordering.order_by_sql(1), "distance");
    }

    #[test]
    fn test_vector_order_by_without_distance() {
        let ordering = VectorOrderBy::new("embedding", test_embedding(), DistanceMetric::L2)
            .without_distance();

        assert_eq!(ordering.select_distance_sql(1), None);
        assert!(ordering.order_by_sql(1).contains("<->"));
    }

    #[test]
    fn test_vector_order_by_custom_alias() {
        let ordering = VectorOrderBy::new("embedding", test_embedding(), DistanceMetric::Cosine)
            .alias("similarity");

        let select_expr = ordering.select_distance_sql(1).unwrap();
        assert!(select_expr.contains("AS similarity"));
    }

    #[test]
    fn test_distance_expr_sql() {
        let nearest =
            VectorFilter::nearest("emb", test_embedding(), DistanceMetric::InnerProduct, 5);

        assert_eq!(nearest.distance_expr_sql(2), "emb <#> $2");
    }
}