1use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::time::Instant;
14use thiserror::Error;
15
16use common::VectorId;
17
/// One raw search hit as returned by the backing search function:
/// (vector id, similarity score, optional stored vector, optional metadata).
type SearchResultRow = (VectorId, f32, Option<Vec<f32>>, Option<serde_json::Value>);
20
/// Errors produced by batch search validation and execution.
#[derive(Debug, Error)]
pub enum BatchSearchError {
    /// The batch held more queries (first value) than the configured maximum (second).
    #[error("Query batch too large: {0} exceeds maximum {1}")]
    BatchTooLarge(usize, usize),

    /// A pagination cursor was malformed or of the wrong kind.
    #[error("Invalid cursor: {0}")]
    InvalidCursor(String),

    /// A latitude/longitude pair was outside the valid range.
    #[error("Invalid geo-coordinate: lat={0}, lon={1}")]
    InvalidGeoCoordinate(f64, f64),

    /// The requested aggregation type is not supported.
    #[error("Unsupported aggregation type: {0}")]
    UnsupportedAggregation(String),

    /// A scoring function configuration was invalid.
    #[error("Invalid scoring function: {0}")]
    InvalidScoringFunction(String),

    /// A query exceeded its time budget, in milliseconds.
    #[error("Query timeout exceeded: {0}ms")]
    Timeout(u64),

    /// Catch-all for unexpected internal failures.
    #[error("Internal error: {0}")]
    Internal(String),
}
49
/// Tunables for batch query execution.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchQueryConfig {
    /// Maximum number of queries accepted in a single batch.
    pub max_batch_size: usize,
    /// Upper bound on concurrent query execution.
    /// NOTE(review): not enforced by `BatchQueryExecutor::execute` in this
    /// file, which runs queries sequentially — confirm intended use.
    pub max_concurrency: usize,
    /// Per-query timeout in milliseconds.
    /// NOTE(review): not enforced by `BatchQueryExecutor::execute` in this file.
    pub query_timeout_ms: u64,
    /// When true, identical queries within one batch share a single execution.
    pub deduplicate_queries: bool,
    /// When true, per-batch timing statistics are collected.
    pub collect_stats: bool,
}
68
69impl Default for BatchQueryConfig {
70 fn default() -> Self {
71 Self {
72 max_batch_size: 100,
73 max_concurrency: 16,
74 query_timeout_ms: 5000,
75 deduplicate_queries: true,
76 collect_stats: true,
77 }
78 }
79}
80
/// A single query within a batch request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchQueryItem {
    /// Caller-supplied id echoed back in the matching response.
    pub query_id: String,
    /// The query embedding to search with.
    pub vector: Vec<f32>,
    /// Number of hits to return for this query.
    pub top_k: usize,
    /// Optional metadata filter applied to candidates.
    pub filter: Option<FilterExpression>,
    /// Optional pagination cursor from a previous page.
    pub cursor: Option<SearchCursor>,
    /// Optional custom scoring applied after retrieval.
    pub scoring: Option<ScoringConfig>,
}
97
/// A batch of queries to run against one namespace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchQueryRequest {
    /// Namespace (collection) the queries target.
    pub namespace: String,
    /// The individual queries to execute.
    pub queries: Vec<BatchQueryItem>,
    /// Whether stored vectors are included in each hit.
    pub include_vectors: bool,
    /// Whether stored metadata is included in each hit.
    pub include_metadata: bool,
    /// Optional facet aggregations to compute for the batch.
    pub facets: Option<Vec<FacetRequest>>,
}
112
/// The per-query portion of a batch response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchQueryItemResponse {
    /// Id copied from the originating `BatchQueryItem`.
    pub query_id: String,
    /// Hits for this query.
    pub results: Vec<SearchHit>,
    /// Cursor for the next page, when more results exist.
    pub next_cursor: Option<SearchCursor>,
    /// Wall-clock time spent on this query, in milliseconds.
    pub took_ms: u64,
    /// Number of candidate matches considered for this query.
    pub total_matches: usize,
    /// Optional score breakdown for debugging.
    pub explanation: Option<QueryExplanation>,
}
129
/// Response for an entire batch request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchQueryResponse {
    /// One response per query, in request order.
    pub responses: Vec<BatchQueryItemResponse>,
    /// Facet results keyed by facet name, when facets were requested.
    pub facets: Option<HashMap<String, FacetResult>>,
    /// Timing and deduplication statistics for the batch.
    pub stats: BatchQueryStats,
}
140
/// Aggregate timing and deduplication statistics for one batch.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BatchQueryStats {
    /// Queries submitted in the batch.
    pub total_queries: usize,
    /// Queries that produced a response.
    pub successful_queries: usize,
    /// Queries that failed.
    pub failed_queries: usize,
    /// Wall-clock time for the whole batch, in milliseconds.
    pub total_took_ms: u64,
    /// Mean per-query time, in milliseconds.
    pub avg_query_ms: f64,
    /// Slowest single query, in milliseconds.
    pub max_query_ms: u64,
    /// Fastest single query, in milliseconds.
    pub min_query_ms: u64,
    /// Queries answered by reusing an identical earlier query's result.
    pub deduplicated_count: usize,
}
161
/// A single matched document.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchHit {
    /// Identifier of the matched vector.
    pub id: VectorId,
    /// Final (possibly re-scored) relevance score.
    pub score: f32,
    /// The stored vector, when the request asked for vectors.
    pub vector: Option<Vec<f32>>,
    /// The stored metadata, when the request asked for metadata.
    pub metadata: Option<serde_json::Value>,
    /// Values used for sort-based (search-after) pagination.
    pub sort_values: Vec<SortValue>,
}
176
/// Opaque pagination cursor handed back to clients between pages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchCursor {
    /// Which pagination strategy the cursor encodes.
    pub cursor_type: CursorType,
    /// Base64-encoded cursor payload.
    pub value: String,
    /// Creation time, in milliseconds since the Unix epoch.
    pub created_at: u64,
}
191
192impl SearchCursor {
193 pub fn cursor_based(offset: usize, total: usize) -> Self {
195 let value = format!("{}:{}", offset, total);
196 Self {
197 cursor_type: CursorType::CursorBased,
198 value: base64_encode(&value),
199 created_at: current_timestamp_ms(),
200 }
201 }
202
203 pub fn search_after(sort_values: &[SortValue]) -> Self {
205 let json = serde_json::to_string(sort_values).unwrap_or_default();
206 Self {
207 cursor_type: CursorType::SearchAfter,
208 value: base64_encode(&json),
209 created_at: current_timestamp_ms(),
210 }
211 }
212
213 pub fn parse_offset(&self) -> Result<usize, BatchSearchError> {
215 if self.cursor_type != CursorType::CursorBased {
216 return Err(BatchSearchError::InvalidCursor(
217 "Expected cursor-based cursor".into(),
218 ));
219 }
220 let decoded = base64_decode(&self.value)
221 .map_err(|_| BatchSearchError::InvalidCursor("Invalid base64".into()))?;
222 let parts: Vec<&str> = decoded.split(':').collect();
223 parts
224 .first()
225 .and_then(|s| s.parse().ok())
226 .ok_or_else(|| BatchSearchError::InvalidCursor("Invalid offset format".into()))
227 }
228
229 pub fn parse_sort_values(&self) -> Result<Vec<SortValue>, BatchSearchError> {
231 if self.cursor_type != CursorType::SearchAfter {
232 return Err(BatchSearchError::InvalidCursor(
233 "Expected search-after cursor".into(),
234 ));
235 }
236 let decoded = base64_decode(&self.value)
237 .map_err(|_| BatchSearchError::InvalidCursor("Invalid base64".into()))?;
238 serde_json::from_str(&decoded)
239 .map_err(|_| BatchSearchError::InvalidCursor("Invalid sort values".into()))
240 }
241}
242
/// Pagination strategy encoded in a `SearchCursor`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CursorType {
    /// Offset-based paging; the payload stores `offset:total`.
    CursorBased,
    /// Sort-value paging; the payload stores the last hit's sort values.
    SearchAfter,
}
251
/// A single sortable value attached to a hit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SortValue {
    /// Relevance score; compares in descending order (see `compare`).
    Score(f32),
    Integer(i64),
    Float(f64),
    String(String),
    /// Missing value; sorts after every other variant.
    Null,
}
261
262impl SortValue {
263 pub fn compare(&self, other: &SortValue) -> std::cmp::Ordering {
265 use std::cmp::Ordering;
266 match (self, other) {
267 (SortValue::Score(a), SortValue::Score(b)) => {
268 b.partial_cmp(a).unwrap_or(Ordering::Equal)
269 }
270 (SortValue::Integer(a), SortValue::Integer(b)) => a.cmp(b),
271 (SortValue::Float(a), SortValue::Float(b)) => {
272 a.partial_cmp(b).unwrap_or(Ordering::Equal)
273 }
274 (SortValue::String(a), SortValue::String(b)) => a.cmp(b),
275 (SortValue::Null, SortValue::Null) => Ordering::Equal,
276 (SortValue::Null, _) => Ordering::Greater,
277 (_, SortValue::Null) => Ordering::Less,
278 _ => Ordering::Equal,
279 }
280 }
281}
282
/// Pagination defaults and limits.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PaginationConfig {
    /// Hits per page.
    pub page_size: usize,
    /// Largest permitted offset for offset-based paging.
    pub max_offset: usize,
    /// Cursor lifetime in seconds.
    /// NOTE(review): expiry is not checked anywhere in this file — confirm
    /// enforcement lives elsewhere.
    pub cursor_ttl_secs: u64,
    /// Sort order applied when a query specifies none.
    pub default_sort: Vec<SortField>,
}
295
296impl Default for PaginationConfig {
297 fn default() -> Self {
298 Self {
299 page_size: 20,
300 max_offset: 10000,
301 cursor_ttl_secs: 3600,
302 default_sort: vec![SortField {
303 field: "_score".into(),
304 order: SortOrder::Descending,
305 }],
306 }
307 }
308}
309
/// A field/direction pair used to order results.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SortField {
    /// Metadata field (or pseudo-field such as "_score") to sort by.
    pub field: String,
    /// Sort direction.
    pub order: SortOrder,
}
318
/// Sort direction for a `SortField`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum SortOrder {
    Ascending,
    Descending,
}
325
/// A requested facet aggregation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FacetRequest {
    /// Name under which the result is reported.
    pub name: String,
    /// Metadata field to aggregate over.
    pub field: String,
    /// Kind of aggregation to compute.
    pub agg_type: AggregationType,
    /// Optional cap on the number of returned buckets.
    pub max_buckets: Option<usize>,
    /// Range definitions; used by `Range` aggregations.
    pub ranges: Option<Vec<RangeBucket>>,
}
344
/// The kinds of facet aggregation a request may ask for.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregationType {
    /// Bucket per distinct scalar value, counted.
    Terms,
    /// Bucket per caller-supplied numeric range.
    Range,
    /// Date histogram with a named interval (e.g. "day").
    DateHistogram { interval: String },
    /// Numeric histogram with a fixed bucket width.
    Histogram { interval: f64 },
    /// Minimum of a numeric field.
    Min,
    /// Maximum of a numeric field.
    Max,
    /// Arithmetic mean of a numeric field.
    Avg,
    /// Sum of a numeric field.
    Sum,
    /// Count of present values.
    Count,
    /// Count of distinct values.
    Cardinality,
}
369
/// One numeric range for a `Range` aggregation: `[from, to)`,
/// where a `None` bound is unbounded on that side.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RangeBucket {
    /// Label reported for this range's bucket.
    pub key: String,
    /// Inclusive lower bound; `None` = unbounded below.
    pub from: Option<f64>,
    /// Exclusive upper bound; `None` = unbounded above.
    pub to: Option<f64>,
}
380
/// Result of one facet aggregation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FacetResult {
    /// Name from the originating `FacetRequest`.
    pub name: String,
    /// Field that was aggregated.
    pub field: String,
    /// Buckets for bucketing aggregations (Terms, Range, histograms).
    pub buckets: Option<Vec<FacetBucket>>,
    /// Single value for metric aggregations (Min, Max, Avg, Sum, ...).
    pub value: Option<f64>,
    /// Number of documents that contributed to the facet.
    pub doc_count: usize,
}
395
/// One bucket of a bucketing aggregation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FacetBucket {
    /// Bucket label (distinct value or range key).
    pub key: String,
    /// Documents falling into this bucket.
    pub doc_count: usize,
    /// Optional nested aggregations computed within the bucket.
    pub sub_aggregations: Option<HashMap<String, FacetResult>>,
}
406
/// Computes facet aggregations over result metadata values.
pub struct FacetExecutor {
    /// Hard upper bound on buckets returned by any aggregation.
    max_buckets: usize,
}
411
412impl FacetExecutor {
413 pub fn new(max_buckets: usize) -> Self {
415 Self { max_buckets }
416 }
417
418 pub fn terms_aggregation(
420 &self,
421 values: &[Option<serde_json::Value>],
422 max_buckets: usize,
423 ) -> Vec<FacetBucket> {
424 let mut counts: HashMap<String, usize> = HashMap::new();
425
426 for value in values.iter().flatten() {
427 let key = match value {
428 serde_json::Value::String(s) => s.clone(),
429 serde_json::Value::Number(n) => n.to_string(),
430 serde_json::Value::Bool(b) => b.to_string(),
431 _ => continue,
432 };
433 *counts.entry(key).or_insert(0) += 1;
434 }
435
436 let mut buckets: Vec<_> = counts
437 .into_iter()
438 .map(|(key, count)| FacetBucket {
439 key,
440 doc_count: count,
441 sub_aggregations: None,
442 })
443 .collect();
444
445 buckets.sort_by(|a, b| b.doc_count.cmp(&a.doc_count));
447 buckets.truncate(max_buckets.min(self.max_buckets));
448 buckets
449 }
450
451 pub fn range_aggregation(
453 &self,
454 values: &[Option<f64>],
455 ranges: &[RangeBucket],
456 ) -> Vec<FacetBucket> {
457 ranges
458 .iter()
459 .map(|range| {
460 let count = values
461 .iter()
462 .filter(|v| {
463 if let Some(val) = v {
464 let from_ok = range.from.is_none_or(|f| *val >= f);
465 let to_ok = range.to.is_none_or(|t| *val < t);
466 from_ok && to_ok
467 } else {
468 false
469 }
470 })
471 .count();
472
473 FacetBucket {
474 key: range.key.clone(),
475 doc_count: count,
476 sub_aggregations: None,
477 }
478 })
479 .collect()
480 }
481
482 pub fn numeric_aggregation(&self, values: &[Option<f64>], agg_type: &AggregationType) -> f64 {
484 let valid_values: Vec<f64> = values.iter().filter_map(|v| *v).collect();
485
486 if valid_values.is_empty() {
487 return 0.0;
488 }
489
490 match agg_type {
491 AggregationType::Min => valid_values.iter().copied().fold(f64::INFINITY, f64::min),
492 AggregationType::Max => valid_values
493 .iter()
494 .copied()
495 .fold(f64::NEG_INFINITY, f64::max),
496 AggregationType::Avg => valid_values.iter().sum::<f64>() / valid_values.len() as f64,
497 AggregationType::Sum => valid_values.iter().sum(),
498 AggregationType::Count => valid_values.len() as f64,
499 _ => 0.0,
500 }
501 }
502}
503
/// A latitude/longitude pair, in decimal degrees.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct GeoPoint {
    /// Latitude in degrees; valid range [-90, 90] (see `GeoPoint::new`).
    pub lat: f64,
    /// Longitude in degrees; valid range [-180, 180] (see `GeoPoint::new`).
    pub lon: f64,
}
516
517impl GeoPoint {
518 pub fn new(lat: f64, lon: f64) -> Result<Self, BatchSearchError> {
520 if !(-90.0..=90.0).contains(&lat) || !(-180.0..=180.0).contains(&lon) {
521 return Err(BatchSearchError::InvalidGeoCoordinate(lat, lon));
522 }
523 Ok(Self { lat, lon })
524 }
525
526 pub fn distance_km(&self, other: &GeoPoint) -> f64 {
528 const EARTH_RADIUS_KM: f64 = 6371.0;
529
530 let lat1 = self.lat.to_radians();
531 let lat2 = other.lat.to_radians();
532 let delta_lat = (other.lat - self.lat).to_radians();
533 let delta_lon = (other.lon - self.lon).to_radians();
534
535 let a = (delta_lat / 2.0).sin().powi(2)
536 + lat1.cos() * lat2.cos() * (delta_lon / 2.0).sin().powi(2);
537 let c = 2.0 * a.sqrt().asin();
538
539 EARTH_RADIUS_KM * c
540 }
541
542 pub fn distance_m(&self, other: &GeoPoint) -> f64 {
544 self.distance_km(other) * 1000.0
545 }
546
547 pub fn distance_miles(&self, other: &GeoPoint) -> f64 {
549 self.distance_km(other) * 0.621371
550 }
551}
552
/// Spatial predicates applicable to a geo-point field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GeoFilter {
    /// Matches points within `distance` (in `unit`) of `center`.
    Distance {
        center: GeoPoint,
        distance: f64,
        unit: DistanceUnit,
    },
    /// Matches points inside the axis-aligned box between the two corners.
    BoundingBox {
        top_left: GeoPoint,
        bottom_right: GeoPoint,
    },
    /// Matches points inside the polygon described by `points`
    /// (even-odd rule; fewer than 3 vertices matches nothing).
    Polygon {
        points: Vec<GeoPoint>,
    },
}
578
/// Length units accepted by distance filters and geo decay scoring.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum DistanceUnit {
    Meters,
    Kilometers,
    Miles,
    Feet,
}
587
588impl DistanceUnit {
589 pub fn to_meters(&self, distance: f64) -> f64 {
591 match self {
592 DistanceUnit::Meters => distance,
593 DistanceUnit::Kilometers => distance * 1000.0,
594 DistanceUnit::Miles => distance * 1609.34,
595 DistanceUnit::Feet => distance * 0.3048,
596 }
597 }
598}
599
/// Stateless evaluator for `GeoFilter` predicates.
pub struct GeoFilterExecutor;
602
603impl GeoFilterExecutor {
604 pub fn matches(filter: &GeoFilter, point: &GeoPoint) -> bool {
606 match filter {
607 GeoFilter::Distance {
608 center,
609 distance,
610 unit,
611 } => {
612 let max_distance_m = unit.to_meters(*distance);
613 center.distance_m(point) <= max_distance_m
614 }
615 GeoFilter::BoundingBox {
616 top_left,
617 bottom_right,
618 } => {
619 point.lat <= top_left.lat
620 && point.lat >= bottom_right.lat
621 && point.lon >= top_left.lon
622 && point.lon <= bottom_right.lon
623 }
624 GeoFilter::Polygon { points } => Self::point_in_polygon(point, points),
625 }
626 }
627
628 fn point_in_polygon(point: &GeoPoint, polygon: &[GeoPoint]) -> bool {
630 if polygon.len() < 3 {
631 return false;
632 }
633
634 let mut inside = false;
635 let n = polygon.len();
636
637 let mut j = n - 1;
638 for i in 0..n {
639 let pi = &polygon[i];
640 let pj = &polygon[j];
641
642 if ((pi.lat > point.lat) != (pj.lat > point.lat))
643 && (point.lon
644 < (pj.lon - pi.lon) * (point.lat - pi.lat) / (pj.lat - pi.lat) + pi.lon)
645 {
646 inside = !inside;
647 }
648 j = i;
649 }
650
651 inside
652 }
653}
654
/// Custom scoring applied to hits after vector retrieval.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoringConfig {
    /// How multiple function scores are folded into one value.
    pub score_mode: ScoreMode,
    /// The score functions to evaluate per document.
    pub functions: Vec<ScoreFunction>,
    /// How the combined function score merges with the original score.
    pub boost_mode: BoostMode,
    /// Hits with a final score below this threshold are dropped.
    pub min_score: Option<f32>,
}
671
672impl Default for ScoringConfig {
673 fn default() -> Self {
674 Self {
675 score_mode: ScoreMode::Multiply,
676 functions: Vec::new(),
677 boost_mode: BoostMode::Multiply,
678 min_score: None,
679 }
680 }
681}
682
/// How the scores of several score functions are combined into one.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ScoreMode {
    Multiply,
    Sum,
    Average,
    /// Use only the first function's score.
    First,
    Max,
    Min,
}
693
/// How the combined function score merges with the original similarity score.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum BoostMode {
    Multiply,
    /// Discard the original score entirely.
    Replace,
    Sum,
    Average,
    Max,
    Min,
}
704
/// A single score function evaluated per document
/// (see `ScoreFunctionExecutor::apply`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ScoreFunction {
    /// Constant contribution regardless of document content.
    Weight { weight: f32 },
    /// Reads a numeric metadata field, applies `modifier`, multiplies by
    /// `factor`; `missing` substitutes for absent/non-numeric fields.
    FieldValue {
        field: String,
        factor: f32,
        modifier: FieldValueModifier,
        missing: f32,
    },
    /// Numeric decay from `origin`: full score within `offset`, then
    /// decaying with `scale`/`decay` per `decay_type`.
    Decay {
        field: String,
        origin: f64,
        scale: f64,
        offset: f64,
        decay: f64,
        decay_type: DecayType,
    },
    /// Deterministic pseudo-random score derived from `seed`.
    /// NOTE(review): `field` is ignored by the executor — confirm whether
    /// per-document randomization was intended.
    RandomScore { seed: u64, field: Option<String> },
    /// Script-style scoring; the executor only honors a `boost` param.
    Script {
        source: String,
        params: HashMap<String, f64>,
    },
    /// Gaussian decay by great-circle distance from `origin`, read from a
    /// `{"lat":..,"lon":..}` metadata object at `field`.
    GeoDecay {
        field: String,
        origin: GeoPoint,
        scale: f64,
        scale_unit: DistanceUnit,
        offset: f64,
        decay: f64,
    },
}
743
/// Numeric transform applied to a field value before the factor is
/// multiplied in (see `ScoreFunctionExecutor::apply_modifier`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum FieldValueModifier {
    None,
    /// log10(x)
    Log,
    /// log10(1 + x)
    Log1p,
    /// log10(2 + x)
    Log2p,
    /// ln(x)
    Ln,
    /// ln(1 + x)
    Ln1p,
    /// ln(2 + x)
    Ln2p,
    Square,
    Sqrt,
    /// 1/x with the divisor clamped away from zero.
    Reciprocal,
}
758
/// Shape of the decay curve used by `Decay` / `GeoDecay` scoring.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum DecayType {
    Gaussian,
    Linear,
    Exponential,
}
766
/// Stateless evaluator for `ScoreFunction`s and score combination.
pub struct ScoreFunctionExecutor;
769
770impl ScoreFunctionExecutor {
771 pub fn apply(
773 function: &ScoreFunction,
774 original_score: f32,
775 metadata: Option<&serde_json::Value>,
776 ) -> f32 {
777 match function {
778 ScoreFunction::Weight { weight } => *weight,
779
780 ScoreFunction::FieldValue {
781 field,
782 factor,
783 modifier,
784 missing,
785 } => {
786 let value = metadata
787 .and_then(|m| m.get(field))
788 .and_then(|v| v.as_f64())
789 .unwrap_or(*missing as f64);
790
791 let modified = Self::apply_modifier(value, *modifier);
792 (modified * *factor as f64) as f32
793 }
794
795 ScoreFunction::Decay {
796 origin,
797 scale,
798 offset,
799 decay,
800 decay_type,
801 field,
802 } => {
803 let value = metadata
804 .and_then(|m| m.get(field))
805 .and_then(|v| v.as_f64())
806 .unwrap_or(*origin);
807
808 let distance = (value - origin).abs() - offset;
809 if distance <= 0.0 {
810 return 1.0;
811 }
812
813 Self::compute_decay(distance, *scale, *decay, *decay_type) as f32
814 }
815
816 ScoreFunction::RandomScore { seed, .. } => {
817 let hash = (*seed as u32).wrapping_mul(2654435769);
819 hash as f32 / u32::MAX as f32
820 }
821
822 ScoreFunction::Script { source, params } => {
823 Self::evaluate_script(source, params, original_score, metadata)
825 }
826
827 ScoreFunction::GeoDecay {
828 field,
829 origin,
830 scale,
831 scale_unit,
832 offset,
833 decay,
834 } => {
835 let point = metadata.and_then(|m| m.get(field)).and_then(|v| {
836 let lat = v.get("lat")?.as_f64()?;
837 let lon = v.get("lon")?.as_f64()?;
838 Some(GeoPoint { lat, lon })
839 });
840
841 if let Some(point) = point {
842 let distance_m = origin.distance_m(&point);
843 let scale_m = scale_unit.to_meters(*scale);
844 let offset_m = scale_unit.to_meters(*offset);
845
846 let adjusted_distance = (distance_m - offset_m).max(0.0);
847 Self::compute_decay(adjusted_distance, scale_m, *decay, DecayType::Gaussian)
848 as f32
849 } else {
850 0.0
851 }
852 }
853 }
854 }
855
856 fn apply_modifier(value: f64, modifier: FieldValueModifier) -> f64 {
857 match modifier {
858 FieldValueModifier::None => value,
859 FieldValueModifier::Log => value.log10(),
860 FieldValueModifier::Log1p => (1.0 + value).log10(),
861 FieldValueModifier::Log2p => (2.0 + value).log10(),
862 FieldValueModifier::Ln => value.ln(),
863 FieldValueModifier::Ln1p => (1.0 + value).ln(),
864 FieldValueModifier::Ln2p => (2.0 + value).ln(),
865 FieldValueModifier::Square => value * value,
866 FieldValueModifier::Sqrt => value.sqrt(),
867 FieldValueModifier::Reciprocal => 1.0 / value.max(0.001),
868 }
869 }
870
871 fn compute_decay(distance: f64, scale: f64, decay: f64, decay_type: DecayType) -> f64 {
872 let lambda = scale.ln().abs() / decay.ln().abs();
873
874 match decay_type {
875 DecayType::Gaussian => (-0.5 * (distance / lambda).powi(2)).exp(),
876 DecayType::Linear => ((lambda - distance) / lambda).max(0.0),
877 DecayType::Exponential => (-distance / lambda).exp(),
878 }
879 }
880
881 fn evaluate_script(
882 _source: &str,
883 params: &HashMap<String, f64>,
884 original_score: f32,
885 _metadata: Option<&serde_json::Value>,
886 ) -> f32 {
887 let boost = params.get("boost").copied().unwrap_or(1.0) as f32;
889 original_score * boost
890 }
891
892 pub fn combine_scores(scores: &[f32], mode: ScoreMode) -> f32 {
894 if scores.is_empty() {
895 return 1.0;
896 }
897
898 match mode {
899 ScoreMode::Multiply => scores.iter().product(),
900 ScoreMode::Sum => scores.iter().sum(),
901 ScoreMode::Average => scores.iter().sum::<f32>() / scores.len() as f32,
902 ScoreMode::First => scores[0],
903 ScoreMode::Max => scores.iter().copied().fold(f32::NEG_INFINITY, f32::max),
904 ScoreMode::Min => scores.iter().copied().fold(f32::INFINITY, f32::min),
905 }
906 }
907
908 pub fn combine_with_original(original: f32, function_score: f32, mode: BoostMode) -> f32 {
910 match mode {
911 BoostMode::Multiply => original * function_score,
912 BoostMode::Replace => function_score,
913 BoostMode::Sum => original + function_score,
914 BoostMode::Average => (original + function_score) / 2.0,
915 BoostMode::Max => original.max(function_score),
916 BoostMode::Min => original.min(function_score),
917 }
918 }
919}
920
/// Human-readable breakdown of how a hit's final score was produced.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryExplanation {
    /// The final score after all functions and boost combination.
    pub score: f32,
    /// Summary line identifying the explained document.
    pub description: String,
    /// Individual score contributions, in evaluation order.
    pub details: Vec<ScoreDetail>,
}
935
/// One node of a score explanation tree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreDetail {
    /// Short identifier for the contribution (e.g. "vector_similarity").
    pub name: String,
    /// The contribution's numeric value.
    pub value: f32,
    /// Human-readable description of the contribution.
    pub description: String,
    /// Optional nested contributions.
    pub details: Option<Vec<ScoreDetail>>,
}
948
/// Builds `QueryExplanation`s by re-running the scoring pipeline verbosely.
pub struct QueryExplainer;
951
952impl QueryExplainer {
953 pub fn explain(
955 id: &VectorId,
956 original_score: f32,
957 scoring_config: Option<&ScoringConfig>,
958 metadata: Option<&serde_json::Value>,
959 ) -> QueryExplanation {
960 let mut details = vec![ScoreDetail {
961 name: "vector_similarity".into(),
962 value: original_score,
963 description: "Base vector similarity score".into(),
964 details: None,
965 }];
966
967 let mut final_score = original_score;
968
969 if let Some(config) = scoring_config {
970 let mut function_scores = Vec::new();
971
972 for (i, func) in config.functions.iter().enumerate() {
973 let func_score = ScoreFunctionExecutor::apply(func, original_score, metadata);
974 function_scores.push(func_score);
975
976 details.push(ScoreDetail {
977 name: format!("function_{}", i),
978 value: func_score,
979 description: Self::describe_function(func),
980 details: None,
981 });
982 }
983
984 if !function_scores.is_empty() {
985 let combined =
986 ScoreFunctionExecutor::combine_scores(&function_scores, config.score_mode);
987 details.push(ScoreDetail {
988 name: "combined_functions".into(),
989 value: combined,
990 description: format!("Functions combined using {:?}", config.score_mode),
991 details: None,
992 });
993
994 final_score = ScoreFunctionExecutor::combine_with_original(
995 original_score,
996 combined,
997 config.boost_mode,
998 );
999 }
1000 }
1001
1002 QueryExplanation {
1003 score: final_score,
1004 description: format!("Explanation for document {}", id),
1005 details,
1006 }
1007 }
1008
1009 fn describe_function(func: &ScoreFunction) -> String {
1010 match func {
1011 ScoreFunction::Weight { weight } => format!("Constant weight: {}", weight),
1012 ScoreFunction::FieldValue { field, factor, .. } => {
1013 format!("Field value boost on '{}' with factor {}", field, factor)
1014 }
1015 ScoreFunction::Decay {
1016 field, decay_type, ..
1017 } => format!("{:?} decay on field '{}'", decay_type, field),
1018 ScoreFunction::RandomScore { seed, .. } => format!("Random score with seed {}", seed),
1019 ScoreFunction::Script { source, .. } => format!("Script: {}", source),
1020 ScoreFunction::GeoDecay { field, .. } => format!("Geo decay on field '{}'", field),
1021 }
1022 }
1023}
1024
/// A boolean metadata predicate tree (see `FilterExecutor::matches`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FilterExpression {
    /// Exact JSON equality on a field.
    Term {
        field: String,
        value: serde_json::Value,
    },
    /// Field equals any of the listed values.
    Terms {
        field: String,
        values: Vec<serde_json::Value>,
    },
    /// Numeric comparison; every present bound must hold.
    Range {
        field: String,
        gte: Option<f64>,
        gt: Option<f64>,
        lte: Option<f64>,
        lt: Option<f64>,
    },
    /// Field is present in the metadata object.
    Exists { field: String },
    /// String field starts with `prefix`.
    Prefix { field: String, prefix: String },
    /// Geo predicate on a `{"lat":..,"lon":..}` field.
    Geo { field: String, filter: GeoFilter },
    /// All sub-filters must match.
    And(Vec<FilterExpression>),
    /// At least one sub-filter must match.
    Or(Vec<FilterExpression>),
    /// Inverts the inner filter (given metadata is present).
    Not(Box<FilterExpression>),
}
1063
/// Stateless evaluator for `FilterExpression` trees.
pub struct FilterExecutor;
1066
1067impl FilterExecutor {
1068 pub fn matches(filter: &FilterExpression, metadata: Option<&serde_json::Value>) -> bool {
1070 let metadata = match metadata {
1071 Some(m) => m,
1072 None => return false,
1073 };
1074
1075 match filter {
1076 FilterExpression::Term { field, value } => metadata.get(field) == Some(value),
1077 FilterExpression::Terms { field, values } => {
1078 metadata.get(field).is_some_and(|v| values.contains(v))
1079 }
1080 FilterExpression::Range {
1081 field,
1082 gte,
1083 gt,
1084 lte,
1085 lt,
1086 } => {
1087 let value = metadata.get(field).and_then(|v| v.as_f64());
1088 if let Some(v) = value {
1089 gte.is_none_or(|x| v >= x)
1090 && gt.is_none_or(|x| v > x)
1091 && lte.is_none_or(|x| v <= x)
1092 && lt.is_none_or(|x| v < x)
1093 } else {
1094 false
1095 }
1096 }
1097 FilterExpression::Exists { field } => metadata.get(field).is_some(),
1098 FilterExpression::Prefix { field, prefix } => metadata
1099 .get(field)
1100 .and_then(|v| v.as_str())
1101 .is_some_and(|s| s.starts_with(prefix)),
1102 FilterExpression::Geo { field, filter } => {
1103 let point = metadata.get(field).and_then(|v| {
1104 let lat = v.get("lat")?.as_f64()?;
1105 let lon = v.get("lon")?.as_f64()?;
1106 Some(GeoPoint { lat, lon })
1107 });
1108 point.is_some_and(|p| GeoFilterExecutor::matches(filter, &p))
1109 }
1110 FilterExpression::And(filters) => {
1111 filters.iter().all(|f| Self::matches(f, Some(metadata)))
1112 }
1113 FilterExpression::Or(filters) => {
1114 filters.iter().any(|f| Self::matches(f, Some(metadata)))
1115 }
1116 FilterExpression::Not(filter) => !Self::matches(filter, Some(metadata)),
1117 }
1118 }
1119}
1120
/// Runs batches of vector queries with deduplication, pagination, and
/// custom scoring.
pub struct BatchQueryExecutor {
    /// Batch limits and execution options.
    config: BatchQueryConfig,
}
1129
1130impl BatchQueryExecutor {
1131 pub fn new(config: BatchQueryConfig) -> Self {
1133 Self { config }
1134 }
1135
1136 pub fn execute(
1138 &self,
1139 request: &BatchQueryRequest,
1140 search_fn: impl Fn(
1141 &[f32],
1142 usize,
1143 Option<&FilterExpression>,
1144 ) -> Vec<(VectorId, f32, Option<Vec<f32>>, Option<serde_json::Value>)>,
1145 ) -> Result<BatchQueryResponse, BatchSearchError> {
1146 let start = Instant::now();
1147
1148 if request.queries.len() > self.config.max_batch_size {
1150 return Err(BatchSearchError::BatchTooLarge(
1151 request.queries.len(),
1152 self.config.max_batch_size,
1153 ));
1154 }
1155
1156 let mut responses: Vec<BatchQueryItemResponse> = Vec::with_capacity(request.queries.len());
1157 let mut query_times: Vec<u64> = Vec::new();
1158 let mut deduplicated_count = 0;
1159
1160 let mut seen_queries: HashMap<Vec<u32>, usize> = HashMap::new();
1162
1163 for query in &request.queries {
1164 let query_start = Instant::now();
1165
1166 let query_hash: Vec<u32> = query.vector.iter().map(|f| f.to_bits()).collect();
1168
1169 if self.config.deduplicate_queries {
1170 if let Some(&existing_idx) = seen_queries.get(&query_hash) {
1171 let mut response = responses[existing_idx].clone();
1173 response.query_id = query.query_id.clone();
1174 responses.push(response);
1175 deduplicated_count += 1;
1176 continue;
1177 }
1178 seen_queries.insert(query_hash, responses.len());
1179 }
1180
1181 let raw_results = search_fn(&query.vector, query.top_k * 2, query.filter.as_ref());
1183
1184 let (results, next_cursor) = self.apply_pagination(&raw_results, query)?;
1186
1187 let scored_results = self.apply_scoring(&results, query.scoring.as_ref());
1189
1190 let hits: Vec<SearchHit> = scored_results
1192 .into_iter()
1193 .take(query.top_k)
1194 .map(|(id, score, vec, meta)| SearchHit {
1195 id,
1196 score,
1197 vector: if request.include_vectors { vec } else { None },
1198 metadata: if request.include_metadata { meta } else { None },
1199 sort_values: vec![SortValue::Score(score)],
1200 })
1201 .collect();
1202
1203 let query_time = query_start.elapsed().as_millis() as u64;
1204 query_times.push(query_time);
1205
1206 responses.push(BatchQueryItemResponse {
1207 query_id: query.query_id.clone(),
1208 results: hits,
1209 next_cursor,
1210 took_ms: query_time,
1211 total_matches: raw_results.len(),
1212 explanation: None,
1213 });
1214 }
1215
1216 let facets = request.facets.as_ref().map(|_| HashMap::new());
1218
1219 let total_took = start.elapsed().as_millis() as u64;
1221 let stats = BatchQueryStats {
1222 total_queries: request.queries.len(),
1223 successful_queries: responses.len(),
1224 failed_queries: 0,
1225 total_took_ms: total_took,
1226 avg_query_ms: if !query_times.is_empty() {
1227 query_times.iter().sum::<u64>() as f64 / query_times.len() as f64
1228 } else {
1229 0.0
1230 },
1231 max_query_ms: query_times.iter().copied().max().unwrap_or(0),
1232 min_query_ms: query_times.iter().copied().min().unwrap_or(0),
1233 deduplicated_count,
1234 };
1235
1236 Ok(BatchQueryResponse {
1237 responses,
1238 facets,
1239 stats,
1240 })
1241 }
1242
1243 fn apply_pagination(
1244 &self,
1245 results: &[SearchResultRow],
1246 query: &BatchQueryItem,
1247 ) -> Result<(Vec<SearchResultRow>, Option<SearchCursor>), BatchSearchError> {
1248 if let Some(cursor) = &query.cursor {
1249 match cursor.cursor_type {
1250 CursorType::CursorBased => {
1251 let offset = cursor.parse_offset()?;
1252 let paginated: Vec<_> = results.iter().skip(offset).cloned().collect();
1253 let next_cursor = if offset + query.top_k < results.len() {
1254 Some(SearchCursor::cursor_based(
1255 offset + query.top_k,
1256 results.len(),
1257 ))
1258 } else {
1259 None
1260 };
1261 Ok((paginated, next_cursor))
1262 }
1263 CursorType::SearchAfter => {
1264 let sort_values = cursor.parse_sort_values()?;
1265 let start_idx = results
1266 .iter()
1267 .position(|(_, score, _, _)| {
1268 let sv = SortValue::Score(*score);
1269 sv.compare(&sort_values[0]) == std::cmp::Ordering::Greater
1270 })
1271 .unwrap_or(0);
1272 let paginated: Vec<_> = results.iter().skip(start_idx).cloned().collect();
1273 let next_cursor = if start_idx + query.top_k < results.len() {
1274 paginated
1275 .get(query.top_k - 1)
1276 .map(|last| SearchCursor::search_after(&[SortValue::Score(last.1)]))
1277 } else {
1278 None
1279 };
1280 Ok((paginated, next_cursor))
1281 }
1282 }
1283 } else {
1284 let next_cursor = if results.len() > query.top_k {
1285 Some(SearchCursor::cursor_based(query.top_k, results.len()))
1286 } else {
1287 None
1288 };
1289 Ok((results.to_vec(), next_cursor))
1290 }
1291 }
1292
1293 fn apply_scoring(
1294 &self,
1295 results: &[SearchResultRow],
1296 scoring: Option<&ScoringConfig>,
1297 ) -> Vec<SearchResultRow> {
1298 let config = match scoring {
1299 Some(c) if !c.functions.is_empty() => c,
1300 _ => return results.to_vec(),
1301 };
1302 let mut scored: Vec<_> = results
1303 .iter()
1304 .map(|(id, score, vec, meta)| {
1305 let function_scores: Vec<f32> = config
1306 .functions
1307 .iter()
1308 .map(|f| ScoreFunctionExecutor::apply(f, *score, meta.as_ref()))
1309 .collect();
1310
1311 let combined =
1312 ScoreFunctionExecutor::combine_scores(&function_scores, config.score_mode);
1313 let final_score = ScoreFunctionExecutor::combine_with_original(
1314 *score,
1315 combined,
1316 config.boost_mode,
1317 );
1318
1319 (id.clone(), final_score, vec.clone(), meta.clone())
1320 })
1321 .collect();
1322
1323 if let Some(min) = config.min_score {
1325 scored.retain(|(_, s, _, _)| *s >= min);
1326 }
1327
1328 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1330 scored
1331 }
1332}
1333
/// Encodes a UTF-8 string as standard (RFC 4648) base64 with `=` padding.
fn base64_encode(input: &str) -> String {
    const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    let data = input.as_bytes();
    // Each 3-byte group (rounded up) becomes exactly 4 output characters.
    let mut encoded = String::with_capacity(data.len().div_ceil(3) * 4);

    for group in data.chunks(3) {
        // Pack up to three bytes into a 24-bit word, high byte first.
        let word = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (i, &byte)| acc | (byte as u32) << (16 - 8 * i));

        // The first two sextets are always data; the last two are data or
        // padding depending on how many bytes this group actually has.
        encoded.push(ALPHABET[(word >> 18 & 0x3f) as usize] as char);
        encoded.push(ALPHABET[(word >> 12 & 0x3f) as usize] as char);
        encoded.push(if group.len() > 1 {
            ALPHABET[(word >> 6 & 0x3f) as usize] as char
        } else {
            '='
        });
        encoded.push(if group.len() > 2 {
            ALPHABET[(word & 0x3f) as usize] as char
        } else {
            '='
        });
    }

    encoded
}
1368
/// Decodes standard-alphabet base64 (with or without `=` padding) back
/// into a UTF-8 string.
///
/// # Errors
/// Returns an error for characters outside the standard alphabet, for a
/// truncated input whose final group has a single character (6 bits can
/// never encode a byte — the old code silently produced a garbage byte),
/// or when the decoded bytes are not valid UTF-8.
fn base64_decode(input: &str) -> Result<String, &'static str> {
    const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    let mut result = Vec::new();
    let input = input.trim_end_matches('=');
    let bytes: Vec<u8> = input.bytes().collect();

    for chunk in bytes.chunks(4) {
        // A lone trailing character is malformed base64, not decodable data.
        if chunk.len() == 1 {
            return Err("Invalid base64 length");
        }

        let mut n = 0u32;
        for (i, &b) in chunk.iter().enumerate() {
            let pos = CHARS.iter().position(|&c| c == b).ok_or("Invalid base64")?;
            n |= (pos as u32) << (18 - i * 6);
        }

        // 2 chars -> 1 byte, 3 chars -> 2 bytes, 4 chars -> 3 bytes.
        result.push((n >> 16) as u8);
        if chunk.len() > 2 {
            result.push((n >> 8) as u8);
        }
        if chunk.len() > 3 {
            result.push(n as u8);
        }
    }

    String::from_utf8(result).map_err(|_| "Invalid UTF-8")
}
1394
/// Milliseconds since the Unix epoch, or 0 if the system clock reads
/// earlier than the epoch.
fn current_timestamp_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
1401
// Unit tests for the batch-search module: pagination cursors, geo filters,
// facet aggregations, scoring functions, metadata filters, batch execution,
// and the base64 cursor-encoding helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults are part of the documented contract; callers size their
    // batches against these limits.
    #[test]
    fn test_batch_query_config_default() {
        let config = BatchQueryConfig::default();
        assert_eq!(config.max_batch_size, 100);
        assert_eq!(config.max_concurrency, 16);
        assert_eq!(config.query_timeout_ms, 5000);
    }

    // Offset-style pagination: the offset must round-trip through the
    // opaque cursor encoding.
    #[test]
    fn test_cursor_based_pagination() {
        let cursor = SearchCursor::cursor_based(20, 100);
        assert_eq!(cursor.cursor_type, CursorType::CursorBased);

        let offset = cursor.parse_offset().unwrap();
        assert_eq!(offset, 20);
    }

    // search_after pagination: all sort values survive cursor round-trip.
    #[test]
    fn test_search_after_pagination() {
        let sort_values = vec![SortValue::Score(0.95), SortValue::String("doc123".into())];
        let cursor = SearchCursor::search_after(&sort_values);
        assert_eq!(cursor.cursor_type, CursorType::SearchAfter);

        let parsed = cursor.parse_sort_values().unwrap();
        assert_eq!(parsed.len(), 2);
    }

    // NOTE: per these assertions, Score compares in DESCENDING order
    // (higher score sorts first) while Integer compares ascending — that
    // asymmetry is the intended sort contract, not a bug.
    #[test]
    fn test_sort_value_comparison() {
        assert_eq!(
            SortValue::Score(0.9).compare(&SortValue::Score(0.8)),
            std::cmp::Ordering::Less
        );
        assert_eq!(
            SortValue::Integer(10).compare(&SortValue::Integer(5)),
            std::cmp::Ordering::Greater
        );
    }

    // NYC -> LA great-circle distance is ~3936 km; the loose band tolerates
    // implementation differences in the haversine constants.
    #[test]
    fn test_geo_point_distance() {
        let nyc = GeoPoint::new(40.7128, -74.0060).unwrap();
        let la = GeoPoint::new(34.0522, -118.2437).unwrap();

        let distance = nyc.distance_km(&la);
        assert!(distance > 3900.0 && distance < 4000.0);
    }

    // Latitude must be within [-90, 90] and longitude within [-180, 180].
    #[test]
    fn test_geo_point_validation() {
        assert!(GeoPoint::new(45.0, 90.0).is_ok());
        assert!(GeoPoint::new(91.0, 0.0).is_err());
        assert!(GeoPoint::new(0.0, 181.0).is_err());
    }

    // Radius filter: a point ~15 km from the center matches a 100 km
    // radius; LA (~3900 km away) does not.
    #[test]
    fn test_geo_distance_filter() {
        let center = GeoPoint::new(40.7128, -74.0060).unwrap();
        let filter = GeoFilter::Distance {
            center,
            distance: 100.0,
            unit: DistanceUnit::Kilometers,
        };

        let nearby = GeoPoint::new(40.8, -74.1).unwrap();
        assert!(GeoFilterExecutor::matches(&filter, &nearby));

        let far = GeoPoint::new(34.0522, -118.2437).unwrap();
        assert!(!GeoFilterExecutor::matches(&filter, &far));
    }

    // Bounding box is specified top-left / bottom-right (lat decreases
    // from top_left to bottom_right).
    #[test]
    fn test_geo_bounding_box() {
        let filter = GeoFilter::BoundingBox {
            top_left: GeoPoint::new(41.0, -75.0).unwrap(),
            bottom_right: GeoPoint::new(40.0, -73.0).unwrap(),
        };

        let inside = GeoPoint::new(40.5, -74.0).unwrap();
        assert!(GeoFilterExecutor::matches(&filter, &inside));

        let outside = GeoPoint::new(42.0, -74.0).unwrap();
        assert!(!GeoFilterExecutor::matches(&filter, &outside));
    }

    // Terms facet: buckets come back ordered by doc_count descending,
    // so "cat" (3 occurrences) must be first.
    #[test]
    fn test_terms_aggregation() {
        let executor = FacetExecutor::new(100);

        let values: Vec<Option<serde_json::Value>> = vec![
            Some(serde_json::json!("cat")),
            Some(serde_json::json!("dog")),
            Some(serde_json::json!("cat")),
            Some(serde_json::json!("bird")),
            Some(serde_json::json!("cat")),
        ];

        let buckets = executor.terms_aggregation(&values, 10);

        assert_eq!(buckets.len(), 3);
        assert_eq!(buckets[0].key, "cat");
        assert_eq!(buckets[0].doc_count, 3);
    }

    // Range facet with half-open buckets. Per the expected counts below,
    // `from` is inclusive and `to` is exclusive: 5,15 -> low; 25,35 ->
    // medium; 45 -> high.
    #[test]
    fn test_range_aggregation() {
        let executor = FacetExecutor::new(100);

        let values: Vec<Option<f64>> =
            vec![Some(5.0), Some(15.0), Some(25.0), Some(35.0), Some(45.0)];
        let ranges = vec![
            RangeBucket {
                key: "low".into(),
                from: None,
                to: Some(20.0),
            },
            RangeBucket {
                key: "medium".into(),
                from: Some(20.0),
                to: Some(40.0),
            },
            RangeBucket {
                key: "high".into(),
                from: Some(40.0),
                to: None,
            },
        ];

        let buckets = executor.range_aggregation(&values, &ranges);

        assert_eq!(buckets.len(), 3);
        assert_eq!(buckets[0].doc_count, 2); assert_eq!(buckets[1].doc_count, 2); assert_eq!(buckets[2].doc_count, 1); }

    // Min/Max/Avg/Sum over [10, 20, 30]; exact f64 comparisons are safe
    // here because the inputs and results are exactly representable.
    #[test]
    fn test_numeric_aggregations() {
        let executor = FacetExecutor::new(100);
        let values: Vec<Option<f64>> = vec![Some(10.0), Some(20.0), Some(30.0)];

        assert_eq!(
            executor.numeric_aggregation(&values, &AggregationType::Min),
            10.0
        );
        assert_eq!(
            executor.numeric_aggregation(&values, &AggregationType::Max),
            30.0
        );
        assert_eq!(
            executor.numeric_aggregation(&values, &AggregationType::Avg),
            20.0
        );
        assert_eq!(
            executor.numeric_aggregation(&values, &AggregationType::Sum),
            60.0
        );
    }

    // Weight function replaces the base score outright: 0.5 in, 2.0 out
    // (it does not multiply the base score).
    #[test]
    fn test_score_function_weight() {
        let func = ScoreFunction::Weight { weight: 2.0 };
        let score = ScoreFunctionExecutor::apply(&func, 0.5, None);
        assert_eq!(score, 2.0);
    }

    // FieldValue scoring with Log1p modifier: reads "popularity" from
    // metadata; expected ~ 0.1 * log1p-style transform of 100 (~0.2).
    #[test]
    fn test_score_function_field_value() {
        let func = ScoreFunction::FieldValue {
            field: "popularity".into(),
            factor: 0.1,
            modifier: FieldValueModifier::Log1p,
            missing: 1.0,
        };

        let metadata = serde_json::json!({"popularity": 100.0});
        let score = ScoreFunctionExecutor::apply(&func, 0.5, Some(&metadata));

        assert!(score > 0.19 && score < 0.21);
    }

    // Each ScoreMode combiner over [2, 3, 4].
    #[test]
    fn test_combine_scores() {
        let scores = vec![2.0f32, 3.0, 4.0];

        assert_eq!(
            ScoreFunctionExecutor::combine_scores(&scores, ScoreMode::Multiply),
            24.0
        );
        assert_eq!(
            ScoreFunctionExecutor::combine_scores(&scores, ScoreMode::Sum),
            9.0
        );
        assert_eq!(
            ScoreFunctionExecutor::combine_scores(&scores, ScoreMode::Average),
            3.0
        );
        assert_eq!(
            ScoreFunctionExecutor::combine_scores(&scores, ScoreMode::Max),
            4.0
        );
        assert_eq!(
            ScoreFunctionExecutor::combine_scores(&scores, ScoreMode::Min),
            2.0
        );
    }

    // Exact-match term filter on a metadata field.
    #[test]
    fn test_filter_term() {
        let filter = FilterExpression::Term {
            field: "category".into(),
            value: serde_json::json!("tech"),
        };

        let metadata = serde_json::json!({"category": "tech"});
        assert!(FilterExecutor::matches(&filter, Some(&metadata)));

        let other = serde_json::json!({"category": "science"});
        assert!(!FilterExecutor::matches(&filter, Some(&other)));
    }

    // Inclusive range filter (gte/lte); JSON integer 50 must compare
    // against the f64 bounds.
    #[test]
    fn test_filter_range() {
        let filter = FilterExpression::Range {
            field: "price".into(),
            gte: Some(10.0),
            gt: None,
            lte: Some(100.0),
            lt: None,
        };

        let match1 = serde_json::json!({"price": 50});
        assert!(FilterExecutor::matches(&filter, Some(&match1)));

        let nomatch = serde_json::json!({"price": 5});
        assert!(!FilterExecutor::matches(&filter, Some(&nomatch)));
    }

    // And(...) requires every sub-filter to match; failing either the term
    // or the range clause rejects the document.
    #[test]
    fn test_filter_boolean_and() {
        let filter = FilterExpression::And(vec![
            FilterExpression::Term {
                field: "category".into(),
                value: serde_json::json!("tech"),
            },
            FilterExpression::Range {
                field: "price".into(),
                gte: Some(10.0),
                gt: None,
                lte: None,
                lt: None,
            },
        ]);

        let match1 = serde_json::json!({"category": "tech", "price": 50});
        assert!(FilterExecutor::matches(&filter, Some(&match1)));

        let nomatch = serde_json::json!({"category": "tech", "price": 5});
        assert!(!FilterExecutor::matches(&filter, Some(&nomatch)));
    }

    // Explanations carry a positive score, at least one detail entry, and
    // mention the document id in the description.
    #[test]
    fn test_query_explanation() {
        let id = "doc123".to_string();
        let scoring = ScoringConfig {
            functions: vec![ScoreFunction::Weight { weight: 1.5 }],
            ..Default::default()
        };

        let explanation = QueryExplainer::explain(&id, 0.8, Some(&scoring), None);

        assert!(explanation.score > 0.0);
        assert!(!explanation.details.is_empty());
        assert!(explanation.description.contains("doc123"));
    }

    // Happy path: two queries through the executor with a stubbed search
    // function; both succeed and stats reflect the batch size.
    #[test]
    fn test_batch_query_executor() {
        let config = BatchQueryConfig::default();
        let executor = BatchQueryExecutor::new(config);

        let request = BatchQueryRequest {
            namespace: "test".into(),
            queries: vec![
                BatchQueryItem {
                    query_id: "q1".into(),
                    vector: vec![1.0, 0.0, 0.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
                BatchQueryItem {
                    query_id: "q2".into(),
                    vector: vec![0.0, 1.0, 0.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
            ],
            include_vectors: false,
            include_metadata: true,
            facets: None,
        };

        // Stub search backend returning two fixed hits regardless of query.
        let search_fn = |_vector: &[f32], _top_k: usize, _filter: Option<&FilterExpression>| {
            vec![
                (
                    "doc1".into(),
                    0.9,
                    None,
                    Some(serde_json::json!({"cat": "a"})),
                ),
                (
                    "doc2".into(),
                    0.8,
                    None,
                    Some(serde_json::json!({"cat": "b"})),
                ),
            ]
        };

        let response = executor.execute(&request, search_fn).unwrap();

        assert_eq!(response.responses.len(), 2);
        assert_eq!(response.stats.total_queries, 2);
        assert_eq!(response.stats.successful_queries, 2);
    }

    // Oversized batches are rejected up front with (actual, limit) in the
    // error — 3 queries against a limit of 2.
    #[test]
    fn test_batch_too_large_error() {
        let config = BatchQueryConfig {
            max_batch_size: 2,
            ..Default::default()
        };
        let executor = BatchQueryExecutor::new(config);

        let request = BatchQueryRequest {
            namespace: "test".into(),
            queries: vec![
                BatchQueryItem {
                    query_id: "q1".into(),
                    vector: vec![1.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
                BatchQueryItem {
                    query_id: "q2".into(),
                    vector: vec![1.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
                BatchQueryItem {
                    query_id: "q3".into(),
                    vector: vec![1.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
            ],
            include_vectors: false,
            include_metadata: false,
            facets: None,
        };

        let result = executor.execute(&request, |_, _, _| vec![]);
        assert!(matches!(result, Err(BatchSearchError::BatchTooLarge(3, 2))));
    }

    // The custom base64 helpers must round-trip cursor payloads exactly.
    #[test]
    fn test_base64_roundtrip() {
        let original = "hello:world:123";
        let encoded = base64_encode(original);
        let decoded = base64_decode(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    // Unit conversions used by geo distance filters (1 mile = 1609.34 m
    // per this implementation's constant).
    #[test]
    fn test_distance_unit_conversion() {
        assert_eq!(DistanceUnit::Meters.to_meters(100.0), 100.0);
        assert_eq!(DistanceUnit::Kilometers.to_meters(1.0), 1000.0);
        assert!((DistanceUnit::Miles.to_meters(1.0) - 1609.34).abs() < 0.01);
    }

    // Two queries with identical vector/top_k/filter collapse into one
    // backend call when deduplication is enabled; stats count the dupes.
    #[test]
    fn test_query_deduplication() {
        let config = BatchQueryConfig {
            deduplicate_queries: true,
            ..Default::default()
        };
        let executor = BatchQueryExecutor::new(config);

        let request = BatchQueryRequest {
            namespace: "test".into(),
            queries: vec![
                BatchQueryItem {
                    query_id: "q1".into(),
                    vector: vec![1.0, 0.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
                BatchQueryItem {
                    query_id: "q2".into(),
                    vector: vec![1.0, 0.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
            ],
            include_vectors: false,
            include_metadata: false,
            facets: None,
        };

        let response = executor
            .execute(&request, |_, _, _| vec![("doc1".into(), 0.9, None, None)])
            .unwrap();

        assert_eq!(response.stats.deduplicated_count, 1);
    }
}