Skip to main content

prax_pgvector/
query.rs

1//! High-level query builder for vector similarity search.
2//!
3//! This module provides a fluent builder API for constructing vector search queries
4//! that integrate with the prax-postgres engine.
5//!
6//! # Examples
7//!
8//! ```rust
9//! use prax_pgvector::query::VectorSearchBuilder;
10//! use prax_pgvector::{Embedding, DistanceMetric};
11//!
12//! let query = VectorSearchBuilder::new("documents", "embedding")
13//!     .query(Embedding::new(vec![0.1, 0.2, 0.3]))
14//!     .metric(DistanceMetric::Cosine)
15//!     .limit(10)
16//!     .select(&["id", "title", "content"])
17//!     .where_clause("category = 'tech'")
18//!     .build();
19//!
20//! let sql = query.to_sql();
21//! assert!(sql.contains("<=>")); // cosine distance operator
22//! ```
23
24use serde::{Deserialize, Serialize};
25
26use crate::ops::{DistanceMetric, SearchParams};
27use crate::types::Embedding;
28
29/// A fully constructed vector search query ready for execution.
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct VectorSearchQuery {
32    /// The table to search.
33    pub table: String,
34    /// The vector column.
35    pub column: String,
36    /// The query vector.
37    pub query_vector: Embedding,
38    /// Distance metric.
39    pub metric: DistanceMetric,
40    /// Maximum number of results.
41    pub limit: usize,
42    /// Columns to select (empty = all).
43    pub select_columns: Vec<String>,
44    /// Additional WHERE conditions.
45    pub where_clauses: Vec<String>,
46    /// Whether to include the distance in results.
47    pub include_distance: bool,
48    /// Alias for the distance column.
49    pub distance_alias: String,
50    /// Maximum distance threshold (radius search).
51    pub max_distance: Option<f64>,
52    /// Minimum distance threshold.
53    pub min_distance: Option<f64>,
54    /// Additional ORDER BY clauses (after distance).
55    pub extra_order_by: Vec<String>,
56    /// Offset for pagination.
57    pub offset: Option<usize>,
58    /// Search parameters (probes, ef_search).
59    pub search_params: SearchParams,
60}
61
62impl VectorSearchQuery {
63    /// Generate the complete SQL query.
64    ///
65    /// The query vector should be passed as parameter `$1`.
66    pub fn to_sql(&self) -> String {
67        self.to_sql_with_param(1)
68    }
69
70    /// Generate the complete SQL query with a custom parameter index.
71    pub fn to_sql_with_param(&self, param_index: usize) -> String {
72        let param = format!("${param_index}");
73        let distance_expr = format!("{} {} {}", self.column, self.metric.operator(), param);
74
75        // SELECT clause
76        let select = if self.select_columns.is_empty() {
77            "*".to_string()
78        } else {
79            self.select_columns.join(", ")
80        };
81
82        let distance_select = if self.include_distance {
83            format!(", {} AS {}", distance_expr, self.distance_alias)
84        } else {
85            String::new()
86        };
87
88        // WHERE clause
89        let mut where_parts = Vec::new();
90
91        if let Some(max) = self.max_distance {
92            where_parts.push(format!("{distance_expr} < {max}"));
93        }
94        if let Some(min) = self.min_distance {
95            where_parts.push(format!("{distance_expr} >= {min}"));
96        }
97        where_parts.extend(self.where_clauses.clone());
98
99        let where_clause = if where_parts.is_empty() {
100            String::new()
101        } else {
102            format!(" WHERE {}", where_parts.join(" AND "))
103        };
104
105        // ORDER BY clause
106        let order_by_main = if self.include_distance {
107            self.distance_alias.clone()
108        } else {
109            distance_expr
110        };
111
112        let order_by = if self.extra_order_by.is_empty() {
113            order_by_main
114        } else {
115            let mut parts = vec![order_by_main];
116            parts.extend(self.extra_order_by.clone());
117            parts.join(", ")
118        };
119
120        // LIMIT and OFFSET
121        let limit = format!(" LIMIT {}", self.limit);
122        let offset = self
123            .offset
124            .map(|o| format!(" OFFSET {o}"))
125            .unwrap_or_default();
126
127        format!(
128            "SELECT {}{} FROM {}{}  ORDER BY {}{}{}",
129            select, distance_select, self.table, where_clause, order_by, limit, offset
130        )
131    }
132
133    /// Generate SET commands for search parameters.
134    ///
135    /// These should be executed before the search query to tune index scan behavior.
136    pub fn param_set_sql(&self) -> Vec<String> {
137        self.search_params.to_set_sql()
138    }
139}
140
141/// Fluent builder for vector search queries.
142///
143/// # Examples
144///
145/// ```rust
146/// use prax_pgvector::query::VectorSearchBuilder;
147/// use prax_pgvector::{Embedding, DistanceMetric};
148///
149/// let query = VectorSearchBuilder::new("documents", "embedding")
150///     .query(Embedding::new(vec![0.1, 0.2, 0.3]))
151///     .metric(DistanceMetric::Cosine)
152///     .limit(10)
153///     .ef_search(200)
154///     .build();
155/// ```
156pub struct VectorSearchBuilder {
157    table: String,
158    column: String,
159    query_vector: Option<Embedding>,
160    metric: DistanceMetric,
161    limit: usize,
162    select_columns: Vec<String>,
163    where_clauses: Vec<String>,
164    include_distance: bool,
165    distance_alias: String,
166    max_distance: Option<f64>,
167    min_distance: Option<f64>,
168    extra_order_by: Vec<String>,
169    offset: Option<usize>,
170    search_params: SearchParams,
171}
172
173impl VectorSearchBuilder {
174    /// Create a new search builder for a table and vector column.
175    pub fn new(table: impl Into<String>, column: impl Into<String>) -> Self {
176        Self {
177            table: table.into(),
178            column: column.into(),
179            query_vector: None,
180            metric: DistanceMetric::L2,
181            limit: 10,
182            select_columns: Vec::new(),
183            where_clauses: Vec::new(),
184            include_distance: true,
185            distance_alias: "distance".to_string(),
186            max_distance: None,
187            min_distance: None,
188            extra_order_by: Vec::new(),
189            offset: None,
190            search_params: SearchParams::new(),
191        }
192    }
193
194    /// Set the query vector.
195    pub fn query(mut self, embedding: Embedding) -> Self {
196        self.query_vector = Some(embedding);
197        self
198    }
199
200    /// Set the distance metric.
201    pub fn metric(mut self, metric: DistanceMetric) -> Self {
202        self.metric = metric;
203        self
204    }
205
206    /// Set the result limit.
207    pub fn limit(mut self, limit: usize) -> Self {
208        self.limit = limit;
209        self
210    }
211
212    /// Set specific columns to select.
213    pub fn select(mut self, columns: &[&str]) -> Self {
214        self.select_columns = columns.iter().map(|c| (*c).to_string()).collect();
215        self
216    }
217
218    /// Add a WHERE condition.
219    pub fn where_clause(mut self, condition: impl Into<String>) -> Self {
220        self.where_clauses.push(condition.into());
221        self
222    }
223
224    /// Set the maximum distance (radius search).
225    pub fn max_distance(mut self, distance: f64) -> Self {
226        self.max_distance = Some(distance);
227        self
228    }
229
230    /// Set the minimum distance.
231    pub fn min_distance(mut self, distance: f64) -> Self {
232        self.min_distance = Some(distance);
233        self
234    }
235
236    /// Don't include the distance in the results.
237    pub fn without_distance(mut self) -> Self {
238        self.include_distance = false;
239        self
240    }
241
242    /// Set a custom distance column alias.
243    pub fn distance_alias(mut self, alias: impl Into<String>) -> Self {
244        self.distance_alias = alias.into();
245        self
246    }
247
248    /// Add an additional ORDER BY clause (after distance).
249    pub fn then_order_by(mut self, clause: impl Into<String>) -> Self {
250        self.extra_order_by.push(clause.into());
251        self
252    }
253
254    /// Set the offset for pagination.
255    pub fn offset(mut self, offset: usize) -> Self {
256        self.offset = Some(offset);
257        self
258    }
259
260    /// Set the IVFFlat probes parameter.
261    pub fn probes(mut self, probes: usize) -> Self {
262        self.search_params = self.search_params.probes(probes);
263        self
264    }
265
266    /// Set the HNSW ef_search parameter.
267    pub fn ef_search(mut self, ef: usize) -> Self {
268        self.search_params = self.search_params.ef_search(ef);
269        self
270    }
271
272    /// Build the vector search query.
273    ///
274    /// # Panics
275    ///
276    /// Panics if no query vector has been set. Use [`Self::try_build`] for
277    /// a non-panicking alternative.
278    pub fn build(self) -> VectorSearchQuery {
279        self.try_build()
280            .expect("query vector must be set before building")
281    }
282
283    /// Try to build the vector search query.
284    ///
285    /// Returns `None` if no query vector has been set.
286    pub fn try_build(self) -> Option<VectorSearchQuery> {
287        let query_vector = self.query_vector?;
288
289        Some(VectorSearchQuery {
290            table: self.table,
291            column: self.column,
292            query_vector,
293            metric: self.metric,
294            limit: self.limit,
295            select_columns: self.select_columns,
296            where_clauses: self.where_clauses,
297            include_distance: self.include_distance,
298            distance_alias: self.distance_alias,
299            max_distance: self.max_distance,
300            min_distance: self.min_distance,
301            extra_order_by: self.extra_order_by,
302            offset: self.offset,
303            search_params: self.search_params,
304        })
305    }
306}
307
308/// Builder for hybrid search queries that combine vector similarity with full-text search.
309///
310/// This generates queries that use both pgvector distance operators and
311/// PostgreSQL tsvector/tsquery for combined similarity scoring.
312///
313/// # Examples
314///
315/// ```rust
316/// use prax_pgvector::query::HybridSearchBuilder;
317/// use prax_pgvector::{Embedding, DistanceMetric};
318///
319/// let query = HybridSearchBuilder::new("documents")
320///     .vector_column("embedding")
321///     .text_column("content")
322///     .query_vector(Embedding::new(vec![0.1, 0.2, 0.3]))
323///     .query_text("machine learning")
324///     .metric(DistanceMetric::Cosine)
325///     .vector_weight(0.7)
326///     .text_weight(0.3)
327///     .limit(10)
328///     .build();
329///
330/// let sql = query.to_sql();
331/// ```
332pub struct HybridSearchBuilder {
333    table: String,
334    vector_column: Option<String>,
335    text_column: Option<String>,
336    query_vector: Option<Embedding>,
337    query_text: Option<String>,
338    metric: DistanceMetric,
339    vector_weight: f64,
340    text_weight: f64,
341    limit: usize,
342    language: String,
343    where_clauses: Vec<String>,
344}
345
346impl HybridSearchBuilder {
347    /// Create a new hybrid search builder.
348    pub fn new(table: impl Into<String>) -> Self {
349        Self {
350            table: table.into(),
351            vector_column: None,
352            text_column: None,
353            query_vector: None,
354            query_text: None,
355            metric: DistanceMetric::Cosine,
356            vector_weight: 0.5,
357            text_weight: 0.5,
358            limit: 10,
359            language: "english".to_string(),
360            where_clauses: Vec::new(),
361        }
362    }
363
364    /// Set the vector column name.
365    pub fn vector_column(mut self, column: impl Into<String>) -> Self {
366        self.vector_column = Some(column.into());
367        self
368    }
369
370    /// Set the text column name.
371    pub fn text_column(mut self, column: impl Into<String>) -> Self {
372        self.text_column = Some(column.into());
373        self
374    }
375
376    /// Set the query vector.
377    pub fn query_vector(mut self, embedding: Embedding) -> Self {
378        self.query_vector = Some(embedding);
379        self
380    }
381
382    /// Set the text query.
383    pub fn query_text(mut self, text: impl Into<String>) -> Self {
384        self.query_text = Some(text.into());
385        self
386    }
387
388    /// Set the vector distance metric.
389    pub fn metric(mut self, metric: DistanceMetric) -> Self {
390        self.metric = metric;
391        self
392    }
393
394    /// Set the weight for the vector similarity component (0.0 to 1.0).
395    pub fn vector_weight(mut self, weight: f64) -> Self {
396        self.vector_weight = weight;
397        self
398    }
399
400    /// Set the weight for the text relevance component (0.0 to 1.0).
401    pub fn text_weight(mut self, weight: f64) -> Self {
402        self.text_weight = weight;
403        self
404    }
405
406    /// Set the result limit.
407    pub fn limit(mut self, limit: usize) -> Self {
408        self.limit = limit;
409        self
410    }
411
412    /// Set the text search language.
413    pub fn language(mut self, language: impl Into<String>) -> Self {
414        self.language = language.into();
415        self
416    }
417
418    /// Add a WHERE condition.
419    pub fn where_clause(mut self, condition: impl Into<String>) -> Self {
420        self.where_clauses.push(condition.into());
421        self
422    }
423
424    /// Build the hybrid search query.
425    pub fn build(self) -> HybridSearchQuery {
426        HybridSearchQuery {
427            table: self.table,
428            vector_column: self
429                .vector_column
430                .unwrap_or_else(|| "embedding".to_string()),
431            text_column: self.text_column.unwrap_or_else(|| "content".to_string()),
432            query_vector: self.query_vector,
433            query_text: self.query_text,
434            metric: self.metric,
435            vector_weight: self.vector_weight,
436            text_weight: self.text_weight,
437            limit: self.limit,
438            language: self.language,
439            where_clauses: self.where_clauses,
440        }
441    }
442}
443
444/// A hybrid search query combining vector similarity and full-text search.
445#[derive(Debug, Clone, Serialize, Deserialize)]
446pub struct HybridSearchQuery {
447    /// Table name.
448    pub table: String,
449    /// Vector column.
450    pub vector_column: String,
451    /// Text column.
452    pub text_column: String,
453    /// Query vector.
454    pub query_vector: Option<Embedding>,
455    /// Text query.
456    pub query_text: Option<String>,
457    /// Distance metric.
458    pub metric: DistanceMetric,
459    /// Weight for vector similarity (0.0-1.0).
460    pub vector_weight: f64,
461    /// Weight for text relevance (0.0-1.0).
462    pub text_weight: f64,
463    /// Result limit.
464    pub limit: usize,
465    /// Text search language.
466    pub language: String,
467    /// Additional WHERE conditions.
468    pub where_clauses: Vec<String>,
469}
470
471impl HybridSearchQuery {
472    /// Generate the SQL query using Reciprocal Rank Fusion (RRF).
473    ///
474    /// RRF combines rankings from multiple retrieval methods:
475    /// `score = sum(1 / (k + rank_i))` where k is a constant (typically 60).
476    ///
477    /// The query vector should be `$1` and the text query should be `$2`.
478    pub fn to_sql(&self) -> String {
479        let vec_distance = format!("{} {} $1", self.vector_column, self.metric.operator());
480        let text_rank = format!(
481            "ts_rank(to_tsvector('{}', {}), plainto_tsquery('{}', $2))",
482            self.language, self.text_column, self.language
483        );
484
485        let where_clause = if self.where_clauses.is_empty() {
486            String::new()
487        } else {
488            format!(" WHERE {}", self.where_clauses.join(" AND "))
489        };
490
491        // Use RRF scoring: combine vector and text rankings
492        format!(
493            "WITH vector_results AS (\
494                SELECT *, ROW_NUMBER() OVER (ORDER BY {vec_distance}) AS vec_rank \
495                FROM {table}{where_clause} \
496                ORDER BY {vec_distance} \
497                LIMIT {fetch_limit}\
498            ), \
499            text_results AS (\
500                SELECT *, ROW_NUMBER() OVER (ORDER BY {text_rank} DESC) AS text_rank \
501                FROM {table}{where_clause} \
502                WHERE to_tsvector('{lang}', {text_col}) @@ plainto_tsquery('{lang}', $2) \
503                ORDER BY {text_rank} DESC \
504                LIMIT {fetch_limit}\
505            ) \
506            SELECT COALESCE(v.*, t.*), \
507                ({vec_weight} / (60.0 + COALESCE(v.vec_rank, 1000))) + \
508                ({text_weight} / (60.0 + COALESCE(t.text_rank, 1000))) AS rrf_score \
509            FROM vector_results v \
510            FULL OUTER JOIN text_results t ON v.id = t.id \
511            ORDER BY rrf_score DESC \
512            LIMIT {limit}",
513            table = self.table,
514            where_clause = where_clause,
515            fetch_limit = self.limit * 3, // Fetch more for fusion
516            vec_weight = self.vector_weight,
517            text_weight = self.text_weight,
518            lang = self.language,
519            text_col = self.text_column,
520            limit = self.limit,
521        )
522    }
523}
524
#[cfg(test)]
mod tests {
    use super::*;

    /// Small fixed embedding shared by all tests.
    fn test_embedding() -> Embedding {
        Embedding::new(vec![0.1, 0.2, 0.3])
    }

    /// Builder over `documents.embedding` with the query vector already set.
    fn documents_search() -> VectorSearchBuilder {
        VectorSearchBuilder::new("documents", "embedding").query(test_embedding())
    }

    #[test]
    fn test_basic_search_query() {
        let sql = documents_search()
            .metric(DistanceMetric::Cosine)
            .limit(10)
            .build()
            .to_sql();

        for fragment in [
            "SELECT *",
            "AS distance",
            "<=>",
            "$1",
            "FROM documents",
            "LIMIT 10",
        ] {
            assert!(sql.contains(fragment));
        }
    }

    #[test]
    fn test_search_with_select() {
        let sql = documents_search().select(&["id", "title"]).build().to_sql();

        assert!(sql.contains("SELECT id, title"));
    }

    #[test]
    fn test_search_with_where() {
        let sql = documents_search()
            .where_clause("category = 'tech'")
            .where_clause("published = true")
            .build()
            .to_sql();

        for fragment in ["WHERE", "category = 'tech'", "published = true", "AND"] {
            assert!(sql.contains(fragment));
        }
    }

    #[test]
    fn test_search_with_max_distance() {
        let sql = documents_search()
            .metric(DistanceMetric::L2)
            .max_distance(0.5)
            .build()
            .to_sql();

        assert!(sql.contains("< 0.5"));
    }

    #[test]
    fn test_search_with_distance_range() {
        let sql = documents_search()
            .min_distance(0.1)
            .max_distance(0.5)
            .build()
            .to_sql();

        assert!(sql.contains("< 0.5"));
        assert!(sql.contains(">= 0.1"));
    }

    #[test]
    fn test_search_without_distance() {
        let sql = documents_search().without_distance().build().to_sql();

        assert!(!sql.contains("AS distance"));
    }

    #[test]
    fn test_search_custom_alias() {
        let sql = documents_search()
            .distance_alias("similarity")
            .build()
            .to_sql();

        assert!(sql.contains("AS similarity"));
    }

    #[test]
    fn test_search_with_pagination() {
        let sql = documents_search().limit(10).offset(20).build().to_sql();

        assert!(sql.contains("LIMIT 10"));
        assert!(sql.contains("OFFSET 20"));
    }

    #[test]
    fn test_search_with_extra_order_by() {
        let sql = documents_search()
            .then_order_by("created_at DESC")
            .build()
            .to_sql();

        assert!(sql.contains("ORDER BY distance, created_at DESC"));
    }

    #[test]
    fn test_search_params() {
        let statements = documents_search()
            .probes(10)
            .ef_search(200)
            .build()
            .param_set_sql();

        assert_eq!(statements.len(), 2);
        assert!(statements[0].contains("ivfflat.probes = 10"));
        assert!(statements[1].contains("hnsw.ef_search = 200"));
    }

    #[test]
    fn test_try_build_without_vector() {
        // No query vector is supplied here on purpose.
        let outcome = VectorSearchBuilder::new("documents", "embedding").try_build();
        assert!(outcome.is_none());
    }

    #[test]
    fn test_custom_param_index() {
        let sql = documents_search().build().to_sql_with_param(3);

        assert!(sql.contains("$3"));
    }

    #[test]
    fn test_hybrid_search() {
        let sql = HybridSearchBuilder::new("documents")
            .vector_column("embedding")
            .text_column("content")
            .query_vector(test_embedding())
            .query_text("machine learning")
            .metric(DistanceMetric::Cosine)
            .vector_weight(0.7)
            .text_weight(0.3)
            .limit(10)
            .build()
            .to_sql();

        for fragment in [
            "vector_results",
            "text_results",
            "rrf_score",
            "<=>",
            "ts_rank",
            "FULL OUTER JOIN",
        ] {
            assert!(sql.contains(fragment));
        }
    }

    #[test]
    fn test_all_metrics_produce_valid_sql() {
        let metrics = [
            DistanceMetric::L2,
            DistanceMetric::InnerProduct,
            DistanceMetric::Cosine,
            DistanceMetric::L1,
        ];

        for metric in metrics {
            let sql = VectorSearchBuilder::new("t", "c")
                .query(test_embedding())
                .metric(metric)
                .build()
                .to_sql();
            assert!(sql.contains(metric.operator()), "failed for {metric}");
        }
    }
}
715}