Skip to main content

fraiseql_core/utils/
vector.rs

1//! Vector query builder for pgvector similarity search.
2//!
3//! This module provides SQL query generation for pgvector operations including:
4//! - Similarity search with configurable distance metrics
5//! - Vector insert and upsert operations
6//! - Proper parameter binding for vector data
7//!
8//! # Example
9//!
10//! ```
11//! use fraiseql_core::utils::vector::{VectorQueryBuilder, VectorSearchQuery};
12//! use fraiseql_core::schema::DistanceMetric;
13//!
14//! let builder = VectorQueryBuilder::new();
15//! let query = VectorSearchQuery {
16//!     table: "documents".to_string(),
17//!     embedding_column: "embedding".to_string(),
18//!     select_columns: vec!["id".to_string(), "content".to_string()],
19//!     distance_metric: DistanceMetric::Cosine,
20//!     limit: 10,
21//!     where_clause: None,
22//!     order_by: None,
23//!     include_distance: false,
24//!     offset: None,
25//! };
26//!
27//! let (sql, _params) = builder.similarity_search(&query, &[0.1, 0.2, 0.3]);
28//! assert!(sql.contains("documents"));
29//! ```
30
31use serde::{Deserialize, Serialize};
32
33use crate::schema::{DistanceMetric, VectorConfig};
34
/// A SQL parameter value for vector queries.
///
/// The query builder in this module returns these alongside the generated SQL
/// string; each value corresponds to one placeholder in that SQL.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[non_exhaustive]
pub enum VectorParam {
    /// A vector embedding (array of floats).
    ///
    /// Placeholders bound to this variant are emitted with a `::vector` cast
    /// by the insert builders.
    Vector(Vec<f32>),
    /// An integer value.
    Int(i64),
    /// A string value.
    String(String),
    /// A JSON value.
    Json(serde_json::Value),
}
48
49impl VectorParam {
50    /// Convert to SQL literal string for debugging.
51    #[must_use]
52    pub fn to_sql_literal(&self) -> String {
53        match self {
54            VectorParam::Vector(v) => {
55                let values: Vec<String> = v.iter().map(std::string::ToString::to_string).collect();
56                format!("'[{}]'::vector", values.join(","))
57            },
58            VectorParam::Int(i) => i.to_string(),
59            VectorParam::String(s) => format!("'{}'", s.replace('\'', "''")),
60            VectorParam::Json(j) => format!("'{j}'::jsonb"),
61        }
62    }
63}
64
/// Configuration for a similarity search query.
///
/// All string fields (table, column names, WHERE / ORDER BY fragments) are
/// inserted into the generated SQL text as-is by the query builder, so they
/// must come from trusted code rather than from end-user input.
#[derive(Debug, Clone)]
pub struct VectorSearchQuery {
    /// Table or view to query.
    pub table:            String,
    /// Column containing the vector embedding.
    pub embedding_column: String,
    /// Columns to select (empty = all).
    pub select_columns:   Vec<String>,
    /// Distance metric to use.
    pub distance_metric:  DistanceMetric,
    /// Maximum number of results.
    pub limit:            u32,
    /// Optional WHERE clause (without "WHERE" keyword).
    pub where_clause:     Option<String>,
    /// Optional additional ORDER BY clause (applied after distance ordering).
    pub order_by:         Option<String>,
    /// Whether to include the distance score in results.
    ///
    /// When set, the computed distance is selected under the alias `distance`.
    pub include_distance: bool,
    /// Optional offset for pagination.
    pub offset:           Option<u32>,
}
87
88impl Default for VectorSearchQuery {
89    fn default() -> Self {
90        Self {
91            table:            String::new(),
92            embedding_column: "embedding".to_string(),
93            select_columns:   Vec::new(),
94            distance_metric:  DistanceMetric::Cosine,
95            limit:            10,
96            where_clause:     None,
97            order_by:         None,
98            include_distance: false,
99            offset:           None,
100        }
101    }
102}
103
104impl VectorSearchQuery {
105    /// Create a new search query for a table.
106    pub fn new(table: impl Into<String>) -> Self {
107        Self {
108            table: table.into(),
109            ..Default::default()
110        }
111    }
112
113    /// Set the embedding column.
114    pub fn with_embedding_column(mut self, column: impl Into<String>) -> Self {
115        self.embedding_column = column.into();
116        self
117    }
118
119    /// Set the columns to select.
120    pub fn with_select_columns(mut self, columns: Vec<String>) -> Self {
121        self.select_columns = columns;
122        self
123    }
124
125    /// Set the distance metric.
126    pub const fn with_distance_metric(mut self, metric: DistanceMetric) -> Self {
127        self.distance_metric = metric;
128        self
129    }
130
131    /// Set the result limit.
132    pub const fn with_limit(mut self, limit: u32) -> Self {
133        self.limit = limit;
134        self
135    }
136
137    /// Set a WHERE clause filter.
138    pub fn with_where(mut self, clause: impl Into<String>) -> Self {
139        self.where_clause = Some(clause.into());
140        self
141    }
142
143    /// Include distance score in results.
144    pub const fn with_distance_score(mut self) -> Self {
145        self.include_distance = true;
146        self
147    }
148
149    /// Set pagination offset.
150    pub const fn with_offset(mut self, offset: u32) -> Self {
151        self.offset = Some(offset);
152        self
153    }
154}
155
/// Configuration for a vector insert/upsert operation.
///
/// Table and column names are inserted into the generated SQL text as-is by
/// the query builder, so they must come from trusted code.
#[derive(Debug, Clone)]
pub struct VectorInsertQuery {
    /// Table to insert into.
    pub table:            String,
    /// Columns to insert (in order).
    ///
    /// The values passed to the builder are expected in the same order.
    pub columns:          Vec<String>,
    /// Name of the vector column.
    // NOTE(review): no builder method in this file reads this field; the
    // `::vector` cast is chosen from the parameter variant instead — confirm
    // whether other modules use it.
    pub vector_column:    String,
    /// Whether to upsert (ON CONFLICT DO UPDATE).
    pub upsert:           bool,
    /// Conflict column(s) for upsert.
    pub conflict_columns: Vec<String>,
    /// Columns to update on conflict (empty = all non-conflict columns).
    pub update_columns:   Vec<String>,
    /// Column to emit in a `RETURNING` clause (`None` omits the clause).
    pub returning:        Option<String>,
}
174
175impl Default for VectorInsertQuery {
176    fn default() -> Self {
177        Self {
178            table:            String::new(),
179            columns:          Vec::new(),
180            vector_column:    "embedding".to_string(),
181            upsert:           false,
182            conflict_columns: vec!["id".to_string()],
183            update_columns:   Vec::new(),
184            returning:        Some("id".to_string()),
185        }
186    }
187}
188
189impl VectorInsertQuery {
190    /// Create a new insert query.
191    pub fn new(table: impl Into<String>) -> Self {
192        Self {
193            table: table.into(),
194            ..Default::default()
195        }
196    }
197
198    /// Set the columns to insert.
199    pub fn with_columns(mut self, columns: Vec<String>) -> Self {
200        self.columns = columns;
201        self
202    }
203
204    /// Set the vector column name.
205    pub fn with_vector_column(mut self, column: impl Into<String>) -> Self {
206        self.vector_column = column.into();
207        self
208    }
209
210    /// Enable upsert mode.
211    pub fn with_upsert(mut self, conflict_columns: Vec<String>) -> Self {
212        self.upsert = true;
213        self.conflict_columns = conflict_columns;
214        self
215    }
216
217    /// Set columns to update on conflict.
218    pub fn with_update_columns(mut self, columns: Vec<String>) -> Self {
219        self.update_columns = columns;
220        self
221    }
222
223    /// Set the RETURNING clause.
224    pub fn with_returning(mut self, column: impl Into<String>) -> Self {
225        self.returning = Some(column.into());
226        self
227    }
228}
229
230/// Builder for pgvector SQL queries.
231///
232/// This struct generates SQL for vector similarity search and manipulation
233/// operations using `PostgreSQL`'s `pgvector` extension.
234#[must_use = "call .build() to construct the final value"]
235#[derive(Debug, Clone, Default)]
236pub struct VectorQueryBuilder {
237    /// Parameter placeholder style ($1 vs ?).
238    placeholder_style: PlaceholderStyle,
239}
240
/// Style of parameter placeholders in generated SQL.
///
/// `Dollar` placeholders are numbered, while `QuestionMark` placeholders are
/// purely positional (every placeholder is rendered as a bare `?`).
#[derive(Debug, Clone, Copy, Default)]
#[non_exhaustive]
pub enum PlaceholderStyle {
    /// `PostgreSQL` style: `$1`, `$2`, `$3`
    #[default]
    Dollar,
    /// MySQL/SQLite style: `?`, `?`, `?`
    QuestionMark,
}
251
252impl VectorQueryBuilder {
253    /// Create a new vector query builder.
254    pub fn new() -> Self {
255        Self::default()
256    }
257
258    /// Create a builder with question mark placeholders.
259    pub const fn with_question_marks() -> Self {
260        Self {
261            placeholder_style: PlaceholderStyle::QuestionMark,
262        }
263    }
264
265    /// Generate a parameter placeholder.
266    fn placeholder(&self, index: usize) -> String {
267        match self.placeholder_style {
268            PlaceholderStyle::Dollar => format!("${index}"),
269            PlaceholderStyle::QuestionMark => "?".to_string(),
270        }
271    }
272
273    /// Build a similarity search query.
274    ///
275    /// Generates SQL like:
276    /// ```sql
277    /// SELECT id, content, (embedding <=> $1::vector) AS distance
278    /// FROM documents
279    /// WHERE metadata->>'type' = 'article'
280    /// ORDER BY embedding <=> $1::vector
281    /// LIMIT 10
282    /// ```
283    ///
284    /// # Arguments
285    /// * `query` - The search query configuration
286    /// * `query_embedding` - The embedding vector to search for
287    ///
288    /// # Returns
289    /// A tuple of (SQL string, parameter values)
290    #[must_use]
291    pub fn similarity_search(
292        &self,
293        query: &VectorSearchQuery,
294        query_embedding: &[f32],
295    ) -> (String, Vec<VectorParam>) {
296        let mut params = Vec::new();
297        let mut param_idx = 1;
298
299        // Add the query embedding as the first parameter
300        params.push(VectorParam::Vector(query_embedding.to_vec()));
301        let embedding_placeholder = format!("{}::vector", self.placeholder(param_idx));
302        param_idx += 1;
303
304        let distance_op = query.distance_metric.operator();
305
306        // Build SELECT clause
307        let select_clause = if query.select_columns.is_empty() {
308            if query.include_distance {
309                format!(
310                    "*, ({} {} {}) AS distance",
311                    query.embedding_column, distance_op, embedding_placeholder
312                )
313            } else {
314                "*".to_string()
315            }
316        } else {
317            let cols = query.select_columns.join(", ");
318            if query.include_distance {
319                format!(
320                    "{}, ({} {} {}) AS distance",
321                    cols, query.embedding_column, distance_op, embedding_placeholder
322                )
323            } else {
324                cols
325            }
326        };
327
328        // Build WHERE clause
329        let where_clause = if let Some(ref clause) = query.where_clause {
330            format!("\nWHERE {clause}")
331        } else {
332            String::new()
333        };
334
335        // Build ORDER BY clause (always order by distance for similarity search)
336        let order_clause = format!(
337            "\nORDER BY {} {} {}",
338            query.embedding_column, distance_op, embedding_placeholder
339        );
340
341        // Build LIMIT clause
342        let limit_clause = format!("\nLIMIT {}", self.placeholder(param_idx));
343        params.push(VectorParam::Int(i64::from(query.limit)));
344        param_idx += 1;
345
346        // Build OFFSET clause
347        let offset_clause = if let Some(offset) = query.offset {
348            let clause = format!("\nOFFSET {}", self.placeholder(param_idx));
349            params.push(VectorParam::Int(i64::from(offset)));
350            clause
351        } else {
352            String::new()
353        };
354
355        let sql = format!(
356            "SELECT {}\nFROM {}{}{}{}{}",
357            select_clause, query.table, where_clause, order_clause, limit_clause, offset_clause
358        );
359
360        (sql, params)
361    }
362
363    /// Build a single vector insert query.
364    ///
365    /// Generates SQL like:
366    /// ```sql
367    /// INSERT INTO documents (id, content, embedding)
368    /// VALUES ($1, $2, $3::vector)
369    /// RETURNING id
370    /// ```
371    #[must_use]
372    pub fn insert_one(
373        &self,
374        query: &VectorInsertQuery,
375        values: &[VectorParam],
376    ) -> (String, Vec<VectorParam>) {
377        let columns = query.columns.join(", ");
378
379        let placeholders: Vec<String> = values
380            .iter()
381            .enumerate()
382            .map(|(i, v)| {
383                let ph = self.placeholder(i + 1);
384                if matches!(v, VectorParam::Vector(_)) {
385                    format!("{ph}::vector")
386                } else {
387                    ph
388                }
389            })
390            .collect();
391
392        let values_clause = placeholders.join(", ");
393
394        let returning_clause = if let Some(ref col) = query.returning {
395            format!("\nRETURNING {col}")
396        } else {
397            String::new()
398        };
399
400        let sql = if query.upsert {
401            let conflict_cols = query.conflict_columns.join(", ");
402
403            // Determine which columns to update
404            let update_cols: Vec<&String> = if query.update_columns.is_empty() {
405                // Update all non-conflict columns
406                query.columns.iter().filter(|c| !query.conflict_columns.contains(c)).collect()
407            } else {
408                query.update_columns.iter().collect()
409            };
410
411            let update_clause: String = update_cols
412                .iter()
413                .map(|c| format!("{c} = EXCLUDED.{c}"))
414                .collect::<Vec<_>>()
415                .join(", ");
416
417            format!(
418                "INSERT INTO {} ({})\nVALUES ({})\nON CONFLICT ({}) DO UPDATE SET {}{}",
419                query.table, columns, values_clause, conflict_cols, update_clause, returning_clause
420            )
421        } else {
422            format!(
423                "INSERT INTO {} ({})\nVALUES ({}){}",
424                query.table, columns, values_clause, returning_clause
425            )
426        };
427
428        (sql, values.to_vec())
429    }
430
431    /// Build a batch vector insert query.
432    ///
433    /// Generates SQL like:
434    /// ```sql
435    /// INSERT INTO documents (id, content, embedding)
436    /// VALUES
437    ///   ($1, $2, $3::vector),
438    ///   ($4, $5, $6::vector),
439    ///   ($7, $8, $9::vector)
440    /// RETURNING id
441    /// ```
442    #[must_use]
443    pub fn insert_batch(
444        &self,
445        query: &VectorInsertQuery,
446        rows: &[Vec<VectorParam>],
447    ) -> (String, Vec<VectorParam>) {
448        if rows.is_empty() {
449            return (String::new(), Vec::new());
450        }
451
452        let columns = query.columns.join(", ");
453        let cols_per_row = query.columns.len();
454
455        let mut all_params = Vec::new();
456        let mut values_clauses = Vec::new();
457
458        for (row_idx, row) in rows.iter().enumerate() {
459            let base_idx = row_idx * cols_per_row + 1;
460            let placeholders: Vec<String> = row
461                .iter()
462                .enumerate()
463                .map(|(i, v)| {
464                    let ph = self.placeholder(base_idx + i);
465                    if matches!(v, VectorParam::Vector(_)) {
466                        format!("{ph}::vector")
467                    } else {
468                        ph
469                    }
470                })
471                .collect();
472
473            values_clauses.push(format!("({})", placeholders.join(", ")));
474            all_params.extend(row.clone());
475        }
476
477        let returning_clause = if let Some(ref col) = query.returning {
478            format!("\nRETURNING {col}")
479        } else {
480            String::new()
481        };
482
483        let sql = format!(
484            "INSERT INTO {} ({})\nVALUES\n  {}{}",
485            query.table,
486            columns,
487            values_clauses.join(",\n  "),
488            returning_clause
489        );
490
491        (sql, all_params)
492    }
493
494    /// Build a query to create a vector index.
495    ///
496    /// Generates SQL like:
497    /// ```sql
498    /// CREATE INDEX ON documents USING hnsw (embedding vector_cosine_ops)
499    /// ```
500    #[must_use]
501    pub fn create_index(&self, config: &VectorConfig, table: &str, column: &str) -> Option<String> {
502        config.index_type.index_sql(table, column, config.distance_metric)
503    }
504}
505
// Unit tests for the SQL text and parameter lists produced by the builder.
// They only inspect the generated strings; no database is involved.
#[cfg(test)]
mod tests {
    use super::*;

    // Default search: `SELECT *`, distance ordering, embedding bound as `$1`
    // and the limit as `$2`.
    #[test]
    fn test_similarity_search_basic() {
        let builder = VectorQueryBuilder::new();
        let query = VectorSearchQuery::new("documents")
            .with_embedding_column("embedding")
            .with_limit(10);

        let embedding = vec![0.1, 0.2, 0.3];
        let (sql, params) = builder.similarity_search(&query, &embedding);

        assert!(sql.contains("SELECT *"));
        assert!(sql.contains("FROM documents"));
        assert!(sql.contains("ORDER BY embedding <=>"));
        assert!(sql.contains("$1::vector"));
        assert!(sql.contains("LIMIT $2"));
        assert_eq!(params.len(), 2);
    }

    // Explicit select columns plus the distance score; with `$n` placeholders
    // the embedding is still bound only once.
    #[test]
    fn test_similarity_search_with_columns() {
        let builder = VectorQueryBuilder::new();
        let query = VectorSearchQuery::new("docs")
            .with_select_columns(vec!["id".to_string(), "content".to_string()])
            .with_distance_score()
            .with_limit(5);

        let embedding = vec![0.1, 0.2];
        let (sql, params) = builder.similarity_search(&query, &embedding);

        assert!(sql.contains("SELECT id, content,"));
        assert!(sql.contains("AS distance"));
        assert!(sql.contains("LIMIT $2"));
        assert_eq!(params.len(), 2);
    }

    // The WHERE fragment is emitted verbatim after the `WHERE` keyword.
    #[test]
    fn test_similarity_search_with_where() {
        let builder = VectorQueryBuilder::new();
        let query = VectorSearchQuery::new("documents")
            .with_where("metadata->>'type' = 'article'")
            .with_limit(10);

        let embedding = vec![0.1, 0.2, 0.3];
        let (sql, _) = builder.similarity_search(&query, &embedding);

        assert!(sql.contains("WHERE metadata->>'type' = 'article'"));
    }

    // The offset is bound as a third parameter after the embedding and limit.
    #[test]
    fn test_similarity_search_with_offset() {
        let builder = VectorQueryBuilder::new();
        let query = VectorSearchQuery::new("documents").with_limit(10).with_offset(20);

        let embedding = vec![0.1, 0.2];
        let (sql, params) = builder.similarity_search(&query, &embedding);

        assert!(sql.contains("OFFSET $3"));
        assert_eq!(params.len(), 3);
    }

    // L2 distance uses the `<->` operator.
    #[test]
    fn test_similarity_search_l2_distance() {
        let builder = VectorQueryBuilder::new();
        let query = VectorSearchQuery::new("docs")
            .with_distance_metric(DistanceMetric::L2)
            .with_limit(5);

        let embedding = vec![0.1, 0.2];
        let (sql, _) = builder.similarity_search(&query, &embedding);

        assert!(sql.contains("<->"));
    }

    // Inner product uses the `<#>` operator.
    #[test]
    fn test_similarity_search_inner_product() {
        let builder = VectorQueryBuilder::new();
        let query = VectorSearchQuery::new("docs")
            .with_distance_metric(DistanceMetric::InnerProduct)
            .with_limit(5);

        let embedding = vec![0.1, 0.2];
        let (sql, _) = builder.similarity_search(&query, &embedding);

        assert!(sql.contains("<#>"));
    }

    // Plain insert: vector values get a `::vector` cast and the default
    // `RETURNING id` clause is present.
    #[test]
    fn test_insert_one_basic() {
        let builder = VectorQueryBuilder::new();
        let query = VectorInsertQuery::new("documents").with_columns(vec![
            "id".to_string(),
            "content".to_string(),
            "embedding".to_string(),
        ]);

        let values = vec![
            VectorParam::String("doc1".to_string()),
            VectorParam::String("Hello world".to_string()),
            VectorParam::Vector(vec![0.1, 0.2, 0.3]),
        ];

        let (sql, params) = builder.insert_one(&query, &values);

        assert!(sql.contains("INSERT INTO documents (id, content, embedding)"));
        assert!(sql.contains("VALUES ($1, $2, $3::vector)"));
        assert!(sql.contains("RETURNING id"));
        assert_eq!(params.len(), 3);
    }

    // Upsert without explicit update columns updates every non-conflict column.
    #[test]
    fn test_insert_upsert() {
        let builder = VectorQueryBuilder::new();
        let query = VectorInsertQuery::new("documents")
            .with_columns(vec![
                "id".to_string(),
                "content".to_string(),
                "embedding".to_string(),
            ])
            .with_upsert(vec!["id".to_string()]);

        let values = vec![
            VectorParam::String("doc1".to_string()),
            VectorParam::String("Hello world".to_string()),
            VectorParam::Vector(vec![0.1, 0.2, 0.3]),
        ];

        let (sql, _) = builder.insert_one(&query, &values);

        assert!(sql.contains("ON CONFLICT (id) DO UPDATE SET"));
        assert!(sql.contains("content = EXCLUDED.content"));
        assert!(sql.contains("embedding = EXCLUDED.embedding"));
    }

    // Batch insert: placeholders continue numbering across rows and the
    // parameters are flattened in row order.
    #[test]
    fn test_insert_batch() {
        let builder = VectorQueryBuilder::new();
        let query = VectorInsertQuery::new("documents")
            .with_columns(vec!["id".to_string(), "embedding".to_string()]);

        let rows = vec![
            vec![
                VectorParam::String("doc1".to_string()),
                VectorParam::Vector(vec![0.1, 0.2]),
            ],
            vec![
                VectorParam::String("doc2".to_string()),
                VectorParam::Vector(vec![0.3, 0.4]),
            ],
        ];

        let (sql, params) = builder.insert_batch(&query, &rows);

        assert!(sql.contains("INSERT INTO documents (id, embedding)"));
        assert!(sql.contains("($1, $2::vector)"));
        assert!(sql.contains("($3, $4::vector)"));
        assert_eq!(params.len(), 4);
    }

    // An empty batch yields empty SQL and no parameters.
    #[test]
    fn test_insert_batch_empty() {
        let builder = VectorQueryBuilder::new();
        let query = VectorInsertQuery::new("documents").with_columns(vec!["id".to_string()]);

        let (sql, params) = builder.insert_batch(&query, &[]);

        assert!(sql.is_empty());
        assert!(params.is_empty());
    }

    // The `openai()` preset is expected to produce an HNSW cosine index.
    #[test]
    fn test_create_index_hnsw() {
        let builder = VectorQueryBuilder::new();
        let config = VectorConfig::openai();

        let sql = builder.create_index(&config, "documents", "embedding");

        assert_eq!(
            sql,
            Some("CREATE INDEX ON documents USING hnsw (embedding vector_cosine_ops)".to_string())
        );
    }

    // IVFFlat index combined with the L2 operator class.
    #[test]
    fn test_create_index_ivfflat() {
        let builder = VectorQueryBuilder::new();
        let config = VectorConfig::new(1536)
            .with_index(crate::schema::VectorIndexType::IvfFlat)
            .with_distance(DistanceMetric::L2);

        let sql = builder.create_index(&config, "docs", "vec");

        assert_eq!(sql, Some("CREATE INDEX ON docs USING ivfflat (vec vector_l2_ops)".to_string()));
    }

    // An index type of `None` produces no statement.
    #[test]
    fn test_create_index_none() {
        let builder = VectorQueryBuilder::new();
        let config = VectorConfig::new(1536).with_index(crate::schema::VectorIndexType::None);

        let sql = builder.create_index(&config, "documents", "embedding");

        assert_eq!(sql, None);
    }

    // Debug literals: vector cast, bare integer, quoted string, and doubling
    // of embedded single quotes.
    #[test]
    fn test_vector_param_to_sql_literal() {
        let vec_param = VectorParam::Vector(vec![0.1, 0.2, 0.3]);
        assert_eq!(vec_param.to_sql_literal(), "'[0.1,0.2,0.3]'::vector");

        let int_param = VectorParam::Int(42);
        assert_eq!(int_param.to_sql_literal(), "42");

        let str_param = VectorParam::String("hello".to_string());
        assert_eq!(str_param.to_sql_literal(), "'hello'");

        let str_param_escape = VectorParam::String("it's a test".to_string());
        assert_eq!(str_param_escape.to_sql_literal(), "'it''s a test'");
    }

    // Question-mark style must not emit any numbered placeholder.
    #[test]
    fn test_question_mark_placeholders() {
        let builder = VectorQueryBuilder::with_question_marks();
        let query = VectorSearchQuery::new("docs").with_limit(10);

        let embedding = vec![0.1, 0.2];
        let (sql, _) = builder.similarity_search(&query, &embedding);

        assert!(sql.contains("?::vector"));
        assert!(!sql.contains("$1"));
    }
}