1use serde::{Deserialize, Serialize};
32
33use crate::schema::{DistanceMetric, VectorConfig};
34
35#[derive(Debug, Clone, Serialize, Deserialize)]
37#[non_exhaustive]
38pub enum VectorParam {
39 Vector(Vec<f32>),
41 Int(i64),
43 String(String),
45 Json(serde_json::Value),
47}
48
49impl VectorParam {
50 #[must_use]
52 pub fn to_sql_literal(&self) -> String {
53 match self {
54 VectorParam::Vector(v) => {
55 let values: Vec<String> = v.iter().map(std::string::ToString::to_string).collect();
56 format!("'[{}]'::vector", values.join(","))
57 },
58 VectorParam::Int(i) => i.to_string(),
59 VectorParam::String(s) => format!("'{}'", s.replace('\'', "''")),
60 VectorParam::Json(j) => format!("'{j}'::jsonb"),
61 }
62 }
63}
64
65#[derive(Debug, Clone)]
67pub struct VectorSearchQuery {
68 pub table: String,
70 pub embedding_column: String,
72 pub select_columns: Vec<String>,
74 pub distance_metric: DistanceMetric,
76 pub limit: u32,
78 pub where_clause: Option<String>,
80 pub order_by: Option<String>,
82 pub include_distance: bool,
84 pub offset: Option<u32>,
86}
87
88impl Default for VectorSearchQuery {
89 fn default() -> Self {
90 Self {
91 table: String::new(),
92 embedding_column: "embedding".to_string(),
93 select_columns: Vec::new(),
94 distance_metric: DistanceMetric::Cosine,
95 limit: 10,
96 where_clause: None,
97 order_by: None,
98 include_distance: false,
99 offset: None,
100 }
101 }
102}
103
104impl VectorSearchQuery {
105 pub fn new(table: impl Into<String>) -> Self {
107 Self {
108 table: table.into(),
109 ..Default::default()
110 }
111 }
112
113 pub fn with_embedding_column(mut self, column: impl Into<String>) -> Self {
115 self.embedding_column = column.into();
116 self
117 }
118
119 pub fn with_select_columns(mut self, columns: Vec<String>) -> Self {
121 self.select_columns = columns;
122 self
123 }
124
125 pub const fn with_distance_metric(mut self, metric: DistanceMetric) -> Self {
127 self.distance_metric = metric;
128 self
129 }
130
131 pub const fn with_limit(mut self, limit: u32) -> Self {
133 self.limit = limit;
134 self
135 }
136
137 pub fn with_where(mut self, clause: impl Into<String>) -> Self {
139 self.where_clause = Some(clause.into());
140 self
141 }
142
143 pub const fn with_distance_score(mut self) -> Self {
145 self.include_distance = true;
146 self
147 }
148
149 pub const fn with_offset(mut self, offset: u32) -> Self {
151 self.offset = Some(offset);
152 self
153 }
154}
155
/// Describes an insert (optionally an upsert) into a table that stores
/// embeddings.
#[derive(Debug, Clone)]
pub struct VectorInsertQuery {
    /// Table to insert into.
    pub table: String,
    /// Column names, in the order the values are supplied.
    pub columns: Vec<String>,
    /// Name of the column that stores the embedding.
    pub vector_column: String,
    /// Whether an `ON CONFLICT` clause should be generated.
    pub upsert: bool,
    /// Columns forming the conflict target of an upsert.
    pub conflict_columns: Vec<String>,
    /// Columns overwritten on conflict; when empty, the builder derives them
    /// from `columns` minus `conflict_columns`.
    pub update_columns: Vec<String>,
    /// Column to return from the statement, if any.
    pub returning: Option<String>,
}

impl Default for VectorInsertQuery {
    /// An unnamed table with no columns, the `embedding` vector column, no
    /// upsert, `id` as the conflict target and `id` as the returned column.
    fn default() -> Self {
        Self {
            table: String::new(),
            columns: vec![],
            vector_column: String::from("embedding"),
            upsert: false,
            conflict_columns: vec![String::from("id")],
            update_columns: vec![],
            returning: Some(String::from("id")),
        }
    }
}

impl VectorInsertQuery {
    /// Starts an insert into `table`, with every other setting at its default.
    pub fn new(table: impl Into<String>) -> Self {
        Self {
            table: table.into(),
            ..Self::default()
        }
    }

    /// Replaces the list of inserted columns.
    pub fn with_columns(self, columns: Vec<String>) -> Self {
        Self { columns, ..self }
    }

    /// Sets the name of the embedding column.
    pub fn with_vector_column(self, column: impl Into<String>) -> Self {
        Self {
            vector_column: column.into(),
            ..self
        }
    }

    /// Turns the insert into an upsert keyed on `conflict_columns`.
    pub fn with_upsert(self, conflict_columns: Vec<String>) -> Self {
        Self {
            upsert: true,
            conflict_columns,
            ..self
        }
    }

    /// Sets the columns to overwrite when a conflict occurs.
    pub fn with_update_columns(self, columns: Vec<String>) -> Self {
        Self {
            update_columns: columns,
            ..self
        }
    }

    /// Sets the column returned by the statement.
    pub fn with_returning(self, column: impl Into<String>) -> Self {
        Self {
            returning: Some(column.into()),
            ..self
        }
    }
}
229
230#[must_use = "call .build() to construct the final value"]
235#[derive(Debug, Clone, Default)]
236pub struct VectorQueryBuilder {
237 placeholder_style: PlaceholderStyle,
239}
240
241#[derive(Debug, Clone, Copy, Default)]
243#[non_exhaustive]
244pub enum PlaceholderStyle {
245 #[default]
247 Dollar,
248 QuestionMark,
250}
251
252impl VectorQueryBuilder {
253 pub fn new() -> Self {
255 Self::default()
256 }
257
258 pub const fn with_question_marks() -> Self {
260 Self {
261 placeholder_style: PlaceholderStyle::QuestionMark,
262 }
263 }
264
265 fn placeholder(&self, index: usize) -> String {
267 match self.placeholder_style {
268 PlaceholderStyle::Dollar => format!("${index}"),
269 PlaceholderStyle::QuestionMark => "?".to_string(),
270 }
271 }
272
273 #[must_use]
291 pub fn similarity_search(
292 &self,
293 query: &VectorSearchQuery,
294 query_embedding: &[f32],
295 ) -> (String, Vec<VectorParam>) {
296 let mut params = Vec::new();
297 let mut param_idx = 1;
298
299 params.push(VectorParam::Vector(query_embedding.to_vec()));
301 let embedding_placeholder = format!("{}::vector", self.placeholder(param_idx));
302 param_idx += 1;
303
304 let distance_op = query.distance_metric.operator();
305
306 let select_clause = if query.select_columns.is_empty() {
308 if query.include_distance {
309 format!(
310 "*, ({} {} {}) AS distance",
311 query.embedding_column, distance_op, embedding_placeholder
312 )
313 } else {
314 "*".to_string()
315 }
316 } else {
317 let cols = query.select_columns.join(", ");
318 if query.include_distance {
319 format!(
320 "{}, ({} {} {}) AS distance",
321 cols, query.embedding_column, distance_op, embedding_placeholder
322 )
323 } else {
324 cols
325 }
326 };
327
328 let where_clause = if let Some(ref clause) = query.where_clause {
330 format!("\nWHERE {clause}")
331 } else {
332 String::new()
333 };
334
335 let order_clause = format!(
337 "\nORDER BY {} {} {}",
338 query.embedding_column, distance_op, embedding_placeholder
339 );
340
341 let limit_clause = format!("\nLIMIT {}", self.placeholder(param_idx));
343 params.push(VectorParam::Int(i64::from(query.limit)));
344 param_idx += 1;
345
346 let offset_clause = if let Some(offset) = query.offset {
348 let clause = format!("\nOFFSET {}", self.placeholder(param_idx));
349 params.push(VectorParam::Int(i64::from(offset)));
350 clause
351 } else {
352 String::new()
353 };
354
355 let sql = format!(
356 "SELECT {}\nFROM {}{}{}{}{}",
357 select_clause, query.table, where_clause, order_clause, limit_clause, offset_clause
358 );
359
360 (sql, params)
361 }
362
363 #[must_use]
372 pub fn insert_one(
373 &self,
374 query: &VectorInsertQuery,
375 values: &[VectorParam],
376 ) -> (String, Vec<VectorParam>) {
377 let columns = query.columns.join(", ");
378
379 let placeholders: Vec<String> = values
380 .iter()
381 .enumerate()
382 .map(|(i, v)| {
383 let ph = self.placeholder(i + 1);
384 if matches!(v, VectorParam::Vector(_)) {
385 format!("{ph}::vector")
386 } else {
387 ph
388 }
389 })
390 .collect();
391
392 let values_clause = placeholders.join(", ");
393
394 let returning_clause = if let Some(ref col) = query.returning {
395 format!("\nRETURNING {col}")
396 } else {
397 String::new()
398 };
399
400 let sql = if query.upsert {
401 let conflict_cols = query.conflict_columns.join(", ");
402
403 let update_cols: Vec<&String> = if query.update_columns.is_empty() {
405 query.columns.iter().filter(|c| !query.conflict_columns.contains(c)).collect()
407 } else {
408 query.update_columns.iter().collect()
409 };
410
411 let update_clause: String = update_cols
412 .iter()
413 .map(|c| format!("{c} = EXCLUDED.{c}"))
414 .collect::<Vec<_>>()
415 .join(", ");
416
417 format!(
418 "INSERT INTO {} ({})\nVALUES ({})\nON CONFLICT ({}) DO UPDATE SET {}{}",
419 query.table, columns, values_clause, conflict_cols, update_clause, returning_clause
420 )
421 } else {
422 format!(
423 "INSERT INTO {} ({})\nVALUES ({}){}",
424 query.table, columns, values_clause, returning_clause
425 )
426 };
427
428 (sql, values.to_vec())
429 }
430
431 #[must_use]
443 pub fn insert_batch(
444 &self,
445 query: &VectorInsertQuery,
446 rows: &[Vec<VectorParam>],
447 ) -> (String, Vec<VectorParam>) {
448 if rows.is_empty() {
449 return (String::new(), Vec::new());
450 }
451
452 let columns = query.columns.join(", ");
453 let cols_per_row = query.columns.len();
454
455 let mut all_params = Vec::new();
456 let mut values_clauses = Vec::new();
457
458 for (row_idx, row) in rows.iter().enumerate() {
459 let base_idx = row_idx * cols_per_row + 1;
460 let placeholders: Vec<String> = row
461 .iter()
462 .enumerate()
463 .map(|(i, v)| {
464 let ph = self.placeholder(base_idx + i);
465 if matches!(v, VectorParam::Vector(_)) {
466 format!("{ph}::vector")
467 } else {
468 ph
469 }
470 })
471 .collect();
472
473 values_clauses.push(format!("({})", placeholders.join(", ")));
474 all_params.extend(row.clone());
475 }
476
477 let returning_clause = if let Some(ref col) = query.returning {
478 format!("\nRETURNING {col}")
479 } else {
480 String::new()
481 };
482
483 let sql = format!(
484 "INSERT INTO {} ({})\nVALUES\n {}{}",
485 query.table,
486 columns,
487 values_clauses.join(",\n "),
488 returning_clause
489 );
490
491 (sql, all_params)
492 }
493
494 #[must_use]
501 pub fn create_index(&self, config: &VectorConfig, table: &str, column: &str) -> Option<String> {
502 config.index_type.index_sql(table, column, config.distance_metric)
503 }
504}
505
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts string slices into the owned column names the query types take.
    fn cols(names: &[&str]) -> Vec<String> {
        names.iter().map(|name| (*name).to_string()).collect()
    }

    /// Asserts that every fragment occurs somewhere in `sql`.
    fn assert_sql_contains(sql: &str, fragments: &[&str]) {
        for fragment in fragments {
            assert!(sql.contains(fragment), "expected `{fragment}` in:\n{sql}");
        }
    }

    /// The three-value row (id, content, embedding) shared by the insert tests.
    fn document_row() -> Vec<VectorParam> {
        vec![
            VectorParam::String("doc1".to_string()),
            VectorParam::String("Hello world".to_string()),
            VectorParam::Vector(vec![0.1, 0.2, 0.3]),
        ]
    }

    #[test]
    fn test_similarity_search_basic() {
        let search = VectorSearchQuery::new("documents")
            .with_embedding_column("embedding")
            .with_limit(10);

        let (sql, params) = VectorQueryBuilder::new().similarity_search(&search, &[0.1, 0.2, 0.3]);

        assert_sql_contains(
            &sql,
            &["SELECT *", "FROM documents", "ORDER BY embedding <=>", "$1::vector", "LIMIT $2"],
        );
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_similarity_search_with_columns() {
        let search = VectorSearchQuery::new("docs")
            .with_select_columns(cols(&["id", "content"]))
            .with_distance_score()
            .with_limit(5);

        let (sql, params) = VectorQueryBuilder::new().similarity_search(&search, &[0.1, 0.2]);

        assert_sql_contains(&sql, &["SELECT id, content,", "AS distance", "LIMIT $2"]);
        assert_eq!(params.len(), 2);
    }

    #[test]
    fn test_similarity_search_with_where() {
        let search = VectorSearchQuery::new("documents")
            .with_where("metadata->>'type' = 'article'")
            .with_limit(10);

        let (sql, _) = VectorQueryBuilder::new().similarity_search(&search, &[0.1, 0.2, 0.3]);

        assert_sql_contains(&sql, &["WHERE metadata->>'type' = 'article'"]);
    }

    #[test]
    fn test_similarity_search_with_offset() {
        let search = VectorSearchQuery::new("documents").with_limit(10).with_offset(20);

        let (sql, params) = VectorQueryBuilder::new().similarity_search(&search, &[0.1, 0.2]);

        assert_sql_contains(&sql, &["OFFSET $3"]);
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_similarity_search_l2_distance() {
        let search = VectorSearchQuery::new("docs")
            .with_distance_metric(DistanceMetric::L2)
            .with_limit(5);

        let (sql, _) = VectorQueryBuilder::new().similarity_search(&search, &[0.1, 0.2]);

        assert_sql_contains(&sql, &["<->"]);
    }

    #[test]
    fn test_similarity_search_inner_product() {
        let search = VectorSearchQuery::new("docs")
            .with_distance_metric(DistanceMetric::InnerProduct)
            .with_limit(5);

        let (sql, _) = VectorQueryBuilder::new().similarity_search(&search, &[0.1, 0.2]);

        assert_sql_contains(&sql, &["<#>"]);
    }

    #[test]
    fn test_insert_one_basic() {
        let insert =
            VectorInsertQuery::new("documents").with_columns(cols(&["id", "content", "embedding"]));

        let (sql, params) = VectorQueryBuilder::new().insert_one(&insert, &document_row());

        assert_sql_contains(
            &sql,
            &[
                "INSERT INTO documents (id, content, embedding)",
                "VALUES ($1, $2, $3::vector)",
                "RETURNING id",
            ],
        );
        assert_eq!(params.len(), 3);
    }

    #[test]
    fn test_insert_upsert() {
        let insert = VectorInsertQuery::new("documents")
            .with_columns(cols(&["id", "content", "embedding"]))
            .with_upsert(cols(&["id"]));

        let (sql, _) = VectorQueryBuilder::new().insert_one(&insert, &document_row());

        assert_sql_contains(
            &sql,
            &[
                "ON CONFLICT (id) DO UPDATE SET",
                "content = EXCLUDED.content",
                "embedding = EXCLUDED.embedding",
            ],
        );
    }

    #[test]
    fn test_insert_batch() {
        let insert = VectorInsertQuery::new("documents").with_columns(cols(&["id", "embedding"]));

        let rows = vec![
            vec![VectorParam::String("doc1".to_string()), VectorParam::Vector(vec![0.1, 0.2])],
            vec![VectorParam::String("doc2".to_string()), VectorParam::Vector(vec![0.3, 0.4])],
        ];

        let (sql, params) = VectorQueryBuilder::new().insert_batch(&insert, &rows);

        assert_sql_contains(
            &sql,
            &["INSERT INTO documents (id, embedding)", "($1, $2::vector)", "($3, $4::vector)"],
        );
        assert_eq!(params.len(), 4);
    }

    #[test]
    fn test_insert_batch_empty() {
        let insert = VectorInsertQuery::new("documents").with_columns(cols(&["id"]));

        let (sql, params) = VectorQueryBuilder::new().insert_batch(&insert, &[]);

        assert!(sql.is_empty());
        assert!(params.is_empty());
    }

    #[test]
    fn test_create_index_hnsw() {
        let config = VectorConfig::openai();

        let sql = VectorQueryBuilder::new().create_index(&config, "documents", "embedding");

        assert_eq!(
            sql.as_deref(),
            Some("CREATE INDEX ON documents USING hnsw (embedding vector_cosine_ops)")
        );
    }

    #[test]
    fn test_create_index_ivfflat() {
        let config = VectorConfig::new(1536)
            .with_index(crate::schema::VectorIndexType::IvfFlat)
            .with_distance(DistanceMetric::L2);

        let sql = VectorQueryBuilder::new().create_index(&config, "docs", "vec");

        assert_eq!(sql.as_deref(), Some("CREATE INDEX ON docs USING ivfflat (vec vector_l2_ops)"));
    }

    #[test]
    fn test_create_index_none() {
        let config = VectorConfig::new(1536).with_index(crate::schema::VectorIndexType::None);

        let sql = VectorQueryBuilder::new().create_index(&config, "documents", "embedding");

        assert!(sql.is_none());
    }

    #[test]
    fn test_vector_param_to_sql_literal() {
        let cases = [
            (VectorParam::Vector(vec![0.1, 0.2, 0.3]), "'[0.1,0.2,0.3]'::vector"),
            (VectorParam::Int(42), "42"),
            (VectorParam::String("hello".to_string()), "'hello'"),
            (VectorParam::String("it's a test".to_string()), "'it''s a test'"),
        ];

        for (param, expected) in cases {
            assert_eq!(param.to_sql_literal(), expected);
        }
    }

    #[test]
    fn test_question_mark_placeholders() {
        let search = VectorSearchQuery::new("docs").with_limit(10);

        let (sql, _) =
            VectorQueryBuilder::with_question_marks().similarity_search(&search, &[0.1, 0.2]);

        assert!(sql.contains("?::vector"));
        assert!(!sql.contains("$1"));
    }
}