1use serde::{Deserialize, Serialize};
32
33use crate::schema::{DistanceMetric, VectorConfig};
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37pub enum VectorParam {
38 Vector(Vec<f32>),
40 Int(i64),
42 String(String),
44 Json(serde_json::Value),
46}
47
48impl VectorParam {
49 #[must_use]
51 pub fn to_sql_literal(&self) -> String {
52 match self {
53 VectorParam::Vector(v) => {
54 let values: Vec<String> = v.iter().map(std::string::ToString::to_string).collect();
55 format!("'[{}]'::vector", values.join(","))
56 },
57 VectorParam::Int(i) => i.to_string(),
58 VectorParam::String(s) => format!("'{}'", s.replace('\'', "''")),
59 VectorParam::Json(j) => format!("'{j}'::jsonb"),
60 }
61 }
62}
63
/// Settings for a similarity search built by [`VectorQueryBuilder::similarity_search`].
#[derive(Debug, Clone)]
pub struct VectorSearchQuery {
    /// Table to search; interpolated verbatim after `FROM`.
    pub table: String,
    /// Column holding the stored embeddings; placed on the left of the distance
    /// operator in the `ORDER BY` and in the optional distance column.
    pub embedding_column: String,
    /// Columns to project, joined with `", "`; an empty list selects `*`.
    pub select_columns: Vec<String>,
    /// Metric whose `operator()` compares the stored and query embeddings.
    pub distance_metric: DistanceMetric,
    /// Maximum number of rows; bound as a parameter for `LIMIT`.
    pub limit: u32,
    /// Raw SQL filter appended after `WHERE`; inserted verbatim (not escaped).
    pub where_clause: Option<String>,
    /// NOTE(review): not read by `VectorQueryBuilder::similarity_search` in this
    /// file — confirm whether anything else consumes it.
    pub order_by: Option<String>,
    /// When `true`, adds `(<embedding_column> <op> <embedding>) AS distance`
    /// to the select list.
    pub include_distance: bool,
    /// Number of rows to skip; bound as a parameter for `OFFSET` when set.
    pub offset: Option<u32>,
}
86
87impl Default for VectorSearchQuery {
88 fn default() -> Self {
89 Self {
90 table: String::new(),
91 embedding_column: "embedding".to_string(),
92 select_columns: Vec::new(),
93 distance_metric: DistanceMetric::Cosine,
94 limit: 10,
95 where_clause: None,
96 order_by: None,
97 include_distance: false,
98 offset: None,
99 }
100 }
101}
102
103impl VectorSearchQuery {
104 pub fn new(table: impl Into<String>) -> Self {
106 Self {
107 table: table.into(),
108 ..Default::default()
109 }
110 }
111
112 #[must_use]
114 pub fn with_embedding_column(mut self, column: impl Into<String>) -> Self {
115 self.embedding_column = column.into();
116 self
117 }
118
119 #[must_use]
121 pub fn with_select_columns(mut self, columns: Vec<String>) -> Self {
122 self.select_columns = columns;
123 self
124 }
125
126 #[must_use]
128 pub fn with_distance_metric(mut self, metric: DistanceMetric) -> Self {
129 self.distance_metric = metric;
130 self
131 }
132
133 #[must_use]
135 pub fn with_limit(mut self, limit: u32) -> Self {
136 self.limit = limit;
137 self
138 }
139
140 #[must_use]
142 pub fn with_where(mut self, clause: impl Into<String>) -> Self {
143 self.where_clause = Some(clause.into());
144 self
145 }
146
147 #[must_use]
149 pub fn with_distance_score(mut self) -> Self {
150 self.include_distance = true;
151 self
152 }
153
154 #[must_use]
156 pub fn with_offset(mut self, offset: u32) -> Self {
157 self.offset = Some(offset);
158 self
159 }
160}
161
/// Settings for inserts built by `VectorQueryBuilder::insert_one` / `insert_batch`.
///
/// Every field is a standard-library type, so `PartialEq`/`Eq` are derived to
/// let callers compare configurations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VectorInsertQuery {
    /// Target table; interpolated verbatim after `INSERT INTO`.
    pub table: String,
    /// Column list of the `INSERT`, joined with `", "`.
    pub columns: Vec<String>,
    /// NOTE(review): not read by the insert builders in this file — whether a
    /// value gets a `::vector` cast is decided by its `VectorParam` variant.
    pub vector_column: String,
    /// Whether to emit an `ON CONFLICT` clause.
    pub upsert: bool,
    /// Conflict target columns used when `upsert` is set.
    pub conflict_columns: Vec<String>,
    /// Columns overwritten on conflict; when empty, every inserted column that
    /// is not a conflict column is used.
    pub update_columns: Vec<String>,
    /// Column(s) emitted after `RETURNING`, if any.
    pub returning: Option<String>,
}
180
181impl Default for VectorInsertQuery {
182 fn default() -> Self {
183 Self {
184 table: String::new(),
185 columns: Vec::new(),
186 vector_column: "embedding".to_string(),
187 upsert: false,
188 conflict_columns: vec!["id".to_string()],
189 update_columns: Vec::new(),
190 returning: Some("id".to_string()),
191 }
192 }
193}
194
195impl VectorInsertQuery {
196 pub fn new(table: impl Into<String>) -> Self {
198 Self {
199 table: table.into(),
200 ..Default::default()
201 }
202 }
203
204 #[must_use]
206 pub fn with_columns(mut self, columns: Vec<String>) -> Self {
207 self.columns = columns;
208 self
209 }
210
211 #[must_use]
213 pub fn with_vector_column(mut self, column: impl Into<String>) -> Self {
214 self.vector_column = column.into();
215 self
216 }
217
218 #[must_use]
220 pub fn with_upsert(mut self, conflict_columns: Vec<String>) -> Self {
221 self.upsert = true;
222 self.conflict_columns = conflict_columns;
223 self
224 }
225
226 #[must_use]
228 pub fn with_update_columns(mut self, columns: Vec<String>) -> Self {
229 self.update_columns = columns;
230 self
231 }
232
233 #[must_use]
235 pub fn with_returning(mut self, column: impl Into<String>) -> Self {
236 self.returning = Some(column.into());
237 self
238 }
239}
240
/// Generates parameterized SQL for vector similarity search, inserts and index
/// creation. Query methods return the SQL text together with the values to bind.
#[derive(Debug, Clone, Default)]
pub struct VectorQueryBuilder {
    /// Placeholder syntax used in generated SQL (`$n` by default).
    placeholder_style: PlaceholderStyle,
}
250
/// Placeholder syntax used for bound parameters in generated SQL.
///
/// `PartialEq`/`Eq` are derived so the style can be compared and matched on
/// by value.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum PlaceholderStyle {
    /// Numbered placeholders: `$1`, `$2`, … (the default).
    #[default]
    Dollar,
    /// Anonymous positional `?` placeholders; no index is rendered.
    QuestionMark,
}
260
261impl VectorQueryBuilder {
262 #[must_use]
264 pub fn new() -> Self {
265 Self::default()
266 }
267
268 #[must_use]
270 pub fn with_question_marks() -> Self {
271 Self {
272 placeholder_style: PlaceholderStyle::QuestionMark,
273 }
274 }
275
276 fn placeholder(&self, index: usize) -> String {
278 match self.placeholder_style {
279 PlaceholderStyle::Dollar => format!("${index}"),
280 PlaceholderStyle::QuestionMark => "?".to_string(),
281 }
282 }
283
284 #[must_use]
302 pub fn similarity_search(
303 &self,
304 query: &VectorSearchQuery,
305 query_embedding: &[f32],
306 ) -> (String, Vec<VectorParam>) {
307 let mut params = Vec::new();
308 let mut param_idx = 1;
309
310 params.push(VectorParam::Vector(query_embedding.to_vec()));
312 let embedding_placeholder = format!("{}::vector", self.placeholder(param_idx));
313 param_idx += 1;
314
315 let distance_op = query.distance_metric.operator();
316
317 let select_clause = if query.select_columns.is_empty() {
319 if query.include_distance {
320 format!(
321 "*, ({} {} {}) AS distance",
322 query.embedding_column, distance_op, embedding_placeholder
323 )
324 } else {
325 "*".to_string()
326 }
327 } else {
328 let cols = query.select_columns.join(", ");
329 if query.include_distance {
330 format!(
331 "{}, ({} {} {}) AS distance",
332 cols, query.embedding_column, distance_op, embedding_placeholder
333 )
334 } else {
335 cols
336 }
337 };
338
339 let where_clause = if let Some(ref clause) = query.where_clause {
341 format!("\nWHERE {clause}")
342 } else {
343 String::new()
344 };
345
346 let order_clause = format!(
348 "\nORDER BY {} {} {}",
349 query.embedding_column, distance_op, embedding_placeholder
350 );
351
352 let limit_clause = format!("\nLIMIT {}", self.placeholder(param_idx));
354 params.push(VectorParam::Int(i64::from(query.limit)));
355 param_idx += 1;
356
357 let offset_clause = if let Some(offset) = query.offset {
359 let clause = format!("\nOFFSET {}", self.placeholder(param_idx));
360 params.push(VectorParam::Int(i64::from(offset)));
361 clause
362 } else {
363 String::new()
364 };
365
366 let sql = format!(
367 "SELECT {}\nFROM {}{}{}{}{}",
368 select_clause, query.table, where_clause, order_clause, limit_clause, offset_clause
369 );
370
371 (sql, params)
372 }
373
374 #[must_use]
383 pub fn insert_one(
384 &self,
385 query: &VectorInsertQuery,
386 values: &[VectorParam],
387 ) -> (String, Vec<VectorParam>) {
388 let columns = query.columns.join(", ");
389
390 let placeholders: Vec<String> = values
391 .iter()
392 .enumerate()
393 .map(|(i, v)| {
394 let ph = self.placeholder(i + 1);
395 if matches!(v, VectorParam::Vector(_)) {
396 format!("{ph}::vector")
397 } else {
398 ph
399 }
400 })
401 .collect();
402
403 let values_clause = placeholders.join(", ");
404
405 let returning_clause = if let Some(ref col) = query.returning {
406 format!("\nRETURNING {col}")
407 } else {
408 String::new()
409 };
410
411 let sql = if query.upsert {
412 let conflict_cols = query.conflict_columns.join(", ");
413
414 let update_cols: Vec<&String> = if query.update_columns.is_empty() {
416 query.columns.iter().filter(|c| !query.conflict_columns.contains(c)).collect()
418 } else {
419 query.update_columns.iter().collect()
420 };
421
422 let update_clause: String = update_cols
423 .iter()
424 .map(|c| format!("{c} = EXCLUDED.{c}"))
425 .collect::<Vec<_>>()
426 .join(", ");
427
428 format!(
429 "INSERT INTO {} ({})\nVALUES ({})\nON CONFLICT ({}) DO UPDATE SET {}{}",
430 query.table, columns, values_clause, conflict_cols, update_clause, returning_clause
431 )
432 } else {
433 format!(
434 "INSERT INTO {} ({})\nVALUES ({}){}",
435 query.table, columns, values_clause, returning_clause
436 )
437 };
438
439 (sql, values.to_vec())
440 }
441
442 #[must_use]
454 pub fn insert_batch(
455 &self,
456 query: &VectorInsertQuery,
457 rows: &[Vec<VectorParam>],
458 ) -> (String, Vec<VectorParam>) {
459 if rows.is_empty() {
460 return (String::new(), Vec::new());
461 }
462
463 let columns = query.columns.join(", ");
464 let cols_per_row = query.columns.len();
465
466 let mut all_params = Vec::new();
467 let mut values_clauses = Vec::new();
468
469 for (row_idx, row) in rows.iter().enumerate() {
470 let base_idx = row_idx * cols_per_row + 1;
471 let placeholders: Vec<String> = row
472 .iter()
473 .enumerate()
474 .map(|(i, v)| {
475 let ph = self.placeholder(base_idx + i);
476 if matches!(v, VectorParam::Vector(_)) {
477 format!("{ph}::vector")
478 } else {
479 ph
480 }
481 })
482 .collect();
483
484 values_clauses.push(format!("({})", placeholders.join(", ")));
485 all_params.extend(row.clone());
486 }
487
488 let returning_clause = if let Some(ref col) = query.returning {
489 format!("\nRETURNING {col}")
490 } else {
491 String::new()
492 };
493
494 let sql = format!(
495 "INSERT INTO {} ({})\nVALUES\n {}{}",
496 query.table,
497 columns,
498 values_clauses.join(",\n "),
499 returning_clause
500 );
501
502 (sql, all_params)
503 }
504
505 #[must_use]
512 pub fn create_index(&self, config: &VectorConfig, table: &str, column: &str) -> Option<String> {
513 config.index_type.index_sql(table, column, config.distance_metric)
514 }
515}
516
#[cfg(test)]
mod tests {
    // SQL-generation tests for `VectorQueryBuilder` plus literal rendering for
    // `VectorParam`. Generated SQL is checked by substring, so the assertions
    // pin clause content rather than exact layout.
    use super::*;

    /// Owned column names from string slices.
    fn names(list: &[&str]) -> Vec<String> {
        list.iter().map(|name| (*name).to_string()).collect()
    }

    /// The `(id, content, embedding)` row shared by the single-row insert tests.
    fn doc1_row() -> Vec<VectorParam> {
        vec![
            VectorParam::String("doc1".to_string()),
            VectorParam::String("Hello world".to_string()),
            VectorParam::Vector(vec![0.1, 0.2, 0.3]),
        ]
    }

    /// An `(id, embedding)` row for the batch insert test.
    fn id_vector_row(id: &str, embedding: &[f32]) -> Vec<VectorParam> {
        vec![VectorParam::String(id.to_string()), VectorParam::Vector(embedding.to_vec())]
    }

    /// Runs a similarity search with the default (`$n`) builder.
    fn search_sql(search: &VectorSearchQuery, embedding: &[f32]) -> (String, Vec<VectorParam>) {
        VectorQueryBuilder::new().similarity_search(search, embedding)
    }

    #[test]
    fn test_similarity_search_basic() {
        let search = VectorSearchQuery::new("documents")
            .with_embedding_column("embedding")
            .with_limit(10);

        let (sql, params) = search_sql(&search, &[0.1, 0.2, 0.3]);

        for fragment in
            ["SELECT *", "FROM documents", "ORDER BY embedding <=>", "$1::vector", "LIMIT $2"]
        {
            assert!(sql.contains(fragment), "expected `{fragment}` in:\n{sql}");
        }
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_similarity_search_with_columns() {
        let search = VectorSearchQuery::new("docs")
            .with_select_columns(names(&["id", "content"]))
            .with_distance_score()
            .with_limit(5);

        let (sql, params) = search_sql(&search, &[0.1, 0.2]);

        for fragment in ["SELECT id, content,", "AS distance", "LIMIT $2"] {
            assert!(sql.contains(fragment), "expected `{fragment}` in:\n{sql}");
        }
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_similarity_search_with_where() {
        let search = VectorSearchQuery::new("documents")
            .with_where("metadata->>'type' = 'article'")
            .with_limit(10);

        let (sql, _params) = search_sql(&search, &[0.1, 0.2, 0.3]);

        assert!(sql.contains("WHERE metadata->>'type' = 'article'"));
    }

    #[test]
    fn test_similarity_search_with_offset() {
        let search = VectorSearchQuery::new("documents").with_limit(10).with_offset(20);

        let (sql, params) = search_sql(&search, &[0.1, 0.2]);

        assert!(sql.contains("OFFSET $3"));
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_similarity_search_l2_distance() {
        let search = VectorSearchQuery::new("docs")
            .with_distance_metric(DistanceMetric::L2)
            .with_limit(5);

        let sql = search_sql(&search, &[0.1, 0.2]).0;

        assert!(sql.contains("<->"));
    }

    #[test]
    fn test_similarity_search_inner_product() {
        let search = VectorSearchQuery::new("docs")
            .with_distance_metric(DistanceMetric::InnerProduct)
            .with_limit(5);

        let sql = search_sql(&search, &[0.1, 0.2]).0;

        assert!(sql.contains("<#>"));
    }

    #[test]
    fn test_insert_one_basic() {
        let insert =
            VectorInsertQuery::new("documents").with_columns(names(&["id", "content", "embedding"]));

        let (sql, params) = VectorQueryBuilder::new().insert_one(&insert, &doc1_row());

        for fragment in [
            "INSERT INTO documents (id, content, embedding)",
            "VALUES ($1, $2, $3::vector)",
            "RETURNING id",
        ] {
            assert!(sql.contains(fragment), "expected `{fragment}` in:\n{sql}");
        }
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_insert_upsert() {
        let insert = VectorInsertQuery::new("documents")
            .with_columns(names(&["id", "content", "embedding"]))
            .with_upsert(names(&["id"]));

        let (sql, _params) = VectorQueryBuilder::new().insert_one(&insert, &doc1_row());

        for fragment in [
            "ON CONFLICT (id) DO UPDATE SET",
            "content = EXCLUDED.content",
            "embedding = EXCLUDED.embedding",
        ] {
            assert!(sql.contains(fragment), "expected `{fragment}` in:\n{sql}");
        }
    }

    #[test]
    fn test_insert_batch() {
        let insert = VectorInsertQuery::new("documents").with_columns(names(&["id", "embedding"]));
        let rows = vec![id_vector_row("doc1", &[0.1, 0.2]), id_vector_row("doc2", &[0.3, 0.4])];

        let (sql, params) = VectorQueryBuilder::new().insert_batch(&insert, &rows);

        for fragment in
            ["INSERT INTO documents (id, embedding)", "($1, $2::vector)", "($3, $4::vector)"]
        {
            assert!(sql.contains(fragment), "expected `{fragment}` in:\n{sql}");
        }
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_insert_batch_empty() {
        let insert = VectorInsertQuery::new("documents").with_columns(names(&["id"]));

        let (sql, params) = VectorQueryBuilder::new().insert_batch(&insert, &[]);

        assert!(sql.is_empty());
        assert!(params.is_empty());
    }

    #[test]
    fn test_create_index_hnsw() {
        let config = VectorConfig::openai();

        let ddl = VectorQueryBuilder::new().create_index(&config, "documents", "embedding");

        assert_eq!(
            ddl,
            Some("CREATE INDEX ON documents USING hnsw (embedding vector_cosine_ops)".to_string())
        );
    }

    #[test]
    fn test_create_index_ivfflat() {
        let config = VectorConfig::new(1536)
            .with_index(crate::schema::VectorIndexType::IvfFlat)
            .with_distance(DistanceMetric::L2);

        let ddl = VectorQueryBuilder::new().create_index(&config, "docs", "vec");

        assert_eq!(ddl, Some("CREATE INDEX ON docs USING ivfflat (vec vector_l2_ops)".to_string()));
    }

    #[test]
    fn test_create_index_none() {
        let config = VectorConfig::new(1536).with_index(crate::schema::VectorIndexType::None);

        let ddl = VectorQueryBuilder::new().create_index(&config, "documents", "embedding");

        assert_eq!(ddl, None);
    }

    #[test]
    fn test_vector_param_to_sql_literal() {
        let cases = [
            (VectorParam::Vector(vec![0.1, 0.2, 0.3]), "'[0.1,0.2,0.3]'::vector"),
            (VectorParam::Int(42), "42"),
            (VectorParam::String("hello".to_string()), "'hello'"),
            (VectorParam::String("it's a test".to_string()), "'it''s a test'"),
        ];

        for (param, expected) in &cases {
            assert_eq!(param.to_sql_literal(), *expected);
        }
    }

    #[test]
    fn test_question_mark_placeholders() {
        let search = VectorSearchQuery::new("docs").with_limit(10);

        let (sql, _params) =
            VectorQueryBuilder::with_question_marks().similarity_search(&search, &[0.1, 0.2]);

        assert!(sql.contains("?::vector"));
        assert!(!sql.contains("$1"));
    }
}