Skip to main content

reddb_server/storage/query/executors/
hybrid.rs

1//! Hybrid Query Executor
2//!
3//! Executes HYBRID queries that combine structured (SQL/Graph) queries with
4//! vector similarity search, using various fusion strategies to merge results.
5//!
6//! # Fusion Strategies
7//!
8//! - **Rerank**: Re-ranks structured results by vector similarity
9//! - **FilterThenSearch**: Filters first, then searches vectors
10//! - **SearchThenFilter**: Searches vectors first, then applies structured filter
11//! - **RRF (Reciprocal Rank Fusion)**: Combines rankings fairly
12//! - **Intersection**: Only returns results matching both queries
13//! - **Union**: Returns results from either query with combined scores
14
15use std::collections::{HashMap, HashSet};
16use std::sync::Arc;
17use std::thread;
18
19use crate::storage::engine::graph_store::GraphStore;
20use crate::storage::engine::graph_table_index::GraphTableIndex;
21use crate::storage::engine::unified_index::UnifiedIndex;
22use crate::storage::engine::vector_store::VectorStore;
23use crate::storage::query::ast::{FusionStrategy, HybridQuery, VectorQuery};
24use crate::storage::query::unified::{
25    ExecutionError, QueryStats, UnifiedExecutor, UnifiedRecord, UnifiedResult,
26};
27use crate::storage::schema::Value;
28
29use super::vector::VectorExecutor;
30
/// Hybrid query executor that combines structured and vector results.
///
/// Holds one executor per query arm; `execute` dispatches on the query's
/// fusion strategy to decide how the two result sets are merged.
pub struct HybridExecutor {
    /// Structured query executor
    unified: UnifiedExecutor,
    /// Vector search executor
    vector: VectorExecutor,
    /// Cross-reference index for linking results. `None` after `new`;
    /// populated by `with_unified_index`.
    unified_index: Option<Arc<UnifiedIndex>>,
}
40
41impl HybridExecutor {
42    /// Create a new hybrid executor
43    pub fn new(
44        graph: Arc<GraphStore>,
45        index: Arc<GraphTableIndex>,
46        vector_store: Arc<VectorStore>,
47    ) -> Self {
48        let unified = UnifiedExecutor::new(Arc::clone(&graph), Arc::clone(&index));
49        let vector = VectorExecutor::new(vector_store);
50
51        Self {
52            unified,
53            vector,
54            unified_index: None,
55        }
56    }
57
58    /// Add cross-reference support
59    pub fn with_unified_index(mut self, index: Arc<UnifiedIndex>) -> Self {
60        self.unified_index = Some(Arc::clone(&index));
61        self.vector = self.vector.with_unified_index(index);
62        self
63    }
64
65    /// Execute a hybrid query
66    pub fn execute(&self, query: &HybridQuery) -> Result<UnifiedResult, ExecutionError> {
67        let start = std::time::Instant::now();
68
69        // Execute based on fusion strategy
70        let mut result = match &query.fusion {
71            FusionStrategy::Rerank { weight } => self.execute_rerank(query, *weight)?,
72            FusionStrategy::FilterThenSearch => self.execute_filter_then_search(query)?,
73            FusionStrategy::SearchThenFilter => self.execute_search_then_filter(query)?,
74            FusionStrategy::RRF { k } => self.execute_rrf(query, *k)?,
75            FusionStrategy::Intersection => self.execute_intersection(query)?,
76            FusionStrategy::Union {
77                structured_weight,
78                vector_weight,
79            } => self.execute_union(query, *structured_weight, *vector_weight)?,
80        };
81
82        // Apply limit if specified
83        if let Some(limit) = query.limit {
84            result.records.truncate(limit);
85        }
86
87        // Update stats
88        result.stats.exec_time_us = start.elapsed().as_micros() as u64;
89
90        Ok(result)
91    }
92
93    // =========================================================================
94    // Fusion Strategies
95    // =========================================================================
96
97    /// Rerank: Execute structured query, then re-rank by vector similarity
98    fn execute_rerank(
99        &self,
100        query: &HybridQuery,
101        weight: f32,
102    ) -> Result<UnifiedResult, ExecutionError> {
103        // 1. Execute structured query
104        let structured_result = self.unified.execute(&query.structured)?;
105
106        if structured_result.is_empty() {
107            return Ok(structured_result);
108        }
109
110        // 2. Execute vector query
111        let vector_result = self.vector.execute(&query.vector)?;
112
113        // 3. Build vector distance lookup
114        let mut vector_distances: HashMap<String, f32> = HashMap::new();
115        for record in &vector_result.records {
116            for vsr in &record.vector_results {
117                // Use vector ID as key
118                let key = format!("{}:{}", vsr.collection, vsr.id);
119                vector_distances.insert(key, vsr.distance);
120            }
121        }
122
123        // 4. Score and rerank structured results
124        let mut scored: Vec<(String, UnifiedRecord, f32)> = structured_result
125            .records
126            .into_iter()
127            .enumerate()
128            .map(|(rank, record)| {
129                // Structured score: inverse rank (higher = better)
130                let struct_score = 1.0 / (rank as f32 + 1.0);
131
132                // Vector score: try to find matching vector via cross-reference
133                let vector_score = self.get_vector_score_for_record(&record, &vector_distances);
134
135                // Combined score
136                let combined = (1.0 - weight) * struct_score + weight * vector_score;
137                (self.record_to_key(&record), record, combined)
138            })
139            .collect();
140
141        // Sort by combined score (descending), then deterministic key
142        scored.sort_by(
143            |a, b| match b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal) {
144                std::cmp::Ordering::Equal => a.0.cmp(&b.0),
145                ordering => ordering,
146            },
147        );
148
149        // Build result
150        let mut result = UnifiedResult::with_columns(structured_result.columns);
151        result.stats = structured_result.stats;
152
153        for (_key, mut record, score) in scored {
154            record.set_arc(Arc::from("_hybrid_score"), Value::Float(score as f64));
155            result.push(record);
156        }
157
158        Ok(result)
159    }
160
161    /// FilterThenSearch: Use structured results to filter vector search space
162    fn execute_filter_then_search(
163        &self,
164        query: &HybridQuery,
165    ) -> Result<UnifiedResult, ExecutionError> {
166        // 1. Execute structured query to get filter candidates
167        let structured_result = self.unified.execute(&query.structured)?;
168
169        if structured_result.is_empty() {
170            return Ok(structured_result);
171        }
172
173        // 2. Extract IDs from structured results for filtering
174        let candidate_ids: HashSet<u64> = structured_result
175            .records
176            .iter()
177            .filter_map(|r| {
178                // Try to get ID from values
179                r.get("id").and_then(|v| match v {
180                    Value::Integer(i) => Some(*i as u64),
181                    _ => None,
182                })
183            })
184            .collect();
185
186        // 3. Execute vector query
187        let vector_result = self.vector.execute(&query.vector)?;
188
189        // 4. Filter vector results to only include structured candidates
190        let mut result = UnifiedResult::with_columns(vector_result.columns.clone());
191
192        for record in vector_result.records {
193            // Check if this vector result matches any structured candidate
194            let matches = record.vector_results.iter().any(|vsr| {
195                candidate_ids.contains(&vsr.id) ||
196                // Also check linked row if available
197                vsr.linked_row.as_ref().map(|(_, row_id)| candidate_ids.contains(row_id)).unwrap_or(false)
198            });
199
200            if matches {
201                result.push(record);
202            }
203        }
204
205        result.stats = QueryStats::merge(&structured_result.stats, &vector_result.stats);
206        Ok(result)
207    }
208
209    /// SearchThenFilter: Vector search first, then apply structured filter
210    fn execute_search_then_filter(
211        &self,
212        query: &HybridQuery,
213    ) -> Result<UnifiedResult, ExecutionError> {
214        // 1. Execute vector query first
215        let vector_result = self.vector.execute(&query.vector)?;
216
217        if vector_result.is_empty() {
218            return Ok(vector_result);
219        }
220
221        // 2. Execute structured query
222        let structured_result = self.unified.execute(&query.structured)?;
223
224        // 3. Extract IDs from structured results
225        let structured_ids: HashSet<u64> = structured_result
226            .records
227            .iter()
228            .filter_map(|r| {
229                r.get("id").and_then(|v| match v {
230                    Value::Integer(i) => Some(*i as u64),
231                    _ => None,
232                })
233            })
234            .collect();
235
236        // 4. Filter vector results to match structured query
237        let mut result = UnifiedResult::with_columns(vector_result.columns.clone());
238
239        for record in vector_result.records {
240            let matches = record.vector_results.iter().any(|vsr| {
241                structured_ids.contains(&vsr.id)
242                    || vsr
243                        .linked_row
244                        .as_ref()
245                        .map(|(_, row_id)| structured_ids.contains(row_id))
246                        .unwrap_or(false)
247            });
248
249            if matches {
250                result.push(record);
251            }
252        }
253
254        result.stats = QueryStats::merge(&vector_result.stats, &structured_result.stats);
255        Ok(result)
256    }
257
258    /// RRF: Reciprocal Rank Fusion
259    /// Combines rankings using: RRF(d) = Σ(1 / (k + rank(d)))
260    /// Execute structured and vector arms concurrently via
261    /// [`std::thread::scope`].
262    ///
263    /// Used by fusion strategies that always run both arms to completion
264    /// (RRF, Intersection, Union). Short-circuiting strategies (Rerank,
265    /// FilterThenSearch, SearchThenFilter) keep serial execution because
266    /// they check for early-exit conditions on the first arm before
267    /// deciding whether to run the second.
268    ///
269    /// Worst-case total latency collapses from `structured + vector` to
270    /// `max(structured, vector)` — the planner's pessimistic estimate for
271    /// hybrid queries is now tight when both arms dominate.
272    fn execute_structured_and_vector_parallel(
273        &self,
274        query: &HybridQuery,
275    ) -> Result<(UnifiedResult, UnifiedResult), ExecutionError> {
276        thread::scope(|s| {
277            let structured_handle = s.spawn(|| self.unified.execute(&query.structured));
278            let vector_handle = s.spawn(|| self.vector.execute(&query.vector));
279
280            // `join` returns `Result<T, Box<dyn Any + Send>>`; a panic in
281            // either arm is surfaced as an `ExecutionError` so callers
282            // don't see a raw thread panic.
283            let structured = structured_handle
284                .join()
285                .map_err(|_| ExecutionError::new("hybrid: structured arm panicked"))??;
286            let vector = vector_handle
287                .join()
288                .map_err(|_| ExecutionError::new("hybrid: vector arm panicked"))??;
289            Ok((structured, vector))
290        })
291    }
292
293    fn execute_rrf(&self, query: &HybridQuery, k: u32) -> Result<UnifiedResult, ExecutionError> {
294        // 1. Execute both queries in parallel — RRF always consumes both.
295        let (structured_result, vector_result) =
296            self.execute_structured_and_vector_parallel(query)?;
297
298        // 2. Build rank maps (lower rank = better, starting from 1)
299        let mut structured_ranks: HashMap<String, u32> = HashMap::new();
300        for (rank, record) in structured_result.records.iter().enumerate() {
301            let key = self.record_to_key(record);
302            structured_ranks.insert(key, (rank + 1) as u32);
303        }
304
305        let mut vector_ranks: HashMap<String, u32> = HashMap::new();
306        for (rank, record) in vector_result.records.iter().enumerate() {
307            let key = self.record_to_key(record);
308            vector_ranks.insert(key, (rank + 1) as u32);
309        }
310
311        // 3. Calculate RRF scores for all unique records
312        let all_keys: HashSet<_> = structured_ranks
313            .keys()
314            .chain(vector_ranks.keys())
315            .cloned()
316            .collect();
317
318        let k_f64 = k as f64;
319        let mut rrf_scores: Vec<(String, f64)> = all_keys
320            .into_iter()
321            .map(|key| {
322                let struct_contrib = structured_ranks
323                    .get(&key)
324                    .map(|r| 1.0 / (k_f64 + *r as f64))
325                    .unwrap_or(0.0);
326                let vector_contrib = vector_ranks
327                    .get(&key)
328                    .map(|r| 1.0 / (k_f64 + *r as f64))
329                    .unwrap_or(0.0);
330                (key, struct_contrib + vector_contrib)
331            })
332            .collect();
333
334        // Sort by RRF score (descending)
335        rrf_scores.sort_by(|a, b| {
336            match b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal) {
337                std::cmp::Ordering::Equal => a.0.cmp(&b.0),
338                ordering => ordering,
339            }
340        });
341
342        // 4. Build result from scored records
343        let mut record_map: HashMap<String, UnifiedRecord> = HashMap::new();
344        for record in structured_result.records {
345            let key = self.record_to_key(&record);
346            record_map.insert(key, record);
347        }
348        for record in vector_result.records {
349            let key = self.record_to_key(&record);
350            if let Some(existing) = record_map.get_mut(&key) {
351                // Merge vector results
352                existing.vector_results.extend(record.vector_results);
353            } else {
354                record_map.insert(key, record);
355            }
356        }
357
358        // Build final result in RRF order
359        let mut columns = structured_result.columns.clone();
360        for col in &vector_result.columns {
361            if !columns.contains(col) {
362                columns.push(col.clone());
363            }
364        }
365
366        let mut result = UnifiedResult::with_columns(columns);
367        result.stats = QueryStats::merge(&structured_result.stats, &vector_result.stats);
368
369        for (key, score) in rrf_scores {
370            if let Some(mut record) = record_map.remove(&key) {
371                record.set_arc(Arc::from("_rrf_score"), Value::Float(score));
372                result.push(record);
373            }
374        }
375
376        Ok(result)
377    }
378
379    /// Intersection: Only return results present in both
380    fn execute_intersection(&self, query: &HybridQuery) -> Result<UnifiedResult, ExecutionError> {
381        // 1. Execute both queries in parallel — intersection needs both
382        //    result sets before it can filter.
383        let (structured_result, vector_result) =
384            self.execute_structured_and_vector_parallel(query)?;
385
386        // 2. Build key sets
387        let structured_keys: HashSet<String> = structured_result
388            .records
389            .iter()
390            .map(|r| self.record_to_key(r))
391            .collect();
392
393        // 3. Filter vector results to only those in structured
394        let mut result = UnifiedResult::with_columns(vector_result.columns.clone());
395
396        for record in vector_result.records {
397            let key = self.record_to_key(&record);
398            if structured_keys.contains(&key) {
399                result.push(record);
400            }
401        }
402
403        result.stats = QueryStats::merge(&structured_result.stats, &vector_result.stats);
404        Ok(result)
405    }
406
407    /// Union: Combine results with weighted scores
408    fn execute_union(
409        &self,
410        query: &HybridQuery,
411        struct_weight: f32,
412        vector_weight: f32,
413    ) -> Result<UnifiedResult, ExecutionError> {
414        // 1. Execute both queries in parallel — union merges both result
415        //    sets with weighted scores, so neither arm can be skipped.
416        let (structured_result, vector_result) =
417            self.execute_structured_and_vector_parallel(query)?;
418
419        // 2. Score and collect all records
420        let mut scored_records: HashMap<String, (UnifiedRecord, f32)> = HashMap::new();
421
422        // Add structured results with score based on rank
423        for (rank, record) in structured_result.records.into_iter().enumerate() {
424            let key = self.record_to_key(&record);
425            let score = struct_weight * (1.0 / (rank as f32 + 1.0));
426            scored_records.insert(key, (record, score));
427        }
428
429        // Add/merge vector results
430        for (rank, record) in vector_result.records.into_iter().enumerate() {
431            let key = self.record_to_key(&record);
432            let vector_score = vector_weight * (1.0 / (rank as f32 + 1.0));
433
434            if let Some((existing, score)) = scored_records.get_mut(&key) {
435                // Merge: add vector score and vector results
436                *score += vector_score;
437                existing.vector_results.extend(record.vector_results);
438            } else {
439                scored_records.insert(key, (record, vector_score));
440            }
441        }
442
443        // 3. Sort by combined score
444        let mut sorted: Vec<(String, UnifiedRecord, f32)> = scored_records
445            .into_iter()
446            .map(|(key, (record, score))| (key, record, score))
447            .collect();
448        sorted.sort_by(
449            |a, b| match b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal) {
450                std::cmp::Ordering::Equal => a.0.cmp(&b.0),
451                ordering => ordering,
452            },
453        );
454
455        // 4. Build result
456        let mut columns = structured_result.columns.clone();
457        for col in &vector_result.columns {
458            if !columns.contains(col) {
459                columns.push(col.clone());
460            }
461        }
462
463        let mut result = UnifiedResult::with_columns(columns);
464        result.stats = QueryStats::merge(&structured_result.stats, &vector_result.stats);
465
466        for (_key, mut record, score) in sorted {
467            record.set_arc(Arc::from("_union_score"), Value::Float(score as f64));
468            result.push(record);
469        }
470
471        Ok(result)
472    }
473
474    // =========================================================================
475    // Helper Methods
476    // =========================================================================
477
478    /// Get a unique key for a record (for deduplication)
479    fn record_to_key(&self, record: &UnifiedRecord) -> String {
480        // Try various ways to identify the record
481        if let Some(Value::Integer(id)) = record.get("id") {
482            return format!("row:{}", id);
483        }
484        if let Some(first_node) = record.nodes.values().next() {
485            return format!("node:{}", first_node.id);
486        }
487        if let Some(first_vsr) = record.vector_results.first() {
488            return format!("vec:{}:{}", first_vsr.collection, first_vsr.id);
489        }
490        // Fallback: hash of all visible fields
491        let fields: Vec<_> = record.iter_fields().collect();
492        format!("hash:{:?}", fields)
493    }
494
495    /// Get vector similarity score for a structured record
496    fn get_vector_score_for_record(
497        &self,
498        record: &UnifiedRecord,
499        vector_distances: &HashMap<String, f32>,
500    ) -> f32 {
501        // Try to find matching vector via ID
502        if let Some(Value::Integer(id)) = record.get("id") {
503            // Check all collections in vector_distances
504            for (key, distance) in vector_distances {
505                if key.ends_with(&format!(":{}", id)) {
506                    // Convert distance to similarity (lower distance = higher similarity)
507                    return 1.0 / (1.0 + distance);
508                }
509            }
510        }
511
512        // Try via cross-reference if available
513        if let Some(ref unified_index) = self.unified_index {
514            if let Some(Value::Integer(id)) = record.get("id") {
515                // Look up if this row has a linked vector
516                // This requires the unified_index to track row->vector mappings
517                // For now, return 0 if no match found
518            }
519        }
520
521        0.0 // No vector match found
522    }
523}
524
525// ============================================================================
526// QueryStats Helper
527// ============================================================================
528
529impl QueryStats {
530    /// Merge two QueryStats
531    fn merge(a: &QueryStats, b: &QueryStats) -> QueryStats {
532        QueryStats {
533            nodes_scanned: a.nodes_scanned + b.nodes_scanned,
534            edges_scanned: a.edges_scanned + b.edges_scanned,
535            rows_scanned: a.rows_scanned + b.rows_scanned,
536            exec_time_us: a.exec_time_us + b.exec_time_us,
537        }
538    }
539}
540
541// ============================================================================
542// In-Memory Hybrid Executor for Testing
543// ============================================================================
544
545use super::vector::InMemoryVectorExecutor;
546
/// In-memory hybrid executor for testing.
///
/// Structured rows are supplied by ID through `add_record`; the structured
/// arm of a query is given as an explicit ID list to `execute_with_fusion`
/// rather than being executed.
pub struct InMemoryHybridExecutor {
    /// Records keyed by ID
    records: HashMap<u64, UnifiedRecord>,
    /// Vector executor
    vector: InMemoryVectorExecutor,
}
554
555impl InMemoryHybridExecutor {
556    /// Create a new in-memory hybrid executor
557    pub fn new() -> Self {
558        Self {
559            records: HashMap::new(),
560            vector: InMemoryVectorExecutor::new(),
561        }
562    }
563
564    /// Add a structured record
565    pub fn add_record(&mut self, id: u64, values: HashMap<String, Value>) {
566        let mut record = UnifiedRecord::new();
567        for (k, v) in values {
568            record.set_owned(k, v);
569        }
570        record.set_arc(Arc::from("id"), Value::Integer(id as i64));
571        self.records.insert(id, record);
572    }
573
574    /// Add a vector with optional link to record
575    pub fn add_vector(
576        &mut self,
577        collection: &str,
578        id: u64,
579        vector: Vec<f32>,
580        linked_record_id: Option<u64>,
581    ) {
582        use crate::storage::engine::vector_metadata::MetadataValue;
583        let mut meta = HashMap::new();
584        if let Some(record_id) = linked_record_id {
585            meta.insert(
586                "_linked_record".to_string(),
587                MetadataValue::Integer(record_id as i64),
588            );
589        }
590        let meta = if meta.is_empty() { None } else { Some(meta) };
591        self.vector.add_vector(collection, id, vector, meta);
592    }
593
594    /// Execute a hybrid query with manual fusion
595    pub fn execute_with_fusion(
596        &self,
597        structured_ids: &[u64],
598        vector_query: &VectorQuery,
599        fusion: &FusionStrategy,
600    ) -> Result<UnifiedResult, ExecutionError> {
601        // Execute vector query
602        let vector_result = self.vector.execute(vector_query)?;
603
604        // Get structured records
605        let structured: Vec<_> = structured_ids
606            .iter()
607            .filter_map(|id| self.records.get(id).cloned())
608            .collect();
609
610        // Apply fusion strategy
611        match fusion {
612            FusionStrategy::Rerank { weight } => {
613                self.fuse_rerank(structured, vector_result, *weight)
614            }
615            FusionStrategy::Intersection => self.fuse_intersection(structured, vector_result),
616            FusionStrategy::RRF { k } => self.fuse_rrf(structured, vector_result, *k),
617            _ => {
618                // Default: just return vector results
619                Ok(vector_result)
620            }
621        }
622    }
623
624    fn fuse_rerank(
625        &self,
626        structured: Vec<UnifiedRecord>,
627        vector_result: UnifiedResult,
628        weight: f32,
629    ) -> Result<UnifiedResult, ExecutionError> {
630        let mut scored: Vec<(String, UnifiedRecord, f32)> = Vec::new();
631
632        for (rank, record) in structured.into_iter().enumerate() {
633            let struct_score = 1.0 / (rank as f32 + 1.0);
634            let vector_score = self.get_vector_score(&record, &vector_result);
635            let combined = (1.0 - weight) * struct_score + weight * vector_score;
636            let key = self.record_to_key_in_memory(&record);
637            scored.push((key, record, combined));
638        }
639
640        scored.sort_by(
641            |a, b| match b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal) {
642                std::cmp::Ordering::Equal => a.0.cmp(&b.0),
643                ordering => ordering,
644            },
645        );
646
647        let mut result = UnifiedResult::with_columns(vec!["id".to_string()]);
648        for (_key, mut record, score) in scored {
649            record.set_arc(Arc::from("_hybrid_score"), Value::Float(score as f64));
650            result.push(record);
651        }
652
653        Ok(result)
654    }
655
656    fn fuse_intersection(
657        &self,
658        structured: Vec<UnifiedRecord>,
659        vector_result: UnifiedResult,
660    ) -> Result<UnifiedResult, ExecutionError> {
661        let struct_ids: HashSet<i64> = structured
662            .iter()
663            .filter_map(|r| match r.get("id") {
664                Some(Value::Integer(i)) => Some(*i),
665                _ => None,
666            })
667            .collect();
668
669        let mut result = UnifiedResult::with_columns(vector_result.columns.clone());
670
671        for record in vector_result.records {
672            if let Some(vsr) = record.vector_results.first() {
673                if struct_ids.contains(&(vsr.id as i64)) {
674                    result.push(record);
675                }
676            }
677        }
678
679        Ok(result)
680    }
681
682    fn fuse_rrf(
683        &self,
684        structured: Vec<UnifiedRecord>,
685        vector_result: UnifiedResult,
686        k: u32,
687    ) -> Result<UnifiedResult, ExecutionError> {
688        let k_f64 = k as f64;
689
690        // Build ID -> structured rank map
691        let struct_ranks: HashMap<i64, u32> = structured
692            .iter()
693            .enumerate()
694            .filter_map(|(rank, r)| match r.get("id") {
695                Some(Value::Integer(i)) => Some((*i, (rank + 1) as u32)),
696                _ => None,
697            })
698            .collect();
699
700        // Calculate RRF scores for vector results
701        let mut scored: Vec<(String, UnifiedRecord, f64)> = Vec::new();
702
703        for (rank, record) in vector_result.records.into_iter().enumerate() {
704            let vector_contrib = 1.0 / (k_f64 + (rank + 1) as f64);
705
706            let struct_contrib = record
707                .vector_results
708                .first()
709                .and_then(|vsr| struct_ranks.get(&(vsr.id as i64)))
710                .map(|r| 1.0 / (k_f64 + *r as f64))
711                .unwrap_or(0.0);
712
713            let rrf_score = struct_contrib + vector_contrib;
714            let key = self.record_to_key_in_memory(&record);
715            scored.push((key, record, rrf_score));
716        }
717
718        scored.sort_by(
719            |a, b| match b.2.partial_cmp(&a.2).unwrap_or(std::cmp::Ordering::Equal) {
720                std::cmp::Ordering::Equal => a.0.cmp(&b.0),
721                ordering => ordering,
722            },
723        );
724
725        let mut result =
726            UnifiedResult::with_columns(vec!["id".to_string(), "distance".to_string()]);
727        for (_key, mut record, score) in scored {
728            record.set_arc(Arc::from("_rrf_score"), Value::Float(score));
729            result.push(record);
730        }
731
732        Ok(result)
733    }
734
735    fn get_vector_score(&self, record: &UnifiedRecord, vector_result: &UnifiedResult) -> f32 {
736        if let Some(Value::Integer(id)) = record.get("id") {
737            for vr in &vector_result.records {
738                for vsr in &vr.vector_results {
739                    if vsr.id == *id as u64 {
740                        return 1.0 / (1.0 + vsr.distance);
741                    }
742                }
743            }
744        }
745        0.0
746    }
747
748    fn record_to_key_in_memory(&self, record: &UnifiedRecord) -> String {
749        if let Some(Value::Integer(id)) = record.get("id") {
750            return format!("row:{}", id);
751        }
752        if let Some(first_vsr) = record.vector_results.first() {
753            return format!("vec:{}:{}", first_vsr.collection, first_vsr.id);
754        }
755        let fields: Vec<_> = record.iter_fields().collect();
756        format!("hash:{:?}", fields)
757    }
758}
759
760impl Default for InMemoryHybridExecutor {
761    fn default() -> Self {
762        Self::new()
763    }
764}
765
766// ============================================================================
767// Tests
768// ============================================================================
769
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::engine::distance::DistanceMetric;
    use crate::storage::query::ast::VectorSource;

    /// L2 query against the "hosts" collection for the probe vector
    /// `[1.0, 0.0]`, with `k = 3`. Tests needing another `k` override it with
    /// struct-update syntax.
    fn hosts_query() -> VectorQuery {
        VectorQuery {
            alias: None,
            collection: "hosts".to_string(),
            query_vector: VectorSource::Literal(vec![1.0, 0.0]),
            k: 3,
            filter: None,
            metric: Some(DistanceMetric::L2),
            include_vectors: false,
            include_metadata: false,
            threshold: None,
        }
    }

    /// Register a structured record `id` whose `name` is `host<id>`.
    fn add_host(executor: &mut InMemoryHybridExecutor, id: u64) {
        let mut vals = HashMap::new();
        vals.insert("name".to_string(), Value::text(format!("host{}", id)));
        executor.add_record(id, vals);
    }

    /// Integer `id` column of every record, in result order.
    fn result_ids(result: &UnifiedResult) -> Vec<i64> {
        result
            .records
            .iter()
            .filter_map(|r| match r.get("id") {
                Some(Value::Integer(i)) => Some(*i),
                _ => None,
            })
            .collect()
    }

    #[test]
    fn test_in_memory_hybrid_rerank() {
        let mut executor = InMemoryHybridExecutor::new();

        // Add structured records
        for id in 1..=3 {
            add_host(&mut executor, id);
        }

        // Add vectors. L2 distance to the probe [1.0, 0.0]:
        //   host1 = 0.9, host2 ~= 0.71, host3 = 0.01 (closest)
        executor.add_vector("hosts", 1, vec![0.1, 0.0], Some(1));
        executor.add_vector("hosts", 2, vec![0.5, 0.5], Some(2));
        executor.add_vector("hosts", 3, vec![0.99, 0.0], Some(3));

        let query = hosts_query();

        // With pure structural ranking (weight=0), the structured order
        // 1, 2, 3 must be preserved.
        let result = executor
            .execute_with_fusion(&[1, 2, 3], &query, &FusionStrategy::Rerank { weight: 0.0 })
            .unwrap();

        assert_eq!(result.len(), 3);
        assert_eq!(result_ids(&result), vec![1, 2, 3]);

        // With pure vector ranking (weight=1), the order follows distance:
        // 3, 2, 1. Only the winner is pinned here.
        let result = executor
            .execute_with_fusion(&[1, 2, 3], &query, &FusionStrategy::Rerank { weight: 1.0 })
            .unwrap();

        assert_eq!(result.len(), 3);
        assert_eq!(result.records[0].get("id"), Some(&Value::Integer(3)));
    }

    #[test]
    fn test_in_memory_hybrid_intersection() {
        let mut executor = InMemoryHybridExecutor::new();

        // Add records 1-5
        for id in 1..=5 {
            add_host(&mut executor, id);
        }

        // Add vectors for only 2, 3, 4
        executor.add_vector("hosts", 2, vec![0.1, 0.0], Some(2));
        executor.add_vector("hosts", 3, vec![0.5, 0.5], Some(3));
        executor.add_vector("hosts", 4, vec![0.9, 0.0], Some(4));

        let query = VectorQuery {
            k: 10,
            ..hosts_query()
        };

        // Intersection of structured [1,2,3] and vectors [2,3,4] should be [2,3]
        let result = executor
            .execute_with_fusion(&[1, 2, 3], &query, &FusionStrategy::Intersection)
            .unwrap();

        assert_eq!(result.len(), 2);

        let ids: HashSet<i64> = result_ids(&result).into_iter().collect();

        assert!(ids.contains(&2));
        assert!(ids.contains(&3));
    }

    #[test]
    fn test_in_memory_hybrid_rrf() {
        let mut executor = InMemoryHybridExecutor::new();

        for id in 1..=4 {
            add_host(&mut executor, id);
            executor.add_vector("hosts", id, vec![id as f32 * 0.25, 0.0], Some(id));
        }

        let query = VectorQuery {
            k: 4,
            ..hosts_query()
        };

        // RRF with k=60
        let result = executor
            .execute_with_fusion(
                &[1, 2, 3, 4], // Structured order: 1, 2, 3, 4
                &query,        // Vector order: 4, 3, 2, 1 (by distance to [1.0, 0.0])
                &FusionStrategy::RRF { k: 60 },
            )
            .unwrap();

        assert_eq!(result.len(), 4);

        // All records should have RRF scores
        for record in &result.records {
            assert!(record.contains_column("_rrf_score"));
        }
    }
}