Skip to main content

fraiseql_core/utils/
vector.rs

//! Vector query builder for pgvector similarity search.
//!
//! This module provides SQL query generation for pgvector operations including:
//! - Similarity search with configurable distance metrics
//! - Vector insert and upsert operations
//! - Proper parameter binding for vector data
//!
//! # Example
//!
//! ```
//! use fraiseql_core::utils::vector::{VectorQueryBuilder, VectorSearchQuery};
//! use fraiseql_core::schema::DistanceMetric;
//!
//! let builder = VectorQueryBuilder::new();
//! let query = VectorSearchQuery {
//!     table: "documents".to_string(),
//!     embedding_column: "embedding".to_string(),
//!     select_columns: vec!["id".to_string(), "content".to_string()],
//!     distance_metric: DistanceMetric::Cosine,
//!     limit: 10,
//!     where_clause: None,
//!     order_by: None,
//!     include_distance: false,
//!     offset: None,
//! };
//!
//! let (sql, _params) = builder.similarity_search(&query, &[0.1, 0.2, 0.3]);
//! assert!(sql.contains("documents"));
//! ```
30
31use serde::{Deserialize, Serialize};
32
33use crate::schema::{DistanceMetric, VectorConfig};
34
35/// A SQL parameter value for vector queries.
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub enum VectorParam {
38    /// A vector embedding (array of floats).
39    Vector(Vec<f32>),
40    /// An integer value.
41    Int(i64),
42    /// A string value.
43    String(String),
44    /// A JSON value.
45    Json(serde_json::Value),
46}
47
48impl VectorParam {
49    /// Convert to SQL literal string for debugging.
50    #[must_use]
51    pub fn to_sql_literal(&self) -> String {
52        match self {
53            VectorParam::Vector(v) => {
54                let values: Vec<String> = v.iter().map(std::string::ToString::to_string).collect();
55                format!("'[{}]'::vector", values.join(","))
56            },
57            VectorParam::Int(i) => i.to_string(),
58            VectorParam::String(s) => format!("'{}'", s.replace('\'', "''")),
59            VectorParam::Json(j) => format!("'{j}'::jsonb"),
60        }
61    }
62}
63
64/// Configuration for a similarity search query.
65#[derive(Debug, Clone)]
66pub struct VectorSearchQuery {
67    /// Table or view to query.
68    pub table:            String,
69    /// Column containing the vector embedding.
70    pub embedding_column: String,
71    /// Columns to select (empty = all).
72    pub select_columns:   Vec<String>,
73    /// Distance metric to use.
74    pub distance_metric:  DistanceMetric,
75    /// Maximum number of results.
76    pub limit:            u32,
77    /// Optional WHERE clause (without "WHERE" keyword).
78    pub where_clause:     Option<String>,
79    /// Optional additional ORDER BY clause (applied after distance ordering).
80    pub order_by:         Option<String>,
81    /// Whether to include the distance score in results.
82    pub include_distance: bool,
83    /// Optional offset for pagination.
84    pub offset:           Option<u32>,
85}
86
87impl Default for VectorSearchQuery {
88    fn default() -> Self {
89        Self {
90            table:            String::new(),
91            embedding_column: "embedding".to_string(),
92            select_columns:   Vec::new(),
93            distance_metric:  DistanceMetric::Cosine,
94            limit:            10,
95            where_clause:     None,
96            order_by:         None,
97            include_distance: false,
98            offset:           None,
99        }
100    }
101}
102
103impl VectorSearchQuery {
104    /// Create a new search query for a table.
105    pub fn new(table: impl Into<String>) -> Self {
106        Self {
107            table: table.into(),
108            ..Default::default()
109        }
110    }
111
112    /// Set the embedding column.
113    #[must_use]
114    pub fn with_embedding_column(mut self, column: impl Into<String>) -> Self {
115        self.embedding_column = column.into();
116        self
117    }
118
119    /// Set the columns to select.
120    #[must_use]
121    pub fn with_select_columns(mut self, columns: Vec<String>) -> Self {
122        self.select_columns = columns;
123        self
124    }
125
126    /// Set the distance metric.
127    #[must_use]
128    pub fn with_distance_metric(mut self, metric: DistanceMetric) -> Self {
129        self.distance_metric = metric;
130        self
131    }
132
133    /// Set the result limit.
134    #[must_use]
135    pub fn with_limit(mut self, limit: u32) -> Self {
136        self.limit = limit;
137        self
138    }
139
140    /// Set a WHERE clause filter.
141    #[must_use]
142    pub fn with_where(mut self, clause: impl Into<String>) -> Self {
143        self.where_clause = Some(clause.into());
144        self
145    }
146
147    /// Include distance score in results.
148    #[must_use]
149    pub fn with_distance_score(mut self) -> Self {
150        self.include_distance = true;
151        self
152    }
153
154    /// Set pagination offset.
155    #[must_use]
156    pub fn with_offset(mut self, offset: u32) -> Self {
157        self.offset = Some(offset);
158        self
159    }
160}
161
/// Configuration for a vector insert/upsert operation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VectorInsertQuery {
    /// Table to insert into.
    pub table:            String,
    /// Columns to insert (in order).
    pub columns:          Vec<String>,
    /// Name of the vector column.
    pub vector_column:    String,
    /// Whether to upsert (ON CONFLICT DO UPDATE).
    pub upsert:           bool,
    /// Conflict column(s) for upsert.
    pub conflict_columns: Vec<String>,
    /// Columns to update on conflict (empty = all non-conflict columns).
    pub update_columns:   Vec<String>,
    /// Expression for the RETURNING clause (`None` = no RETURNING clause).
    pub returning:        Option<String>,
}

impl Default for VectorInsertQuery {
    /// Defaults: empty table name, no columns, `embedding` vector column,
    /// plain insert (no upsert), `id` as conflict column, and `RETURNING id`.
    fn default() -> Self {
        Self {
            table:            String::new(),
            columns:          Vec::new(),
            vector_column:    "embedding".to_string(),
            upsert:           false,
            conflict_columns: vec!["id".to_string()],
            update_columns:   Vec::new(),
            returning:        Some("id".to_string()),
        }
    }
}

impl VectorInsertQuery {
    /// Create a new insert query for a table, with every other field at its default.
    #[must_use]
    pub fn new(table: impl Into<String>) -> Self {
        Self {
            table: table.into(),
            ..Default::default()
        }
    }

    /// Set the columns to insert.
    #[must_use]
    pub fn with_columns(mut self, columns: Vec<String>) -> Self {
        self.columns = columns;
        self
    }

    /// Set the vector column name.
    #[must_use]
    pub fn with_vector_column(mut self, column: impl Into<String>) -> Self {
        self.vector_column = column.into();
        self
    }

    /// Enable upsert mode, using `conflict_columns` as the conflict target.
    #[must_use]
    pub fn with_upsert(mut self, conflict_columns: Vec<String>) -> Self {
        self.upsert = true;
        self.conflict_columns = conflict_columns;
        self
    }

    /// Set columns to update on conflict.
    #[must_use]
    pub fn with_update_columns(mut self, columns: Vec<String>) -> Self {
        self.update_columns = columns;
        self
    }

    /// Set the RETURNING clause.
    #[must_use]
    pub fn with_returning(mut self, column: impl Into<String>) -> Self {
        self.returning = Some(column.into());
        self
    }

    /// Remove the RETURNING clause (the default is `RETURNING id`).
    #[must_use]
    pub fn without_returning(mut self) -> Self {
        self.returning = None;
        self
    }
}
240
/// Builder for pgvector SQL queries.
///
/// This struct generates SQL for vector similarity search and manipulation
/// operations using `PostgreSQL`'s `pgvector` extension.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct VectorQueryBuilder {
    /// Parameter placeholder style ($1 vs ?).
    placeholder_style: PlaceholderStyle,
}

/// Style of parameter placeholders in generated SQL.
///
/// Defaults to [`PlaceholderStyle::Dollar`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum PlaceholderStyle {
    /// `PostgreSQL` style: `$1`, `$2`, `$3`
    #[default]
    Dollar,
    /// MySQL/SQLite style: `?`, `?`, `?`
    QuestionMark,
}
260
261impl VectorQueryBuilder {
262    /// Create a new vector query builder.
263    #[must_use]
264    pub fn new() -> Self {
265        Self::default()
266    }
267
268    /// Create a builder with question mark placeholders.
269    #[must_use]
270    pub fn with_question_marks() -> Self {
271        Self {
272            placeholder_style: PlaceholderStyle::QuestionMark,
273        }
274    }
275
276    /// Generate a parameter placeholder.
277    fn placeholder(&self, index: usize) -> String {
278        match self.placeholder_style {
279            PlaceholderStyle::Dollar => format!("${index}"),
280            PlaceholderStyle::QuestionMark => "?".to_string(),
281        }
282    }
283
284    /// Build a similarity search query.
285    ///
286    /// Generates SQL like:
287    /// ```sql
288    /// SELECT id, content, (embedding <=> $1::vector) AS distance
289    /// FROM documents
290    /// WHERE metadata->>'type' = 'article'
291    /// ORDER BY embedding <=> $1::vector
292    /// LIMIT 10
293    /// ```
294    ///
295    /// # Arguments
296    /// * `query` - The search query configuration
297    /// * `query_embedding` - The embedding vector to search for
298    ///
299    /// # Returns
300    /// A tuple of (SQL string, parameter values)
301    #[must_use]
302    pub fn similarity_search(
303        &self,
304        query: &VectorSearchQuery,
305        query_embedding: &[f32],
306    ) -> (String, Vec<VectorParam>) {
307        let mut params = Vec::new();
308        let mut param_idx = 1;
309
310        // Add the query embedding as the first parameter
311        params.push(VectorParam::Vector(query_embedding.to_vec()));
312        let embedding_placeholder = format!("{}::vector", self.placeholder(param_idx));
313        param_idx += 1;
314
315        let distance_op = query.distance_metric.operator();
316
317        // Build SELECT clause
318        let select_clause = if query.select_columns.is_empty() {
319            if query.include_distance {
320                format!(
321                    "*, ({} {} {}) AS distance",
322                    query.embedding_column, distance_op, embedding_placeholder
323                )
324            } else {
325                "*".to_string()
326            }
327        } else {
328            let cols = query.select_columns.join(", ");
329            if query.include_distance {
330                format!(
331                    "{}, ({} {} {}) AS distance",
332                    cols, query.embedding_column, distance_op, embedding_placeholder
333                )
334            } else {
335                cols
336            }
337        };
338
339        // Build WHERE clause
340        let where_clause = if let Some(ref clause) = query.where_clause {
341            format!("\nWHERE {clause}")
342        } else {
343            String::new()
344        };
345
346        // Build ORDER BY clause (always order by distance for similarity search)
347        let order_clause = format!(
348            "\nORDER BY {} {} {}",
349            query.embedding_column, distance_op, embedding_placeholder
350        );
351
352        // Build LIMIT clause
353        let limit_clause = format!("\nLIMIT {}", self.placeholder(param_idx));
354        params.push(VectorParam::Int(i64::from(query.limit)));
355        param_idx += 1;
356
357        // Build OFFSET clause
358        let offset_clause = if let Some(offset) = query.offset {
359            let clause = format!("\nOFFSET {}", self.placeholder(param_idx));
360            params.push(VectorParam::Int(i64::from(offset)));
361            clause
362        } else {
363            String::new()
364        };
365
366        let sql = format!(
367            "SELECT {}\nFROM {}{}{}{}{}",
368            select_clause, query.table, where_clause, order_clause, limit_clause, offset_clause
369        );
370
371        (sql, params)
372    }
373
374    /// Build a single vector insert query.
375    ///
376    /// Generates SQL like:
377    /// ```sql
378    /// INSERT INTO documents (id, content, embedding)
379    /// VALUES ($1, $2, $3::vector)
380    /// RETURNING id
381    /// ```
382    #[must_use]
383    pub fn insert_one(
384        &self,
385        query: &VectorInsertQuery,
386        values: &[VectorParam],
387    ) -> (String, Vec<VectorParam>) {
388        let columns = query.columns.join(", ");
389
390        let placeholders: Vec<String> = values
391            .iter()
392            .enumerate()
393            .map(|(i, v)| {
394                let ph = self.placeholder(i + 1);
395                if matches!(v, VectorParam::Vector(_)) {
396                    format!("{ph}::vector")
397                } else {
398                    ph
399                }
400            })
401            .collect();
402
403        let values_clause = placeholders.join(", ");
404
405        let returning_clause = if let Some(ref col) = query.returning {
406            format!("\nRETURNING {col}")
407        } else {
408            String::new()
409        };
410
411        let sql = if query.upsert {
412            let conflict_cols = query.conflict_columns.join(", ");
413
414            // Determine which columns to update
415            let update_cols: Vec<&String> = if query.update_columns.is_empty() {
416                // Update all non-conflict columns
417                query.columns.iter().filter(|c| !query.conflict_columns.contains(c)).collect()
418            } else {
419                query.update_columns.iter().collect()
420            };
421
422            let update_clause: String = update_cols
423                .iter()
424                .map(|c| format!("{c} = EXCLUDED.{c}"))
425                .collect::<Vec<_>>()
426                .join(", ");
427
428            format!(
429                "INSERT INTO {} ({})\nVALUES ({})\nON CONFLICT ({}) DO UPDATE SET {}{}",
430                query.table, columns, values_clause, conflict_cols, update_clause, returning_clause
431            )
432        } else {
433            format!(
434                "INSERT INTO {} ({})\nVALUES ({}){}",
435                query.table, columns, values_clause, returning_clause
436            )
437        };
438
439        (sql, values.to_vec())
440    }
441
442    /// Build a batch vector insert query.
443    ///
444    /// Generates SQL like:
445    /// ```sql
446    /// INSERT INTO documents (id, content, embedding)
447    /// VALUES
448    ///   ($1, $2, $3::vector),
449    ///   ($4, $5, $6::vector),
450    ///   ($7, $8, $9::vector)
451    /// RETURNING id
452    /// ```
453    #[must_use]
454    pub fn insert_batch(
455        &self,
456        query: &VectorInsertQuery,
457        rows: &[Vec<VectorParam>],
458    ) -> (String, Vec<VectorParam>) {
459        if rows.is_empty() {
460            return (String::new(), Vec::new());
461        }
462
463        let columns = query.columns.join(", ");
464        let cols_per_row = query.columns.len();
465
466        let mut all_params = Vec::new();
467        let mut values_clauses = Vec::new();
468
469        for (row_idx, row) in rows.iter().enumerate() {
470            let base_idx = row_idx * cols_per_row + 1;
471            let placeholders: Vec<String> = row
472                .iter()
473                .enumerate()
474                .map(|(i, v)| {
475                    let ph = self.placeholder(base_idx + i);
476                    if matches!(v, VectorParam::Vector(_)) {
477                        format!("{ph}::vector")
478                    } else {
479                        ph
480                    }
481                })
482                .collect();
483
484            values_clauses.push(format!("({})", placeholders.join(", ")));
485            all_params.extend(row.clone());
486        }
487
488        let returning_clause = if let Some(ref col) = query.returning {
489            format!("\nRETURNING {col}")
490        } else {
491            String::new()
492        };
493
494        let sql = format!(
495            "INSERT INTO {} ({})\nVALUES\n  {}{}",
496            query.table,
497            columns,
498            values_clauses.join(",\n  "),
499            returning_clause
500        );
501
502        (sql, all_params)
503    }
504
505    /// Build a query to create a vector index.
506    ///
507    /// Generates SQL like:
508    /// ```sql
509    /// CREATE INDEX ON documents USING hnsw (embedding vector_cosine_ops)
510    /// ```
511    #[must_use]
512    pub fn create_index(&self, config: &VectorConfig, table: &str, column: &str) -> Option<String> {
513        config.index_type.index_sql(table, column, config.distance_metric)
514    }
515}
516
#[cfg(test)]
mod tests {
    use super::*;
    use crate::schema::VectorIndexType;

    /// Turn string literals into the owned column lists the query types expect.
    fn owned(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| (*name).to_string()).collect()
    }

    /// Run a similarity search through a default (`$n` placeholder) builder.
    fn dollar_search(query: &VectorSearchQuery, embedding: &[f32]) -> (String, Vec<VectorParam>) {
        VectorQueryBuilder::new().similarity_search(query, embedding)
    }

    /// Assert that every fragment occurs somewhere in the generated SQL.
    fn assert_sql_contains(sql: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(sql.contains(fragment), "expected `{fragment}` in:\n{sql}");
        }
    }

    /// The three-value row shared by the single-insert tests.
    fn sample_document_row() -> Vec<VectorParam> {
        vec![
            VectorParam::String("doc1".to_string()),
            VectorParam::String("Hello world".to_string()),
            VectorParam::Vector(vec![0.1, 0.2, 0.3]),
        ]
    }

    #[test]
    fn test_similarity_search_basic() {
        let query =
            VectorSearchQuery::new("documents").with_embedding_column("embedding").with_limit(10);

        let (sql, params) = dollar_search(&query, &[0.1, 0.2, 0.3]);

        assert_sql_contains(
            &sql,
            &[
                "SELECT *",
                "FROM documents",
                "ORDER BY embedding <=>",
                "$1::vector",
                "LIMIT $2",
            ],
        );
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_similarity_search_with_columns() {
        let query = VectorSearchQuery::new("docs")
            .with_select_columns(owned(&["id", "content"]))
            .with_distance_score()
            .with_limit(5);

        let (sql, params) = dollar_search(&query, &[0.1, 0.2]);

        assert_sql_contains(&sql, &["SELECT id, content,", "AS distance", "LIMIT $2"]);
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_similarity_search_with_where() {
        let query = VectorSearchQuery::new("documents")
            .with_where("metadata->>'type' = 'article'")
            .with_limit(10);

        let (sql, _params) = dollar_search(&query, &[0.1, 0.2, 0.3]);

        assert_sql_contains(&sql, &["WHERE metadata->>'type' = 'article'"]);
    }

    #[test]
    fn test_similarity_search_with_offset() {
        let query = VectorSearchQuery::new("documents").with_limit(10).with_offset(20);

        let (sql, params) = dollar_search(&query, &[0.1, 0.2]);

        assert_sql_contains(&sql, &["OFFSET $3"]);
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_similarity_search_l2_distance() {
        let query =
            VectorSearchQuery::new("docs").with_distance_metric(DistanceMetric::L2).with_limit(5);

        let (sql, _params) = dollar_search(&query, &[0.1, 0.2]);

        assert_sql_contains(&sql, &["<->"]);
    }

    #[test]
    fn test_similarity_search_inner_product() {
        let query = VectorSearchQuery::new("docs")
            .with_distance_metric(DistanceMetric::InnerProduct)
            .with_limit(5);

        let (sql, _params) = dollar_search(&query, &[0.1, 0.2]);

        assert_sql_contains(&sql, &["<#>"]);
    }

    #[test]
    fn test_insert_one_basic() {
        let query = VectorInsertQuery::new("documents")
            .with_columns(owned(&["id", "content", "embedding"]));

        let (sql, params) = VectorQueryBuilder::new().insert_one(&query, &sample_document_row());

        assert_sql_contains(
            &sql,
            &[
                "INSERT INTO documents (id, content, embedding)",
                "VALUES ($1, $2, $3::vector)",
                "RETURNING id",
            ],
        );
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_insert_upsert() {
        let query = VectorInsertQuery::new("documents")
            .with_columns(owned(&["id", "content", "embedding"]))
            .with_upsert(owned(&["id"]));

        let (sql, _params) = VectorQueryBuilder::new().insert_one(&query, &sample_document_row());

        assert_sql_contains(
            &sql,
            &[
                "ON CONFLICT (id) DO UPDATE SET",
                "content = EXCLUDED.content",
                "embedding = EXCLUDED.embedding",
            ],
        );
    }

    #[test]
    fn test_insert_batch() {
        let query = VectorInsertQuery::new("documents").with_columns(owned(&["id", "embedding"]));

        let rows: Vec<Vec<VectorParam>> = [("doc1", [0.1, 0.2]), ("doc2", [0.3, 0.4])]
            .iter()
            .map(|(id, embedding)| {
                vec![
                    VectorParam::String((*id).to_string()),
                    VectorParam::Vector(embedding.to_vec()),
                ]
            })
            .collect();

        let (sql, params) = VectorQueryBuilder::new().insert_batch(&query, &rows);

        assert_sql_contains(
            &sql,
            &[
                "INSERT INTO documents (id, embedding)",
                "($1, $2::vector)",
                "($3, $4::vector)",
            ],
        );
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_insert_batch_empty() {
        let query = VectorInsertQuery::new("documents").with_columns(owned(&["id"]));

        let (sql, params) = VectorQueryBuilder::new().insert_batch(&query, &[]);

        assert!(sql.is_empty());
        assert!(params.is_empty());
    }

    #[test]
    fn test_create_index_hnsw() {
        let config = VectorConfig::openai();

        let sql = VectorQueryBuilder::new().create_index(&config, "documents", "embedding");

        assert_eq!(
            sql.as_deref(),
            Some("CREATE INDEX ON documents USING hnsw (embedding vector_cosine_ops)")
        );
    }

    #[test]
    fn test_create_index_ivfflat() {
        let config = VectorConfig::new(1536)
            .with_index(VectorIndexType::IvfFlat)
            .with_distance(DistanceMetric::L2);

        let sql = VectorQueryBuilder::new().create_index(&config, "docs", "vec");

        assert_eq!(sql.as_deref(), Some("CREATE INDEX ON docs USING ivfflat (vec vector_l2_ops)"));
    }

    #[test]
    fn test_create_index_none() {
        let config = VectorConfig::new(1536).with_index(VectorIndexType::None);

        let sql = VectorQueryBuilder::new().create_index(&config, "documents", "embedding");

        assert!(sql.is_none());
    }

    #[test]
    fn test_vector_param_to_sql_literal() {
        let cases = [
            (VectorParam::Vector(vec![0.1, 0.2, 0.3]), "'[0.1,0.2,0.3]'::vector"),
            (VectorParam::Int(42), "42"),
            (VectorParam::String("hello".to_string()), "'hello'"),
            // Embedded single quotes must be doubled.
            (VectorParam::String("it's a test".to_string()), "'it''s a test'"),
        ];

        for (param, expected) in &cases {
            assert_eq!(param.to_sql_literal(), *expected);
        }
    }

    #[test]
    fn test_question_mark_placeholders() {
        let query = VectorSearchQuery::new("docs").with_limit(10);

        let (sql, _params) =
            VectorQueryBuilder::with_question_marks().similarity_search(&query, &[0.1, 0.2]);

        assert!(sql.contains("?::vector"));
        assert!(!sql.contains("$1"));
    }
}