engine/batch_search.rs
1//! Batch Search and Advanced Query Features for Dakera
2//!
3//! Provides high-throughput search capabilities:
4//! - **Batch Queries**: Process multiple queries in parallel for efficiency
5//! - **Pagination**: Cursor-based and search-after pagination for large result sets
6//! - **Faceted Search**: Aggregations and facet counting on metadata fields
7//! - **Geo-Spatial**: Distance-based filtering with geolocation support
8//! - **Custom Scoring**: Boosting, function scores, and script scoring
9//! - **Query Explain**: Debug and understand query scoring
10
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use std::time::Instant;
14use thiserror::Error;
15
16use common::VectorId;
17
/// A search result row: (vector ID, score, optional vector, optional metadata).
/// NOTE(review): not referenced in this section of the file; presumably
/// consumed by search plumbing further down — confirm before removing.
type SearchResultRow = (VectorId, f32, Option<Vec<f32>>, Option<serde_json::Value>);
20
21// ============================================================================
22// Errors
23// ============================================================================
24
/// Errors that can occur during batch search operations
#[derive(Debug, Error)]
pub enum BatchSearchError {
    /// The request held more queries than allowed: (actual size, limit).
    #[error("Query batch too large: {0} exceeds maximum {1}")]
    BatchTooLarge(usize, usize),

    /// A pagination cursor failed to decode or was of the wrong type.
    #[error("Invalid cursor: {0}")]
    InvalidCursor(String),

    /// Latitude or longitude fell outside [-90, 90] / [-180, 180].
    #[error("Invalid geo-coordinate: lat={0}, lon={1}")]
    InvalidGeoCoordinate(f64, f64),

    /// An aggregation type was requested that this engine cannot run.
    #[error("Unsupported aggregation type: {0}")]
    UnsupportedAggregation(String),

    /// A custom scoring function was malformed or unknown.
    #[error("Invalid scoring function: {0}")]
    InvalidScoringFunction(String),

    /// Query execution exceeded its deadline (budget in milliseconds).
    #[error("Query timeout exceeded: {0}ms")]
    Timeout(u64),

    /// Catch-all for unexpected internal failures.
    #[error("Internal error: {0}")]
    Internal(String),
}
49
50// ============================================================================
51// Batch Query API
52// ============================================================================
53
/// Configuration for batch query execution
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchQueryConfig {
    /// Maximum queries per batch (cf. `BatchSearchError::BatchTooLarge`)
    pub max_batch_size: usize,
    /// Maximum concurrent queries executed at once
    pub max_concurrency: usize,
    /// Timeout per query in milliseconds
    pub query_timeout_ms: u64,
    /// Enable query deduplication
    /// NOTE(review): dedup semantics are implemented elsewhere — confirm.
    pub deduplicate_queries: bool,
    /// Collect timing statistics (see `BatchQueryStats`)
    pub collect_stats: bool,
}
68
69impl Default for BatchQueryConfig {
70    fn default() -> Self {
71        Self {
72            max_batch_size: 100,
73            max_concurrency: 16,
74            query_timeout_ms: 5000,
75            deduplicate_queries: true,
76            collect_stats: true,
77        }
78    }
79}
80
/// A single query in a batch
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchQueryItem {
    /// Unique identifier for this query; echoed back as
    /// `BatchQueryItemResponse::query_id` for correlation
    pub query_id: String,
    /// Query vector
    pub vector: Vec<f32>,
    /// Number of results to return
    pub top_k: usize,
    /// Optional filter expression
    /// NOTE(review): `FilterExpression` is declared elsewhere in this crate.
    pub filter: Option<FilterExpression>,
    /// Optional pagination cursor (see `SearchCursor`)
    pub cursor: Option<SearchCursor>,
    /// Custom scoring configuration (boosts, decay functions, scripts)
    pub scoring: Option<ScoringConfig>,
}

/// Batch query request
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchQueryRequest {
    /// Namespace/collection to search
    pub namespace: String,
    /// List of queries to execute
    pub queries: Vec<BatchQueryItem>,
    /// Include vectors in response (populates `SearchHit::vector`)
    pub include_vectors: bool,
    /// Include metadata in response (populates `SearchHit::metadata`)
    pub include_metadata: bool,
    /// Global facet aggregations
    pub facets: Option<Vec<FacetRequest>>,
}

/// Response for a single query in the batch
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchQueryItemResponse {
    /// Query ID (matches `BatchQueryItem::query_id` from the request)
    pub query_id: String,
    /// Search results
    pub results: Vec<SearchHit>,
    /// Pagination cursor for next page
    pub next_cursor: Option<SearchCursor>,
    /// Query execution time in milliseconds
    pub took_ms: u64,
    /// Total matches (may be estimate for large result sets)
    pub total_matches: usize,
    /// Query explanation (if requested)
    pub explanation: Option<QueryExplanation>,
}

/// Full batch query response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchQueryResponse {
    /// Responses for each query
    pub responses: Vec<BatchQueryItemResponse>,
    /// Global facet results, keyed by `FacetRequest::name`
    pub facets: Option<HashMap<String, FacetResult>>,
    /// Batch execution statistics
    pub stats: BatchQueryStats,
}

/// Statistics for batch query execution
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct BatchQueryStats {
    /// Total queries in batch
    pub total_queries: usize,
    /// Successful queries
    pub successful_queries: usize,
    /// Failed queries
    pub failed_queries: usize,
    /// Total batch execution time in milliseconds
    pub total_took_ms: u64,
    /// Average query time in milliseconds
    pub avg_query_ms: f64,
    /// Maximum query time in milliseconds
    pub max_query_ms: u64,
    /// Minimum query time in milliseconds
    pub min_query_ms: u64,
    /// Queries that were deduplicated
    pub deduplicated_count: usize,
}

/// A search hit result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchHit {
    /// Vector ID
    pub id: VectorId,
    /// Similarity/relevance score
    pub score: f32,
    /// Vector data (if requested via `include_vectors`)
    pub vector: Option<Vec<f32>>,
    /// Metadata (if requested via `include_metadata`)
    pub metadata: Option<serde_json::Value>,
    /// Sort values for search-after pagination
    /// (see `SearchCursor::search_after`)
    pub sort_values: Vec<SortValue>,
}
176
177// ============================================================================
178// Pagination
179// ============================================================================
180
/// Search cursor for pagination
///
/// Opaque token handed back to clients; the payload in `value` is
/// base64-encoded and its shape depends on `cursor_type`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchCursor {
    /// Cursor type (offset-based vs. search-after)
    pub cursor_type: CursorType,
    /// Encoded cursor value (base64 of "offset:total" or of JSON sort values)
    pub value: String,
    /// Timestamp when cursor was created (see `current_timestamp_ms`);
    /// NOTE(review): TTL enforcement is not visible in this section — confirm
    /// it is checked against `PaginationConfig::cursor_ttl_secs` elsewhere.
    pub created_at: u64,
}
191
192impl SearchCursor {
193    /// Create a cursor-based pagination cursor
194    pub fn cursor_based(offset: usize, total: usize) -> Self {
195        let value = format!("{}:{}", offset, total);
196        Self {
197            cursor_type: CursorType::CursorBased,
198            value: base64_encode(&value),
199            created_at: current_timestamp_ms(),
200        }
201    }
202
203    /// Create a search-after pagination cursor
204    pub fn search_after(sort_values: &[SortValue]) -> Self {
205        let json = serde_json::to_string(sort_values).unwrap_or_default();
206        Self {
207            cursor_type: CursorType::SearchAfter,
208            value: base64_encode(&json),
209            created_at: current_timestamp_ms(),
210        }
211    }
212
213    /// Parse cursor offset for cursor-based pagination
214    pub fn parse_offset(&self) -> Result<usize, BatchSearchError> {
215        if self.cursor_type != CursorType::CursorBased {
216            return Err(BatchSearchError::InvalidCursor(
217                "Expected cursor-based cursor".into(),
218            ));
219        }
220        let decoded = base64_decode(&self.value)
221            .map_err(|_| BatchSearchError::InvalidCursor("Invalid base64".into()))?;
222        let parts: Vec<&str> = decoded.split(':').collect();
223        parts
224            .first()
225            .and_then(|s| s.parse().ok())
226            .ok_or_else(|| BatchSearchError::InvalidCursor("Invalid offset format".into()))
227    }
228
229    /// Parse sort values for search-after pagination
230    pub fn parse_sort_values(&self) -> Result<Vec<SortValue>, BatchSearchError> {
231        if self.cursor_type != CursorType::SearchAfter {
232            return Err(BatchSearchError::InvalidCursor(
233                "Expected search-after cursor".into(),
234            ));
235        }
236        let decoded = base64_decode(&self.value)
237            .map_err(|_| BatchSearchError::InvalidCursor("Invalid base64".into()))?;
238        serde_json::from_str(&decoded)
239            .map_err(|_| BatchSearchError::InvalidCursor("Invalid sort values".into()))
240    }
241}
242
/// Type of pagination cursor
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CursorType {
    /// Offset-based cursor (fast for small offsets)
    CursorBased,
    /// Search-after cursor (efficient for deep pagination)
    SearchAfter,
}

/// Sort value for search-after pagination
///
/// Heterogeneous sort key; comparison rules live in `SortValue::compare`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum SortValue {
    /// Relevance score (ordered descending by `compare`)
    Score(f32),
    /// Integer field value
    Integer(i64),
    /// Floating-point field value
    Float(f64),
    /// String field value
    String(String),
    /// Missing value; sorts after any non-null value
    Null,
}
261
262impl SortValue {
263    /// Compare two sort values
264    pub fn compare(&self, other: &SortValue) -> std::cmp::Ordering {
265        use std::cmp::Ordering;
266        match (self, other) {
267            (SortValue::Score(a), SortValue::Score(b)) => {
268                b.partial_cmp(a).unwrap_or(Ordering::Equal)
269            }
270            (SortValue::Integer(a), SortValue::Integer(b)) => a.cmp(b),
271            (SortValue::Float(a), SortValue::Float(b)) => {
272                a.partial_cmp(b).unwrap_or(Ordering::Equal)
273            }
274            (SortValue::String(a), SortValue::String(b)) => a.cmp(b),
275            (SortValue::Null, SortValue::Null) => Ordering::Equal,
276            (SortValue::Null, _) => Ordering::Greater,
277            (_, SortValue::Null) => Ordering::Less,
278            _ => Ordering::Equal,
279        }
280    }
281}
282
/// Pagination configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PaginationConfig {
    /// Page size (number of results per page)
    pub page_size: usize,
    /// Maximum allowed offset for cursor-based pagination; deeper
    /// pagination should switch to search-after cursors
    pub max_offset: usize,
    /// Cursor expiration time in seconds
    pub cursor_ttl_secs: u64,
    /// Default sort order applied when a query specifies none
    pub default_sort: Vec<SortField>,
}
295
296impl Default for PaginationConfig {
297    fn default() -> Self {
298        Self {
299            page_size: 20,
300            max_offset: 10000,
301            cursor_ttl_secs: 3600,
302            default_sort: vec![SortField {
303                field: "_score".into(),
304                order: SortOrder::Descending,
305            }],
306        }
307    }
308}
309
/// Sort field specification
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SortField {
    /// Field name to sort by ("_score" refers to the relevance score)
    pub field: String,
    /// Sort order
    pub order: SortOrder,
}

/// Sort order
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum SortOrder {
    /// Smallest values first
    Ascending,
    /// Largest values first
    Descending,
}
325
326// ============================================================================
327// Faceted Search / Aggregations
328// ============================================================================
329
/// Request for a facet/aggregation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FacetRequest {
    /// Name of this facet (key under which the result is returned)
    pub name: String,
    /// Field to aggregate on
    pub field: String,
    /// Type of aggregation
    pub agg_type: AggregationType,
    /// Maximum number of buckets (for terms aggregation)
    pub max_buckets: Option<usize>,
    /// Ranges for range aggregation (required when `agg_type` is `Range`)
    pub ranges: Option<Vec<RangeBucket>>,
}

/// Type of aggregation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AggregationType {
    /// Count unique terms
    Terms,
    /// Numeric range buckets
    Range,
    /// Date histogram
    /// NOTE(review): interval parsing is not visible in this section —
    /// confirm the expected format (e.g. "1d", "1h").
    DateHistogram { interval: String },
    /// Numeric histogram with fixed bucket width
    Histogram { interval: f64 },
    /// Minimum value
    Min,
    /// Maximum value
    Max,
    /// Average value
    Avg,
    /// Sum of values
    Sum,
    /// Value count
    Count,
    /// Cardinality (unique count estimate)
    Cardinality,
}

/// Range bucket definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RangeBucket {
    /// Bucket key/name
    pub key: String,
    /// From value (inclusive); `None` means unbounded below
    pub from: Option<f64>,
    /// To value (exclusive); `None` means unbounded above
    pub to: Option<f64>,
}

/// Result of a facet aggregation
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FacetResult {
    /// Facet name
    pub name: String,
    /// Field that was aggregated
    pub field: String,
    /// Buckets (for terms, range, histogram)
    pub buckets: Option<Vec<FacetBucket>>,
    /// Metric value (for min, max, avg, sum, count)
    pub value: Option<f64>,
    /// Total document count
    pub doc_count: usize,
}

/// A bucket in a facet result
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FacetBucket {
    /// Bucket key
    pub key: String,
    /// Document count in bucket
    pub doc_count: usize,
    /// Nested aggregations computed within this bucket
    pub sub_aggregations: Option<HashMap<String, FacetResult>>,
}
406
407/// Faceted search executor
408pub struct FacetExecutor {
409    max_buckets: usize,
410}
411
412impl FacetExecutor {
413    /// Create a new facet executor
414    pub fn new(max_buckets: usize) -> Self {
415        Self { max_buckets }
416    }
417
418    /// Execute a terms aggregation
419    pub fn terms_aggregation(
420        &self,
421        values: &[Option<serde_json::Value>],
422        max_buckets: usize,
423    ) -> Vec<FacetBucket> {
424        let mut counts: HashMap<String, usize> = HashMap::new();
425
426        for value in values.iter().flatten() {
427            let key = match value {
428                serde_json::Value::String(s) => s.clone(),
429                serde_json::Value::Number(n) => n.to_string(),
430                serde_json::Value::Bool(b) => b.to_string(),
431                _ => continue,
432            };
433            *counts.entry(key).or_insert(0) += 1;
434        }
435
436        let mut buckets: Vec<_> = counts
437            .into_iter()
438            .map(|(key, count)| FacetBucket {
439                key,
440                doc_count: count,
441                sub_aggregations: None,
442            })
443            .collect();
444
445        // Sort by count descending
446        buckets.sort_by(|a, b| b.doc_count.cmp(&a.doc_count));
447        buckets.truncate(max_buckets.min(self.max_buckets));
448        buckets
449    }
450
451    /// Execute a range aggregation
452    pub fn range_aggregation(
453        &self,
454        values: &[Option<f64>],
455        ranges: &[RangeBucket],
456    ) -> Vec<FacetBucket> {
457        ranges
458            .iter()
459            .map(|range| {
460                let count = values
461                    .iter()
462                    .filter(|v| {
463                        if let Some(val) = v {
464                            let from_ok = range.from.is_none_or(|f| *val >= f);
465                            let to_ok = range.to.is_none_or(|t| *val < t);
466                            from_ok && to_ok
467                        } else {
468                            false
469                        }
470                    })
471                    .count();
472
473                FacetBucket {
474                    key: range.key.clone(),
475                    doc_count: count,
476                    sub_aggregations: None,
477                }
478            })
479            .collect()
480    }
481
482    /// Execute a numeric aggregation (min, max, avg, sum)
483    pub fn numeric_aggregation(&self, values: &[Option<f64>], agg_type: &AggregationType) -> f64 {
484        let valid_values: Vec<f64> = values.iter().filter_map(|v| *v).collect();
485
486        if valid_values.is_empty() {
487            return 0.0;
488        }
489
490        match agg_type {
491            AggregationType::Min => valid_values.iter().copied().fold(f64::INFINITY, f64::min),
492            AggregationType::Max => valid_values
493                .iter()
494                .copied()
495                .fold(f64::NEG_INFINITY, f64::max),
496            AggregationType::Avg => valid_values.iter().sum::<f64>() / valid_values.len() as f64,
497            AggregationType::Sum => valid_values.iter().sum(),
498            AggregationType::Count => valid_values.len() as f64,
499            _ => 0.0,
500        }
501    }
502}
503
504// ============================================================================
505// Geo-Spatial Filtering
506// ============================================================================
507
508/// Geographic point
509#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
510pub struct GeoPoint {
511    /// Latitude (-90 to 90)
512    pub lat: f64,
513    /// Longitude (-180 to 180)
514    pub lon: f64,
515}
516
517impl GeoPoint {
518    /// Create a new geo point with validation
519    pub fn new(lat: f64, lon: f64) -> Result<Self, BatchSearchError> {
520        if !(-90.0..=90.0).contains(&lat) || !(-180.0..=180.0).contains(&lon) {
521            return Err(BatchSearchError::InvalidGeoCoordinate(lat, lon));
522        }
523        Ok(Self { lat, lon })
524    }
525
526    /// Calculate haversine distance to another point in kilometers
527    pub fn distance_km(&self, other: &GeoPoint) -> f64 {
528        const EARTH_RADIUS_KM: f64 = 6371.0;
529
530        let lat1 = self.lat.to_radians();
531        let lat2 = other.lat.to_radians();
532        let delta_lat = (other.lat - self.lat).to_radians();
533        let delta_lon = (other.lon - self.lon).to_radians();
534
535        let a = (delta_lat / 2.0).sin().powi(2)
536            + lat1.cos() * lat2.cos() * (delta_lon / 2.0).sin().powi(2);
537        let c = 2.0 * a.sqrt().asin();
538
539        EARTH_RADIUS_KM * c
540    }
541
542    /// Calculate distance in meters
543    pub fn distance_m(&self, other: &GeoPoint) -> f64 {
544        self.distance_km(other) * 1000.0
545    }
546
547    /// Calculate distance in miles
548    pub fn distance_miles(&self, other: &GeoPoint) -> f64 {
549        self.distance_km(other) * 0.621371
550    }
551}
552
/// Geo-spatial filter types
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GeoFilter {
    /// Filter by distance from a point (radius filter)
    Distance {
        /// Center point
        center: GeoPoint,
        /// Maximum distance (expressed in `unit`)
        distance: f64,
        /// Distance unit
        unit: DistanceUnit,
    },
    /// Filter by bounding box
    BoundingBox {
        /// Top-left corner (north-west: highest latitude, lowest longitude)
        top_left: GeoPoint,
        /// Bottom-right corner (south-east: lowest latitude, highest longitude)
        bottom_right: GeoPoint,
    },
    /// Filter by polygon
    Polygon {
        /// Polygon vertices (must be closed)
        points: Vec<GeoPoint>,
    },
}

/// Distance unit for geo queries
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum DistanceUnit {
    Meters,
    Kilometers,
    Miles,
    Feet,
}
587
588impl DistanceUnit {
589    /// Convert distance to meters
590    pub fn to_meters(&self, distance: f64) -> f64 {
591        match self {
592            DistanceUnit::Meters => distance,
593            DistanceUnit::Kilometers => distance * 1000.0,
594            DistanceUnit::Miles => distance * 1609.34,
595            DistanceUnit::Feet => distance * 0.3048,
596        }
597    }
598}
599
/// Geo-spatial filter executor
pub struct GeoFilterExecutor;

impl GeoFilterExecutor {
    /// Check if a point passes the geo filter
    ///
    /// - `Distance`: true when the haversine distance from `center` is
    ///   within the radius (converted to meters).
    /// - `BoundingBox`: plain lat/lon interval test; assumes `top_left`
    ///   is the north-west corner. NOTE(review): a box crossing the
    ///   antimeridian (lon wrap at ±180°) will match nothing — confirm
    ///   whether that case must be supported.
    /// - `Polygon`: ray-casting point-in-polygon test.
    pub fn matches(filter: &GeoFilter, point: &GeoPoint) -> bool {
        match filter {
            GeoFilter::Distance {
                center,
                distance,
                unit,
            } => {
                let max_distance_m = unit.to_meters(*distance);
                center.distance_m(point) <= max_distance_m
            }
            GeoFilter::BoundingBox {
                top_left,
                bottom_right,
            } => {
                // Latitude decreases top -> bottom; longitude increases
                // left -> right.
                point.lat <= top_left.lat
                    && point.lat >= bottom_right.lat
                    && point.lon >= top_left.lon
                    && point.lon <= bottom_right.lon
            }
            GeoFilter::Polygon { points } => Self::point_in_polygon(point, points),
        }
    }

    /// Ray casting algorithm for point-in-polygon test
    ///
    /// Casts a horizontal ray from the point and counts edge crossings:
    /// an odd count means the point is inside. Degenerate polygons
    /// (< 3 vertices) never match. The division below cannot be by zero:
    /// the crossing condition requires `pi.lat` and `pj.lat` to lie on
    /// opposite sides of `point.lat`, hence `pj.lat != pi.lat`.
    /// NOTE(review): treats coordinates as planar — no special handling
    /// for polygons crossing the antimeridian or enclosing a pole.
    fn point_in_polygon(point: &GeoPoint, polygon: &[GeoPoint]) -> bool {
        if polygon.len() < 3 {
            return false;
        }

        let mut inside = false;
        let n = polygon.len();

        // Walk each edge polygon[j] -> polygon[i], with j trailing i.
        let mut j = n - 1;
        for i in 0..n {
            let pi = &polygon[i];
            let pj = &polygon[j];

            if ((pi.lat > point.lat) != (pj.lat > point.lat))
                && (point.lon
                    < (pj.lon - pi.lon) * (point.lat - pi.lat) / (pj.lat - pi.lat) + pi.lon)
            {
                inside = !inside;
            }
            j = i;
        }

        inside
    }
}
654
655// ============================================================================
656// Custom Scoring / Boosting
657// ============================================================================
658
/// Custom scoring configuration
///
/// Shaped like a `function_score` query: a list of score functions is
/// evaluated per hit, combined with `score_mode`, then merged into the
/// base similarity score with `boost_mode`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoringConfig {
    /// How the individual function scores are combined with each other
    pub score_mode: ScoreMode,
    /// Boost functions to apply
    pub functions: Vec<ScoreFunction>,
    /// How the combined function score is merged with the original score
    pub boost_mode: BoostMode,
    /// Minimum score threshold
    /// NOTE(review): enforcement is not visible in this section — confirm
    /// hits below the threshold are dropped elsewhere.
    pub min_score: Option<f32>,
}
671
672impl Default for ScoringConfig {
673    fn default() -> Self {
674        Self {
675            score_mode: ScoreMode::Multiply,
676            functions: Vec::new(),
677            boost_mode: BoostMode::Multiply,
678            min_score: None,
679        }
680    }
681}
682
/// How to combine multiple function scores
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum ScoreMode {
    /// Product of all function scores
    Multiply,
    /// Sum of all function scores
    Sum,
    /// Arithmetic mean of the function scores
    Average,
    /// Only the first function's score
    First,
    /// Largest function score
    Max,
    /// Smallest function score
    Min,
}

/// How to combine function score with original score
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum BoostMode {
    /// original * function
    Multiply,
    /// Function score replaces the original entirely
    Replace,
    /// original + function
    Sum,
    /// Mean of original and function score
    Average,
    /// max(original, function)
    Max,
    /// min(original, function)
    Min,
}

/// Score function definition
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ScoreFunction {
    /// Constant boost
    Weight { weight: f32 },
    /// Boost from a numeric metadata field: `modifier(value) * factor`,
    /// with `missing` substituted when the field is absent or non-numeric
    FieldValue {
        field: String,
        factor: f32,
        modifier: FieldValueModifier,
        missing: f32,
    },
    /// Decay function (gaussian, linear, exponential) around `origin`;
    /// distances within `offset` score a full 1.0
    Decay {
        field: String,
        origin: f64,
        scale: f64,
        offset: f64,
        decay: f64,
        decay_type: DecayType,
    },
    /// Random score for variety (deterministic for a given seed)
    RandomScore { seed: u64, field: Option<String> },
    /// Script-based scoring
    Script {
        source: String,
        params: HashMap<String, f64>,
    },
    /// Geo distance decay around `origin` (gaussian falloff)
    GeoDecay {
        field: String,
        origin: GeoPoint,
        scale: f64,
        scale_unit: DistanceUnit,
        offset: f64,
        decay: f64,
    },
}

/// Modifier for field value scoring
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum FieldValueModifier {
    /// Use the raw value
    None,
    /// log10(v)
    Log,
    /// log10(1 + v)
    Log1p,
    /// log10(2 + v)
    Log2p,
    /// ln(v)
    Ln,
    /// ln(1 + v)
    Ln1p,
    /// ln(2 + v)
    Ln2p,
    /// v * v
    Square,
    /// sqrt(v)
    Sqrt,
    /// 1 / v, with v floored at 0.001 to avoid division by zero
    Reciprocal,
}

/// Type of decay function
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum DecayType {
    /// Bell-shaped falloff
    Gaussian,
    /// Straight-line falloff reaching zero
    Linear,
    /// Exponential falloff
    Exponential,
}
766
767/// Score function executor
768pub struct ScoreFunctionExecutor;
769
770impl ScoreFunctionExecutor {
771    /// Apply a score function
772    pub fn apply(
773        function: &ScoreFunction,
774        original_score: f32,
775        metadata: Option<&serde_json::Value>,
776    ) -> f32 {
777        match function {
778            ScoreFunction::Weight { weight } => *weight,
779
780            ScoreFunction::FieldValue {
781                field,
782                factor,
783                modifier,
784                missing,
785            } => {
786                let value = metadata
787                    .and_then(|m| m.get(field))
788                    .and_then(|v| v.as_f64())
789                    .unwrap_or(*missing as f64);
790
791                let modified = Self::apply_modifier(value, *modifier);
792                (modified * *factor as f64) as f32
793            }
794
795            ScoreFunction::Decay {
796                origin,
797                scale,
798                offset,
799                decay,
800                decay_type,
801                field,
802            } => {
803                let value = metadata
804                    .and_then(|m| m.get(field))
805                    .and_then(|v| v.as_f64())
806                    .unwrap_or(*origin);
807
808                let distance = (value - origin).abs() - offset;
809                if distance <= 0.0 {
810                    return 1.0;
811                }
812
813                Self::compute_decay(distance, *scale, *decay, *decay_type) as f32
814            }
815
816            ScoreFunction::RandomScore { seed, .. } => {
817                // Simple pseudo-random based on seed
818                let hash = (*seed as u32).wrapping_mul(2654435769);
819                hash as f32 / u32::MAX as f32
820            }
821
822            ScoreFunction::Script { source, params } => {
823                // Simple expression evaluation for demo
824                Self::evaluate_script(source, params, original_score, metadata)
825            }
826
827            ScoreFunction::GeoDecay {
828                field,
829                origin,
830                scale,
831                scale_unit,
832                offset,
833                decay,
834            } => {
835                let point = metadata.and_then(|m| m.get(field)).and_then(|v| {
836                    let lat = v.get("lat")?.as_f64()?;
837                    let lon = v.get("lon")?.as_f64()?;
838                    Some(GeoPoint { lat, lon })
839                });
840
841                if let Some(point) = point {
842                    let distance_m = origin.distance_m(&point);
843                    let scale_m = scale_unit.to_meters(*scale);
844                    let offset_m = scale_unit.to_meters(*offset);
845
846                    let adjusted_distance = (distance_m - offset_m).max(0.0);
847                    Self::compute_decay(adjusted_distance, scale_m, *decay, DecayType::Gaussian)
848                        as f32
849                } else {
850                    0.0
851                }
852            }
853        }
854    }
855
856    fn apply_modifier(value: f64, modifier: FieldValueModifier) -> f64 {
857        match modifier {
858            FieldValueModifier::None => value,
859            FieldValueModifier::Log => value.log10(),
860            FieldValueModifier::Log1p => (1.0 + value).log10(),
861            FieldValueModifier::Log2p => (2.0 + value).log10(),
862            FieldValueModifier::Ln => value.ln(),
863            FieldValueModifier::Ln1p => (1.0 + value).ln(),
864            FieldValueModifier::Ln2p => (2.0 + value).ln(),
865            FieldValueModifier::Square => value * value,
866            FieldValueModifier::Sqrt => value.sqrt(),
867            FieldValueModifier::Reciprocal => 1.0 / value.max(0.001),
868        }
869    }
870
871    fn compute_decay(distance: f64, scale: f64, decay: f64, decay_type: DecayType) -> f64 {
872        let lambda = scale.ln().abs() / decay.ln().abs();
873
874        match decay_type {
875            DecayType::Gaussian => (-0.5 * (distance / lambda).powi(2)).exp(),
876            DecayType::Linear => ((lambda - distance) / lambda).max(0.0),
877            DecayType::Exponential => (-distance / lambda).exp(),
878        }
879    }
880
881    fn evaluate_script(
882        _source: &str,
883        params: &HashMap<String, f64>,
884        original_score: f32,
885        _metadata: Option<&serde_json::Value>,
886    ) -> f32 {
887        // Simplified: just use params to boost score
888        let boost = params.get("boost").copied().unwrap_or(1.0) as f32;
889        original_score * boost
890    }
891
892    /// Combine multiple function scores
893    pub fn combine_scores(scores: &[f32], mode: ScoreMode) -> f32 {
894        if scores.is_empty() {
895            return 1.0;
896        }
897
898        match mode {
899            ScoreMode::Multiply => scores.iter().product(),
900            ScoreMode::Sum => scores.iter().sum(),
901            ScoreMode::Average => scores.iter().sum::<f32>() / scores.len() as f32,
902            ScoreMode::First => scores[0],
903            ScoreMode::Max => scores.iter().copied().fold(f32::NEG_INFINITY, f32::max),
904            ScoreMode::Min => scores.iter().copied().fold(f32::INFINITY, f32::min),
905        }
906    }
907
908    /// Combine function score with original score
909    pub fn combine_with_original(original: f32, function_score: f32, mode: BoostMode) -> f32 {
910        match mode {
911            BoostMode::Multiply => original * function_score,
912            BoostMode::Replace => function_score,
913            BoostMode::Sum => original + function_score,
914            BoostMode::Average => (original + function_score) / 2.0,
915            BoostMode::Max => original.max(function_score),
916            BoostMode::Min => original.min(function_score),
917        }
918    }
919}
920
921// ============================================================================
922// Query Explain API
923// ============================================================================
924
/// Query explanation for debugging scoring
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QueryExplanation {
    /// Final computed score
    pub score: f32,
    /// Human-readable description of the scoring
    pub description: String,
    /// Component explanations (base similarity, each function, combination)
    pub details: Vec<ScoreDetail>,
}

/// Detail of a score component
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScoreDetail {
    /// Component name (e.g. "vector_similarity", "function_0")
    pub name: String,
    /// Component score value
    pub value: f32,
    /// How this component was computed
    pub description: String,
    /// Nested details
    pub details: Option<Vec<ScoreDetail>>,
}

/// Query explainer: builds `QueryExplanation` values for search hits
pub struct QueryExplainer;
951
952impl QueryExplainer {
953    /// Generate explanation for a query result
954    pub fn explain(
955        id: &VectorId,
956        original_score: f32,
957        scoring_config: Option<&ScoringConfig>,
958        metadata: Option<&serde_json::Value>,
959    ) -> QueryExplanation {
960        let mut details = vec![ScoreDetail {
961            name: "vector_similarity".into(),
962            value: original_score,
963            description: "Base vector similarity score".into(),
964            details: None,
965        }];
966
967        let mut final_score = original_score;
968
969        if let Some(config) = scoring_config {
970            let mut function_scores = Vec::new();
971
972            for (i, func) in config.functions.iter().enumerate() {
973                let func_score = ScoreFunctionExecutor::apply(func, original_score, metadata);
974                function_scores.push(func_score);
975
976                details.push(ScoreDetail {
977                    name: format!("function_{}", i),
978                    value: func_score,
979                    description: Self::describe_function(func),
980                    details: None,
981                });
982            }
983
984            if !function_scores.is_empty() {
985                let combined =
986                    ScoreFunctionExecutor::combine_scores(&function_scores, config.score_mode);
987                details.push(ScoreDetail {
988                    name: "combined_functions".into(),
989                    value: combined,
990                    description: format!("Functions combined using {:?}", config.score_mode),
991                    details: None,
992                });
993
994                final_score = ScoreFunctionExecutor::combine_with_original(
995                    original_score,
996                    combined,
997                    config.boost_mode,
998                );
999            }
1000        }
1001
1002        QueryExplanation {
1003            score: final_score,
1004            description: format!("Explanation for document {}", id),
1005            details,
1006        }
1007    }
1008
1009    fn describe_function(func: &ScoreFunction) -> String {
1010        match func {
1011            ScoreFunction::Weight { weight } => format!("Constant weight: {}", weight),
1012            ScoreFunction::FieldValue { field, factor, .. } => {
1013                format!("Field value boost on '{}' with factor {}", field, factor)
1014            }
1015            ScoreFunction::Decay {
1016                field, decay_type, ..
1017            } => format!("{:?} decay on field '{}'", decay_type, field),
1018            ScoreFunction::RandomScore { seed, .. } => format!("Random score with seed {}", seed),
1019            ScoreFunction::Script { source, .. } => format!("Script: {}", source),
1020            ScoreFunction::GeoDecay { field, .. } => format!("Geo decay on field '{}'", field),
1021        }
1022    }
1023}
1024
1025// ============================================================================
1026// Filter Expressions
1027// ============================================================================
1028
/// Filter expression for queries
///
/// A recursive predicate tree evaluated against a result's JSON metadata by
/// [`FilterExecutor::matches`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum FilterExpression {
    /// Match exact value
    Term {
        field: String,
        value: serde_json::Value,
    },
    /// Match any of the values
    Terms {
        field: String,
        values: Vec<serde_json::Value>,
    },
    /// Numeric range filter; each bound is optional and all are ANDed
    Range {
        field: String,
        /// Greater than or equal
        gte: Option<f64>,
        /// Strictly greater than
        gt: Option<f64>,
        /// Less than or equal
        lte: Option<f64>,
        /// Strictly less than
        lt: Option<f64>,
    },
    /// Matches when the field is present in the metadata
    Exists { field: String },
    /// String prefix match
    Prefix { field: String, prefix: String },
    /// Geo filter applied to an object with `lat`/`lon` keys
    Geo { field: String, filter: GeoFilter },
    /// Boolean AND
    And(Vec<FilterExpression>),
    /// Boolean OR
    Or(Vec<FilterExpression>),
    /// Boolean NOT
    Not(Box<FilterExpression>),
}
1063
/// Filter executor
///
/// Stateless evaluator for [`FilterExpression`] trees against JSON metadata.
pub struct FilterExecutor;
1066
1067impl FilterExecutor {
1068    /// Check if metadata matches filter
1069    pub fn matches(filter: &FilterExpression, metadata: Option<&serde_json::Value>) -> bool {
1070        let metadata = match metadata {
1071            Some(m) => m,
1072            None => return false,
1073        };
1074
1075        match filter {
1076            FilterExpression::Term { field, value } => metadata.get(field) == Some(value),
1077            FilterExpression::Terms { field, values } => {
1078                metadata.get(field).is_some_and(|v| values.contains(v))
1079            }
1080            FilterExpression::Range {
1081                field,
1082                gte,
1083                gt,
1084                lte,
1085                lt,
1086            } => {
1087                let value = metadata.get(field).and_then(|v| v.as_f64());
1088                if let Some(v) = value {
1089                    gte.is_none_or(|x| v >= x)
1090                        && gt.is_none_or(|x| v > x)
1091                        && lte.is_none_or(|x| v <= x)
1092                        && lt.is_none_or(|x| v < x)
1093                } else {
1094                    false
1095                }
1096            }
1097            FilterExpression::Exists { field } => metadata.get(field).is_some(),
1098            FilterExpression::Prefix { field, prefix } => metadata
1099                .get(field)
1100                .and_then(|v| v.as_str())
1101                .is_some_and(|s| s.starts_with(prefix)),
1102            FilterExpression::Geo { field, filter } => {
1103                let point = metadata.get(field).and_then(|v| {
1104                    let lat = v.get("lat")?.as_f64()?;
1105                    let lon = v.get("lon")?.as_f64()?;
1106                    Some(GeoPoint { lat, lon })
1107                });
1108                point.is_some_and(|p| GeoFilterExecutor::matches(filter, &p))
1109            }
1110            FilterExpression::And(filters) => {
1111                filters.iter().all(|f| Self::matches(f, Some(metadata)))
1112            }
1113            FilterExpression::Or(filters) => {
1114                filters.iter().any(|f| Self::matches(f, Some(metadata)))
1115            }
1116            FilterExpression::Not(filter) => !Self::matches(filter, Some(metadata)),
1117        }
1118    }
1119}
1120
1121// ============================================================================
1122// Batch Query Executor
1123// ============================================================================
1124
/// Batch query executor for high-throughput search
///
/// Drives validation, deduplication, pagination, and custom scoring for each
/// query in a batch, delegating the actual vector search to a caller-supplied
/// closure.
pub struct BatchQueryExecutor {
    /// Executor configuration (batch limits, dedup flag, timeout settings)
    config: BatchQueryConfig,
}
1129
1130impl BatchQueryExecutor {
1131    /// Create a new batch query executor
1132    pub fn new(config: BatchQueryConfig) -> Self {
1133        Self { config }
1134    }
1135
1136    /// Execute a batch of queries
1137    pub fn execute(
1138        &self,
1139        request: &BatchQueryRequest,
1140        search_fn: impl Fn(
1141            &[f32],
1142            usize,
1143            Option<&FilterExpression>,
1144        ) -> Vec<(VectorId, f32, Option<Vec<f32>>, Option<serde_json::Value>)>,
1145    ) -> Result<BatchQueryResponse, BatchSearchError> {
1146        let start = Instant::now();
1147
1148        // Validate batch size
1149        if request.queries.len() > self.config.max_batch_size {
1150            return Err(BatchSearchError::BatchTooLarge(
1151                request.queries.len(),
1152                self.config.max_batch_size,
1153            ));
1154        }
1155
1156        let mut responses: Vec<BatchQueryItemResponse> = Vec::with_capacity(request.queries.len());
1157        let mut query_times: Vec<u64> = Vec::new();
1158        let mut deduplicated_count = 0;
1159
1160        // Process queries (could be parallelized in production)
1161        let mut seen_queries: HashMap<Vec<u32>, usize> = HashMap::new();
1162
1163        for query in &request.queries {
1164            let query_start = Instant::now();
1165
1166            // Check for duplicate queries (deduplication)
1167            let query_hash: Vec<u32> = query.vector.iter().map(|f| f.to_bits()).collect();
1168
1169            if self.config.deduplicate_queries {
1170                if let Some(&existing_idx) = seen_queries.get(&query_hash) {
1171                    // Reuse previous response
1172                    let mut response = responses[existing_idx].clone();
1173                    response.query_id = query.query_id.clone();
1174                    responses.push(response);
1175                    deduplicated_count += 1;
1176                    continue;
1177                }
1178                seen_queries.insert(query_hash, responses.len());
1179            }
1180
1181            // Execute search
1182            let raw_results = search_fn(&query.vector, query.top_k * 2, query.filter.as_ref());
1183
1184            // Apply pagination
1185            let (results, next_cursor) = self.apply_pagination(&raw_results, query)?;
1186
1187            // Apply custom scoring
1188            let scored_results = self.apply_scoring(&results, query.scoring.as_ref());
1189
1190            // Build hits
1191            let hits: Vec<SearchHit> = scored_results
1192                .into_iter()
1193                .take(query.top_k)
1194                .map(|(id, score, vec, meta)| SearchHit {
1195                    id,
1196                    score,
1197                    vector: if request.include_vectors { vec } else { None },
1198                    metadata: if request.include_metadata { meta } else { None },
1199                    sort_values: vec![SortValue::Score(score)],
1200                })
1201                .collect();
1202
1203            let query_time = query_start.elapsed().as_millis() as u64;
1204            query_times.push(query_time);
1205
1206            responses.push(BatchQueryItemResponse {
1207                query_id: query.query_id.clone(),
1208                results: hits,
1209                next_cursor,
1210                took_ms: query_time,
1211                total_matches: raw_results.len(),
1212                explanation: None,
1213            });
1214        }
1215
1216        // Compute facets if requested
1217        let facets = request.facets.as_ref().map(|_| HashMap::new());
1218
1219        // Compute statistics
1220        let total_took = start.elapsed().as_millis() as u64;
1221        let stats = BatchQueryStats {
1222            total_queries: request.queries.len(),
1223            successful_queries: responses.len(),
1224            failed_queries: 0,
1225            total_took_ms: total_took,
1226            avg_query_ms: if !query_times.is_empty() {
1227                query_times.iter().sum::<u64>() as f64 / query_times.len() as f64
1228            } else {
1229                0.0
1230            },
1231            max_query_ms: query_times.iter().copied().max().unwrap_or(0),
1232            min_query_ms: query_times.iter().copied().min().unwrap_or(0),
1233            deduplicated_count,
1234        };
1235
1236        Ok(BatchQueryResponse {
1237            responses,
1238            facets,
1239            stats,
1240        })
1241    }
1242
1243    fn apply_pagination(
1244        &self,
1245        results: &[SearchResultRow],
1246        query: &BatchQueryItem,
1247    ) -> Result<(Vec<SearchResultRow>, Option<SearchCursor>), BatchSearchError> {
1248        if let Some(cursor) = &query.cursor {
1249            match cursor.cursor_type {
1250                CursorType::CursorBased => {
1251                    let offset = cursor.parse_offset()?;
1252                    let paginated: Vec<_> = results.iter().skip(offset).cloned().collect();
1253                    let next_cursor = if offset + query.top_k < results.len() {
1254                        Some(SearchCursor::cursor_based(
1255                            offset + query.top_k,
1256                            results.len(),
1257                        ))
1258                    } else {
1259                        None
1260                    };
1261                    Ok((paginated, next_cursor))
1262                }
1263                CursorType::SearchAfter => {
1264                    let sort_values = cursor.parse_sort_values()?;
1265                    let start_idx = results
1266                        .iter()
1267                        .position(|(_, score, _, _)| {
1268                            let sv = SortValue::Score(*score);
1269                            sv.compare(&sort_values[0]) == std::cmp::Ordering::Greater
1270                        })
1271                        .unwrap_or(0);
1272                    let paginated: Vec<_> = results.iter().skip(start_idx).cloned().collect();
1273                    let next_cursor = if start_idx + query.top_k < results.len() {
1274                        paginated
1275                            .get(query.top_k - 1)
1276                            .map(|last| SearchCursor::search_after(&[SortValue::Score(last.1)]))
1277                    } else {
1278                        None
1279                    };
1280                    Ok((paginated, next_cursor))
1281                }
1282            }
1283        } else {
1284            let next_cursor = if results.len() > query.top_k {
1285                Some(SearchCursor::cursor_based(query.top_k, results.len()))
1286            } else {
1287                None
1288            };
1289            Ok((results.to_vec(), next_cursor))
1290        }
1291    }
1292
1293    fn apply_scoring(
1294        &self,
1295        results: &[SearchResultRow],
1296        scoring: Option<&ScoringConfig>,
1297    ) -> Vec<SearchResultRow> {
1298        let config = match scoring {
1299            Some(c) if !c.functions.is_empty() => c,
1300            _ => return results.to_vec(),
1301        };
1302        let mut scored: Vec<_> = results
1303            .iter()
1304            .map(|(id, score, vec, meta)| {
1305                let function_scores: Vec<f32> = config
1306                    .functions
1307                    .iter()
1308                    .map(|f| ScoreFunctionExecutor::apply(f, *score, meta.as_ref()))
1309                    .collect();
1310
1311                let combined =
1312                    ScoreFunctionExecutor::combine_scores(&function_scores, config.score_mode);
1313                let final_score = ScoreFunctionExecutor::combine_with_original(
1314                    *score,
1315                    combined,
1316                    config.boost_mode,
1317                );
1318
1319                (id.clone(), final_score, vec.clone(), meta.clone())
1320            })
1321            .collect();
1322
1323        // Apply min_score filter
1324        if let Some(min) = config.min_score {
1325            scored.retain(|(_, s, _, _)| *s >= min);
1326        }
1327
1328        // Re-sort by new scores
1329        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
1330        scored
1331    }
1332}
1333
1334// ============================================================================
1335// Utility Functions
1336// ============================================================================
1337
/// Minimal standard-alphabet base64 encoder used for cursor values.
///
/// Output is padded with `=` so its length is always a multiple of 4; an
/// empty input yields an empty string.
fn base64_encode(input: &str) -> String {
    const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    let bytes = input.as_bytes();
    // Each 3-byte group expands to exactly 4 output characters, so the
    // output size is known up front — preallocate to avoid regrowth.
    let mut result = String::with_capacity(bytes.len().div_ceil(3) * 4);

    for chunk in bytes.chunks(3) {
        // Pack up to 3 bytes into a 24-bit group (missing bytes stay 0).
        let mut n = (chunk[0] as u32) << 16;
        if chunk.len() > 1 {
            n |= (chunk[1] as u32) << 8;
        }
        if chunk.len() > 2 {
            n |= chunk[2] as u32;
        }

        // The first two sextets are always meaningful; the last two are
        // emitted only when the chunk carried enough input bytes.
        result.push(CHARS[(n >> 18 & 0x3f) as usize] as char);
        result.push(CHARS[(n >> 12 & 0x3f) as usize] as char);
        if chunk.len() > 1 {
            result.push(CHARS[(n >> 6 & 0x3f) as usize] as char);
        } else {
            result.push('=');
        }
        if chunk.len() > 2 {
            result.push(CHARS[(n & 0x3f) as usize] as char);
        } else {
            result.push('=');
        }
    }
    result
}
1368
/// Decode a base64 string produced by [`base64_encode`].
///
/// # Errors
///
/// Returns an error for characters outside the standard alphabet, for a
/// trailing group of a single sextet (6 bits cannot encode a whole byte),
/// and for decoded bytes that are not valid UTF-8.
fn base64_decode(input: &str) -> Result<String, &'static str> {
    const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    let mut result = Vec::new();
    // Padding carries no data; strip it before decoding.
    let input = input.trim_end_matches('=');
    let bytes: Vec<u8> = input.bytes().collect();

    for chunk in bytes.chunks(4) {
        // A lone trailing sextet is malformed base64 — previously this fell
        // through and emitted a garbage byte instead of an error.
        if chunk.len() == 1 {
            return Err("Invalid base64");
        }

        // Reassemble the 24-bit group from up to 4 sextets.
        let mut n = 0u32;
        for (i, &b) in chunk.iter().enumerate() {
            let pos = CHARS.iter().position(|&c| c == b).ok_or("Invalid base64")?;
            n |= (pos as u32) << (18 - i * 6);
        }

        // 2 sextets -> 1 byte, 3 -> 2 bytes, 4 -> 3 bytes.
        result.push((n >> 16) as u8);
        if chunk.len() > 2 {
            result.push((n >> 8) as u8);
        }
        if chunk.len() > 3 {
            result.push(n as u8);
        }
    }

    String::from_utf8(result).map_err(|_| "Invalid UTF-8")
}
1394
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Falls back to 0 if the system clock reports a time before the epoch.
fn current_timestamp_ms() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_millis() as u64,
        Err(_) => 0,
    }
}
1401
1402// ============================================================================
1403// Tests
1404// ============================================================================
1405
// Unit tests covering configuration defaults, pagination cursors, geo
// filtering, aggregations, custom scoring, metadata filters, explanations,
// and the batch executor's validation/deduplication behavior.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_query_config_default() {
        let config = BatchQueryConfig::default();
        assert_eq!(config.max_batch_size, 100);
        assert_eq!(config.max_concurrency, 16);
        assert_eq!(config.query_timeout_ms, 5000);
    }

    #[test]
    fn test_cursor_based_pagination() {
        // The offset survives the cursor's encode/parse round-trip.
        let cursor = SearchCursor::cursor_based(20, 100);
        assert_eq!(cursor.cursor_type, CursorType::CursorBased);

        let offset = cursor.parse_offset().unwrap();
        assert_eq!(offset, 20);
    }

    #[test]
    fn test_search_after_pagination() {
        // Mixed sort-value types round-trip through the cursor.
        let sort_values = vec![SortValue::Score(0.95), SortValue::String("doc123".into())];
        let cursor = SearchCursor::search_after(&sort_values);
        assert_eq!(cursor.cursor_type, CursorType::SearchAfter);

        let parsed = cursor.parse_sort_values().unwrap();
        assert_eq!(parsed.len(), 2);
    }

    #[test]
    fn test_sort_value_comparison() {
        // Scores order descending (higher score compares Less, i.e. ranks
        // first); integers order ascending.
        assert_eq!(
            SortValue::Score(0.9).compare(&SortValue::Score(0.8)),
            std::cmp::Ordering::Less
        );
        assert_eq!(
            SortValue::Integer(10).compare(&SortValue::Integer(5)),
            std::cmp::Ordering::Greater
        );
    }

    #[test]
    fn test_geo_point_distance() {
        // New York to Los Angeles (approximately 3,940 km)
        let nyc = GeoPoint::new(40.7128, -74.0060).unwrap();
        let la = GeoPoint::new(34.0522, -118.2437).unwrap();

        let distance = nyc.distance_km(&la);
        assert!(distance > 3900.0 && distance < 4000.0);
    }

    #[test]
    fn test_geo_point_validation() {
        // Out-of-range latitude (>90) or longitude (>180) is rejected.
        assert!(GeoPoint::new(45.0, 90.0).is_ok());
        assert!(GeoPoint::new(91.0, 0.0).is_err());
        assert!(GeoPoint::new(0.0, 181.0).is_err());
    }

    #[test]
    fn test_geo_distance_filter() {
        let center = GeoPoint::new(40.7128, -74.0060).unwrap();
        let filter = GeoFilter::Distance {
            center,
            distance: 100.0,
            unit: DistanceUnit::Kilometers,
        };

        // Point within 100km
        let nearby = GeoPoint::new(40.8, -74.1).unwrap();
        assert!(GeoFilterExecutor::matches(&filter, &nearby));

        // Point far away
        let far = GeoPoint::new(34.0522, -118.2437).unwrap();
        assert!(!GeoFilterExecutor::matches(&filter, &far));
    }

    #[test]
    fn test_geo_bounding_box() {
        let filter = GeoFilter::BoundingBox {
            top_left: GeoPoint::new(41.0, -75.0).unwrap(),
            bottom_right: GeoPoint::new(40.0, -73.0).unwrap(),
        };

        let inside = GeoPoint::new(40.5, -74.0).unwrap();
        assert!(GeoFilterExecutor::matches(&filter, &inside));

        let outside = GeoPoint::new(42.0, -74.0).unwrap();
        assert!(!GeoFilterExecutor::matches(&filter, &outside));
    }

    #[test]
    fn test_terms_aggregation() {
        let executor = FacetExecutor::new(100);

        let values: Vec<Option<serde_json::Value>> = vec![
            Some(serde_json::json!("cat")),
            Some(serde_json::json!("dog")),
            Some(serde_json::json!("cat")),
            Some(serde_json::json!("bird")),
            Some(serde_json::json!("cat")),
        ];

        let buckets = executor.terms_aggregation(&values, 10);

        // Buckets come back ordered by count, most frequent first.
        assert_eq!(buckets.len(), 3);
        assert_eq!(buckets[0].key, "cat");
        assert_eq!(buckets[0].doc_count, 3);
    }

    #[test]
    fn test_range_aggregation() {
        let executor = FacetExecutor::new(100);

        let values: Vec<Option<f64>> =
            vec![Some(5.0), Some(15.0), Some(25.0), Some(35.0), Some(45.0)];
        let ranges = vec![
            RangeBucket {
                key: "low".into(),
                from: None,
                to: Some(20.0),
            },
            RangeBucket {
                key: "medium".into(),
                from: Some(20.0),
                to: Some(40.0),
            },
            RangeBucket {
                key: "high".into(),
                from: Some(40.0),
                to: None,
            },
        ];

        let buckets = executor.range_aggregation(&values, &ranges);

        assert_eq!(buckets.len(), 3);
        assert_eq!(buckets[0].doc_count, 2); // low: 5, 15
        assert_eq!(buckets[1].doc_count, 2); // medium: 25, 35
        assert_eq!(buckets[2].doc_count, 1); // high: 45
    }

    #[test]
    fn test_numeric_aggregations() {
        let executor = FacetExecutor::new(100);
        let values: Vec<Option<f64>> = vec![Some(10.0), Some(20.0), Some(30.0)];

        assert_eq!(
            executor.numeric_aggregation(&values, &AggregationType::Min),
            10.0
        );
        assert_eq!(
            executor.numeric_aggregation(&values, &AggregationType::Max),
            30.0
        );
        assert_eq!(
            executor.numeric_aggregation(&values, &AggregationType::Avg),
            20.0
        );
        assert_eq!(
            executor.numeric_aggregation(&values, &AggregationType::Sum),
            60.0
        );
    }

    #[test]
    fn test_score_function_weight() {
        // A constant-weight function ignores the original score entirely.
        let func = ScoreFunction::Weight { weight: 2.0 };
        let score = ScoreFunctionExecutor::apply(&func, 0.5, None);
        assert_eq!(score, 2.0);
    }

    #[test]
    fn test_score_function_field_value() {
        let func = ScoreFunction::FieldValue {
            field: "popularity".into(),
            factor: 0.1,
            modifier: FieldValueModifier::Log1p,
            missing: 1.0,
        };

        let metadata = serde_json::json!({"popularity": 100.0});
        let score = ScoreFunctionExecutor::apply(&func, 0.5, Some(&metadata));

        // log1p(100) * 0.1 ≈ 2.004 * 0.1 ≈ 0.2004
        assert!(score > 0.19 && score < 0.21);
    }

    #[test]
    fn test_combine_scores() {
        let scores = vec![2.0f32, 3.0, 4.0];

        // Multiply: 2*3*4, Sum: 2+3+4, Average: 9/3, Max/Min: extremes.
        assert_eq!(
            ScoreFunctionExecutor::combine_scores(&scores, ScoreMode::Multiply),
            24.0
        );
        assert_eq!(
            ScoreFunctionExecutor::combine_scores(&scores, ScoreMode::Sum),
            9.0
        );
        assert_eq!(
            ScoreFunctionExecutor::combine_scores(&scores, ScoreMode::Average),
            3.0
        );
        assert_eq!(
            ScoreFunctionExecutor::combine_scores(&scores, ScoreMode::Max),
            4.0
        );
        assert_eq!(
            ScoreFunctionExecutor::combine_scores(&scores, ScoreMode::Min),
            2.0
        );
    }

    #[test]
    fn test_filter_term() {
        let filter = FilterExpression::Term {
            field: "category".into(),
            value: serde_json::json!("tech"),
        };

        let metadata = serde_json::json!({"category": "tech"});
        assert!(FilterExecutor::matches(&filter, Some(&metadata)));

        let other = serde_json::json!({"category": "science"});
        assert!(!FilterExecutor::matches(&filter, Some(&other)));
    }

    #[test]
    fn test_filter_range() {
        let filter = FilterExpression::Range {
            field: "price".into(),
            gte: Some(10.0),
            gt: None,
            lte: Some(100.0),
            lt: None,
        };

        let match1 = serde_json::json!({"price": 50});
        assert!(FilterExecutor::matches(&filter, Some(&match1)));

        let nomatch = serde_json::json!({"price": 5});
        assert!(!FilterExecutor::matches(&filter, Some(&nomatch)));
    }

    #[test]
    fn test_filter_boolean_and() {
        // Both sub-filters must match for an And() to match.
        let filter = FilterExpression::And(vec![
            FilterExpression::Term {
                field: "category".into(),
                value: serde_json::json!("tech"),
            },
            FilterExpression::Range {
                field: "price".into(),
                gte: Some(10.0),
                gt: None,
                lte: None,
                lt: None,
            },
        ]);

        let match1 = serde_json::json!({"category": "tech", "price": 50});
        assert!(FilterExecutor::matches(&filter, Some(&match1)));

        let nomatch = serde_json::json!({"category": "tech", "price": 5});
        assert!(!FilterExecutor::matches(&filter, Some(&nomatch)));
    }

    #[test]
    fn test_query_explanation() {
        let id = "doc123".to_string();
        let scoring = ScoringConfig {
            functions: vec![ScoreFunction::Weight { weight: 1.5 }],
            ..Default::default()
        };

        let explanation = QueryExplainer::explain(&id, 0.8, Some(&scoring), None);

        assert!(explanation.score > 0.0);
        assert!(!explanation.details.is_empty());
        assert!(explanation.description.contains("doc123"));
    }

    #[test]
    fn test_batch_query_executor() {
        let config = BatchQueryConfig::default();
        let executor = BatchQueryExecutor::new(config);

        let request = BatchQueryRequest {
            namespace: "test".into(),
            queries: vec![
                BatchQueryItem {
                    query_id: "q1".into(),
                    vector: vec![1.0, 0.0, 0.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
                BatchQueryItem {
                    query_id: "q2".into(),
                    vector: vec![0.0, 1.0, 0.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
            ],
            include_vectors: false,
            include_metadata: true,
            facets: None,
        };

        // Mock search function
        let search_fn = |_vector: &[f32], _top_k: usize, _filter: Option<&FilterExpression>| {
            vec![
                (
                    "doc1".into(),
                    0.9,
                    None,
                    Some(serde_json::json!({"cat": "a"})),
                ),
                (
                    "doc2".into(),
                    0.8,
                    None,
                    Some(serde_json::json!({"cat": "b"})),
                ),
            ]
        };

        let response = executor.execute(&request, search_fn).unwrap();

        assert_eq!(response.responses.len(), 2);
        assert_eq!(response.stats.total_queries, 2);
        assert_eq!(response.stats.successful_queries, 2);
    }

    #[test]
    fn test_batch_too_large_error() {
        // Three queries against a max_batch_size of 2 must be rejected.
        let config = BatchQueryConfig {
            max_batch_size: 2,
            ..Default::default()
        };
        let executor = BatchQueryExecutor::new(config);

        let request = BatchQueryRequest {
            namespace: "test".into(),
            queries: vec![
                BatchQueryItem {
                    query_id: "q1".into(),
                    vector: vec![1.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
                BatchQueryItem {
                    query_id: "q2".into(),
                    vector: vec![1.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
                BatchQueryItem {
                    query_id: "q3".into(),
                    vector: vec![1.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
            ],
            include_vectors: false,
            include_metadata: false,
            facets: None,
        };

        let result = executor.execute(&request, |_, _, _| vec![]);
        assert!(matches!(result, Err(BatchSearchError::BatchTooLarge(3, 2))));
    }

    #[test]
    fn test_base64_roundtrip() {
        let original = "hello:world:123";
        let encoded = base64_encode(original);
        let decoded = base64_decode(&encoded).unwrap();
        assert_eq!(decoded, original);
    }

    #[test]
    fn test_distance_unit_conversion() {
        assert_eq!(DistanceUnit::Meters.to_meters(100.0), 100.0);
        assert_eq!(DistanceUnit::Kilometers.to_meters(1.0), 1000.0);
        assert!((DistanceUnit::Miles.to_meters(1.0) - 1609.34).abs() < 0.01);
    }

    #[test]
    fn test_query_deduplication() {
        let config = BatchQueryConfig {
            deduplicate_queries: true,
            ..Default::default()
        };
        let executor = BatchQueryExecutor::new(config);

        // Same vector twice
        let request = BatchQueryRequest {
            namespace: "test".into(),
            queries: vec![
                BatchQueryItem {
                    query_id: "q1".into(),
                    vector: vec![1.0, 0.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
                BatchQueryItem {
                    query_id: "q2".into(),
                    vector: vec![1.0, 0.0],
                    top_k: 5,
                    filter: None,
                    cursor: None,
                    scoring: None,
                },
            ],
            include_vectors: false,
            include_metadata: false,
            facets: None,
        };

        let response = executor
            .execute(&request, |_, _, _| vec![("doc1".into(), 0.9, None, None)])
            .unwrap();

        // The second identical query reuses the first query's response.
        assert_eq!(response.stats.deduplicated_count, 1);
    }
}