Skip to main content

do_memory_core/retrieval/
cascade.rs

1//! Cascading retrieval pipeline (WG-131).
2//!
3//! Implements a 4-tier retrieval cascade:
4//! 1. BM25 keyword index (CPU-local, no API calls)
5//! 2. HDC hyperdimensional encoding (CPU-local, no API calls)
6//! 3. ConceptGraph ontology expansion (CPU-local, no API calls)
7//! 4. API embedding fallback (external API call)
8//!
9//! The cascade eliminates 50-70% of embedding API calls by satisfying
10//! queries from CPU-local tiers before falling back to the API.
11
12use anyhow::Result;
13
/// Tunable settings for the cascading retrieval pipeline.
///
/// All thresholds are minimum scores on a normalized 0.0-1.0 scale.
#[derive(Debug, Clone)]
pub struct CascadeConfig {
    /// How many results each tier should return.
    pub top_k: usize,
    /// Lowest BM25 score (0.0-1.0) a result must reach.
    pub bm25_threshold: f32,
    /// Lowest HDC similarity (0.0-1.0) a result must reach.
    pub hdc_threshold: f32,
    /// Lowest ConceptGraph confidence (0.0-1.0) a result must reach.
    pub concept_graph_threshold: f32,
    /// When `true`, results from different tiers are merged together.
    pub merge_results: bool,
}

impl Default for CascadeConfig {
    /// Baseline settings: 10 results per tier, thresholds of 0.3 (BM25),
    /// 0.5 (HDC) and 0.4 (ConceptGraph), with cross-tier merging enabled.
    fn default() -> Self {
        let top_k = 10;
        let merge_results = true;
        CascadeConfig {
            top_k,
            bm25_threshold: 0.3,
            hdc_threshold: 0.5,
            concept_graph_threshold: 0.4,
            merge_results,
        }
    }
}
40
/// Output produced by one stage of the retrieval cascade.
///
/// NOTE(review): `scores` presumably lines up index-for-index with `ids`;
/// the type does not enforce equal lengths, so producers must keep them in sync.
#[derive(Debug, Clone)]
pub struct TierResult {
    /// Name of the tier that produced this output: `bm25`, `hdc`,
    /// `concept_graph` or `api`.
    pub tier: String,
    /// Episode IDs retrieved by this tier, as strings.
    pub ids: Vec<String>,
    /// Scores normalized into the 0.0-1.0 range.
    pub scores: Vec<f32>,
    /// `true` when this tier's output was judged sufficient.
    pub sufficient: bool,
}
53
/// Aggregate outcome of running the whole retrieval cascade.
#[derive(Debug, Clone)]
pub struct CascadeResult {
    /// Episode IDs after merging and re-ranking.
    pub episode_ids: Vec<String>,
    /// Scores after merging and re-ranking.
    pub scores: Vec<f32>,
    /// Names of the tier(s) that contributed to this result.
    pub contributing_tiers: Vec<String>,
    /// How many external API calls were made; expected to be 0 or 1.
    pub api_calls: u32,
}
66
67/// Cascading retrieval orchestrator.
68///
69/// Coordinates the 4-tier retrieval pipeline, falling back to API
70/// only when CPU-local tiers cannot satisfy the query.
71pub struct CascadeRetriever {
72    config: CascadeConfig,
73}
74
75impl CascadeRetriever {
76    /// Create a new cascade retriever with given configuration.
77    pub fn new(config: CascadeConfig) -> Self {
78        Self { config }
79    }
80
81    /// Execute the cascading retrieval pipeline.
82    ///
83    /// This is a placeholder that returns empty results when the `csm`
84    /// feature is not enabled. With `csm` enabled, it uses BM25, HDC,
85    /// and ConceptGraph from the `chaotic_semantic_memory` crate.
86    ///
87    /// Note: This placeholder is non-async. The full CSM implementation
88    /// will be async to allow for concurrent tier queries.
89    pub fn retrieve(&self, _query: &str) -> Result<CascadeResult> {
90        // Placeholder implementation - returns empty results
91        // Full implementation requires csm feature to be enabled
92        Ok(CascadeResult {
93            episode_ids: Vec::new(),
94            scores: Vec::new(),
95            contributing_tiers: Vec::new(),
96            api_calls: 0,
97        })
98    }
99
100    /// Get the configuration for this retriever.
101    pub fn config(&self) -> &CascadeConfig {
102        &self.config
103    }
104
105    /// Estimate the number of API calls that would be saved for a query.
106    ///
107    /// Returns 1.0 if the query would likely require an API call,
108    /// or 0.0 if CPU-local tiers would likely suffice.
109    pub fn estimate_api_call_probability(&self, _query: &str) -> f32 {
110        // Placeholder - in full implementation, this would analyze
111        // query characteristics (length, keywords, complexity)
112        // to estimate probability of needing API fallback
113        0.5
114    }
115}
116
/// Weight computation for query-length-dependent tier weighting.
///
/// Short queries favor BM25 (keyword matching); long queries favor
/// HDC/semantic matching. The tuple is returned in cascade tier order,
/// `(bm25, hdc, concept_graph)`, and each triple sums to 1.0.
///
/// Length is measured in Unicode scalar values (`char`s), not UTF-8 bytes.
/// This keeps the bucket independent of the script's byte width: a
/// 10-character CJK query is 30 bytes but still counts as short.
///
/// | characters | bm25 | hdc | concept_graph |
/// |------------|------|-----|---------------|
/// | 0-19       | 0.7  | 0.2 | 0.1           |
/// | 20-99      | 0.4  | 0.4 | 0.2           |
/// | 100+       | 0.2  | 0.5 | 0.3           |
#[cfg(feature = "csm")]
pub fn compute_tier_weights(query: &str) -> (f32, f32, f32) {
    // Exclusive upper bound (in chars) of the "short" bucket.
    const SHORT_QUERY_LIMIT: usize = 20;
    // Exclusive upper bound (in chars) of the "medium" bucket; longer is "long".
    const MEDIUM_QUERY_LIMIT: usize = 100;

    // `str::len` counts bytes; count chars so non-ASCII queries bucket correctly.
    let len = query.chars().count();
    if len < SHORT_QUERY_LIMIT {
        // Short query: favor keyword matching
        (0.7, 0.2, 0.1)
    } else if len < MEDIUM_QUERY_LIMIT {
        // Medium query: balanced weighting
        (0.4, 0.4, 0.2)
    } else {
        // Long query: favor semantic matching
        (0.2, 0.5, 0.3)
    }
}
135
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cascade_config_default() {
        let config = CascadeConfig::default();
        assert_eq!(config.top_k, 10);
        // Thresholds are documented as 0.0-1.0; defaults must be positive and at most 1.0.
        for threshold in [
            config.bm25_threshold,
            config.hdc_threshold,
            config.concept_graph_threshold,
        ] {
            assert!(
                threshold > 0.0 && threshold <= 1.0,
                "threshold {threshold} outside (0.0, 1.0]"
            );
        }
        assert!(config.merge_results);
    }

    #[test]
    fn test_cascade_retriever_creation() {
        let retriever = CascadeRetriever::new(CascadeConfig::default());
        assert_eq!(retriever.config().top_k, 10);

        // A non-default config must be stored as given, not replaced by defaults.
        let custom = CascadeConfig {
            top_k: 3,
            merge_results: false,
            ..CascadeConfig::default()
        };
        let retriever = CascadeRetriever::new(custom);
        assert_eq!(retriever.config().top_k, 3);
        assert!(!retriever.config().merge_results);
    }

    #[test]
    fn test_placeholder_retrieve() {
        let retriever = CascadeRetriever::new(CascadeConfig::default());
        for query in ["test query", ""] {
            let result = retriever
                .retrieve(query)
                .expect("placeholder retrieve never fails");
            assert!(result.episode_ids.is_empty());
            assert!(result.scores.is_empty());
            assert!(result.contributing_tiers.is_empty());
            assert_eq!(result.api_calls, 0);
        }
    }

    #[test]
    fn test_estimate_api_call_probability() {
        let retriever = CascadeRetriever::new(CascadeConfig::default());
        let long_query = "word ".repeat(100);
        for query in ["test", "", long_query.as_str()] {
            let prob = retriever.estimate_api_call_probability(query);
            assert!((0.0..=1.0).contains(&prob), "probability {prob} out of range");
        }
    }

    #[cfg(feature = "csm")]
    #[test]
    fn test_compute_tier_weights_buckets() {
        let short = compute_tier_weights("rust");
        let medium = compute_tier_weights(&"a".repeat(50));
        let long = compute_tier_weights(&"a".repeat(200));
        // Short queries lean on BM25, long ones on HDC.
        assert!(short.0 > short.1);
        assert!(long.1 > long.0);
        for (bm25, hdc, graph) in [short, medium, long] {
            assert!((bm25 + hdc + graph - 1.0).abs() < 1e-5);
        }
    }
}