Skip to main content

prax_pgvector/
ops.rs

//! Distance operators and similarity metrics for pgvector.
//!
//! pgvector supports multiple distance functions, each with a corresponding
//! PostgreSQL operator. This module provides type-safe abstractions for them:
//! [`DistanceMetric`] covers the `vector_*` rows of the table below and
//! [`BinaryDistanceMetric`] covers the `bit_*` rows.
//!
//! # Operators
//!
//! | Metric | Operator | Index Ops Class |
//! |--------|----------|-----------------|
//! | L2 (Euclidean) | `<->` | `vector_l2_ops` |
//! | Inner Product | `<#>` | `vector_ip_ops` |
//! | Cosine | `<=>` | `vector_cosine_ops` |
//! | L1 (Manhattan) | `<+>` | `vector_l1_ops` |
//! | Hamming | `<~>` | `bit_hamming_ops` |
//! | Jaccard | `<%>` | `bit_jaccard_ops` |
16
17use std::fmt;
18
19use serde::{Deserialize, Serialize};
20
21use crate::types::Embedding;
22
23/// Vector distance metric supported by pgvector.
24#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
25#[non_exhaustive]
26pub enum DistanceMetric {
27    /// Euclidean distance (L2 norm).
28    ///
29    /// Operator: `<->`
30    /// Range: [0, ∞)
31    /// Use when: comparing absolute distances between vectors.
32    L2,
33
34    /// Negative inner product.
35    ///
36    /// Operator: `<#>`
37    /// Range: (-∞, ∞)
38    /// Use when: vectors are normalized and you want maximum inner product.
39    ///
40    /// Note: pgvector returns *negative* inner product so that smaller = more similar,
41    /// consistent with the ORDER BY ASC convention.
42    InnerProduct,
43
44    /// Cosine distance (1 - cosine similarity).
45    ///
46    /// Operator: `<=>`
47    /// Range: [0, 2]
48    /// Use when: comparing direction regardless of magnitude.
49    Cosine,
50
51    /// Manhattan distance (L1 norm).
52    ///
53    /// Operator: `<+>`
54    /// Range: [0, ∞)
55    /// Use when: you need L1 distance, often in recommendation systems.
56    L1,
57}
58
59impl DistanceMetric {
60    /// Get the PostgreSQL operator for this metric.
61    pub fn operator(&self) -> &'static str {
62        match self {
63            Self::L2 => "<->",
64            Self::InnerProduct => "<#>",
65            Self::Cosine => "<=>",
66            Self::L1 => "<+>",
67        }
68    }
69
70    /// Get the operator class name for index creation.
71    pub fn ops_class(&self) -> &'static str {
72        match self {
73            Self::L2 => "vector_l2_ops",
74            Self::InnerProduct => "vector_ip_ops",
75            Self::Cosine => "vector_cosine_ops",
76            Self::L1 => "vector_l1_ops",
77        }
78    }
79
80    /// Get a human-readable name for this metric.
81    pub fn name(&self) -> &'static str {
82        match self {
83            Self::L2 => "euclidean",
84            Self::InnerProduct => "inner_product",
85            Self::Cosine => "cosine",
86            Self::L1 => "manhattan",
87        }
88    }
89
90    /// Whether this metric benefits from normalized vectors.
91    pub fn prefers_normalized(&self) -> bool {
92        matches!(self, Self::InnerProduct | Self::Cosine)
93    }
94}
95
96impl fmt::Display for DistanceMetric {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        write!(f, "{}", self.name())
99    }
100}
101
102/// Binary vector distance metric supported by pgvector.
103#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
104#[non_exhaustive]
105pub enum BinaryDistanceMetric {
106    /// Hamming distance (number of differing bits).
107    ///
108    /// Operator: `<~>`
109    Hamming,
110
111    /// Jaccard distance (1 - Jaccard index).
112    ///
113    /// Operator: `<%>`
114    Jaccard,
115}
116
117impl BinaryDistanceMetric {
118    /// Get the PostgreSQL operator for this metric.
119    pub fn operator(&self) -> &'static str {
120        match self {
121            Self::Hamming => "<~>",
122            Self::Jaccard => "<%>",
123        }
124    }
125
126    /// Get the operator class name for index creation.
127    pub fn ops_class(&self) -> &'static str {
128        match self {
129            Self::Hamming => "bit_hamming_ops",
130            Self::Jaccard => "bit_jaccard_ops",
131        }
132    }
133
134    /// Get a human-readable name for this metric.
135    pub fn name(&self) -> &'static str {
136        match self {
137            Self::Hamming => "hamming",
138            Self::Jaccard => "jaccard",
139        }
140    }
141}
142
143impl fmt::Display for BinaryDistanceMetric {
144    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
145        write!(f, "{}", self.name())
146    }
147}
148
149/// Generate SQL for computing the distance between a column and a query vector.
150///
151/// # Examples
152///
153/// ```rust
154/// use prax_pgvector::{Embedding, DistanceMetric, ops::distance_sql};
155///
156/// let query = Embedding::new(vec![0.1, 0.2, 0.3]);
157/// let sql = distance_sql("embedding", &query, DistanceMetric::Cosine);
158/// assert!(sql.contains("<=>"));
159/// ```
160pub fn distance_sql(column: &str, query_vector: &Embedding, metric: DistanceMetric) -> String {
161    format!(
162        "{} {} {}",
163        column,
164        metric.operator(),
165        query_vector.to_sql_literal()
166    )
167}
168
169/// Generate SQL for computing distance with a parameter placeholder.
170///
171/// This is preferred over [`distance_sql`] when using parameterized queries
172/// to prevent SQL injection.
173///
174/// # Examples
175///
176/// ```rust
177/// use prax_pgvector::{DistanceMetric, ops::distance_param_sql};
178///
179/// let sql = distance_param_sql("embedding", "$1", DistanceMetric::L2);
180/// assert_eq!(sql, "embedding <-> $1");
181/// ```
182pub fn distance_param_sql(column: &str, param: &str, metric: DistanceMetric) -> String {
183    format!("{} {} {}", column, metric.operator(), param)
184}
185
186/// Generate an ORDER BY clause for nearest-neighbor search.
187///
188/// Returns a SQL fragment like: `embedding <-> '[0.1,0.2,0.3]'::vector`
189/// suitable for use in ORDER BY.
190pub fn order_by_distance(column: &str, query_vector: &Embedding, metric: DistanceMetric) -> String {
191    distance_sql(column, query_vector, metric)
192}
193
194/// Generate a complete nearest-neighbor search query.
195///
196/// This generates SQL like:
197/// ```sql
198/// SELECT *, embedding <-> $1 AS distance
199/// FROM documents
200/// ORDER BY distance
201/// LIMIT 10
202/// ```
203pub fn nearest_neighbor_sql(
204    table: &str,
205    column: &str,
206    metric: DistanceMetric,
207    param_index: usize,
208    limit: usize,
209    extra_columns: &[&str],
210) -> String {
211    let distance_expr = distance_param_sql(column, &format!("${param_index}"), metric);
212
213    let select_cols = if extra_columns.is_empty() {
214        "*".to_string()
215    } else {
216        let mut cols = vec!["*".to_string()];
217        cols.extend(extra_columns.iter().map(|c| (*c).to_string()));
218        cols.join(", ")
219    };
220
221    format!(
222        "SELECT {}, {} AS distance FROM {} ORDER BY distance LIMIT {}",
223        select_cols, distance_expr, table, limit
224    )
225}
226
227/// Generate SQL for a distance-filtered search (within a radius).
228///
229/// Returns SQL like:
230/// ```sql
231/// SELECT *, embedding <-> $1 AS distance
232/// FROM documents
233/// WHERE embedding <-> $1 < 0.5
234/// ORDER BY distance
235/// LIMIT 100
236/// ```
237pub fn radius_search_sql(
238    table: &str,
239    column: &str,
240    metric: DistanceMetric,
241    param_index: usize,
242    max_distance: f64,
243    limit: Option<usize>,
244) -> String {
245    let param = format!("${param_index}");
246    let distance_expr = distance_param_sql(column, &param, metric);
247
248    let limit_clause = limit.map(|l| format!(" LIMIT {l}")).unwrap_or_default();
249
250    format!(
251        "SELECT *, {} AS distance FROM {} WHERE {} < {} ORDER BY distance{}",
252        distance_expr, table, distance_expr, max_distance, limit_clause
253    )
254}
255
256/// Configuration for setting pgvector search parameters.
257///
258/// These SET commands tune the behavior of approximate index scans.
259#[derive(Debug, Clone, Serialize, Deserialize)]
260pub struct SearchParams {
261    /// Number of IVFFlat lists to probe (default: 1).
262    ///
263    /// Higher values improve recall at the cost of speed.
264    pub ivfflat_probes: Option<usize>,
265
266    /// HNSW search ef parameter (default: 40).
267    ///
268    /// Higher values improve recall at the cost of speed.
269    pub hnsw_ef_search: Option<usize>,
270}
271
272impl SearchParams {
273    /// Create new search parameters.
274    pub fn new() -> Self {
275        Self {
276            ivfflat_probes: None,
277            hnsw_ef_search: None,
278        }
279    }
280
281    /// Set the number of IVFFlat probes.
282    pub fn probes(mut self, probes: usize) -> Self {
283        self.ivfflat_probes = Some(probes);
284        self
285    }
286
287    /// Set the HNSW ef_search parameter.
288    pub fn ef_search(mut self, ef: usize) -> Self {
289        self.hnsw_ef_search = Some(ef);
290        self
291    }
292
293    /// Generate the SET commands for these parameters.
294    pub fn to_set_sql(&self) -> Vec<String> {
295        let mut statements = Vec::new();
296
297        if let Some(probes) = self.ivfflat_probes {
298            statements.push(format!("SET ivfflat.probes = {probes}"));
299        }
300        if let Some(ef) = self.hnsw_ef_search {
301            statements.push(format!("SET hnsw.ef_search = {ef}"));
302        }
303
304        statements
305    }
306
307    /// Check if any parameters are set.
308    pub fn has_params(&self) -> bool {
309        self.ivfflat_probes.is_some() || self.hnsw_ef_search.is_some()
310    }
311}
312
313impl Default for SearchParams {
314    fn default() -> Self {
315        Self::new()
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;

    /// Each vector metric maps to its SQL operator.
    #[test]
    fn test_distance_metric_operator() {
        let cases = [
            (DistanceMetric::L2, "<->"),
            (DistanceMetric::InnerProduct, "<#>"),
            (DistanceMetric::Cosine, "<=>"),
            (DistanceMetric::L1, "<+>"),
        ];
        for (metric, expected) in cases {
            assert_eq!(metric.operator(), expected);
        }
    }

    /// Each vector metric maps to its index operator class.
    #[test]
    fn test_distance_metric_ops_class() {
        let cases = [
            (DistanceMetric::L2, "vector_l2_ops"),
            (DistanceMetric::InnerProduct, "vector_ip_ops"),
            (DistanceMetric::Cosine, "vector_cosine_ops"),
            (DistanceMetric::L1, "vector_l1_ops"),
        ];
        for (metric, expected) in cases {
            assert_eq!(metric.ops_class(), expected);
        }
    }

    /// Only inner product and cosine prefer normalized vectors.
    #[test]
    fn test_distance_metric_prefers_normalized() {
        for metric in [DistanceMetric::InnerProduct, DistanceMetric::Cosine] {
            assert!(metric.prefers_normalized());
        }
        for metric in [DistanceMetric::L2, DistanceMetric::L1] {
            assert!(!metric.prefers_normalized());
        }
    }

    /// Each binary metric maps to its SQL operator.
    #[test]
    fn test_binary_distance_metric_operator() {
        let cases = [
            (BinaryDistanceMetric::Hamming, "<~>"),
            (BinaryDistanceMetric::Jaccard, "<%>"),
        ];
        for (metric, expected) in cases {
            assert_eq!(metric.operator(), expected);
        }
    }

    /// Each binary metric maps to its index operator class.
    #[test]
    fn test_binary_distance_metric_ops_class() {
        let cases = [
            (BinaryDistanceMetric::Hamming, "bit_hamming_ops"),
            (BinaryDistanceMetric::Jaccard, "bit_jaccard_ops"),
        ];
        for (metric, expected) in cases {
            assert_eq!(metric.ops_class(), expected);
        }
    }

    /// The inline form contains the operator and a `::vector` cast.
    #[test]
    fn test_distance_sql() {
        let query = Embedding::new(vec![0.1, 0.2, 0.3]);
        let sql = distance_sql("embedding", &query, DistanceMetric::Cosine);
        for fragment in ["<=>", "::vector"] {
            assert!(sql.contains(fragment));
        }
    }

    /// The placeholder form is exactly `column operator param`.
    #[test]
    fn test_distance_param_sql() {
        assert_eq!(
            distance_param_sql("embedding", "$1", DistanceMetric::L2),
            "embedding <-> $1"
        );
    }

    /// The nearest-neighbor statement carries every expected clause.
    #[test]
    fn test_nearest_neighbor_sql() {
        let sql =
            nearest_neighbor_sql("documents", "embedding", DistanceMetric::Cosine, 1, 10, &[]);
        let fragments = [
            "SELECT *",
            "<=>",
            "$1",
            "LIMIT 10",
            "AS distance",
            "ORDER BY distance",
        ];
        for fragment in fragments {
            assert!(sql.contains(fragment));
        }
    }

    /// The radius search includes the operator, threshold, and limit.
    #[test]
    fn test_radius_search_sql() {
        let sql = radius_search_sql(
            "documents",
            "embedding",
            DistanceMetric::L2,
            1,
            0.5,
            Some(100),
        );
        for fragment in ["<->", "< 0.5", "LIMIT 100"] {
            assert!(sql.contains(fragment));
        }
    }

    /// Without a limit, no `LIMIT` clause is emitted.
    #[test]
    fn test_radius_search_sql_no_limit() {
        let sql = radius_search_sql("documents", "embedding", DistanceMetric::L2, 1, 1.0, None);
        assert!(!sql.contains("LIMIT"));
    }

    /// Setting only probes yields a single IVFFlat statement.
    #[test]
    fn test_search_params_probes() {
        let statements = SearchParams::new().probes(10).to_set_sql();
        assert_eq!(statements, ["SET ivfflat.probes = 10"]);
    }

    /// Setting only ef_search yields a single HNSW statement.
    #[test]
    fn test_search_params_ef_search() {
        let statements = SearchParams::new().ef_search(200).to_set_sql();
        assert_eq!(statements, ["SET hnsw.ef_search = 200"]);
    }

    /// Setting both parameters yields two statements.
    #[test]
    fn test_search_params_both() {
        let params = SearchParams::new().probes(10).ef_search(200);
        assert_eq!(params.to_set_sql().len(), 2);
        assert!(params.has_params());
    }

    /// Fresh parameters report nothing set and render nothing.
    #[test]
    fn test_search_params_empty() {
        let params = SearchParams::new();
        assert!(!params.has_params());
        assert!(params.to_set_sql().is_empty());
    }

    /// `Display` prints the human-readable metric name.
    #[test]
    fn test_distance_metric_display() {
        assert_eq!(DistanceMetric::L2.to_string(), "euclidean");
        assert_eq!(DistanceMetric::Cosine.to_string(), "cosine");
    }
}