prax-pgvector 0.9.2

pgvector integration for the Prax ORM — vector similarity search, embeddings, and index management
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
//! Vector filter operations for integration with the prax query builder.
//!
//! This module provides filter types that can be used with prax-query's
//! filter system to perform vector similarity searches as part of WHERE clauses.
//!
//! # Examples
//!
//! ```rust
//! use prax_pgvector::filter::{VectorFilter, VectorOrderBy};
//! use prax_pgvector::{Embedding, DistanceMetric};
//!
//! // Create a nearest-neighbor filter
//! let query_vec = Embedding::new(vec![0.1, 0.2, 0.3]);
//! let filter = VectorFilter::nearest("embedding", query_vec, DistanceMetric::Cosine, 10);
//!
//! // Create a distance-filtered search
//! let query_vec = Embedding::new(vec![0.1, 0.2, 0.3]);
//! let filter = VectorFilter::within_distance("embedding", query_vec, DistanceMetric::L2, 0.5);
//! ```

use serde::{Deserialize, Serialize};

use crate::ops::DistanceMetric;
use crate::types::Embedding;

/// A vector filter operation for use in WHERE and ORDER BY clauses.
///
/// Bundles the target column, the query vector, the distance metric and the
/// kind of search ([`VectorFilterType`]). The SQL-generating methods on this
/// type derive every fragment (distance expression, WHERE predicate, ORDER BY
/// and LIMIT) from these four fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorFilter {
    /// Column containing the vector.
    ///
    /// Interpolated verbatim (unquoted, unescaped) into the generated SQL, so
    /// it must be a trusted identifier.
    pub column: String,
    /// Query vector to compare against.
    ///
    /// The SQL generated by this type refers to it only through a `$n`
    /// placeholder; the vector itself is never written into the SQL text.
    /// NOTE(review): presumably the caller binds this value at the same
    /// parameter index it passes to the SQL builders — confirm.
    pub query_vector: Embedding,
    /// Distance metric to use.
    ///
    /// Its `operator()` supplies the SQL operator placed between the column
    /// and the placeholder.
    pub metric: DistanceMetric,
    /// Type of vector filter.
    ///
    /// Decides whether a WHERE predicate is produced and whether a LIMIT is
    /// always or only optionally present.
    pub filter_type: VectorFilterType,
}

/// The type of vector filter operation.
///
/// Each variant maps to a different combination of WHERE / LIMIT fragments
/// produced by [`VectorFilter`]. Distance bounds are stored exactly as given;
/// nothing in this module validates them (for example that a minimum does not
/// exceed a maximum).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum VectorFilterType {
    /// K-nearest neighbor search (ORDER BY distance LIMIT k).
    ///
    /// Produces no WHERE predicate; the limit is mandatory for this variant.
    Nearest {
        /// Maximum number of results to return.
        limit: usize,
    },

    /// Distance-based filter (WHERE distance < threshold).
    ///
    /// The comparison is strict (`<`), so rows exactly at `max_distance` are
    /// excluded.
    WithinDistance {
        /// Maximum distance threshold.
        max_distance: f64,
        /// Optional result limit.
        ///
        /// `None` means no LIMIT clause is generated.
        limit: Option<usize>,
    },

    /// Distance range filter (WHERE distance BETWEEN min AND max).
    ///
    /// SQL `BETWEEN` is inclusive on both ends.
    DistanceRange {
        /// Minimum distance.
        min_distance: f64,
        /// Maximum distance.
        max_distance: f64,
        /// Optional result limit.
        ///
        /// `None` means no LIMIT clause is generated.
        limit: Option<usize>,
    },
}

impl VectorFilter {
    /// Create a k-nearest neighbor filter.
    ///
    /// This generates an ORDER BY with the vector distance operator and LIMIT.
    pub fn nearest(
        column: impl Into<String>,
        query_vector: Embedding,
        metric: DistanceMetric,
        limit: usize,
    ) -> Self {
        Self {
            column: column.into(),
            query_vector,
            metric,
            filter_type: VectorFilterType::Nearest { limit },
        }
    }

    /// Create a distance-based filter.
    ///
    /// This generates a WHERE clause filtering by maximum distance.
    pub fn within_distance(
        column: impl Into<String>,
        query_vector: Embedding,
        metric: DistanceMetric,
        max_distance: f64,
    ) -> Self {
        Self {
            column: column.into(),
            query_vector,
            metric,
            filter_type: VectorFilterType::WithinDistance {
                max_distance,
                limit: None,
            },
        }
    }

    /// Create a distance range filter.
    pub fn distance_range(
        column: impl Into<String>,
        query_vector: Embedding,
        metric: DistanceMetric,
        min_distance: f64,
        max_distance: f64,
    ) -> Self {
        Self {
            column: column.into(),
            query_vector,
            metric,
            filter_type: VectorFilterType::DistanceRange {
                min_distance,
                max_distance,
                limit: None,
            },
        }
    }

    /// Add a limit to this filter.
    pub fn with_limit(mut self, limit: usize) -> Self {
        match &mut self.filter_type {
            VectorFilterType::Nearest { limit: l } => *l = limit,
            VectorFilterType::WithinDistance { limit: l, .. } => *l = Some(limit),
            VectorFilterType::DistanceRange { limit: l, .. } => *l = Some(limit),
        }
        self
    }

    /// Generate the distance expression SQL fragment.
    ///
    /// Returns something like: `embedding <=> $1`
    pub fn distance_expr_sql(&self, param_index: usize) -> String {
        format!(
            "{} {} ${}",
            self.column,
            self.metric.operator(),
            param_index
        )
    }

    /// Generate the WHERE clause SQL fragment.
    ///
    /// Returns `None` for nearest-neighbor searches (which only use ORDER BY).
    pub fn where_sql(&self, param_index: usize) -> Option<String> {
        let distance_expr = self.distance_expr_sql(param_index);

        match &self.filter_type {
            VectorFilterType::Nearest { .. } => None,
            VectorFilterType::WithinDistance { max_distance, .. } => {
                Some(format!("{distance_expr} < {max_distance}"))
            }
            VectorFilterType::DistanceRange {
                min_distance,
                max_distance,
                ..
            } => Some(format!(
                "{distance_expr} BETWEEN {min_distance} AND {max_distance}"
            )),
        }
    }

    /// Generate the ORDER BY clause SQL fragment.
    pub fn order_by_sql(&self, param_index: usize) -> String {
        self.distance_expr_sql(param_index)
    }

    /// Generate the LIMIT clause.
    pub fn limit_sql(&self) -> Option<String> {
        let limit = match &self.filter_type {
            VectorFilterType::Nearest { limit } => Some(*limit),
            VectorFilterType::WithinDistance { limit, .. } => *limit,
            VectorFilterType::DistanceRange { limit, .. } => *limit,
        };

        limit.map(|l| format!("LIMIT {l}"))
    }

    /// Generate the complete SELECT query incorporating this vector filter.
    ///
    /// This produces a query like:
    /// ```sql
    /// SELECT *, embedding <=> $1 AS distance
    /// FROM documents
    /// WHERE embedding <=> $1 < 0.5
    /// ORDER BY distance
    /// LIMIT 10
    /// ```
    pub fn to_select_sql(
        &self,
        table: &str,
        param_index: usize,
        extra_where: Option<&str>,
        select_columns: &str,
    ) -> String {
        let distance_expr = self.distance_expr_sql(param_index);

        let mut sql = format!(
            "SELECT {}, {} AS distance FROM {}",
            select_columns, distance_expr, table
        );

        // WHERE clause
        let mut where_parts = Vec::new();
        if let Some(vec_where) = self.where_sql(param_index) {
            where_parts.push(vec_where);
        }
        if let Some(extra) = extra_where {
            where_parts.push(extra.to_string());
        }
        if !where_parts.is_empty() {
            sql.push_str(&format!(" WHERE {}", where_parts.join(" AND ")));
        }

        // ORDER BY
        sql.push_str(&format!(" ORDER BY {}", self.order_by_sql(param_index)));

        // LIMIT
        if let Some(limit) = self.limit_sql() {
            sql.push_str(&format!(" {limit}"));
        }

        sql
    }
}

/// Vector ordering specification for use with query builders.
///
/// Describes how to order rows by distance to a query vector and, optionally,
/// how to expose that distance as an aliased result column.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorOrderBy {
    /// Column containing the vector.
    ///
    /// Interpolated verbatim (unquoted, unescaped) into the generated SQL.
    pub column: String,
    /// Query vector to compare against.
    ///
    /// The generated SQL refers to it only through a `$n` placeholder.
    /// NOTE(review): presumably bound by the caller at that parameter index —
    /// confirm.
    pub query_vector: Embedding,
    /// Distance metric.
    ///
    /// Its `operator()` supplies the SQL operator used in the distance
    /// expression.
    pub metric: DistanceMetric,
    /// Whether to include the distance as a result column.
    ///
    /// `VectorOrderBy::new` sets this to `true`. When `true`, ordering is done
    /// by the alias; when `false`, by the full distance expression.
    pub include_distance: bool,
    /// Alias for the distance column.
    ///
    /// `VectorOrderBy::new` sets this to `"distance"`. Interpolated verbatim
    /// into the generated SQL.
    pub distance_alias: String,
}

impl VectorOrderBy {
    /// Build an ordering over `column` by distance to `query_vector`.
    ///
    /// The distance column is included by default and is aliased `distance`.
    pub fn new(column: impl Into<String>, query_vector: Embedding, metric: DistanceMetric) -> Self {
        let column = column.into();
        let distance_alias = String::from("distance");

        Self {
            column,
            query_vector,
            metric,
            include_distance: true,
            distance_alias,
        }
    }

    /// Replace the alias under which the distance column is exposed.
    pub fn alias(self, alias: impl Into<String>) -> Self {
        Self {
            distance_alias: alias.into(),
            ..self
        }
    }

    /// Omit the distance from the result columns.
    ///
    /// Ordering then falls back to the full distance expression.
    pub fn without_distance(self) -> Self {
        Self {
            include_distance: false,
            ..self
        }
    }

    /// Produce the extra SELECT item for the distance column.
    ///
    /// Yields `<column> <operator> $<param_index> AS <alias>` when the
    /// distance column is enabled, and `None` when it is not.
    pub fn select_distance_sql(&self, param_index: usize) -> Option<String> {
        self.include_distance.then(|| {
            let Self {
                column,
                metric,
                distance_alias,
                ..
            } = self;
            format!(
                "{column} {} ${param_index} AS {distance_alias}",
                metric.operator()
            )
        })
    }

    /// Produce the ORDER BY expression (without the `ORDER BY` keyword).
    ///
    /// Uses the alias when the distance column is selected; otherwise spells
    /// out `<column> <operator> $<param_index>`.
    pub fn order_by_sql(&self, param_index: usize) -> String {
        if !self.include_distance {
            // No aliased column exists in the SELECT list, so the expression
            // itself has to be repeated here.
            return format!(
                "{} {} ${param_index}",
                self.column,
                self.metric.operator()
            );
        }

        self.distance_alias.clone()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Small three-component vector shared by every test.
    fn test_embedding() -> Embedding {
        let components = vec![0.1, 0.2, 0.3];
        Embedding::new(components)
    }

    /// A nearest-neighbor filter has no WHERE part, only ORDER BY and LIMIT.
    #[test]
    fn test_nearest_filter() {
        let knn = VectorFilter::nearest(
            "embedding",
            test_embedding(),
            DistanceMetric::Cosine,
            10,
        );

        assert_eq!(knn.where_sql(1), None);
        assert_eq!(knn.order_by_sql(1).as_str(), "embedding <=> $1");
        assert_eq!(knn.limit_sql().as_deref(), Some("LIMIT 10"));
    }

    /// The within-distance predicate uses the L2 operator and a strict bound.
    #[test]
    fn test_within_distance_filter() {
        let predicate = VectorFilter::within_distance(
            "embedding",
            test_embedding(),
            DistanceMetric::L2,
            0.5,
        )
        .where_sql(1)
        .unwrap();

        for fragment in ["<->", "< 0.5"] {
            assert!(predicate.contains(fragment), "missing fragment: {fragment}");
        }
    }

    /// The range predicate mentions BETWEEN and both bounds.
    #[test]
    fn test_distance_range_filter() {
        let ranged = VectorFilter::distance_range(
            "embedding",
            test_embedding(),
            DistanceMetric::L2,
            0.1,
            0.5,
        );
        let predicate = ranged.where_sql(1).unwrap();

        for fragment in ["BETWEEN", "0.1", "0.5"] {
            assert!(predicate.contains(fragment), "missing fragment: {fragment}");
        }
    }

    /// `with_limit` attaches a LIMIT to a filter that had none.
    #[test]
    fn test_filter_with_limit() {
        let limited = VectorFilter::within_distance(
            "embedding",
            test_embedding(),
            DistanceMetric::L2,
            0.5,
        )
        .with_limit(50);

        assert_eq!(limited.limit_sql().as_deref(), Some("LIMIT 50"));
    }

    /// A full nearest-neighbor query has SELECT/FROM/ORDER BY/LIMIT but no WHERE.
    #[test]
    fn test_to_select_sql_nearest() {
        let knn = VectorFilter::nearest(
            "embedding",
            test_embedding(),
            DistanceMetric::Cosine,
            5,
        );
        let query = knn.to_select_sql("documents", 1, None, "*");

        let required = [
            "SELECT *, embedding <=> $1 AS distance",
            "FROM documents",
            "ORDER BY",
            "LIMIT 5",
        ];
        for fragment in required {
            assert!(query.contains(fragment), "missing fragment: {fragment}");
        }
        // No WHERE for nearest
        assert!(!query.contains("WHERE"));
    }

    /// An extra condition is ANDed with the vector predicate.
    #[test]
    fn test_to_select_sql_with_extra_where() {
        let limited = VectorFilter::within_distance(
            "embedding",
            test_embedding(),
            DistanceMetric::L2,
            0.5,
        )
        .with_limit(20);
        let query = limited.to_select_sql("documents", 1, Some("category = 'tech'"), "*");

        for fragment in ["WHERE", "< 0.5", "category = 'tech'", "AND"] {
            assert!(query.contains(fragment), "missing fragment: {fragment}");
        }
    }

    /// By default the distance is selected under the alias and ordered by it.
    #[test]
    fn test_vector_order_by() {
        let ordering = VectorOrderBy::new("embedding", test_embedding(), DistanceMetric::Cosine);
        assert!(ordering.include_distance);

        let select_item = ordering.select_distance_sql(1).unwrap();
        for fragment in ["<=>", "AS distance"] {
            assert!(select_item.contains(fragment), "missing fragment: {fragment}");
        }

        assert_eq!(ordering.order_by_sql(1).as_str(), "distance");
    }

    /// Without the distance column, ordering uses the raw expression.
    #[test]
    fn test_vector_order_by_without_distance() {
        let ordering = VectorOrderBy::new("embedding", test_embedding(), DistanceMetric::L2)
            .without_distance();

        assert_eq!(ordering.select_distance_sql(1), None);
        assert!(ordering.order_by_sql(1).contains("<->"));
    }

    /// A custom alias shows up in the SELECT item.
    #[test]
    fn test_vector_order_by_custom_alias() {
        let select_item = VectorOrderBy::new("embedding", test_embedding(), DistanceMetric::Cosine)
            .alias("similarity")
            .select_distance_sql(1)
            .unwrap();

        assert!(select_item.contains("AS similarity"));
    }

    /// The distance expression honors the column, operator and parameter index.
    #[test]
    fn test_distance_expr_sql() {
        let knn = VectorFilter::nearest(
            "emb",
            test_embedding(),
            DistanceMetric::InnerProduct,
            5,
        );

        assert_eq!(knn.distance_expr_sql(2).as_str(), "emb <#> $2");
    }
}