1use serde::{Deserialize, Serialize};
25
26use crate::ops::{DistanceMetric, SearchParams};
27use crate::types::Embedding;
28
/// A fully specified nearest-neighbour search over a single vector column.
///
/// Usually produced by [`VectorSearchBuilder`]; render it with [`VectorSearchQuery::to_sql`]
/// or [`VectorSearchQuery::to_sql_with_param`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorSearchQuery {
    /// Table placed verbatim in the `FROM` clause.
    pub table: String,
    /// Vector column used as the left operand of the distance operator.
    pub column: String,
    /// The query embedding. It is not rendered into the SQL text; the generated SQL refers to
    /// it through a `$n` placeholder, so the caller presumably binds it there.
    pub query_vector: Embedding,
    /// Metric whose `operator()` is placed between `column` and the placeholder.
    pub metric: DistanceMetric,
    /// Value of the `LIMIT` clause.
    pub limit: usize,
    /// Columns to select; an empty list selects `*`.
    pub select_columns: Vec<String>,
    /// Raw SQL conditions ANDed into the `WHERE` clause. Inserted verbatim — never pass
    /// untrusted input.
    pub where_clauses: Vec<String>,
    /// When `true`, the distance is selected as `distance_alias` and results are ordered by
    /// that alias; otherwise results are ordered by the distance expression itself.
    pub include_distance: bool,
    /// Output column name for the distance (used only when `include_distance` is `true`).
    pub distance_alias: String,
    /// Exclusive upper bound: keeps rows whose distance is `< max_distance`.
    pub max_distance: Option<f64>,
    /// Inclusive lower bound: keeps rows whose distance is `>= min_distance`.
    pub min_distance: Option<f64>,
    /// Additional raw `ORDER BY` items, applied after the distance ordering.
    pub extra_order_by: Vec<String>,
    /// Optional `OFFSET`, emitted after `LIMIT`.
    pub offset: Option<usize>,
    /// Index tuning knobs, rendered by [`VectorSearchQuery::param_set_sql`] as one SQL string
    /// per configured knob (e.g. `ivfflat.probes`, `hnsw.ef_search`).
    pub search_params: SearchParams,
}
61
62impl VectorSearchQuery {
63 pub fn to_sql(&self) -> String {
67 self.to_sql_with_param(1)
68 }
69
70 pub fn to_sql_with_param(&self, param_index: usize) -> String {
72 let param = format!("${param_index}");
73 let distance_expr = format!("{} {} {}", self.column, self.metric.operator(), param);
74
75 let select = if self.select_columns.is_empty() {
77 "*".to_string()
78 } else {
79 self.select_columns.join(", ")
80 };
81
82 let distance_select = if self.include_distance {
83 format!(", {} AS {}", distance_expr, self.distance_alias)
84 } else {
85 String::new()
86 };
87
88 let mut where_parts = Vec::new();
90
91 if let Some(max) = self.max_distance {
92 where_parts.push(format!("{distance_expr} < {max}"));
93 }
94 if let Some(min) = self.min_distance {
95 where_parts.push(format!("{distance_expr} >= {min}"));
96 }
97 where_parts.extend(self.where_clauses.clone());
98
99 let where_clause = if where_parts.is_empty() {
100 String::new()
101 } else {
102 format!(" WHERE {}", where_parts.join(" AND "))
103 };
104
105 let order_by_main = if self.include_distance {
107 self.distance_alias.clone()
108 } else {
109 distance_expr
110 };
111
112 let order_by = if self.extra_order_by.is_empty() {
113 order_by_main
114 } else {
115 let mut parts = vec![order_by_main];
116 parts.extend(self.extra_order_by.clone());
117 parts.join(", ")
118 };
119
120 let limit = format!(" LIMIT {}", self.limit);
122 let offset = self
123 .offset
124 .map(|o| format!(" OFFSET {o}"))
125 .unwrap_or_default();
126
127 format!(
128 "SELECT {}{} FROM {}{} ORDER BY {}{}{}",
129 select, distance_select, self.table, where_clause, order_by, limit, offset
130 )
131 }
132
133 pub fn param_set_sql(&self) -> Vec<String> {
137 self.search_params.to_set_sql()
138 }
139}
140
141pub struct VectorSearchBuilder {
157 table: String,
158 column: String,
159 query_vector: Option<Embedding>,
160 metric: DistanceMetric,
161 limit: usize,
162 select_columns: Vec<String>,
163 where_clauses: Vec<String>,
164 include_distance: bool,
165 distance_alias: String,
166 max_distance: Option<f64>,
167 min_distance: Option<f64>,
168 extra_order_by: Vec<String>,
169 offset: Option<usize>,
170 search_params: SearchParams,
171}
172
173impl VectorSearchBuilder {
174 pub fn new(table: impl Into<String>, column: impl Into<String>) -> Self {
176 Self {
177 table: table.into(),
178 column: column.into(),
179 query_vector: None,
180 metric: DistanceMetric::L2,
181 limit: 10,
182 select_columns: Vec::new(),
183 where_clauses: Vec::new(),
184 include_distance: true,
185 distance_alias: "distance".to_string(),
186 max_distance: None,
187 min_distance: None,
188 extra_order_by: Vec::new(),
189 offset: None,
190 search_params: SearchParams::new(),
191 }
192 }
193
194 pub fn query(mut self, embedding: Embedding) -> Self {
196 self.query_vector = Some(embedding);
197 self
198 }
199
200 pub fn metric(mut self, metric: DistanceMetric) -> Self {
202 self.metric = metric;
203 self
204 }
205
206 pub fn limit(mut self, limit: usize) -> Self {
208 self.limit = limit;
209 self
210 }
211
212 pub fn select(mut self, columns: &[&str]) -> Self {
214 self.select_columns = columns.iter().map(|c| (*c).to_string()).collect();
215 self
216 }
217
218 pub fn where_clause(mut self, condition: impl Into<String>) -> Self {
220 self.where_clauses.push(condition.into());
221 self
222 }
223
224 pub fn max_distance(mut self, distance: f64) -> Self {
226 self.max_distance = Some(distance);
227 self
228 }
229
230 pub fn min_distance(mut self, distance: f64) -> Self {
232 self.min_distance = Some(distance);
233 self
234 }
235
236 pub fn without_distance(mut self) -> Self {
238 self.include_distance = false;
239 self
240 }
241
242 pub fn distance_alias(mut self, alias: impl Into<String>) -> Self {
244 self.distance_alias = alias.into();
245 self
246 }
247
248 pub fn then_order_by(mut self, clause: impl Into<String>) -> Self {
250 self.extra_order_by.push(clause.into());
251 self
252 }
253
254 pub fn offset(mut self, offset: usize) -> Self {
256 self.offset = Some(offset);
257 self
258 }
259
260 pub fn probes(mut self, probes: usize) -> Self {
262 self.search_params = self.search_params.probes(probes);
263 self
264 }
265
266 pub fn ef_search(mut self, ef: usize) -> Self {
268 self.search_params = self.search_params.ef_search(ef);
269 self
270 }
271
272 pub fn build(self) -> VectorSearchQuery {
279 self.try_build()
280 .expect("query vector must be set before building")
281 }
282
283 pub fn try_build(self) -> Option<VectorSearchQuery> {
287 let query_vector = self.query_vector?;
288
289 Some(VectorSearchQuery {
290 table: self.table,
291 column: self.column,
292 query_vector,
293 metric: self.metric,
294 limit: self.limit,
295 select_columns: self.select_columns,
296 where_clauses: self.where_clauses,
297 include_distance: self.include_distance,
298 distance_alias: self.distance_alias,
299 max_distance: self.max_distance,
300 min_distance: self.min_distance,
301 extra_order_by: self.extra_order_by,
302 offset: self.offset,
303 search_params: self.search_params,
304 })
305 }
306}
307
308pub struct HybridSearchBuilder {
333 table: String,
334 vector_column: Option<String>,
335 text_column: Option<String>,
336 query_vector: Option<Embedding>,
337 query_text: Option<String>,
338 metric: DistanceMetric,
339 vector_weight: f64,
340 text_weight: f64,
341 limit: usize,
342 language: String,
343 where_clauses: Vec<String>,
344}
345
346impl HybridSearchBuilder {
347 pub fn new(table: impl Into<String>) -> Self {
349 Self {
350 table: table.into(),
351 vector_column: None,
352 text_column: None,
353 query_vector: None,
354 query_text: None,
355 metric: DistanceMetric::Cosine,
356 vector_weight: 0.5,
357 text_weight: 0.5,
358 limit: 10,
359 language: "english".to_string(),
360 where_clauses: Vec::new(),
361 }
362 }
363
364 pub fn vector_column(mut self, column: impl Into<String>) -> Self {
366 self.vector_column = Some(column.into());
367 self
368 }
369
370 pub fn text_column(mut self, column: impl Into<String>) -> Self {
372 self.text_column = Some(column.into());
373 self
374 }
375
376 pub fn query_vector(mut self, embedding: Embedding) -> Self {
378 self.query_vector = Some(embedding);
379 self
380 }
381
382 pub fn query_text(mut self, text: impl Into<String>) -> Self {
384 self.query_text = Some(text.into());
385 self
386 }
387
388 pub fn metric(mut self, metric: DistanceMetric) -> Self {
390 self.metric = metric;
391 self
392 }
393
394 pub fn vector_weight(mut self, weight: f64) -> Self {
396 self.vector_weight = weight;
397 self
398 }
399
400 pub fn text_weight(mut self, weight: f64) -> Self {
402 self.text_weight = weight;
403 self
404 }
405
406 pub fn limit(mut self, limit: usize) -> Self {
408 self.limit = limit;
409 self
410 }
411
412 pub fn language(mut self, language: impl Into<String>) -> Self {
414 self.language = language.into();
415 self
416 }
417
418 pub fn where_clause(mut self, condition: impl Into<String>) -> Self {
420 self.where_clauses.push(condition.into());
421 self
422 }
423
424 pub fn build(self) -> HybridSearchQuery {
426 HybridSearchQuery {
427 table: self.table,
428 vector_column: self
429 .vector_column
430 .unwrap_or_else(|| "embedding".to_string()),
431 text_column: self.text_column.unwrap_or_else(|| "content".to_string()),
432 query_vector: self.query_vector,
433 query_text: self.query_text,
434 metric: self.metric,
435 vector_weight: self.vector_weight,
436 text_weight: self.text_weight,
437 limit: self.limit,
438 language: self.language,
439 where_clauses: self.where_clauses,
440 }
441 }
442}
443
/// A hybrid search combining vector similarity with PostgreSQL full-text search.
///
/// Usually produced by [`HybridSearchBuilder`]; render it with [`HybridSearchQuery::to_sql`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridSearchQuery {
    /// Table searched by both retrievers.
    pub table: String,
    /// Vector column compared against the `$1` placeholder using `metric`.
    pub vector_column: String,
    /// Text column passed to `to_tsvector` for full-text matching and ranking.
    pub text_column: String,
    /// Query embedding. Not rendered into the SQL text; the SQL refers to it as `$1`, so the
    /// caller presumably binds it there.
    pub query_vector: Option<Embedding>,
    /// Query text. Not rendered into the SQL text; the SQL refers to it as `$2` (through
    /// `plainto_tsquery`), so the caller presumably binds it there.
    pub query_text: Option<String>,
    /// Metric whose operator ranks the vector retriever.
    pub metric: DistanceMetric,
    /// Weight of the vector reciprocal-rank term in `rrf_score`.
    pub vector_weight: f64,
    /// Weight of the text reciprocal-rank term in `rrf_score`.
    pub text_weight: f64,
    /// Maximum number of fused results; each retriever fetches `3 * limit` candidates.
    pub limit: usize,
    /// Text search configuration name (e.g. `english`) for `to_tsvector` / `plainto_tsquery`.
    pub language: String,
    /// Raw SQL conditions ANDed together to filter rows before ranking. Inserted verbatim —
    /// never pass untrusted input.
    pub where_clauses: Vec<String>,
}
470
471impl HybridSearchQuery {
472 pub fn to_sql(&self) -> String {
479 let vec_distance = format!("{} {} $1", self.vector_column, self.metric.operator());
480 let text_rank = format!(
481 "ts_rank(to_tsvector('{}', {}), plainto_tsquery('{}', $2))",
482 self.language, self.text_column, self.language
483 );
484
485 let where_clause = if self.where_clauses.is_empty() {
486 String::new()
487 } else {
488 format!(" WHERE {}", self.where_clauses.join(" AND "))
489 };
490
491 format!(
493 "WITH vector_results AS (\
494 SELECT *, ROW_NUMBER() OVER (ORDER BY {vec_distance}) AS vec_rank \
495 FROM {table}{where_clause} \
496 ORDER BY {vec_distance} \
497 LIMIT {fetch_limit}\
498 ), \
499 text_results AS (\
500 SELECT *, ROW_NUMBER() OVER (ORDER BY {text_rank} DESC) AS text_rank \
501 FROM {table}{where_clause} \
502 WHERE to_tsvector('{lang}', {text_col}) @@ plainto_tsquery('{lang}', $2) \
503 ORDER BY {text_rank} DESC \
504 LIMIT {fetch_limit}\
505 ) \
506 SELECT COALESCE(v.*, t.*), \
507 ({vec_weight} / (60.0 + COALESCE(v.vec_rank, 1000))) + \
508 ({text_weight} / (60.0 + COALESCE(t.text_rank, 1000))) AS rrf_score \
509 FROM vector_results v \
510 FULL OUTER JOIN text_results t ON v.id = t.id \
511 ORDER BY rrf_score DESC \
512 LIMIT {limit}",
513 table = self.table,
514 where_clause = where_clause,
515 fetch_limit = self.limit * 3, vec_weight = self.vector_weight,
517 text_weight = self.text_weight,
518 lang = self.language,
519 text_col = self.text_column,
520 limit = self.limit,
521 )
522 }
523}
524
#[cfg(test)]
mod tests {
    use super::*;

    /// Small fixed embedding shared by every test.
    fn test_embedding() -> Embedding {
        Embedding::new(vec![0.1, 0.2, 0.3])
    }

    /// A `documents.embedding` search builder with the test embedding already supplied.
    fn docs_search() -> VectorSearchBuilder {
        VectorSearchBuilder::new("documents", "embedding").query(test_embedding())
    }

    /// Asserts that `sql` contains every fragment in `needles`, reporting the first missing one.
    fn assert_contains_all(sql: &str, needles: &[&str]) {
        for &needle in needles {
            assert!(sql.contains(needle), "expected {needle:?} in {sql}");
        }
    }

    #[test]
    fn test_basic_search_query() {
        let sql = docs_search()
            .metric(DistanceMetric::Cosine)
            .limit(10)
            .build()
            .to_sql();
        assert_contains_all(
            &sql,
            &["SELECT *", "AS distance", "<=>", "$1", "FROM documents", "LIMIT 10"],
        );
    }

    #[test]
    fn test_search_with_select() {
        let sql = docs_search().select(&["id", "title"]).build().to_sql();
        assert_contains_all(&sql, &["SELECT id, title"]);
    }

    #[test]
    fn test_search_with_where() {
        let sql = docs_search()
            .where_clause("category = 'tech'")
            .where_clause("published = true")
            .build()
            .to_sql();
        assert_contains_all(&sql, &["WHERE", "category = 'tech'", "published = true", "AND"]);
    }

    #[test]
    fn test_search_with_max_distance() {
        let sql = docs_search()
            .metric(DistanceMetric::L2)
            .max_distance(0.5)
            .build()
            .to_sql();
        assert_contains_all(&sql, &["< 0.5"]);
    }

    #[test]
    fn test_search_with_distance_range() {
        let sql = docs_search()
            .min_distance(0.1)
            .max_distance(0.5)
            .build()
            .to_sql();
        assert_contains_all(&sql, &["< 0.5", ">= 0.1"]);
    }

    #[test]
    fn test_search_without_distance() {
        let sql = docs_search().without_distance().build().to_sql();
        assert!(!sql.contains("AS distance"), "unexpected distance column in {sql}");
    }

    #[test]
    fn test_search_custom_alias() {
        let sql = docs_search().distance_alias("similarity").build().to_sql();
        assert_contains_all(&sql, &["AS similarity"]);
    }

    #[test]
    fn test_search_with_pagination() {
        let sql = docs_search().limit(10).offset(20).build().to_sql();
        assert_contains_all(&sql, &["LIMIT 10", "OFFSET 20"]);
    }

    #[test]
    fn test_search_with_extra_order_by() {
        let sql = docs_search().then_order_by("created_at DESC").build().to_sql();
        assert_contains_all(&sql, &["ORDER BY distance, created_at DESC"]);
    }

    #[test]
    fn test_search_params() {
        let set_sql = docs_search()
            .probes(10)
            .ef_search(200)
            .build()
            .param_set_sql();
        assert_eq!(set_sql.len(), 2);
        assert!(set_sql[0].contains("ivfflat.probes = 10"));
        assert!(set_sql[1].contains("hnsw.ef_search = 200"));
    }

    #[test]
    fn test_try_build_without_vector() {
        let built = VectorSearchBuilder::new("documents", "embedding").try_build();
        assert!(built.is_none());
    }

    #[test]
    fn test_custom_param_index() {
        let sql = docs_search().build().to_sql_with_param(3);
        assert_contains_all(&sql, &["$3"]);
    }

    #[test]
    fn test_hybrid_search() {
        let sql = HybridSearchBuilder::new("documents")
            .vector_column("embedding")
            .text_column("content")
            .query_vector(test_embedding())
            .query_text("machine learning")
            .metric(DistanceMetric::Cosine)
            .vector_weight(0.7)
            .text_weight(0.3)
            .limit(10)
            .build()
            .to_sql();
        assert_contains_all(
            &sql,
            &[
                "vector_results",
                "text_results",
                "rrf_score",
                "<=>",
                "ts_rank",
                "FULL OUTER JOIN",
            ],
        );
    }

    #[test]
    fn test_all_metrics_produce_valid_sql() {
        let metrics = [
            DistanceMetric::L2,
            DistanceMetric::InnerProduct,
            DistanceMetric::Cosine,
            DistanceMetric::L1,
        ];
        for metric in metrics {
            let sql = VectorSearchBuilder::new("t", "c")
                .query(test_embedding())
                .metric(metric)
                .build()
                .to_sql();
            assert!(sql.contains(metric.operator()), "failed for {metric}");
        }
    }
}