Skip to main content

prax_pgvector/
filter.rs

1//! Vector filter operations for integration with the prax query builder.
2//!
3//! This module provides filter types that can be used with prax-query's
4//! filter system to perform vector similarity searches as part of WHERE clauses.
5//!
6//! # Examples
7//!
8//! ```rust
9//! use prax_pgvector::filter::{VectorFilter, VectorOrderBy};
10//! use prax_pgvector::{Embedding, DistanceMetric};
11//!
12//! // Create a nearest-neighbor filter
13//! let query_vec = Embedding::new(vec![0.1, 0.2, 0.3]);
14//! let filter = VectorFilter::nearest("embedding", query_vec, DistanceMetric::Cosine, 10);
15//!
16//! // Create a distance-filtered search
17//! let query_vec = Embedding::new(vec![0.1, 0.2, 0.3]);
18//! let filter = VectorFilter::within_distance("embedding", query_vec, DistanceMetric::L2, 0.5);
19//! ```
20
21use serde::{Deserialize, Serialize};
22
23use crate::ops::DistanceMetric;
24use crate::types::Embedding;
25
26/// A vector filter operation for use in WHERE and ORDER BY clauses.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct VectorFilter {
29    /// Column containing the vector.
30    pub column: String,
31    /// Query vector to compare against.
32    pub query_vector: Embedding,
33    /// Distance metric to use.
34    pub metric: DistanceMetric,
35    /// Type of vector filter.
36    pub filter_type: VectorFilterType,
37}
38
39/// The type of vector filter operation.
40#[derive(Debug, Clone, Serialize, Deserialize)]
41#[non_exhaustive]
42pub enum VectorFilterType {
43    /// K-nearest neighbor search (ORDER BY distance LIMIT k).
44    Nearest {
45        /// Maximum number of results to return.
46        limit: usize,
47    },
48
49    /// Distance-based filter (WHERE distance < threshold).
50    WithinDistance {
51        /// Maximum distance threshold.
52        max_distance: f64,
53        /// Optional result limit.
54        limit: Option<usize>,
55    },
56
57    /// Distance range filter (WHERE distance BETWEEN min AND max).
58    DistanceRange {
59        /// Minimum distance.
60        min_distance: f64,
61        /// Maximum distance.
62        max_distance: f64,
63        /// Optional result limit.
64        limit: Option<usize>,
65    },
66}
67
68impl VectorFilter {
69    /// Create a k-nearest neighbor filter.
70    ///
71    /// This generates an ORDER BY with the vector distance operator and LIMIT.
72    pub fn nearest(
73        column: impl Into<String>,
74        query_vector: Embedding,
75        metric: DistanceMetric,
76        limit: usize,
77    ) -> Self {
78        Self {
79            column: column.into(),
80            query_vector,
81            metric,
82            filter_type: VectorFilterType::Nearest { limit },
83        }
84    }
85
86    /// Create a distance-based filter.
87    ///
88    /// This generates a WHERE clause filtering by maximum distance.
89    pub fn within_distance(
90        column: impl Into<String>,
91        query_vector: Embedding,
92        metric: DistanceMetric,
93        max_distance: f64,
94    ) -> Self {
95        Self {
96            column: column.into(),
97            query_vector,
98            metric,
99            filter_type: VectorFilterType::WithinDistance {
100                max_distance,
101                limit: None,
102            },
103        }
104    }
105
106    /// Create a distance range filter.
107    pub fn distance_range(
108        column: impl Into<String>,
109        query_vector: Embedding,
110        metric: DistanceMetric,
111        min_distance: f64,
112        max_distance: f64,
113    ) -> Self {
114        Self {
115            column: column.into(),
116            query_vector,
117            metric,
118            filter_type: VectorFilterType::DistanceRange {
119                min_distance,
120                max_distance,
121                limit: None,
122            },
123        }
124    }
125
126    /// Add a limit to this filter.
127    pub fn with_limit(mut self, limit: usize) -> Self {
128        match &mut self.filter_type {
129            VectorFilterType::Nearest { limit: l } => *l = limit,
130            VectorFilterType::WithinDistance { limit: l, .. } => *l = Some(limit),
131            VectorFilterType::DistanceRange { limit: l, .. } => *l = Some(limit),
132        }
133        self
134    }
135
136    /// Generate the distance expression SQL fragment.
137    ///
138    /// Returns something like: `embedding <=> $1`
139    pub fn distance_expr_sql(&self, param_index: usize) -> String {
140        format!(
141            "{} {} ${}",
142            self.column,
143            self.metric.operator(),
144            param_index
145        )
146    }
147
148    /// Generate the WHERE clause SQL fragment.
149    ///
150    /// Returns `None` for nearest-neighbor searches (which only use ORDER BY).
151    pub fn where_sql(&self, param_index: usize) -> Option<String> {
152        let distance_expr = self.distance_expr_sql(param_index);
153
154        match &self.filter_type {
155            VectorFilterType::Nearest { .. } => None,
156            VectorFilterType::WithinDistance { max_distance, .. } => {
157                Some(format!("{distance_expr} < {max_distance}"))
158            }
159            VectorFilterType::DistanceRange {
160                min_distance,
161                max_distance,
162                ..
163            } => Some(format!(
164                "{distance_expr} BETWEEN {min_distance} AND {max_distance}"
165            )),
166        }
167    }
168
169    /// Generate the ORDER BY clause SQL fragment.
170    pub fn order_by_sql(&self, param_index: usize) -> String {
171        self.distance_expr_sql(param_index)
172    }
173
174    /// Generate the LIMIT clause.
175    pub fn limit_sql(&self) -> Option<String> {
176        let limit = match &self.filter_type {
177            VectorFilterType::Nearest { limit } => Some(*limit),
178            VectorFilterType::WithinDistance { limit, .. } => *limit,
179            VectorFilterType::DistanceRange { limit, .. } => *limit,
180        };
181
182        limit.map(|l| format!("LIMIT {l}"))
183    }
184
185    /// Generate the complete SELECT query incorporating this vector filter.
186    ///
187    /// This produces a query like:
188    /// ```sql
189    /// SELECT *, embedding <=> $1 AS distance
190    /// FROM documents
191    /// WHERE embedding <=> $1 < 0.5
192    /// ORDER BY distance
193    /// LIMIT 10
194    /// ```
195    pub fn to_select_sql(
196        &self,
197        table: &str,
198        param_index: usize,
199        extra_where: Option<&str>,
200        select_columns: &str,
201    ) -> String {
202        let distance_expr = self.distance_expr_sql(param_index);
203
204        let mut sql = format!(
205            "SELECT {}, {} AS distance FROM {}",
206            select_columns, distance_expr, table
207        );
208
209        // WHERE clause
210        let mut where_parts = Vec::new();
211        if let Some(vec_where) = self.where_sql(param_index) {
212            where_parts.push(vec_where);
213        }
214        if let Some(extra) = extra_where {
215            where_parts.push(extra.to_string());
216        }
217        if !where_parts.is_empty() {
218            sql.push_str(&format!(" WHERE {}", where_parts.join(" AND ")));
219        }
220
221        // ORDER BY
222        sql.push_str(&format!(" ORDER BY {}", self.order_by_sql(param_index)));
223
224        // LIMIT
225        if let Some(limit) = self.limit_sql() {
226            sql.push_str(&format!(" {limit}"));
227        }
228
229        sql
230    }
231}
232
233/// Vector ordering specification for use with query builders.
234#[derive(Debug, Clone, Serialize, Deserialize)]
235pub struct VectorOrderBy {
236    /// Column containing the vector.
237    pub column: String,
238    /// Query vector to compare against.
239    pub query_vector: Embedding,
240    /// Distance metric.
241    pub metric: DistanceMetric,
242    /// Whether to include the distance as a result column.
243    pub include_distance: bool,
244    /// Alias for the distance column.
245    pub distance_alias: String,
246}
247
248impl VectorOrderBy {
249    /// Create a new vector ordering.
250    pub fn new(column: impl Into<String>, query_vector: Embedding, metric: DistanceMetric) -> Self {
251        Self {
252            column: column.into(),
253            query_vector,
254            metric,
255            include_distance: true,
256            distance_alias: "distance".to_string(),
257        }
258    }
259
260    /// Set the distance column alias.
261    pub fn alias(mut self, alias: impl Into<String>) -> Self {
262        self.distance_alias = alias.into();
263        self
264    }
265
266    /// Don't include the distance as a result column.
267    pub fn without_distance(mut self) -> Self {
268        self.include_distance = false;
269        self
270    }
271
272    /// Generate the SELECT addition for the distance column.
273    pub fn select_distance_sql(&self, param_index: usize) -> Option<String> {
274        if self.include_distance {
275            Some(format!(
276                "{} {} ${} AS {}",
277                self.column,
278                self.metric.operator(),
279                param_index,
280                self.distance_alias
281            ))
282        } else {
283            None
284        }
285    }
286
287    /// Generate the ORDER BY clause.
288    pub fn order_by_sql(&self, param_index: usize) -> String {
289        if self.include_distance {
290            self.distance_alias.clone()
291        } else {
292            format!(
293                "{} {} ${}",
294                self.column,
295                self.metric.operator(),
296                param_index
297            )
298        }
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// Small three-dimensional embedding shared by all tests.
    fn test_embedding() -> Embedding {
        Embedding::new(vec![0.1, 0.2, 0.3])
    }

    /// A nearest-neighbor filter has no WHERE part, orders by the raw
    /// distance expression, and carries its limit.
    #[test]
    fn test_nearest_filter() {
        let knn = VectorFilter::nearest("embedding", test_embedding(), DistanceMetric::Cosine, 10);

        assert!(knn.where_sql(1).is_none());
        assert_eq!(knn.order_by_sql(1), "embedding <=> $1");
        assert_eq!(knn.limit_sql().as_deref(), Some("LIMIT 10"));
    }

    /// A max-distance filter compares the L2 expression against the threshold.
    #[test]
    fn test_within_distance_filter() {
        let within =
            VectorFilter::within_distance("embedding", test_embedding(), DistanceMetric::L2, 0.5);

        let condition = within.where_sql(1).unwrap();
        for needle in ["<->", "< 0.5"] {
            assert!(condition.contains(needle));
        }
    }

    /// A range filter emits BETWEEN with both bounds.
    #[test]
    fn test_distance_range_filter() {
        let ranged = VectorFilter::distance_range(
            "embedding",
            test_embedding(),
            DistanceMetric::L2,
            0.1,
            0.5,
        );

        let condition = ranged.where_sql(1).unwrap();
        for needle in ["BETWEEN", "0.1", "0.5"] {
            assert!(condition.contains(needle));
        }
    }

    /// `with_limit` adds a LIMIT to a filter that starts without one.
    #[test]
    fn test_filter_with_limit() {
        let limited =
            VectorFilter::within_distance("embedding", test_embedding(), DistanceMetric::L2, 0.5)
                .with_limit(50);

        assert_eq!(limited.limit_sql().as_deref(), Some("LIMIT 50"));
    }

    /// The full SELECT for a nearest-neighbor filter has no WHERE clause.
    #[test]
    fn test_to_select_sql_nearest() {
        let knn = VectorFilter::nearest("embedding", test_embedding(), DistanceMetric::Cosine, 5);
        let query = knn.to_select_sql("documents", 1, None, "*");

        let expected = [
            "SELECT *, embedding <=> $1 AS distance",
            "FROM documents",
            "ORDER BY",
            "LIMIT 5",
        ];
        for needle in expected {
            assert!(query.contains(needle));
        }
        // No WHERE for nearest
        assert!(!query.contains("WHERE"));
    }

    /// An extra condition is AND-ed with the filter's own condition.
    #[test]
    fn test_to_select_sql_with_extra_where() {
        let limited =
            VectorFilter::within_distance("embedding", test_embedding(), DistanceMetric::L2, 0.5)
                .with_limit(20);

        let query = limited.to_select_sql("documents", 1, Some("category = 'tech'"), "*");
        for needle in ["WHERE", "< 0.5", "category = 'tech'", "AND"] {
            assert!(query.contains(needle));
        }
    }

    /// By default the distance is selected under `distance` and ordering
    /// refers to that alias.
    #[test]
    fn test_vector_order_by() {
        let ordering = VectorOrderBy::new("embedding", test_embedding(), DistanceMetric::Cosine);
        assert!(ordering.include_distance);

        let select_entry = ordering.select_distance_sql(1).unwrap();
        assert!(select_entry.contains("<=>"));
        assert!(select_entry.contains("AS distance"));

        assert_eq!(ordering.order_by_sql(1), "distance");
    }

    /// Without the distance column, ordering falls back to the raw expression.
    #[test]
    fn test_vector_order_by_without_distance() {
        let ordering = VectorOrderBy::new("embedding", test_embedding(), DistanceMetric::L2)
            .without_distance();

        assert!(ordering.select_distance_sql(1).is_none());
        assert!(ordering.order_by_sql(1).contains("<->"));
    }

    /// A custom alias shows up in the SELECT entry.
    #[test]
    fn test_vector_order_by_custom_alias() {
        let ordering = VectorOrderBy::new("embedding", test_embedding(), DistanceMetric::Cosine)
            .alias("similarity");

        let select_entry = ordering.select_distance_sql(1).unwrap();
        assert!(select_entry.contains("AS similarity"));
    }

    /// The distance expression honors the column, operator and placeholder.
    #[test]
    fn test_distance_expr_sql() {
        let knn = VectorFilter::nearest("emb", test_embedding(), DistanceMetric::InnerProduct, 5);

        assert_eq!(knn.distance_expr_sql(2), "emb <#> $2");
    }
}