1use std::borrow::Cow;
20
21use bytes::{BufMut, Bytes};
22
23use crate::filter::FilterExpression;
24
/// Value bound to a named query parameter.
///
/// Textual parameters travel as strings; raw payloads (such as an encoded
/// query vector) travel as bytes.
#[derive(Debug, Clone)]
pub enum QueryParamValue {
    /// Parameter sent as a string.
    String(String),
    /// Parameter sent as raw bytes.
    Binary(Vec<u8>),
}
33
34#[derive(Debug, Clone)]
36pub struct QueryParam {
37 pub name: String,
39 pub value: QueryParamValue,
41}
42
/// Ordering applied to a sort field.
#[derive(Debug, Clone, Copy)]
pub enum SortDirection {
    /// Smallest values first.
    Asc,
    /// Largest values first.
    Desc,
}
51
52#[derive(Debug, Clone)]
54pub struct SortBy {
55 pub field: String,
57 pub direction: SortDirection,
59}
60
/// Result window of a query: skip `offset` hits, then return up to `num`.
#[derive(Debug, Clone, Copy)]
pub struct QueryLimit {
    /// Number of leading results to skip.
    pub offset: usize,
    /// Maximum number of results to return.
    pub num: usize,
}
69
/// Geographic radius filter around a longitude/latitude point.
#[derive(Debug, Clone)]
pub struct GeoFilter {
    /// Name of the geo field to filter on.
    pub field: String,
    /// Longitude of the centre point.
    pub lon: f64,
    /// Latitude of the centre point.
    pub lat: f64,
    /// Radius around the centre point, expressed in `unit`.
    pub radius: f64,
    /// Distance unit for `radius`, passed through as a string.
    pub unit: String,
}
84
85#[derive(Debug, Clone)]
87pub struct QueryRender {
88 pub query_string: String,
90 pub params: Vec<QueryParam>,
92 pub return_fields: Vec<String>,
94 pub sort_by: Option<SortBy>,
96 pub limit: Option<QueryLimit>,
98 pub dialect: u32,
100 pub in_order: bool,
102 pub no_content: bool,
104 pub scorer: Option<String>,
106 pub geofilter: Option<GeoFilter>,
108}
109
/// What a query yields: matching documents or only a match count.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QueryKind {
    /// The query returns documents.
    Documents,
    /// The query only counts matches.
    Count,
}
118
119pub trait QueryString {
121 fn to_redis_query(&self) -> String;
123
124 fn render(&self) -> QueryRender {
126 QueryRender {
127 query_string: self.to_redis_query(),
128 params: self.params(),
129 return_fields: self.return_fields(),
130 sort_by: self.sort_by(),
131 limit: self.limit(),
132 dialect: self.dialect(),
133 in_order: self.in_order(),
134 no_content: self.no_content(),
135 scorer: self.scorer(),
136 geofilter: self.geofilter(),
137 }
138 }
139
140 fn params(&self) -> Vec<QueryParam> {
142 Vec::new()
143 }
144
145 fn return_fields(&self) -> Vec<String> {
147 Vec::new()
148 }
149
150 fn sort_by(&self) -> Option<SortBy> {
152 None
153 }
154
155 fn limit(&self) -> Option<QueryLimit> {
157 None
158 }
159
160 fn dialect(&self) -> u32 {
162 2
163 }
164
165 fn in_order(&self) -> bool {
167 false
168 }
169
170 fn no_content(&self) -> bool {
172 false
173 }
174
175 fn scorer(&self) -> Option<String> {
177 None
178 }
179
180 fn kind(&self) -> QueryKind {
182 QueryKind::Documents
183 }
184
185 fn should_unpack_json(&self) -> bool {
188 false
189 }
190
191 fn geofilter(&self) -> Option<GeoFilter> {
193 None
194 }
195}
196
197pub trait PageableQuery: QueryString + Clone {
199 fn paged(&self, offset: usize, num: usize) -> Self;
201}
202
203#[derive(Debug, Clone)]
204struct QueryOptions {
205 return_fields: Vec<String>,
206 limit: QueryLimit,
207 dialect: u32,
208 sort_by: Option<SortBy>,
209 in_order: bool,
210 scorer: Option<String>,
211}
212
213impl QueryOptions {
214 fn with_num_results(num_results: usize) -> Self {
215 Self {
216 return_fields: Vec::new(),
217 limit: QueryLimit {
218 offset: 0,
219 num: num_results,
220 },
221 dialect: 2,
222 sort_by: None,
223 in_order: false,
224 scorer: None,
225 }
226 }
227}
228
229#[derive(Debug, Clone)]
231pub struct Vector<'a> {
232 elements: Cow<'a, [f32]>,
233}
234
235impl<'a> Vector<'a> {
236 pub fn new(elements: impl Into<Cow<'a, [f32]>>) -> Self {
238 Self {
239 elements: elements.into(),
240 }
241 }
242
243 pub fn elements(&self) -> &[f32] {
245 &self.elements
246 }
247
248 pub fn to_bytes(&self) -> Bytes {
250 let mut buffer =
251 bytes::BytesMut::with_capacity(self.elements.len() * std::mem::size_of::<f32>());
252 for value in self.elements.iter().copied() {
253 buffer.put_f32_le(value);
254 }
255 buffer.freeze()
256 }
257}
258
/// Value of the `HYBRID_POLICY` vector-query attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HybridPolicy {
    /// `BATCHES` policy.
    Batches,
    /// `ADHOC_BF` policy.
    AdhocBf,
}

impl HybridPolicy {
    /// Keyword written into the query for this policy.
    pub fn as_str(&self) -> &'static str {
        match *self {
            HybridPolicy::Batches => "BATCHES",
            HybridPolicy::AdhocBf => "ADHOC_BF",
        }
    }
}
277
/// Value of the `USE_SEARCH_HISTORY` vector-query attribute.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchHistoryMode {
    /// `OFF`.
    Off,
    /// `ON`.
    On,
    /// `AUTO`.
    Auto,
}

impl SearchHistoryMode {
    /// Keyword written into the query for this mode.
    pub fn as_str(&self) -> &'static str {
        match *self {
            SearchHistoryMode::Off => "OFF",
            SearchHistoryMode::On => "ON",
            SearchHistoryMode::Auto => "AUTO",
        }
    }
}
299
300#[derive(Debug, Clone)]
302pub struct VectorQuery<'a> {
303 vector: Vector<'a>,
304 vector_field_name: String,
305 num_results: usize,
306 filter_expression: Option<FilterExpression>,
307 ef_runtime: Option<usize>,
308 epsilon: Option<f64>,
309 hybrid_policy: Option<HybridPolicy>,
310 batch_size: Option<usize>,
311 search_window_size: Option<usize>,
312 use_search_history: Option<SearchHistoryMode>,
313 search_buffer_capacity: Option<usize>,
314 options: QueryOptions,
315}
316
317impl<'a> VectorQuery<'a> {
318 pub fn new(
320 vector: Vector<'a>,
321 vector_field_name: impl Into<String>,
322 num_results: usize,
323 ) -> Self {
324 let mut options = QueryOptions::with_num_results(num_results);
325 options.return_fields.push("vector_distance".to_owned());
326 options.sort_by = Some(SortBy {
327 field: "vector_distance".to_owned(),
328 direction: SortDirection::Asc,
329 });
330
331 Self {
332 vector,
333 vector_field_name: vector_field_name.into(),
334 num_results,
335 filter_expression: None,
336 ef_runtime: None,
337 epsilon: None,
338 hybrid_policy: None,
339 batch_size: None,
340 search_window_size: None,
341 use_search_history: None,
342 search_buffer_capacity: None,
343 options,
344 }
345 }
346
347 pub fn with_filter(mut self, filter_expression: FilterExpression) -> Self {
349 self.filter_expression = Some(filter_expression);
350 self
351 }
352
353 pub fn set_filter(&mut self, filter_expression: FilterExpression) {
355 self.filter_expression = Some(filter_expression);
356 }
357
358 pub fn with_ef_runtime(mut self, ef_runtime: usize) -> Self {
360 self.ef_runtime = Some(ef_runtime);
361 self
362 }
363
364 pub fn set_ef_runtime(&mut self, ef_runtime: usize) {
366 self.ef_runtime = Some(ef_runtime);
367 }
368
369 pub fn ef_runtime(&self) -> Option<usize> {
371 self.ef_runtime
372 }
373
374 pub fn with_epsilon(mut self, epsilon: f64) -> Self {
376 self.epsilon = Some(epsilon);
377 self
378 }
379
380 pub fn set_epsilon(&mut self, epsilon: f64) {
382 self.epsilon = Some(epsilon);
383 }
384
385 pub fn epsilon(&self) -> Option<f64> {
387 self.epsilon
388 }
389
390 pub fn with_hybrid_policy(mut self, policy: HybridPolicy) -> Self {
392 self.hybrid_policy = Some(policy);
393 self
394 }
395
396 pub fn set_hybrid_policy(&mut self, policy: HybridPolicy) {
398 self.hybrid_policy = Some(policy);
399 }
400
401 pub fn hybrid_policy(&self) -> Option<HybridPolicy> {
403 self.hybrid_policy
404 }
405
406 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
408 self.batch_size = Some(batch_size);
409 self
410 }
411
412 pub fn set_batch_size(&mut self, batch_size: usize) {
414 self.batch_size = Some(batch_size);
415 }
416
417 pub fn batch_size(&self) -> Option<usize> {
419 self.batch_size
420 }
421
422 pub fn with_search_window_size(mut self, size: usize) -> Self {
424 self.search_window_size = Some(size);
425 self
426 }
427
428 pub fn set_search_window_size(&mut self, size: usize) {
430 self.search_window_size = Some(size);
431 }
432
433 pub fn search_window_size(&self) -> Option<usize> {
435 self.search_window_size
436 }
437
438 pub fn with_use_search_history(mut self, mode: SearchHistoryMode) -> Self {
440 self.use_search_history = Some(mode);
441 self
442 }
443
444 pub fn set_use_search_history(&mut self, mode: SearchHistoryMode) {
446 self.use_search_history = Some(mode);
447 }
448
449 pub fn use_search_history(&self) -> Option<SearchHistoryMode> {
451 self.use_search_history
452 }
453
454 pub fn with_search_buffer_capacity(mut self, capacity: usize) -> Self {
456 self.search_buffer_capacity = Some(capacity);
457 self
458 }
459
460 pub fn set_search_buffer_capacity(&mut self, capacity: usize) {
462 self.search_buffer_capacity = Some(capacity);
463 }
464
465 pub fn search_buffer_capacity(&self) -> Option<usize> {
467 self.search_buffer_capacity
468 }
469
470 pub fn with_return_fields<I, S>(mut self, return_fields: I) -> Self
472 where
473 I: IntoIterator<Item = S>,
474 S: Into<String>,
475 {
476 self.options.return_fields = return_fields.into_iter().map(Into::into).collect();
477 if !self
478 .options
479 .return_fields
480 .iter()
481 .any(|field| field == "vector_distance")
482 {
483 self.options
484 .return_fields
485 .push("vector_distance".to_owned());
486 }
487 self
488 }
489
490 pub fn paging(mut self, offset: usize, num: usize) -> Self {
492 self.options.limit = QueryLimit { offset, num };
493 self
494 }
495
496 pub fn sort_by(mut self, field: impl Into<String>, direction: SortDirection) -> Self {
498 self.options.sort_by = Some(SortBy {
499 field: field.into(),
500 direction,
501 });
502 self
503 }
504
505 pub fn in_order(mut self, in_order: bool) -> Self {
507 self.options.in_order = in_order;
508 self
509 }
510
511 pub fn with_dialect(mut self, dialect: u32) -> Self {
513 self.options.dialect = dialect;
514 self
515 }
516
517 pub fn vector(&self) -> &Vector<'a> {
519 &self.vector
520 }
521}
522
523impl QueryString for VectorQuery<'_> {
524 fn to_redis_query(&self) -> String {
525 let base = self
526 .filter_expression
527 .as_ref()
528 .map_or_else(|| "*".to_owned(), FilterExpression::to_redis_syntax);
529 let mut query = format!(
530 "{}=>[KNN {} @{} $vector AS vector_distance",
531 base, self.num_results, self.vector_field_name
532 );
533 if self.ef_runtime.is_some() {
534 query.push_str(" EF_RUNTIME $EF");
535 }
536 if self.epsilon.is_some() {
537 query.push_str(" EPSILON $EPSILON");
538 }
539 if self.search_window_size.is_some() {
540 query.push_str(" SEARCH_WINDOW_SIZE $SEARCH_WINDOW_SIZE");
541 }
542 if self.use_search_history.is_some() {
543 query.push_str(" USE_SEARCH_HISTORY $USE_SEARCH_HISTORY");
544 }
545 if self.search_buffer_capacity.is_some() {
546 query.push_str(" SEARCH_BUFFER_CAPACITY $SEARCH_BUFFER_CAPACITY");
547 }
548 query.push(']');
549 if let Some(policy) = &self.hybrid_policy {
550 query.push_str(&format!(" HYBRID_POLICY {}", policy.as_str()));
551 if let Some(batch_size) = self.batch_size {
552 query.push_str(&format!(" BATCH_SIZE {}", batch_size));
553 }
554 }
555 query
556 }
557
558 fn params(&self) -> Vec<QueryParam> {
559 let mut params = vec![QueryParam {
560 name: "vector".to_owned(),
561 value: QueryParamValue::Binary(self.vector.to_bytes().to_vec()),
562 }];
563 if let Some(ef_runtime) = self.ef_runtime {
564 params.push(QueryParam {
565 name: "EF".to_owned(),
566 value: QueryParamValue::String(ef_runtime.to_string()),
567 });
568 }
569 if let Some(epsilon) = self.epsilon {
570 params.push(QueryParam {
571 name: "EPSILON".to_owned(),
572 value: QueryParamValue::String(epsilon.to_string()),
573 });
574 }
575 if let Some(size) = self.search_window_size {
576 params.push(QueryParam {
577 name: "SEARCH_WINDOW_SIZE".to_owned(),
578 value: QueryParamValue::String(size.to_string()),
579 });
580 }
581 if let Some(mode) = &self.use_search_history {
582 params.push(QueryParam {
583 name: "USE_SEARCH_HISTORY".to_owned(),
584 value: QueryParamValue::String(mode.as_str().to_owned()),
585 });
586 }
587 if let Some(capacity) = self.search_buffer_capacity {
588 params.push(QueryParam {
589 name: "SEARCH_BUFFER_CAPACITY".to_owned(),
590 value: QueryParamValue::String(capacity.to_string()),
591 });
592 }
593 params
594 }
595
596 fn return_fields(&self) -> Vec<String> {
597 self.options.return_fields.clone()
598 }
599
600 fn sort_by(&self) -> Option<SortBy> {
601 self.options.sort_by.clone()
602 }
603
604 fn limit(&self) -> Option<QueryLimit> {
605 Some(self.options.limit)
606 }
607
608 fn dialect(&self) -> u32 {
609 self.options.dialect
610 }
611
612 fn in_order(&self) -> bool {
613 self.options.in_order
614 }
615}
616
617impl PageableQuery for VectorQuery<'_> {
618 fn paged(&self, offset: usize, num: usize) -> Self {
619 self.clone().paging(offset, num)
620 }
621}
622
623#[derive(Debug, Clone)]
625pub struct VectorRangeQuery<'a> {
626 vector: Vector<'a>,
627 vector_field_name: String,
628 distance_threshold: f32,
629 filter_expression: Option<FilterExpression>,
630 epsilon: Option<f64>,
631 hybrid_policy: Option<HybridPolicy>,
632 batch_size: Option<usize>,
633 search_window_size: Option<usize>,
634 use_search_history: Option<SearchHistoryMode>,
635 search_buffer_capacity: Option<usize>,
636 options: QueryOptions,
637}
638
639impl<'a> VectorRangeQuery<'a> {
640 pub fn new(
642 vector: Vector<'a>,
643 vector_field_name: impl Into<String>,
644 distance_threshold: f32,
645 ) -> Self {
646 let mut options = QueryOptions::with_num_results(10);
647 options.return_fields.push("vector_distance".to_owned());
648 options.sort_by = Some(SortBy {
649 field: "vector_distance".to_owned(),
650 direction: SortDirection::Asc,
651 });
652
653 Self {
654 vector,
655 vector_field_name: vector_field_name.into(),
656 distance_threshold,
657 filter_expression: None,
658 epsilon: None,
659 hybrid_policy: None,
660 batch_size: None,
661 search_window_size: None,
662 use_search_history: None,
663 search_buffer_capacity: None,
664 options,
665 }
666 }
667
668 pub fn with_filter(mut self, filter_expression: FilterExpression) -> Self {
670 self.filter_expression = Some(filter_expression);
671 self
672 }
673
674 pub fn set_filter(&mut self, filter_expression: FilterExpression) {
676 self.filter_expression = Some(filter_expression);
677 }
678
679 pub fn distance_threshold(&self) -> f32 {
681 self.distance_threshold
682 }
683
684 pub fn set_distance_threshold(&mut self, distance_threshold: f32) {
686 self.distance_threshold = distance_threshold;
687 }
688
689 pub fn with_epsilon(mut self, epsilon: f64) -> Self {
691 self.epsilon = Some(epsilon);
692 self
693 }
694
695 pub fn set_epsilon(&mut self, epsilon: f64) {
697 self.epsilon = Some(epsilon);
698 }
699
700 pub fn epsilon(&self) -> Option<f64> {
702 self.epsilon
703 }
704
705 pub fn with_hybrid_policy(mut self, policy: HybridPolicy) -> Self {
707 self.hybrid_policy = Some(policy);
708 self
709 }
710
711 pub fn set_hybrid_policy(&mut self, policy: HybridPolicy) {
713 self.hybrid_policy = Some(policy);
714 }
715
716 pub fn hybrid_policy(&self) -> Option<HybridPolicy> {
718 self.hybrid_policy
719 }
720
721 pub fn with_batch_size(mut self, batch_size: usize) -> Self {
723 self.batch_size = Some(batch_size);
724 self
725 }
726
727 pub fn set_batch_size(&mut self, batch_size: usize) {
729 self.batch_size = Some(batch_size);
730 }
731
732 pub fn batch_size(&self) -> Option<usize> {
734 self.batch_size
735 }
736
737 pub fn with_search_window_size(mut self, size: usize) -> Self {
739 self.search_window_size = Some(size);
740 self
741 }
742
743 pub fn set_search_window_size(&mut self, size: usize) {
745 self.search_window_size = Some(size);
746 }
747
748 pub fn search_window_size(&self) -> Option<usize> {
750 self.search_window_size
751 }
752
753 pub fn with_use_search_history(mut self, mode: SearchHistoryMode) -> Self {
755 self.use_search_history = Some(mode);
756 self
757 }
758
759 pub fn set_use_search_history(&mut self, mode: SearchHistoryMode) {
761 self.use_search_history = Some(mode);
762 }
763
764 pub fn use_search_history(&self) -> Option<SearchHistoryMode> {
766 self.use_search_history
767 }
768
769 pub fn with_search_buffer_capacity(mut self, capacity: usize) -> Self {
771 self.search_buffer_capacity = Some(capacity);
772 self
773 }
774
775 pub fn set_search_buffer_capacity(&mut self, capacity: usize) {
777 self.search_buffer_capacity = Some(capacity);
778 }
779
780 pub fn search_buffer_capacity(&self) -> Option<usize> {
782 self.search_buffer_capacity
783 }
784
785 pub fn with_return_fields<I, S>(mut self, return_fields: I) -> Self
787 where
788 I: IntoIterator<Item = S>,
789 S: Into<String>,
790 {
791 self.options.return_fields = return_fields.into_iter().map(Into::into).collect();
792 if !self
793 .options
794 .return_fields
795 .iter()
796 .any(|field| field == "vector_distance")
797 {
798 self.options
799 .return_fields
800 .push("vector_distance".to_owned());
801 }
802 self
803 }
804
805 pub fn paging(mut self, offset: usize, num: usize) -> Self {
807 self.options.limit = QueryLimit { offset, num };
808 self
809 }
810
811 pub fn sort_by(mut self, field: impl Into<String>, direction: SortDirection) -> Self {
813 self.options.sort_by = Some(SortBy {
814 field: field.into(),
815 direction,
816 });
817 self
818 }
819
820 pub fn in_order(mut self, in_order: bool) -> Self {
822 self.options.in_order = in_order;
823 self
824 }
825
826 pub fn with_dialect(mut self, dialect: u32) -> Self {
828 self.options.dialect = dialect;
829 self
830 }
831
832 pub fn vector(&self) -> &Vector<'a> {
834 &self.vector
835 }
836}
837
838impl QueryString for VectorRangeQuery<'_> {
839 fn to_redis_query(&self) -> String {
840 let filter = self
841 .filter_expression
842 .as_ref()
843 .map_or_else(|| "*".to_owned(), FilterExpression::to_redis_syntax);
844
845 let base_query = format!(
847 "@{}:[VECTOR_RANGE $distance_threshold $vector]",
848 self.vector_field_name
849 );
850
851 let mut attr_parts = vec!["$YIELD_DISTANCE_AS: vector_distance".to_owned()];
854 if let Some(epsilon) = self.epsilon {
855 attr_parts.push(format!("$EPSILON: {}", epsilon));
856 }
857 if let Some(size) = self.search_window_size {
858 attr_parts.push(format!("$SEARCH_WINDOW_SIZE: {}", size));
859 }
860 if let Some(mode) = &self.use_search_history {
861 attr_parts.push(format!("$USE_SEARCH_HISTORY: {}", mode.as_str()));
862 }
863 if let Some(capacity) = self.search_buffer_capacity {
864 attr_parts.push(format!("$SEARCH_BUFFER_CAPACITY: {}", capacity));
865 }
866 let attr_section = format!("=>{{{}}}", attr_parts.join("; "));
867
868 if filter == "*" {
871 format!("{}{}", base_query, attr_section)
872 } else {
873 format!("({}{} {})", base_query, attr_section, filter)
874 }
875 }
876
877 fn params(&self) -> Vec<QueryParam> {
878 let mut params = vec![
879 QueryParam {
880 name: "vector".to_owned(),
881 value: QueryParamValue::Binary(self.vector.to_bytes().to_vec()),
882 },
883 QueryParam {
884 name: "distance_threshold".to_owned(),
885 value: QueryParamValue::String(self.distance_threshold.to_string()),
886 },
887 ];
888 if let Some(policy) = &self.hybrid_policy {
889 params.push(QueryParam {
890 name: "HYBRID_POLICY".to_owned(),
891 value: QueryParamValue::String(policy.as_str().to_owned()),
892 });
893 }
894 if let Some(batch_size) = self.batch_size {
895 params.push(QueryParam {
896 name: "BATCH_SIZE".to_owned(),
897 value: QueryParamValue::String(batch_size.to_string()),
898 });
899 }
900 params
901 }
902
903 fn return_fields(&self) -> Vec<String> {
904 self.options.return_fields.clone()
905 }
906
907 fn sort_by(&self) -> Option<SortBy> {
908 self.options.sort_by.clone()
909 }
910
911 fn limit(&self) -> Option<QueryLimit> {
912 Some(self.options.limit)
913 }
914
915 fn dialect(&self) -> u32 {
916 self.options.dialect
917 }
918
919 fn in_order(&self) -> bool {
920 self.options.in_order
921 }
922}
923
924impl PageableQuery for VectorRangeQuery<'_> {
925 fn paged(&self, offset: usize, num: usize) -> Self {
926 self.clone().paging(offset, num)
927 }
928}
929
930#[derive(Debug, Clone)]
932pub struct TextQuery {
933 text: String,
934 text_field_name: Option<String>,
935 filter_expression: Option<FilterExpression>,
936 return_score: bool,
937 options: QueryOptions,
938 stopwords: Option<std::collections::HashSet<String>>,
939 text_weights: Option<std::collections::HashMap<String, f32>>,
940}
941
942impl TextQuery {
943 pub fn new(text: impl Into<String>) -> Self {
945 Self {
946 text: text.into(),
947 text_field_name: None,
948 filter_expression: None,
949 return_score: true,
950 options: QueryOptions::with_num_results(10),
951 stopwords: None,
952 text_weights: None,
953 }
954 }
955
956 pub fn for_field(mut self, text_field_name: impl Into<String>) -> Self {
958 self.text_field_name = Some(text_field_name.into());
959 self
960 }
961
962 pub fn with_filter(mut self, filter_expression: FilterExpression) -> Self {
964 self.filter_expression = Some(filter_expression);
965 self
966 }
967
968 pub fn set_filter(&mut self, filter_expression: FilterExpression) {
970 self.filter_expression = Some(filter_expression);
971 }
972
973 pub fn with_return_score(mut self, return_score: bool) -> Self {
975 self.return_score = return_score;
976 self
977 }
978
979 pub fn with_return_fields<I, S>(mut self, return_fields: I) -> Self
981 where
982 I: IntoIterator<Item = S>,
983 S: Into<String>,
984 {
985 self.options.return_fields = return_fields.into_iter().map(Into::into).collect();
986 self
987 }
988
989 pub fn paging(mut self, offset: usize, num: usize) -> Self {
991 self.options.limit = QueryLimit { offset, num };
992 self
993 }
994
995 pub fn sort_by(mut self, field: impl Into<String>, direction: SortDirection) -> Self {
997 self.options.sort_by = Some(SortBy {
998 field: field.into(),
999 direction,
1000 });
1001 self
1002 }
1003
1004 pub fn in_order(mut self, in_order: bool) -> Self {
1006 self.options.in_order = in_order;
1007 self
1008 }
1009
1010 pub fn with_dialect(mut self, dialect: u32) -> Self {
1012 self.options.dialect = dialect;
1013 self
1014 }
1015
1016 pub fn with_scorer(mut self, scorer: impl Into<String>) -> Self {
1018 self.options.scorer = Some(scorer.into());
1019 self
1020 }
1021
1022 pub fn with_stopwords(mut self, stopwords: std::collections::HashSet<String>) -> Self {
1027 self.stopwords = Some(stopwords);
1028 self
1029 }
1030
1031 pub fn with_text_weights(mut self, weights: std::collections::HashMap<String, f32>) -> Self {
1036 self.text_weights = Some(weights);
1037 self
1038 }
1039
1040 pub fn set_text_weights(&mut self, weights: std::collections::HashMap<String, f32>) {
1042 self.text_weights = Some(weights);
1043 }
1044
1045 pub fn text_weights(&self) -> Option<&std::collections::HashMap<String, f32>> {
1047 self.text_weights.as_ref()
1048 }
1049
1050 fn build_query_text(&self) -> String {
1052 let mut text = self.text.clone();
1053
1054 if let Some(stopwords) = &self.stopwords {
1056 if !stopwords.is_empty() {
1057 let words: Vec<&str> = text.split_whitespace().collect();
1058 let filtered: Vec<&str> = words
1059 .into_iter()
1060 .filter(|w| !stopwords.contains(&w.to_lowercase()))
1061 .collect();
1062 text = filtered.join(" ");
1063 }
1064 }
1065
1066 if let Some(weights) = &self.text_weights {
1068 if !weights.is_empty() {
1069 let words: Vec<String> = text
1070 .split_whitespace()
1071 .map(|w| {
1072 if let Some(weight) = weights.get(w) {
1073 format!("{}=>{{{}}}", w, weight)
1074 } else {
1075 w.to_owned()
1076 }
1077 })
1078 .collect();
1079 text = words.join(" ");
1080 }
1081 }
1082
1083 text
1084 }
1085}
1086
1087impl QueryString for TextQuery {
1088 fn to_redis_query(&self) -> String {
1089 let processed_text = self.build_query_text();
1090 let text_part = match &self.text_field_name {
1091 Some(field) => format!("@{}:({})", field, processed_text),
1092 None => processed_text,
1093 };
1094 match &self.filter_expression {
1095 Some(filter) => {
1096 let filter_str = filter.to_redis_syntax();
1097 if filter_str == "*" {
1098 text_part
1099 } else {
1100 format!("{} AND {}", text_part, filter_str)
1101 }
1102 }
1103 None => text_part,
1104 }
1105 }
1106
1107 fn return_fields(&self) -> Vec<String> {
1108 self.options.return_fields.clone()
1109 }
1110
1111 fn sort_by(&self) -> Option<SortBy> {
1112 self.options.sort_by.clone()
1113 }
1114
1115 fn limit(&self) -> Option<QueryLimit> {
1116 Some(self.options.limit)
1117 }
1118
1119 fn dialect(&self) -> u32 {
1120 self.options.dialect
1121 }
1122
1123 fn in_order(&self) -> bool {
1124 self.options.in_order
1125 }
1126
1127 fn scorer(&self) -> Option<String> {
1128 self.options.scorer.clone()
1129 }
1130}
1131
1132impl PageableQuery for TextQuery {
1133 fn paged(&self, offset: usize, num: usize) -> Self {
1134 self.clone().paging(offset, num)
1135 }
1136}
1137
1138#[derive(Debug, Clone)]
1140pub struct FilterQuery {
1141 filter_expression: FilterExpression,
1142 options: QueryOptions,
1143}
1144
1145impl FilterQuery {
1146 pub fn new(filter_expression: FilterExpression) -> Self {
1148 Self {
1149 filter_expression,
1150 options: QueryOptions::with_num_results(10),
1151 }
1152 }
1153
1154 pub fn set_filter(&mut self, filter_expression: FilterExpression) {
1156 self.filter_expression = filter_expression;
1157 }
1158
1159 pub fn with_return_fields<I, S>(mut self, return_fields: I) -> Self
1161 where
1162 I: IntoIterator<Item = S>,
1163 S: Into<String>,
1164 {
1165 self.options.return_fields = return_fields.into_iter().map(Into::into).collect();
1166 self
1167 }
1168
1169 pub fn paging(mut self, offset: usize, num: usize) -> Self {
1171 self.options.limit = QueryLimit { offset, num };
1172 self
1173 }
1174
1175 pub fn sort_by(mut self, field: impl Into<String>, direction: SortDirection) -> Self {
1177 self.options.sort_by = Some(SortBy {
1178 field: field.into(),
1179 direction,
1180 });
1181 self
1182 }
1183
1184 pub fn in_order(mut self, in_order: bool) -> Self {
1186 self.options.in_order = in_order;
1187 self
1188 }
1189
1190 pub fn with_dialect(mut self, dialect: u32) -> Self {
1192 self.options.dialect = dialect;
1193 self
1194 }
1195}
1196
1197impl QueryString for FilterQuery {
1198 fn to_redis_query(&self) -> String {
1199 self.filter_expression.to_redis_syntax()
1200 }
1201
1202 fn return_fields(&self) -> Vec<String> {
1203 self.options.return_fields.clone()
1204 }
1205
1206 fn sort_by(&self) -> Option<SortBy> {
1207 self.options.sort_by.clone()
1208 }
1209
1210 fn limit(&self) -> Option<QueryLimit> {
1211 Some(self.options.limit)
1212 }
1213
1214 fn dialect(&self) -> u32 {
1215 self.options.dialect
1216 }
1217
1218 fn in_order(&self) -> bool {
1219 self.options.in_order
1220 }
1221
1222 fn should_unpack_json(&self) -> bool {
1223 true
1224 }
1225}
1226
1227impl PageableQuery for FilterQuery {
1228 fn paged(&self, offset: usize, num: usize) -> Self {
1229 self.clone().paging(offset, num)
1230 }
1231}
1232
1233#[derive(Debug, Clone)]
1235pub struct CountQuery {
1236 filter_expression: Option<FilterExpression>,
1237 dialect: u32,
1238}
1239
1240impl CountQuery {
1241 pub fn new() -> Self {
1243 Self {
1244 filter_expression: None,
1245 dialect: 2,
1246 }
1247 }
1248
1249 pub fn with_filter(mut self, filter_expression: FilterExpression) -> Self {
1251 self.filter_expression = Some(filter_expression);
1252 self
1253 }
1254
1255 pub fn with_dialect(mut self, dialect: u32) -> Self {
1257 self.dialect = dialect;
1258 self
1259 }
1260}
1261
1262impl Default for CountQuery {
1263 fn default() -> Self {
1264 Self::new()
1265 }
1266}
1267
1268impl QueryString for CountQuery {
1269 fn to_redis_query(&self) -> String {
1270 self.filter_expression
1271 .as_ref()
1272 .map_or_else(|| "*".to_owned(), FilterExpression::to_redis_syntax)
1273 }
1274
1275 fn limit(&self) -> Option<QueryLimit> {
1276 Some(QueryLimit { offset: 0, num: 0 })
1277 }
1278
1279 fn dialect(&self) -> u32 {
1280 self.dialect
1281 }
1282
1283 fn no_content(&self) -> bool {
1284 true
1285 }
1286
1287 fn kind(&self) -> QueryKind {
1288 QueryKind::Count
1289 }
1290}
1291
/// How text and vector scores are combined in a hybrid query.
#[derive(Debug, Clone, Copy)]
pub enum HybridCombinationMethod {
    /// `LINEAR` combination.
    Linear,
    /// `RRF` combination.
    Rrf,
}

impl HybridCombinationMethod {
    /// Keyword used for this method in the command.
    pub fn redis_name(self) -> &'static str {
        if let HybridCombinationMethod::Linear = self {
            "LINEAR"
        } else {
            "RRF"
        }
    }
}
1310
/// Vector-similarity strategy used by a hybrid query.
#[derive(Debug, Clone, Copy)]
pub enum VectorSearchMethod {
    /// K-nearest-neighbour search.
    Knn,
    /// Radius (range) search.
    Range,
}
1319
1320#[derive(Debug, Clone)]
1343pub struct HybridQuery<'a> {
1344 text: String,
1345 text_field_name: String,
1346 vector: Vector<'a>,
1347 vector_field_name: String,
1348 vector_param_name: String,
1349 num_results: usize,
1350 text_scorer: Option<String>,
1351 yield_text_score_as: Option<String>,
1352 vector_search_method: Option<VectorSearchMethod>,
1353 knn_ef_runtime: Option<usize>,
1354 range_radius: Option<f32>,
1355 range_epsilon: Option<f32>,
1356 yield_vsim_score_as: Option<String>,
1357 filter_expression: Option<FilterExpression>,
1358 combination_method: Option<HybridCombinationMethod>,
1359 rrf_window: Option<usize>,
1360 rrf_constant: Option<usize>,
1361 linear_alpha: Option<f32>,
1362 yield_combined_score_as: Option<String>,
1363 return_fields: Vec<String>,
1364 stopwords: Option<std::collections::HashSet<String>>,
1365 text_weights: Option<std::collections::HashMap<String, f32>>,
1366}
1367
1368impl<'a> HybridQuery<'a> {
1369 pub fn new(
1371 text: impl Into<String>,
1372 text_field_name: impl Into<String>,
1373 vector: Vector<'a>,
1374 vector_field_name: impl Into<String>,
1375 ) -> Self {
1376 Self {
1377 text: text.into(),
1378 text_field_name: text_field_name.into(),
1379 vector,
1380 vector_field_name: vector_field_name.into(),
1381 vector_param_name: "vector".to_owned(),
1382 num_results: 10,
1383 text_scorer: None,
1384 yield_text_score_as: None,
1385 vector_search_method: None,
1386 knn_ef_runtime: None,
1387 range_radius: None,
1388 range_epsilon: None,
1389 yield_vsim_score_as: None,
1390 filter_expression: None,
1391 combination_method: None,
1392 rrf_window: None,
1393 rrf_constant: None,
1394 linear_alpha: None,
1395 yield_combined_score_as: None,
1396 return_fields: Vec::new(),
1397 stopwords: None,
1398 text_weights: None,
1399 }
1400 }
1401
1402 pub fn with_num_results(mut self, num_results: usize) -> Self {
1404 self.num_results = num_results;
1405 self
1406 }
1407
1408 pub fn with_text_scorer(mut self, scorer: impl Into<String>) -> Self {
1410 self.text_scorer = Some(scorer.into());
1411 self
1412 }
1413
1414 pub fn with_yield_text_score_as(mut self, alias: impl Into<String>) -> Self {
1416 self.yield_text_score_as = Some(alias.into());
1417 self
1418 }
1419
1420 pub fn with_knn(mut self, ef_runtime: Option<usize>) -> Self {
1422 self.vector_search_method = Some(VectorSearchMethod::Knn);
1423 self.knn_ef_runtime = ef_runtime;
1424 self
1425 }
1426
1427 pub fn with_range(mut self, radius: f32, epsilon: Option<f32>) -> Self {
1429 self.vector_search_method = Some(VectorSearchMethod::Range);
1430 self.range_radius = Some(radius);
1431 self.range_epsilon = epsilon;
1432 self
1433 }
1434
1435 pub fn with_yield_vsim_score_as(mut self, alias: impl Into<String>) -> Self {
1437 self.yield_vsim_score_as = Some(alias.into());
1438 self
1439 }
1440
1441 pub fn with_filter(mut self, filter_expression: FilterExpression) -> Self {
1443 self.filter_expression = Some(filter_expression);
1444 self
1445 }
1446
1447 pub fn with_combination_method(mut self, method: HybridCombinationMethod) -> Self {
1449 self.combination_method = Some(method);
1450 self
1451 }
1452
1453 pub fn with_rrf(mut self, window: Option<usize>, constant: Option<usize>) -> Self {
1455 self.combination_method = Some(HybridCombinationMethod::Rrf);
1456 self.rrf_window = window;
1457 self.rrf_constant = constant;
1458 self
1459 }
1460
1461 pub fn with_linear(mut self, alpha: f32) -> Self {
1463 self.combination_method = Some(HybridCombinationMethod::Linear);
1464 self.linear_alpha = Some(alpha);
1465 self
1466 }
1467
1468 pub fn with_yield_combined_score_as(mut self, alias: impl Into<String>) -> Self {
1470 self.yield_combined_score_as = Some(alias.into());
1471 self
1472 }
1473
1474 pub fn with_return_fields<I, S>(mut self, return_fields: I) -> Self
1476 where
1477 I: IntoIterator<Item = S>,
1478 S: Into<String>,
1479 {
1480 self.return_fields = return_fields.into_iter().map(Into::into).collect();
1481 self
1482 }
1483
1484 pub fn with_stopwords(mut self, stopwords: std::collections::HashSet<String>) -> Self {
1486 self.stopwords = Some(stopwords);
1487 self
1488 }
1489
1490 pub fn with_text_weights(mut self, weights: std::collections::HashMap<String, f32>) -> Self {
1492 self.text_weights = Some(weights);
1493 self
1494 }
1495
1496 pub fn with_vector_param_name(mut self, name: impl Into<String>) -> Self {
1498 self.vector_param_name = name.into();
1499 self
1500 }
1501
1502 pub fn vector(&self) -> &Vector<'a> {
1504 &self.vector
1505 }
1506
1507 fn build_query_string(&self) -> String {
1509 let mut text = self.text.clone();
1510
1511 if let Some(stopwords) = &self.stopwords {
1513 if !stopwords.is_empty() {
1514 let words: Vec<&str> = text.split_whitespace().collect();
1515 let filtered: Vec<&str> = words
1516 .into_iter()
1517 .filter(|w| !stopwords.contains(&w.to_lowercase()))
1518 .collect();
1519 text = filtered.join(" ");
1520 }
1521 }
1522
1523 if let Some(weights) = &self.text_weights {
1525 if !weights.is_empty() {
1526 let words: Vec<String> = text
1527 .split_whitespace()
1528 .map(|w| {
1529 if let Some(weight) = weights.get(w) {
1530 format!("{}=>{{{}}}", w, weight)
1531 } else {
1532 w.to_owned()
1533 }
1534 })
1535 .collect();
1536 text = words.join(" ");
1537 }
1538 }
1539
1540 format!("@{}:({})", self.text_field_name, text)
1541 }
1542
1543 pub fn build_cmd(&self, index_name: &str) -> redis::Cmd {
1561 let mut cmd = redis::cmd("FT.HYBRID");
1562 cmd.arg(index_name);
1563
1564 let query_string = self.build_query_string();
1566 cmd.arg("SEARCH").arg(&query_string);
1567
1568 if let Some(scorer) = &self.text_scorer {
1569 cmd.arg("SCORER").arg(scorer);
1570 }
1571 if let Some(alias) = &self.yield_text_score_as {
1572 cmd.arg("YIELD_SCORE_AS").arg(alias);
1573 }
1574
1575 cmd.arg("VSIM")
1577 .arg(format!("@{}", self.vector_field_name))
1578 .arg(format!("${}", self.vector_param_name));
1579
1580 if let Some(method) = self.vector_search_method {
1584 match method {
1585 VectorSearchMethod::Knn => {
1586 let mut kv_count = 1_usize; if self.knn_ef_runtime.is_some() {
1589 kv_count += 1;
1590 }
1591 cmd.arg("KNN").arg(kv_count * 2);
1592 cmd.arg("K").arg(self.num_results);
1593 if let Some(ef) = self.knn_ef_runtime {
1594 cmd.arg("EF_RUNTIME").arg(ef);
1595 }
1596 }
1597 VectorSearchMethod::Range => {
1598 let mut kv_count = 0_usize;
1599 if self.range_radius.is_some() {
1600 kv_count += 1;
1601 }
1602 if self.range_epsilon.is_some() {
1603 kv_count += 1;
1604 }
1605 if kv_count > 0 {
1606 cmd.arg("RANGE").arg(kv_count * 2);
1607 } else {
1608 cmd.arg("RANGE");
1609 }
1610 if let Some(radius) = self.range_radius {
1611 cmd.arg("RADIUS").arg(radius);
1612 }
1613 if let Some(epsilon) = self.range_epsilon {
1614 cmd.arg("EPSILON").arg(epsilon);
1615 }
1616 }
1617 }
1618 }
1619
1620 if let Some(filter) = &self.filter_expression {
1621 let filter_str = filter.to_redis_syntax();
1622 if filter_str != "*" {
1623 cmd.arg("FILTER").arg(&filter_str);
1624 }
1625 }
1626
1627 if let Some(alias) = &self.yield_vsim_score_as {
1628 cmd.arg("YIELD_SCORE_AS").arg(alias);
1629 }
1630
1631 if let Some(method) = &self.combination_method {
1638 cmd.arg("COMBINE").arg(method.redis_name());
1639
1640 let mut kv_pairs: Vec<(String, String)> = Vec::new();
1644 match method {
1645 HybridCombinationMethod::Rrf => {
1646 if let Some(window) = self.rrf_window {
1647 kv_pairs.push(("WINDOW".to_owned(), window.to_string()));
1648 }
1649 if let Some(constant) = self.rrf_constant {
1650 kv_pairs.push(("CONSTANT".to_owned(), constant.to_string()));
1651 }
1652 }
1653 HybridCombinationMethod::Linear => {
1654 if let Some(alpha) = self.linear_alpha {
1655 kv_pairs.push(("ALPHA".to_owned(), alpha.to_string()));
1656 kv_pairs.push(("BETA".to_owned(), (1.0 - alpha).to_string()));
1657 }
1658 }
1659 }
1660 if let Some(alias) = &self.yield_combined_score_as {
1661 kv_pairs.push(("YIELD_SCORE_AS".to_owned(), alias.clone()));
1662 }
1663
1664 if !kv_pairs.is_empty() {
1668 cmd.arg(kv_pairs.len() * 2);
1669 for (k, v) in &kv_pairs {
1670 cmd.arg(k).arg(v);
1671 }
1672 }
1673 }
1674
1675 if !self.return_fields.is_empty() {
1677 cmd.arg("LOAD");
1678 cmd.arg(self.return_fields.len());
1679 for field in &self.return_fields {
1680 cmd.arg(format!("@{}", field));
1681 }
1682 }
1683
1684 cmd.arg("LIMIT").arg(0).arg(self.num_results);
1686
1687 cmd.arg("PARAMS")
1689 .arg(2)
1690 .arg(&self.vector_param_name)
1691 .arg(self.vector.to_bytes().as_ref());
1692
1693 cmd
1694 }
1695}
1696
#[derive(Debug, Clone)]
/// Hybrid (full-text + vector KNN) query executed through `FT.AGGREGATE`.
///
/// The text and vector scores are fused in the aggregation pipeline as
/// `(1 - alpha) * text_score + alpha * vector_similarity`, where
/// `vector_similarity = (2 - vector_distance) / 2`.
pub struct AggregateHybridQuery<'a> {
    /// Free-text search input; `new` rejects empty or whitespace-only text.
    text: String,
    /// Text field searched with the (optional, `~`-prefixed) OR-joined tokens.
    text_field_name: String,
    /// Query vector, sent as the `vector` query parameter.
    vector: Vector<'a>,
    /// Vector field used in the `KNN` clause.
    vector_field_name: String,
    /// Weight of the vector similarity in the hybrid score (default 0.7).
    alpha: f32,
    /// `K` of the KNN clause and the `SORTBY ... MAX` cap (default 10).
    num_results: usize,
    /// Value passed to `SCORER` (default `"BM25STD"`).
    text_scorer: String,
    /// Optional filter AND-ed into the query; a filter rendering as `*` is ignored.
    filter_expression: Option<FilterExpression>,
    /// Fields passed to `LOAD` (sent without an `@` prefix).
    return_fields: Vec<String>,
    /// Tokens removed from `text` before building the query.
    stopwords: Option<std::collections::HashSet<String>>,
    /// Per-token weights, rendered as `token=>{weight}`.
    text_weights: Option<std::collections::HashMap<String, f32>>,
    /// Value passed to `DIALECT` (default 2).
    dialect: u32,
}
1723
1724impl<'a> AggregateHybridQuery<'a> {
1725 pub fn new(
1732 text: impl Into<String>,
1733 text_field_name: impl Into<String>,
1734 vector: Vector<'a>,
1735 vector_field_name: impl Into<String>,
1736 ) -> std::result::Result<Self, String> {
1737 let text = text.into();
1738 if text.trim().is_empty() {
1739 return Err("text string cannot be empty".to_owned());
1740 }
1741 Ok(Self {
1742 text,
1743 text_field_name: text_field_name.into(),
1744 vector,
1745 vector_field_name: vector_field_name.into(),
1746 alpha: 0.7,
1747 num_results: 10,
1748 text_scorer: "BM25STD".to_owned(),
1749 filter_expression: None,
1750 return_fields: Vec::new(),
1751 stopwords: None,
1752 text_weights: None,
1753 dialect: 2,
1754 })
1755 }
1756
1757 pub fn with_alpha(mut self, alpha: f32) -> Self {
1759 self.alpha = alpha;
1760 self
1761 }
1762
1763 pub fn with_num_results(mut self, num_results: usize) -> Self {
1765 self.num_results = num_results;
1766 self
1767 }
1768
1769 pub fn with_text_scorer(mut self, scorer: impl Into<String>) -> Self {
1771 self.text_scorer = scorer.into();
1772 self
1773 }
1774
1775 pub fn with_filter(mut self, filter_expression: FilterExpression) -> Self {
1777 self.filter_expression = Some(filter_expression);
1778 self
1779 }
1780
1781 pub fn with_return_fields<I, S>(mut self, return_fields: I) -> Self
1783 where
1784 I: IntoIterator<Item = S>,
1785 S: Into<String>,
1786 {
1787 self.return_fields = return_fields.into_iter().map(Into::into).collect();
1788 self
1789 }
1790
1791 pub fn with_stopwords(mut self, stopwords: std::collections::HashSet<String>) -> Self {
1793 self.stopwords = Some(stopwords);
1794 self
1795 }
1796
1797 pub fn with_text_weights(mut self, weights: std::collections::HashMap<String, f32>) -> Self {
1799 self.text_weights = Some(weights);
1800 self
1801 }
1802
1803 pub fn set_text_weights(&mut self, weights: std::collections::HashMap<String, f32>) {
1805 self.text_weights = Some(weights);
1806 }
1807
1808 pub fn with_dialect(mut self, dialect: u32) -> Self {
1810 self.dialect = dialect;
1811 self
1812 }
1813
1814 pub fn vector(&self) -> &Vector<'a> {
1816 &self.vector
1817 }
1818
1819 pub fn alpha(&self) -> f32 {
1821 self.alpha
1822 }
1823
1824 pub fn text(&self) -> &str {
1826 &self.text
1827 }
1828
1829 pub(crate) fn build_query_string(&self) -> String {
1831 let tokens: Vec<String> = self
1834 .text
1835 .split_whitespace()
1836 .map(|w| w.to_lowercase())
1837 .filter(|w| {
1838 if let Some(stopwords) = &self.stopwords {
1839 !stopwords.contains(w.as_str())
1840 } else {
1841 true
1842 }
1843 })
1844 .collect();
1845
1846 let tokens: Vec<String> = tokens
1848 .into_iter()
1849 .map(|w| {
1850 if let Some(weights) = &self.text_weights {
1851 if let Some(weight) = weights.get(&w) {
1852 return format!("{}=>{{{}}}", w, weight);
1853 }
1854 }
1855 w
1856 })
1857 .collect();
1858
1859 let text = tokens.join(" | ");
1860
1861 let base = if let Some(filter) = &self.filter_expression {
1866 let filter_str = filter.to_redis_syntax();
1867 if filter_str == "*" {
1868 format!("(~@{}:({}))", self.text_field_name, text)
1869 } else {
1870 format!("(~@{}:({}) AND {})", self.text_field_name, text, filter_str)
1871 }
1872 } else {
1873 format!("(~@{}:({}))", self.text_field_name, text)
1874 };
1875
1876 format!(
1878 "{}=>[KNN {} @{} $vector AS vector_distance]",
1879 base, self.num_results, self.vector_field_name,
1880 )
1881 }
1882
1883 pub fn build_aggregate_cmd(&self, index_name: &str) -> redis::Cmd {
1888 let query_string = self.build_query_string();
1889 let mut cmd = redis::cmd("FT.AGGREGATE");
1890 cmd.arg(index_name);
1891 cmd.arg(&query_string);
1892
1893 cmd.arg("SCORER").arg(&self.text_scorer);
1895
1896 cmd.arg("ADDSCORES");
1898
1899 if !self.return_fields.is_empty() {
1902 cmd.arg("LOAD");
1903 cmd.arg(self.return_fields.len());
1904 for field in &self.return_fields {
1905 cmd.arg(field);
1906 }
1907 }
1908
1909 cmd.arg("DIALECT").arg(self.dialect);
1911
1912 cmd.arg("APPLY")
1915 .arg("(2 - @vector_distance)/2")
1916 .arg("AS")
1917 .arg("vector_similarity");
1918 cmd.arg("APPLY").arg("@__score").arg("AS").arg("text_score");
1919
1920 let hybrid_expr = format!(
1921 "{}*@text_score + {}*@vector_similarity",
1922 1.0 - self.alpha,
1923 self.alpha
1924 );
1925 cmd.arg("APPLY")
1926 .arg(&hybrid_expr)
1927 .arg("AS")
1928 .arg("hybrid_score");
1929
1930 cmd.arg("SORTBY")
1931 .arg(2)
1932 .arg("@hybrid_score")
1933 .arg("DESC")
1934 .arg("MAX")
1935 .arg(self.num_results);
1936
1937 cmd.arg("PARAMS")
1939 .arg(2)
1940 .arg("vector")
1941 .arg(self.vector.to_bytes().as_ref());
1942
1943 cmd
1944 }
1945}
1946
/// Element type of a raw vector buffer (see `VectorInput::dtype`).
///
/// `Float32` is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum VectorDtype {
    /// `bfloat16` elements, 2 bytes each.
    BFloat16,
    /// 16-bit float elements, 2 bytes each.
    Float16,
    /// 32-bit float elements, 4 bytes each (the default).
    #[default]
    Float32,
    /// 64-bit float elements, 8 bytes each.
    Float64,
    /// Signed 8-bit integer elements, 1 byte each.
    Int8,
    /// Unsigned 8-bit integer elements, 1 byte each.
    Uint8,
}

impl VectorDtype {
    /// Returns the number of bytes one element of this type occupies.
    pub const fn bytes_per_element(self) -> usize {
        match self {
            Self::BFloat16 | Self::Float16 => 2,
            Self::Float32 => 4,
            Self::Float64 => 8,
            Self::Int8 | Self::Uint8 => 1,
        }
    }
}
1983
#[derive(Debug, Clone)]
/// One vector participating in a [`MultiVectorQuery`].
///
/// Holds the raw query-vector bytes, the vector field they are matched against, and
/// per-vector scoring options.
pub struct VectorInput<'a> {
    /// Raw vector bytes, sent as the `vector_<i>` query parameter.
    ///
    /// `from_floats` encodes `f32` values little-endian; `from_bytes` stores bytes as given.
    pub vector: Cow<'a, [u8]>,
    /// Vector field to search, rendered as `@<field_name>` without escaping.
    pub field_name: String,
    /// Multiplier applied to this vector's score when summing `combined_score` (default 1.0).
    pub weight: f32,
    /// Declared element type of `vector` (default `Float32`).
    ///
    /// NOTE(review): not read by `MultiVectorQuery`'s query/command building; confirm
    /// whether anything else consumes it.
    pub dtype: VectorDtype,
    /// Radius of this vector's `VECTOR_RANGE` clause (default 2.0).
    ///
    /// `with_max_distance` enforces `[0.0, 2.0]`.
    pub max_distance: f32,
}
2002
2003impl<'a> VectorInput<'a> {
2004 pub fn from_floats(elements: &[f32], field_name: impl Into<String>) -> Self {
2007 let mut buf = Vec::with_capacity(elements.len() * std::mem::size_of::<f32>());
2008 for &v in elements {
2009 buf.extend_from_slice(&v.to_le_bytes());
2010 }
2011 Self {
2012 vector: Cow::Owned(buf),
2013 field_name: field_name.into(),
2014 weight: 1.0,
2015 dtype: VectorDtype::Float32,
2016 max_distance: 2.0,
2017 }
2018 }
2019
2020 pub fn from_bytes(
2022 bytes: impl Into<Cow<'a, [u8]>>,
2023 field_name: impl Into<String>,
2024 dtype: VectorDtype,
2025 ) -> Self {
2026 Self {
2027 vector: bytes.into(),
2028 field_name: field_name.into(),
2029 weight: 1.0,
2030 dtype,
2031 max_distance: 2.0,
2032 }
2033 }
2034
2035 pub fn with_weight(mut self, weight: f32) -> Self {
2037 self.weight = weight;
2038 self
2039 }
2040
2041 pub fn with_dtype(mut self, dtype: VectorDtype) -> Self {
2043 self.dtype = dtype;
2044 self
2045 }
2046
2047 pub fn with_max_distance(mut self, max_distance: f32) -> Self {
2053 assert!(
2054 (0.0..=2.0).contains(&max_distance),
2055 "max_distance must be in [0.0, 2.0], got {}",
2056 max_distance
2057 );
2058 self.max_distance = max_distance;
2059 self
2060 }
2061}
2062
#[derive(Debug, Clone)]
/// Query that matches several vector fields at once through `FT.AGGREGATE`.
///
/// Each [`VectorInput`] contributes a `VECTOR_RANGE` clause and a weighted score;
/// results are sorted by the weighted sum (`combined_score`).
pub struct MultiVectorQuery<'a> {
    /// Vectors to match; vector `i` is bound to parameter `vector_<i>`.
    vectors: Vec<VectorInput<'a>>,
    /// Optional filter prefixed to the range clauses; ignored if it renders as `*`.
    filter_expression: Option<FilterExpression>,
    /// `SORTBY ... MAX` cap (default 10).
    num_results: usize,
    /// Fields passed to `LOAD` (each sent as `@<field>`).
    return_fields: Vec<String>,
    /// Value passed to `DIALECT` (default 2).
    dialect: u32,
}
2077
2078impl<'a> MultiVectorQuery<'a> {
2079 pub fn new(vectors: Vec<VectorInput<'a>>) -> Self {
2081 Self {
2082 vectors,
2083 filter_expression: None,
2084 num_results: 10,
2085 return_fields: Vec::new(),
2086 dialect: 2,
2087 }
2088 }
2089
2090 pub fn with_num_results(mut self, num_results: usize) -> Self {
2092 self.num_results = num_results;
2093 self
2094 }
2095
2096 pub fn with_filter(mut self, filter: FilterExpression) -> Self {
2098 self.filter_expression = Some(filter);
2099 self
2100 }
2101
2102 pub fn with_return_fields<I, S>(mut self, fields: I) -> Self
2104 where
2105 I: IntoIterator<Item = S>,
2106 S: Into<String>,
2107 {
2108 self.return_fields = fields.into_iter().map(Into::into).collect();
2109 self
2110 }
2111
2112 pub fn with_dialect(mut self, dialect: u32) -> Self {
2114 self.dialect = dialect;
2115 self
2116 }
2117
2118 pub fn vectors(&self) -> &[VectorInput<'a>] {
2120 &self.vectors
2121 }
2122
2123 pub fn build_query_string(&self) -> String {
2131 let mut parts = Vec::with_capacity(self.vectors.len());
2132 for (i, vi) in self.vectors.iter().enumerate() {
2133 parts.push(format!(
2134 "@{}:[VECTOR_RANGE {} $vector_{}]=>{{$YIELD_DISTANCE_AS: distance_{}}}",
2135 vi.field_name, vi.max_distance, i, i
2136 ));
2137 }
2138
2139 let base = parts.join(" AND ");
2140
2141 if let Some(filter) = &self.filter_expression {
2142 let filter_str = filter.to_redis_syntax();
2143 if filter_str != "*" {
2144 format!("({}) {}", filter_str, base)
2145 } else {
2146 base
2147 }
2148 } else {
2149 base
2150 }
2151 }
2152
2153 pub fn build_aggregate_cmd(&self, index_name: &str) -> redis::Cmd {
2155 let query_string = self.build_query_string();
2156 let mut cmd = redis::cmd("FT.AGGREGATE");
2157 cmd.arg(index_name);
2158 cmd.arg(&query_string);
2159
2160 cmd.arg("SCORER").arg("TFIDF");
2162
2163 cmd.arg("DIALECT").arg(self.dialect);
2165
2166 for i in 0..self.vectors.len() {
2168 cmd.arg("APPLY")
2169 .arg(format!("(2 - @distance_{})/2", i))
2170 .arg("AS")
2171 .arg(format!("score_{}", i));
2172 }
2173
2174 let combined_expr: Vec<String> = self
2176 .vectors
2177 .iter()
2178 .enumerate()
2179 .map(|(i, vi)| format!("@score_{} * {}", i, vi.weight))
2180 .collect();
2181 cmd.arg("APPLY")
2182 .arg(combined_expr.join(" + "))
2183 .arg("AS")
2184 .arg("combined_score");
2185
2186 cmd.arg("SORTBY")
2188 .arg(2)
2189 .arg("@combined_score")
2190 .arg("DESC")
2191 .arg("MAX")
2192 .arg(self.num_results);
2193
2194 if !self.return_fields.is_empty() {
2196 cmd.arg("LOAD");
2197 cmd.arg(self.return_fields.len());
2198 for field in &self.return_fields {
2199 cmd.arg(format!("@{}", field));
2200 }
2201 }
2202
2203 let param_count = self.vectors.len() * 2;
2205 cmd.arg("PARAMS").arg(param_count);
2206 for (i, vi) in self.vectors.iter().enumerate() {
2207 cmd.arg(format!("vector_{}", i));
2208 cmd.arg(vi.vector.as_ref());
2209 }
2210
2211 cmd
2212 }
2213}
2214
2215impl QueryString for str {
2216 fn to_redis_query(&self) -> String {
2217 self.to_owned()
2218 }
2219}
2220
2221impl QueryString for &str {
2222 fn to_redis_query(&self) -> String {
2223 (*self).to_owned()
2224 }
2225}
2226
2227impl QueryString for String {
2228 fn to_redis_query(&self) -> String {
2229 self.clone()
2230 }
2231}
2232
// `sql` submodule (declared in its own file); compiled only when the `sql` Cargo
// feature is enabled.
#[cfg(feature = "sql")]
mod sql;
// Re-export the submodule's public query types here, under the same feature gate.
#[cfg(feature = "sql")]
pub use sql::{SQLQuery, SqlParam};
2237
2238#[cfg(test)]
2239mod tests {
2240 use super::{
2241 AggregateHybridQuery, CountQuery, FilterQuery, HybridCombinationMethod, HybridPolicy,
2242 HybridQuery, MultiVectorQuery, PageableQuery, QueryString, SearchHistoryMode,
2243 SortDirection, TextQuery, Vector, VectorDtype, VectorInput, VectorQuery, VectorRangeQuery,
2244 };
2245 use crate::filter::{Num, Tag, Text};
2246
2247 #[test]
2248 fn vector_query_should_render_knn() {
2249 let query = VectorQuery::new(Vector::new(vec![1.0, 2.0, 3.0]), "embedding", 5)
2250 .with_return_fields(["field1", "field2"])
2251 .with_dialect(3);
2252
2253 assert!(query.to_redis_query().contains("KNN 5"));
2254 assert_eq!(query.vector().to_bytes().len(), 12);
2255 assert_eq!(
2256 query.render().return_fields,
2257 vec!["field1", "field2", "vector_distance"]
2258 );
2259 assert_eq!(query.render().dialect, 3);
2260 }
2261
2262 #[test]
2263 fn hybrid_query_should_build_ft_hybrid_cmd_like_python_hybrid_query() {
2264 let query = HybridQuery::new(
2265 "a medical professional",
2266 "description",
2267 Vector::new(vec![0.1, 0.1, 0.5]),
2268 "user_embedding",
2269 )
2270 .with_num_results(10)
2271 .with_combination_method(HybridCombinationMethod::Rrf)
2272 .with_yield_combined_score_as("hybrid_score")
2273 .with_return_fields(["user", "age", "job"]);
2274
2275 let cmd = query.build_cmd("my_index");
2276 let packed = cmd.get_packed_command();
2277 let cmd_str = String::from_utf8_lossy(&packed);
2278
2279 assert!(cmd_str.contains("FT.HYBRID"));
2280 assert!(cmd_str.contains("my_index"));
2281 assert!(cmd_str.contains("@description:(a medical professional)"));
2282 assert!(cmd_str.contains("COMBINE"));
2283 assert!(cmd_str.contains("RRF"));
2284 assert!(cmd_str.contains("YIELD_SCORE_AS"));
2285 assert!(cmd_str.contains("hybrid_score"));
2286 }
2287
2288 #[test]
2289 fn hybrid_query_with_rrf_params_like_python_hybrid_query_rrf() {
2290 let query = HybridQuery::new(
2291 "search text",
2292 "content",
2293 Vector::new(vec![0.5, 0.5]),
2294 "vec_field",
2295 )
2296 .with_rrf(Some(100), Some(10));
2297
2298 let cmd = query.build_cmd("idx");
2299 let packed = cmd.get_packed_command();
2300 let cmd_str = String::from_utf8_lossy(&packed);
2301
2302 assert!(cmd_str.contains("COMBINE"));
2303 assert!(cmd_str.contains("RRF"));
2304 assert!(cmd_str.contains("WINDOW"));
2305 assert!(cmd_str.contains("CONSTANT"));
2306 }
2307
2308 #[test]
2309 fn hybrid_query_with_linear_alpha_like_python_hybrid_query_linear() {
2310 let query =
2311 HybridQuery::new("query text", "body", Vector::new(vec![1.0]), "vec").with_linear(0.3);
2312
2313 let cmd = query.build_cmd("idx");
2314 let packed = cmd.get_packed_command();
2315 let cmd_str = String::from_utf8_lossy(&packed);
2316
2317 assert!(cmd_str.contains("COMBINE"));
2318 assert!(cmd_str.contains("LINEAR"));
2319 assert!(cmd_str.contains("ALPHA"));
2320 }
2321
2322 #[test]
2323 fn hybrid_query_with_filter_like_python_hybrid_query_filter() {
2324 let filter = Tag::new("status").eq("active");
2325 let query = HybridQuery::new("doctors", "description", Vector::new(vec![1.0, 2.0]), "vec")
2326 .with_filter(filter);
2327
2328 let cmd = query.build_cmd("idx");
2329 let packed = cmd.get_packed_command();
2330 let cmd_str = String::from_utf8_lossy(&packed);
2331
2332 assert!(cmd_str.contains("FILTER"));
2333 assert!(cmd_str.contains("@status:{active}"));
2334 }
2335
2336 #[test]
2337 fn hybrid_query_with_stopwords_and_weights_like_python_hybrid_query() {
2338 use std::collections::{HashMap, HashSet};
2339 let mut stopwords = HashSet::new();
2340 stopwords.insert("the".to_owned());
2341 stopwords.insert("a".to_owned());
2342
2343 let mut weights = HashMap::new();
2344 weights.insert("doctor".to_owned(), 2.0_f32);
2345
2346 let query = HybridQuery::new(
2347 "a doctor in the house",
2348 "description",
2349 Vector::new(vec![1.0]),
2350 "vec",
2351 )
2352 .with_stopwords(stopwords)
2353 .with_text_weights(weights);
2354
2355 let query_string = query.build_query_string();
2356 assert!(!query_string.contains(" a "));
2358 assert!(!query_string.contains(" the "));
2359 assert!(query_string.contains("doctor"));
2360 assert!(query_string.contains("doctor=>{2}"));
2362 }
2363
2364 #[test]
2365 fn hybrid_query_with_text_scorer_like_python_hybrid_query() {
2366 let query = HybridQuery::new("test", "body", Vector::new(vec![1.0]), "vec")
2367 .with_text_scorer("BM25STD")
2368 .with_yield_text_score_as("text_score");
2369
2370 let cmd = query.build_cmd("idx");
2371 let packed = cmd.get_packed_command();
2372 let cmd_str = String::from_utf8_lossy(&packed);
2373
2374 assert!(cmd_str.contains("SCORER"));
2375 assert!(cmd_str.contains("BM25STD"));
2376 assert!(cmd_str.contains("YIELD_SCORE_AS"));
2377 assert!(cmd_str.contains("text_score"));
2378 }
2379
2380 #[test]
2381 fn filter_query_should_track_paging_and_sort_like_python_test_query_types() {
2382 let query = FilterQuery::new(Tag::new("brand").eq("Nike"))
2383 .with_return_fields(["brand", "price"])
2384 .paging(5, 7)
2385 .sort_by("price", SortDirection::Asc)
2386 .in_order(true)
2387 .with_dialect(2);
2388
2389 let render = query.render();
2390 assert_eq!(render.return_fields, vec!["brand", "price"]);
2391 assert_eq!(render.limit.expect("limit").offset, 5);
2392 assert_eq!(render.limit.expect("limit").num, 7);
2393 assert!(render.sort_by.is_some());
2394 assert!(render.in_order);
2395 assert_eq!(render.dialect, 2);
2396 }
2397
2398 #[test]
2399 fn count_query_should_use_nocontent_and_zero_limit_like_python_test_query_types() {
2400 let render = CountQuery::new()
2401 .with_filter(Tag::new("brand").eq("Nike"))
2402 .render();
2403
2404 assert!(render.no_content);
2405 assert_eq!(render.limit.expect("limit").num, 0);
2406 assert_eq!(render.dialect, 2);
2407 }
2408
2409 #[test]
2410 fn text_query_should_track_return_fields_and_limit_like_python_test_query_types() {
2411 let render = TextQuery::new("basketball")
2412 .for_field("description")
2413 .with_return_fields(["title", "genre", "rating"])
2414 .paging(5, 7)
2415 .render();
2416
2417 assert_eq!(render.return_fields, vec!["title", "genre", "rating"]);
2418 assert_eq!(render.limit.expect("limit").offset, 5);
2419 assert!(render.query_string.contains("@description:(basketball)"));
2420 }
2421
2422 #[test]
2423 fn range_query_should_include_vector_params_like_python_test_query_types() {
2424 let render = VectorRangeQuery::new(Vector::new(vec![1.0, 2.0, 3.0]), "embedding", 0.2)
2425 .with_return_fields(["field1"])
2426 .render();
2427
2428 assert_eq!(render.params.len(), 2);
2430 assert_eq!(render.params[0].name, "vector");
2431 assert_eq!(render.params[1].name, "distance_threshold");
2432 assert_eq!(render.return_fields, vec!["field1", "vector_distance"]);
2433 }
2434
2435 #[test]
2436 fn vector_range_query_should_update_distance_threshold_like_python_integration_test_query() {
2437 let mut query = VectorRangeQuery::new(Vector::new(vec![1.0, 2.0, 3.0]), "embedding", 0.2);
2438 assert_eq!(query.distance_threshold(), 0.2);
2439
2440 query.set_distance_threshold(0.1);
2441
2442 assert_eq!(query.distance_threshold(), 0.1);
2443 assert!(
2444 query
2445 .to_redis_query()
2446 .contains("VECTOR_RANGE $distance_threshold")
2447 );
2448 }
2449
2450 #[test]
2451 fn vector_query_should_replace_filter_in_place_like_python_integration_test_query() {
2452 let mut query = VectorQuery::new(Vector::new(vec![1.0, 2.0, 3.0]), "embedding", 5);
2453 query.set_filter(Tag::new("brand").eq("Nike"));
2454 assert!(query.to_redis_query().starts_with("@brand:{Nike}"));
2455
2456 query.set_filter(Num::new("price").gte(10.0));
2457 assert!(query.to_redis_query().starts_with("@price:[10 +inf]"));
2458 }
2459
2460 #[test]
2461 fn pageable_queries_should_clone_updated_limits_for_pagination() {
2462 let query = FilterQuery::new(Tag::new("brand").eq("Nike")).paging(0, 5);
2463
2464 let paged = query.paged(10, 3);
2465
2466 assert_eq!(paged.render().limit.expect("limit").offset, 10);
2467 assert_eq!(paged.render().limit.expect("limit").num, 3);
2468 assert_eq!(query.render().limit.expect("limit").offset, 0);
2469 }
2470
2471 #[test]
2472 fn raw_string_queries_should_render_directly_for_python_style_batch_search() {
2473 let render = "@test:{foo}".render();
2474
2475 assert_eq!(render.query_string, "@test:{foo}");
2476 assert!(render.params.is_empty());
2477 }
2478
2479 #[test]
2480 fn aggregate_hybrid_query_should_reject_empty_text() {
2481 let result = AggregateHybridQuery::new("", "desc", Vector::new(vec![1.0]), "vec");
2482 assert!(result.is_err());
2483 }
2484
2485 #[test]
2486 fn aggregate_hybrid_query_should_build_query_string_like_python_aggregate_hybrid() {
2487 let query = AggregateHybridQuery::new(
2488 "a medical professional with expertise in lung cancer",
2489 "description",
2490 Vector::new(vec![0.1, 0.1, 0.5]),
2491 "user_embedding",
2492 )
2493 .unwrap()
2494 .with_num_results(10);
2495
2496 let qs = query.build_query_string();
2497 assert!(
2500 qs.contains("~@description:("),
2501 "should use ~ (optional) prefix: {qs}"
2502 );
2503 assert!(qs.contains(" | "), "tokens should be OR-joined: {qs}");
2504 assert!(qs.contains("=>[KNN 10 @user_embedding $vector AS vector_distance]"));
2505 }
2506
2507 #[test]
2508 fn aggregate_hybrid_query_should_build_ft_aggregate_cmd_like_python() {
2509 let query = AggregateHybridQuery::new(
2510 "medical professional",
2511 "description",
2512 Vector::new(vec![0.1, 0.1, 0.5]),
2513 "user_embedding",
2514 )
2515 .unwrap()
2516 .with_alpha(0.5)
2517 .with_num_results(3)
2518 .with_text_scorer("BM25STD")
2519 .with_return_fields(["user", "age", "job"]);
2520
2521 let cmd = query.build_aggregate_cmd("my_index");
2522 let packed = cmd.get_packed_command();
2523 let cmd_str = String::from_utf8_lossy(&packed);
2524
2525 assert!(cmd_str.contains("FT.AGGREGATE"));
2526 assert!(cmd_str.contains("my_index"));
2527 assert!(cmd_str.contains("SCORER"));
2528 assert!(cmd_str.contains("BM25STD"));
2529 assert!(cmd_str.contains("ADDSCORES"));
2530 assert!(cmd_str.contains("vector_similarity"));
2531 assert!(cmd_str.contains("text_score"));
2532 assert!(cmd_str.contains("hybrid_score"));
2533 assert!(cmd_str.contains("SORTBY"));
2534 assert!(cmd_str.contains("LOAD"));
2535 assert!(cmd_str.contains("DIALECT"));
2536 assert!(cmd_str.contains("PARAMS"));
2537 }
2538
2539 #[test]
2540 fn aggregate_hybrid_query_with_filter_like_python_aggregate_filter() {
2541 let filter = Tag::new("credit_score").eq("high") & Num::new("age").gt(30.0);
2542 let query = AggregateHybridQuery::new(
2543 "medical professional",
2544 "description",
2545 Vector::new(vec![0.1, 0.1, 0.5]),
2546 "user_embedding",
2547 )
2548 .unwrap()
2549 .with_filter(filter);
2550
2551 let qs = query.build_query_string();
2552 assert!(qs.contains("@credit_score:{high}"));
2553 assert!(qs.contains("@age:[(30"));
2554 }
2555
2556 #[test]
2557 fn aggregate_hybrid_query_with_stopwords_like_python_aggregate_stopwords() {
2558 use std::collections::HashSet;
2559 let mut stopwords = HashSet::new();
2560 stopwords.insert("medical".to_owned());
2561 stopwords.insert("expertise".to_owned());
2562
2563 let query = AggregateHybridQuery::new(
2564 "a medical professional with expertise in lung cancer",
2565 "description",
2566 Vector::new(vec![0.1, 0.1, 0.5]),
2567 "user_embedding",
2568 )
2569 .unwrap()
2570 .with_stopwords(stopwords);
2571
2572 let qs = query.build_query_string();
2573 assert!(!qs.contains("medical"));
2574 assert!(!qs.contains("expertise"));
2575 }
2576
2577 #[test]
2578 fn aggregate_hybrid_query_with_text_weights_like_python_aggregate_word_weights() {
2579 use std::collections::HashMap;
2580 let mut weights = HashMap::new();
2581 weights.insert("medical".to_owned(), 3.4_f32);
2582 weights.insert("cancers".to_owned(), 5.0_f32);
2583
2584 let query = AggregateHybridQuery::new(
2585 "a medical professional with expertise in lung cancers",
2586 "description",
2587 Vector::new(vec![0.1, 0.1, 0.5]),
2588 "user_embedding",
2589 )
2590 .unwrap()
2591 .with_text_weights(weights);
2592
2593 let qs = query.build_query_string();
2594 assert!(qs.contains("medical=>{3.4}"));
2595 assert!(qs.contains("cancers=>{5}"));
2596 }
2597
2598 #[test]
2599 fn aggregate_hybrid_query_set_text_weights_should_match_constructor_weights() {
2600 use std::collections::HashMap;
2601 let mut weights = HashMap::new();
2602 weights.insert("medical".to_owned(), 5.0_f32);
2603
2604 let query1 = AggregateHybridQuery::new(
2605 "a medical professional",
2606 "description",
2607 Vector::new(vec![0.1, 0.1, 0.5]),
2608 "user_embedding",
2609 )
2610 .unwrap()
2611 .with_text_weights(weights.clone());
2612
2613 let mut query2 = AggregateHybridQuery::new(
2614 "a medical professional",
2615 "description",
2616 Vector::new(vec![0.1, 0.1, 0.5]),
2617 "user_embedding",
2618 )
2619 .unwrap();
2620 query2.set_text_weights(weights);
2621
2622 assert_eq!(query1.build_query_string(), query2.build_query_string());
2623 }
2624
2625 #[test]
2631 fn multi_vector_query_should_build_vector_range_query_like_python() {
2632 let v1 = VectorInput::from_floats(&[0.1, 0.2, 0.3, 0.4], "text embedding")
2634 .with_weight(0.2)
2635 .with_max_distance(0.7);
2636 let v2 = VectorInput::from_floats(&[0.5, 0.5], "image embedding")
2637 .with_weight(0.7)
2638 .with_max_distance(1.8);
2639
2640 let query = MultiVectorQuery::new(vec![v1, v2]);
2641 let qs = query.build_query_string();
2642
2643 assert!(qs.contains("@text embedding:[VECTOR_RANGE 0.7 $vector_0]"));
2644 assert!(qs.contains("YIELD_DISTANCE_AS: distance_0"));
2645 assert!(qs.contains("@image embedding:[VECTOR_RANGE 1.8 $vector_1]"));
2646 assert!(qs.contains("YIELD_DISTANCE_AS: distance_1"));
2647 assert!(qs.contains("AND"));
2648 }
2649
2650 #[test]
2651 fn multi_vector_query_default_properties_like_python() {
2652 let vi = VectorInput::from_floats(&[0.1, 0.2, 0.3, 0.4], "field_1");
2654 assert_eq!(vi.weight, 1.0);
2655 assert_eq!(vi.dtype, VectorDtype::Float32);
2656 assert_eq!(vi.max_distance, 2.0);
2657
2658 let query = MultiVectorQuery::new(vec![vi]);
2659 assert!(query.filter_expression.is_none());
2660 assert_eq!(query.num_results, 10);
2661 assert!(query.return_fields.is_empty());
2662 assert_eq!(query.dialect, 2);
2663 }
2664
2665 #[test]
2666 fn multi_vector_query_should_accept_multiple_vectors_like_python() {
2667 let v1 = VectorInput::from_floats(&[0.1, 0.2, 0.3, 0.4], "field_1")
2669 .with_weight(0.2)
2670 .with_max_distance(2.0);
2671 let v2 = VectorInput::from_floats(&[0.1, 0.2, 0.3, 0.4], "field_2")
2672 .with_weight(0.5)
2673 .with_max_distance(1.5);
2674 let v3 = VectorInput::from_floats(&[0.5, 0.5], "field_3")
2675 .with_weight(0.6)
2676 .with_max_distance(0.4);
2677 let v4 = VectorInput::from_floats(&[0.1, 0.1, 0.1], "field_4")
2678 .with_weight(0.1)
2679 .with_max_distance(0.01);
2680
2681 let query = MultiVectorQuery::new(vec![v1, v2, v3, v4]);
2682 assert_eq!(query.vectors().len(), 4);
2683 }
2684
2685 #[test]
2686 fn multi_vector_query_overrides_like_python() {
2687 let vi = VectorInput::from_floats(&[0.1, 0.2], "field_1");
2689 let filter = Tag::new("user group").one_of(["group A", "group C"]);
2690
2691 let query = MultiVectorQuery::new(vec![vi])
2692 .with_filter(filter)
2693 .with_num_results(5)
2694 .with_return_fields(["field_1", "user name", "address"])
2695 .with_dialect(4);
2696
2697 assert!(query.filter_expression.is_some());
2698 assert_eq!(query.num_results, 5);
2699 assert_eq!(query.return_fields, vec!["field_1", "user name", "address"]);
2700 assert_eq!(query.dialect, 4);
2701 }
2702
2703 #[test]
2704 fn multi_vector_query_aggregate_cmd_like_python() {
2705 let v1 = VectorInput::from_floats(&[0.1, 0.2, 0.3, 0.4], "text embedding")
2707 .with_weight(0.2)
2708 .with_max_distance(0.7);
2709 let v2 = VectorInput::from_floats(&[0.5, 0.5], "image embedding")
2710 .with_weight(0.7)
2711 .with_max_distance(1.8);
2712
2713 let query = MultiVectorQuery::new(vec![v1, v2]);
2714 let cmd = query.build_aggregate_cmd("my_index");
2715 let packed = cmd.get_packed_command();
2716 let cmd_str = String::from_utf8_lossy(&packed);
2717
2718 assert!(cmd_str.contains("FT.AGGREGATE"));
2719 assert!(cmd_str.contains("my_index"));
2720 assert!(cmd_str.contains("SCORER"));
2721 assert!(cmd_str.contains("TFIDF"));
2722 assert!(cmd_str.contains("score_0"));
2723 assert!(cmd_str.contains("score_1"));
2724 assert!(cmd_str.contains("combined_score"));
2725 assert!(cmd_str.contains("SORTBY"));
2726 assert!(cmd_str.contains("PARAMS"));
2727 }
2728
2729 #[test]
2730 fn multi_vector_query_with_filter_like_python() {
2731 let v1 = VectorInput::from_floats(&[0.1, 0.1, 0.5], "user_embedding");
2733 let v2 = VectorInput::from_floats(&[0.3, 0.4, 0.7, 0.2, -0.3], "image_embedding");
2734 let filter = Text::new("description").eq("medical");
2735
2736 let query = MultiVectorQuery::new(vec![v1, v2]).with_filter(filter);
2737
2738 let qs = query.build_query_string();
2739 assert!(qs.contains("@description"));
2740 assert!(qs.contains("medical"));
2741 }
2742
2743 #[test]
2744 #[should_panic(expected = "max_distance must be in [0.0, 2.0]")]
2745 fn vector_input_should_reject_invalid_max_distance_like_python() {
2746 VectorInput::from_floats(&[0.1, 0.2], "field").with_max_distance(2.001);
2748 }
2749
2750 #[test]
2751 #[should_panic(expected = "max_distance must be in [0.0, 2.0]")]
2752 fn vector_input_should_reject_negative_max_distance_like_python() {
2753 VectorInput::from_floats(&[0.1, 0.2], "field").with_max_distance(-0.1);
2755 }
2756
2757 #[test]
2758 fn vector_input_from_bytes_like_python() {
2759 let floats = [0.1_f32, 0.2, 0.3, 0.4];
2761 let mut expected_bytes = Vec::new();
2762 for &f in &floats {
2763 expected_bytes.extend_from_slice(&f.to_le_bytes());
2764 }
2765 let vi = VectorInput::from_floats(&floats, "field_1");
2766 assert_eq!(vi.vector.as_ref(), expected_bytes.as_slice());
2767
2768 let vi2 = VectorInput::from_bytes(expected_bytes.clone(), "field_1", VectorDtype::Float32);
2770 assert_eq!(vi2.vector.as_ref(), expected_bytes.as_slice());
2771 }
2772
2773 #[test]
2779 fn aggregate_hybrid_query_reject_stopword_only_text_like_python() {
2780 let result = AggregateHybridQuery::new(
2784 "",
2785 "description",
2786 Vector::new(vec![0.1, 0.1, 0.5]),
2787 "user_embedding",
2788 );
2789 assert!(result.is_err());
2790 }
2791
2792 #[test]
2793 fn aggregate_hybrid_query_with_string_filter_like_python() {
2794 use crate::filter::FilterExpression;
2797 let filter_str = "@category:{tech|science|engineering}";
2798 let filter = FilterExpression::raw(filter_str);
2799
2800 let query = AggregateHybridQuery::new(
2801 "search for document 12345",
2802 "description",
2803 Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
2804 "embedding",
2805 )
2806 .unwrap()
2807 .with_filter(filter);
2808
2809 let qs = query.build_query_string();
2810 assert!(
2813 qs.contains("~@description:(search | for | document | 12345)"),
2814 "tokens should be OR-joined with ~ prefix: {qs}"
2815 );
2816 assert!(
2817 qs.contains("AND @category:{tech|science|engineering}"),
2818 "filter should be AND-joined: {qs}"
2819 );
2820 }
2821
2822 #[test]
2823 fn aggregate_hybrid_query_wildcard_filter_is_ignored_like_python() {
2824 use crate::filter::FilterExpression;
2826 let filter = FilterExpression::raw("*");
2827
2828 let query = AggregateHybridQuery::new(
2829 "search text",
2830 "description",
2831 Vector::new(vec![0.1]),
2832 "embedding",
2833 )
2834 .unwrap()
2835 .with_filter(filter);
2836
2837 let qs = query.build_query_string();
2838 assert!(!qs.contains("AND"));
2839 }
2840
2841 #[test]
2842 fn aggregate_hybrid_query_text_weights_validation_like_python() {
2843 use std::collections::HashMap;
2846
2847 let q1 = AggregateHybridQuery::new(
2848 "sample text query",
2849 "description",
2850 Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
2851 "embedding",
2852 )
2853 .unwrap()
2854 .with_text_weights(HashMap::new());
2855 assert!(q1.build_query_string().contains("sample"));
2856
2857 let mut weights = HashMap::new();
2859 weights.insert("alpha".to_owned(), 0.2_f32);
2860 weights.insert("bravo".to_owned(), 0.4_f32);
2861 let q2 = AggregateHybridQuery::new(
2862 "sample text query",
2863 "description",
2864 Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
2865 "embedding",
2866 )
2867 .unwrap()
2868 .with_text_weights(weights);
2869 let qs = q2.build_query_string();
2871 assert!(qs.contains("sample"));
2872 }
2873
2874 #[test]
2880 fn hybrid_query_without_filter_like_python() {
2881 let query = HybridQuery::new(
2883 "test query",
2884 "description",
2885 Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
2886 "embedding",
2887 );
2888
2889 let cmd = query.build_cmd("idx");
2890 let packed = cmd.get_packed_command();
2891 let cmd_str = String::from_utf8_lossy(&packed);
2892
2893 assert!(!cmd_str.contains("FILTER"));
2895 assert!(cmd_str.contains("@description:(test query)"));
2896 }
2897
2898 #[test]
2899 fn hybrid_query_vector_search_method_knn_like_python() {
2900 let query = HybridQuery::new(
2902 "test query",
2903 "description",
2904 Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
2905 "embedding",
2906 )
2907 .with_knn(Some(100))
2908 .with_num_results(10);
2909
2910 let cmd = query.build_cmd("idx");
2911 let packed = cmd.get_packed_command();
2912 let cmd_str = String::from_utf8_lossy(&packed);
2913
2914 assert!(cmd_str.contains("KNN"));
2915 assert!(cmd_str.contains("EF_RUNTIME"));
2916 }
2917
2918 #[test]
2919 fn hybrid_query_vector_search_method_range_like_python() {
2920 let query = HybridQuery::new(
2922 "test query",
2923 "description",
2924 Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
2925 "embedding",
2926 )
2927 .with_range(10.0, Some(0.1));
2928
2929 let cmd = query.build_cmd("idx");
2930 let packed = cmd.get_packed_command();
2931 let cmd_str = String::from_utf8_lossy(&packed);
2932
2933 assert!(cmd_str.contains("RANGE"));
2934 assert!(cmd_str.contains("RADIUS"));
2935 assert!(cmd_str.contains("EPSILON"));
2936 }
2937
2938 #[test]
2939 fn hybrid_query_without_vector_search_method_like_python() {
2940 let query = HybridQuery::new(
2942 "test query",
2943 "description",
2944 Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
2945 "embedding",
2946 );
2947
2948 let cmd = query.build_cmd("idx");
2949 let packed = cmd.get_packed_command();
2950 let cmd_str = String::from_utf8_lossy(&packed);
2951
2952 assert!(cmd_str.contains("VSIM"));
2953 assert!(!cmd_str.contains("KNN"));
2955 assert!(!cmd_str.contains("RANGE"));
2956 }
2957
2958 #[test]
2959 fn hybrid_query_rrf_with_both_params_like_python() {
2960 let query = HybridQuery::new(
2962 "test query",
2963 "description",
2964 Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
2965 "embedding",
2966 )
2967 .with_rrf(Some(20), Some(50))
2968 .with_yield_combined_score_as("rrf_score");
2969
2970 let cmd = query.build_cmd("idx");
2971 let packed = cmd.get_packed_command();
2972 let cmd_str = String::from_utf8_lossy(&packed);
2973
2974 assert!(cmd_str.contains("RRF"));
2975 assert!(cmd_str.contains("WINDOW"));
2976 assert!(cmd_str.contains("CONSTANT"));
2977 assert!(cmd_str.contains("YIELD_SCORE_AS"));
2978 assert!(cmd_str.contains("rrf_score"));
2979 }
2980
2981 #[test]
2982 fn hybrid_query_linear_with_alpha_like_python() {
2983 for alpha in [0.1_f32, 0.5, 0.9] {
2985 let query = HybridQuery::new(
2986 "test query",
2987 "description",
2988 Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
2989 "embedding",
2990 )
2991 .with_linear(alpha);
2992
2993 let cmd = query.build_cmd("idx");
2994 let packed = cmd.get_packed_command();
2995 let cmd_str = String::from_utf8_lossy(&packed);
2996
2997 assert!(cmd_str.contains("LINEAR"));
2998 assert!(cmd_str.contains("ALPHA"));
2999 assert!(cmd_str.contains("BETA"));
3000 }
3001 }
3002
3003 #[test]
3004 fn hybrid_query_without_combination_method_like_python() {
3005 let query = HybridQuery::new(
3007 "test query",
3008 "description",
3009 Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
3010 "embedding",
3011 );
3012
3013 let cmd = query.build_cmd("idx");
3014 let packed = cmd.get_packed_command();
3015 let cmd_str = String::from_utf8_lossy(&packed);
3016
3017 assert!(!cmd_str.contains("COMBINE"));
3018 }
3019
3020 #[test]
3021 fn hybrid_query_with_combined_filters_like_python() {
3022 let filter = Tag::new("genre").eq("comedy") & Num::new("rating").gt(7.0);
3024
3025 let query = HybridQuery::new(
3026 "test query",
3027 "description",
3028 Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
3029 "embedding",
3030 )
3031 .with_filter(filter);
3032
3033 let cmd = query.build_cmd("idx");
3034 let packed = cmd.get_packed_command();
3035 let cmd_str = String::from_utf8_lossy(&packed);
3036
3037 assert!(cmd_str.contains("FILTER"));
3038 assert!(cmd_str.contains("genre"));
3039 assert!(cmd_str.contains("comedy"));
3040 assert!(cmd_str.contains("rating"));
3041 }
3042
3043 #[test]
3044 fn hybrid_query_with_numeric_filter_like_python() {
3045 let filter = Num::new("age").gt(30.0);
3047
3048 let query = HybridQuery::new(
3049 "test query",
3050 "description",
3051 Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
3052 "embedding",
3053 )
3054 .with_filter(filter);
3055
3056 let cmd = query.build_cmd("idx");
3057 let packed = cmd.get_packed_command();
3058 let cmd_str = String::from_utf8_lossy(&packed);
3059
3060 assert!(cmd_str.contains("FILTER"));
3061 assert!(cmd_str.contains("@age:[(30"));
3062 }
3063
3064 #[test]
3065 fn hybrid_query_with_text_filter_like_python() {
3066 let filter = Text::new("job").eq("engineer");
3068
3069 let query = HybridQuery::new(
3070 "test query",
3071 "description",
3072 Vector::new(vec![0.1, 0.2, 0.3, 0.4]),
3073 "embedding",
3074 )
3075 .with_filter(filter);
3076
3077 let cmd = query.build_cmd("idx");
3078 let packed = cmd.get_packed_command();
3079 let cmd_str = String::from_utf8_lossy(&packed);
3080
3081 assert!(cmd_str.contains("FILTER"));
3082 assert!(cmd_str.contains("@job"));
3083 assert!(cmd_str.contains("engineer"));
3084 }
3085
3086 #[test]
3089 fn vector_query_hybrid_policy_like_python_test_query_types() {
3090 let query = VectorQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 10)
3091 .with_hybrid_policy(HybridPolicy::Batches);
3092
3093 assert_eq!(query.hybrid_policy(), Some(HybridPolicy::Batches));
3094 assert!(query.to_redis_query().contains("HYBRID_POLICY BATCHES"));
3095 }
3096
3097 #[test]
3098 fn vector_query_hybrid_policy_with_batch_size_like_python() {
3099 let query = VectorQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 10)
3100 .with_hybrid_policy(HybridPolicy::Batches)
3101 .with_batch_size(50);
3102
3103 let qs = query.to_redis_query();
3104 assert!(qs.contains("HYBRID_POLICY BATCHES BATCH_SIZE 50"));
3105 }
3106
3107 #[test]
3108 fn vector_query_adhoc_bf_policy_like_python() {
3109 let query = VectorQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 10)
3110 .with_hybrid_policy(HybridPolicy::AdhocBf);
3111
3112 assert!(query.to_redis_query().contains("HYBRID_POLICY ADHOC_BF"));
3113 }
3114
3115 #[test]
3116 fn vector_query_epsilon_like_python_test_query_types() {
3117 let query = VectorQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 10)
3118 .with_epsilon(0.05);
3119
3120 assert_eq!(query.epsilon(), Some(0.05));
3121 let qs = query.to_redis_query();
3122 assert!(qs.contains("EPSILON $EPSILON"));
3123 let params = query.params();
3124 assert!(params.iter().any(|p| p.name == "EPSILON"));
3125 }
3126
3127 #[test]
3128 fn vector_query_ef_runtime_params_like_python_test_query_types() {
3129 let query = VectorQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 10)
3130 .with_ef_runtime(100);
3131
3132 assert_eq!(query.ef_runtime(), Some(100));
3133 let qs = query.to_redis_query();
3134 assert!(qs.contains("EF_RUNTIME $EF"));
3135 let params = query.params();
3136 assert!(params.iter().any(|p| p.name == "EF"));
3137 }
3138
3139 #[test]
3140 fn vector_query_search_window_size_like_python() {
3141 let query = VectorQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 10)
3142 .with_search_window_size(40);
3143
3144 assert_eq!(query.search_window_size(), Some(40));
3145 let qs = query.to_redis_query();
3146 assert!(qs.contains("SEARCH_WINDOW_SIZE $SEARCH_WINDOW_SIZE"));
3147 }
3148
3149 #[test]
3150 fn vector_query_use_search_history_like_python() {
3151 for mode in [
3152 SearchHistoryMode::Off,
3153 SearchHistoryMode::On,
3154 SearchHistoryMode::Auto,
3155 ] {
3156 let query = VectorQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 10)
3157 .with_use_search_history(mode);
3158
3159 assert_eq!(query.use_search_history(), Some(mode));
3160 let qs = query.to_redis_query();
3161 assert!(qs.contains("USE_SEARCH_HISTORY $USE_SEARCH_HISTORY"));
3162 }
3163 }
3164
3165 #[test]
3166 fn vector_query_search_buffer_capacity_like_python() {
3167 let query = VectorQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 10)
3168 .with_search_buffer_capacity(50);
3169
3170 assert_eq!(query.search_buffer_capacity(), Some(50));
3171 let qs = query.to_redis_query();
3172 assert!(qs.contains("SEARCH_BUFFER_CAPACITY $SEARCH_BUFFER_CAPACITY"));
3173 }
3174
3175 #[test]
3176 fn vector_query_all_runtime_params_like_python() {
3177 let query = VectorQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 10)
3178 .with_ef_runtime(100)
3179 .with_epsilon(0.05)
3180 .with_search_window_size(40)
3181 .with_use_search_history(SearchHistoryMode::On)
3182 .with_search_buffer_capacity(50);
3183
3184 let qs = query.to_redis_query();
3185 assert!(qs.contains("EF_RUNTIME $EF"));
3186 assert!(qs.contains("EPSILON $EPSILON"));
3187 assert!(qs.contains("SEARCH_WINDOW_SIZE $SEARCH_WINDOW_SIZE"));
3188 assert!(qs.contains("USE_SEARCH_HISTORY $USE_SEARCH_HISTORY"));
3189 assert!(qs.contains("SEARCH_BUFFER_CAPACITY $SEARCH_BUFFER_CAPACITY"));
3190
3191 let params = query.params();
3192 assert!(params.iter().any(|p| p.name == "EF"));
3193 assert!(params.iter().any(|p| p.name == "EPSILON"));
3194 assert!(params.iter().any(|p| p.name == "SEARCH_WINDOW_SIZE"));
3195 assert!(params.iter().any(|p| p.name == "USE_SEARCH_HISTORY"));
3196 assert!(params.iter().any(|p| p.name == "SEARCH_BUFFER_CAPACITY"));
3197 }
3198
3199 #[test]
3200 fn vector_query_set_methods_like_python() {
3201 let mut query = VectorQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 10);
3202
3203 assert!(query.ef_runtime().is_none());
3204 assert!(query.epsilon().is_none());
3205 assert!(query.hybrid_policy().is_none());
3206
3207 query.set_ef_runtime(200);
3208 assert_eq!(query.ef_runtime(), Some(200));
3209
3210 query.set_epsilon(0.1);
3211 assert_eq!(query.epsilon(), Some(0.1));
3212
3213 query.set_hybrid_policy(HybridPolicy::Batches);
3214 assert_eq!(query.hybrid_policy(), Some(HybridPolicy::Batches));
3215
3216 query.set_batch_size(100);
3217 assert_eq!(query.batch_size(), Some(100));
3218 }
3219
3220 #[test]
3223 fn range_query_epsilon_like_python_test_query_types() {
3224 let query =
3225 VectorRangeQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 0.3)
3226 .with_epsilon(0.05);
3227
3228 assert_eq!(query.epsilon(), Some(0.05));
3229 let qs = query.to_redis_query();
3230 assert!(qs.contains("$EPSILON: 0.05"));
3231 }
3232
3233 #[test]
3234 fn range_query_construction_like_python() {
3235 let basic = VectorRangeQuery::new(Vector::new(vec![0.1, 0.1, 0.5]), "user_embedding", 0.2)
3237 .with_return_fields(["user", "credit_score"]);
3238
3239 let qs = basic.to_redis_query();
3240 assert!(qs.contains("VECTOR_RANGE $distance_threshold $vector"));
3241 assert!(qs.contains("$YIELD_DISTANCE_AS: vector_distance"));
3242 assert!(!qs.contains("HYBRID_POLICY"));
3243
3244 let epsilon_query =
3246 VectorRangeQuery::new(Vector::new(vec![0.1, 0.1, 0.5]), "user_embedding", 0.2)
3247 .with_epsilon(0.05);
3248
3249 let qs = epsilon_query.to_redis_query();
3250 assert!(qs.contains("$EPSILON: 0.05"));
3251 assert_eq!(epsilon_query.epsilon(), Some(0.05));
3252 }
3253
3254 #[test]
3255 fn range_query_hybrid_policy_in_params_not_query_string_like_python() {
3256 let query = VectorRangeQuery::new(Vector::new(vec![0.1, 0.1, 0.5]), "user_embedding", 0.2)
3257 .with_hybrid_policy(HybridPolicy::Batches);
3258
3259 let qs = query.to_redis_query();
3260 assert!(!qs.contains("HYBRID_POLICY"));
3262 assert_eq!(query.hybrid_policy(), Some(HybridPolicy::Batches));
3263
3264 let params = query.params();
3265 assert!(params.iter().any(|p| p.name == "HYBRID_POLICY"));
3266 }
3267
3268 #[test]
3269 fn range_query_hybrid_policy_with_batch_size_in_params_like_python() {
3270 let query = VectorRangeQuery::new(Vector::new(vec![0.1, 0.1, 0.5]), "user_embedding", 0.2)
3271 .with_hybrid_policy(HybridPolicy::Batches)
3272 .with_batch_size(50);
3273
3274 let qs = query.to_redis_query();
3275 assert!(!qs.contains("HYBRID_POLICY"));
3276 assert!(!qs.contains("BATCH_SIZE"));
3277
3278 let params = query.params();
3279 assert!(params.iter().any(|p| p.name == "HYBRID_POLICY"));
3280 assert!(params.iter().any(|p| p.name == "BATCH_SIZE"));
3281 }
3282
3283 #[test]
3284 fn range_query_setter_methods_like_python() {
3285 let mut query =
3286 VectorRangeQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "user_embedding", 0.2);
3287
3288 assert!(query.epsilon().is_none());
3289 assert!(query.hybrid_policy().is_none());
3290 assert!(query.batch_size().is_none());
3291
3292 query.set_epsilon(0.1);
3293 assert_eq!(query.epsilon(), Some(0.1));
3294 assert!(query.to_redis_query().contains("$EPSILON: 0.1"));
3295
3296 query.set_hybrid_policy(HybridPolicy::Batches);
3297 assert_eq!(query.hybrid_policy(), Some(HybridPolicy::Batches));
3298
3299 query.set_batch_size(25);
3300 assert_eq!(query.batch_size(), Some(25));
3301 }
3302
3303 #[test]
3304 fn range_query_search_window_size_like_python() {
3305 let query =
3306 VectorRangeQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 0.3)
3307 .with_search_window_size(40);
3308
3309 assert_eq!(query.search_window_size(), Some(40));
3310 assert!(query.to_redis_query().contains("$SEARCH_WINDOW_SIZE: 40"));
3311 }
3312
3313 #[test]
3314 fn range_query_use_search_history_like_python() {
3315 for (mode, expected_str) in [
3316 (SearchHistoryMode::Off, "OFF"),
3317 (SearchHistoryMode::On, "ON"),
3318 (SearchHistoryMode::Auto, "AUTO"),
3319 ] {
3320 let query =
3321 VectorRangeQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 0.3)
3322 .with_use_search_history(mode);
3323
3324 assert_eq!(query.use_search_history(), Some(mode));
3325 let qs = query.to_redis_query();
3326 assert!(
3327 qs.contains(&format!("$USE_SEARCH_HISTORY: {}", expected_str)),
3328 "query string should contain USE_SEARCH_HISTORY for {:?}",
3329 mode,
3330 );
3331 }
3332 }
3333
3334 #[test]
3335 fn range_query_search_buffer_capacity_like_python() {
3336 let query =
3337 VectorRangeQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 0.3)
3338 .with_search_buffer_capacity(50);
3339
3340 assert_eq!(query.search_buffer_capacity(), Some(50));
3341 assert!(
3342 query
3343 .to_redis_query()
3344 .contains("$SEARCH_BUFFER_CAPACITY: 50")
3345 );
3346 }
3347
3348 #[test]
3349 fn range_query_all_svs_params_like_python() {
3350 let query =
3351 VectorRangeQuery::new(Vector::new(vec![0.1, 0.2, 0.3, 0.4]), "vector_field", 0.3)
3352 .with_epsilon(0.05)
3353 .with_search_window_size(40)
3354 .with_use_search_history(SearchHistoryMode::On)
3355 .with_search_buffer_capacity(50);
3356
3357 let qs = query.to_redis_query();
3358 assert!(qs.contains("$EPSILON: 0.05"));
3359 assert!(qs.contains("$SEARCH_WINDOW_SIZE: 40"));
3360 assert!(qs.contains("$USE_SEARCH_HISTORY: ON"));
3361 assert!(qs.contains("$SEARCH_BUFFER_CAPACITY: 50"));
3362 }
3363
3364 #[test]
3367 fn text_query_with_filter_expression_like_python() {
3368 let filter = Tag::new("genre").eq("comedy");
3369 let query = TextQuery::new("basketball")
3370 .for_field("description")
3371 .with_filter(filter);
3372
3373 let qs = query.to_redis_query();
3374 assert!(qs.contains("@description:(basketball)"));
3375 assert!(qs.contains("AND @genre:{comedy}"));
3376 }
3377
3378 #[test]
3379 fn text_query_without_filter_like_python() {
3380 let query = TextQuery::new("basketball").for_field("description");
3381
3382 let qs = query.to_redis_query();
3383 assert!(qs.contains("@description:(basketball)"));
3384 assert!(!qs.contains("AND"));
3385 }
3386
3387 #[test]
3388 fn text_query_set_filter_like_python() {
3389 let mut query = TextQuery::new("basketball").for_field("description");
3390 query.set_filter(Tag::new("category").eq("sports"));
3391
3392 let qs = query.to_redis_query();
3393 assert!(qs.contains("AND @category:{sports}"));
3394 }
3395
3396 #[test]
3399 fn text_query_with_stopwords_removes_words() {
3400 use std::collections::HashSet;
3401 let mut stopwords = HashSet::new();
3402 stopwords.insert("the".to_owned());
3403 stopwords.insert("a".to_owned());
3404
3405 let query = TextQuery::new("a doctor in the house")
3406 .for_field("description")
3407 .with_stopwords(stopwords);
3408
3409 let qs = query.to_redis_query();
3410 assert!(!qs.contains(" a "));
3412 assert!(!qs.contains(" the "));
3413 assert!(qs.contains("doctor"));
3414 assert!(qs.contains("house"));
3415 }
3416
3417 #[test]
3418 fn text_query_with_text_weights_applies_weight_syntax() {
3419 use std::collections::HashMap;
3420 let mut weights = HashMap::new();
3421 weights.insert("doctor".to_owned(), 2.0_f32);
3422
3423 let query = TextQuery::new("a doctor in the house")
3424 .for_field("description")
3425 .with_text_weights(weights);
3426
3427 let qs = query.to_redis_query();
3428 assert!(qs.contains("doctor=>{2}"));
3429 assert!(qs.contains("house"));
3431 }
3432
3433 #[test]
3434 fn text_query_with_stopwords_and_weights_combined() {
3435 use std::collections::{HashMap, HashSet};
3436 let mut stopwords = HashSet::new();
3437 stopwords.insert("the".to_owned());
3438 stopwords.insert("a".to_owned());
3439
3440 let mut weights = HashMap::new();
3441 weights.insert("doctor".to_owned(), 2.0_f32);
3442
3443 let query = TextQuery::new("a doctor in the house")
3444 .for_field("description")
3445 .with_stopwords(stopwords)
3446 .with_text_weights(weights);
3447
3448 let qs = query.to_redis_query();
3449 assert!(!qs.contains(" a "));
3451 assert!(!qs.contains(" the "));
3452 assert!(qs.contains("doctor=>{2}"));
3453 }
3454
3455 #[test]
3456 fn text_query_set_text_weights_mirrors_builder() {
3457 use std::collections::HashMap;
3458 let mut weights = HashMap::new();
3459 weights.insert("medical".to_owned(), 5.0_f32);
3460
3461 let query1 = TextQuery::new("a medical professional")
3462 .for_field("description")
3463 .with_text_weights(weights.clone());
3464
3465 let mut query2 = TextQuery::new("a medical professional").for_field("description");
3466 query2.set_text_weights(weights);
3467
3468 assert_eq!(query1.to_redis_query(), query2.to_redis_query());
3469 }
3470
3471 #[test]
3472 fn text_query_text_weights_accessor() {
3473 use std::collections::HashMap;
3474 let mut weights = HashMap::new();
3475 weights.insert("alpha".to_owned(), 0.2_f32);
3476 weights.insert("bravo".to_owned(), 0.4_f32);
3477
3478 let query = TextQuery::new("sample text query")
3479 .for_field("description")
3480 .with_text_weights(weights.clone());
3481
3482 assert_eq!(query.text_weights(), Some(&weights));
3483 }
3484
3485 #[test]
3486 fn text_query_no_weights_returns_none() {
3487 let query = TextQuery::new("sample text query").for_field("description");
3488 assert!(query.text_weights().is_none());
3489 }
3490}