// do_memory_core/retrieval/cascade.rs
1//! Cascading retrieval pipeline (WG-131).
2//!
3//! Implements a 4-tier retrieval cascade:
4//! 1. BM25 keyword index (CPU-local, no API calls)
5//! 2. HDC hyperdimensional encoding (CPU-local, no API calls)
6//! 3. ConceptGraph ontology expansion (CPU-local, no API calls)
7//! 4. API embedding fallback (external API call)
8//!
9//! The cascade eliminates 50-70% of embedding API calls by satisfying
10//! queries from CPU-local tiers before falling back to the API.
11
12use anyhow::Result;
13
/// Tuning knobs for the cascading retrieval pipeline.
#[derive(Debug, Clone)]
pub struct CascadeConfig {
    /// How many results each tier should return.
    pub top_k: usize,
    /// BM25 results scoring below this value (0.0-1.0) are discarded.
    pub bm25_threshold: f32,
    /// HDC results with similarity below this value (0.0-1.0) are discarded.
    pub hdc_threshold: f32,
    /// ConceptGraph results with confidence below this value (0.0-1.0) are discarded.
    pub concept_graph_threshold: f32,
    /// When true, results from multiple tiers are merged together.
    pub merge_results: bool,
}
28
29impl Default for CascadeConfig {
30 fn default() -> Self {
31 Self {
32 top_k: 10,
33 bm25_threshold: 0.3,
34 hdc_threshold: 0.5,
35 concept_graph_threshold: 0.4,
36 merge_results: true,
37 }
38 }
39}
40
/// Result produced by one tier of the cascade.
///
/// `ids` and `scores` are parallel vectors: `scores[i]` is the
/// normalized score for `ids[i]`.
#[derive(Debug, Clone)]
pub struct TierResult {
    /// Tier identifier (bm25, hdc, concept_graph, api).
    pub tier: String,
    /// Retrieved episode IDs as strings.
    pub ids: Vec<String>,
    /// Normalized scores (0.0-1.0), one per entry in `ids`.
    pub scores: Vec<f32>,
    /// Whether this tier produced sufficient results.
    pub sufficient: bool,
}
53
/// Final output of the cascading retrieval pipeline.
///
/// `episode_ids` and `scores` are parallel vectors after merging and
/// re-ranking across the contributing tiers.
#[derive(Debug, Clone)]
pub struct CascadeResult {
    /// Final merged/re-ranked episode IDs.
    pub episode_ids: Vec<String>,
    /// Final merged/re-ranked scores, one per episode ID.
    pub scores: Vec<f32>,
    /// Which tier(s) contributed to the final result.
    pub contributing_tiers: Vec<String>,
    /// Number of API calls made (should be 0 or 1).
    pub api_calls: u32,
}
66
/// Cascading retrieval orchestrator.
///
/// Coordinates the 4-tier retrieval pipeline, falling back to API
/// only when CPU-local tiers cannot satisfy the query.
///
/// Construct via [`CascadeRetriever::new`]; the configuration is
/// immutable for the lifetime of the retriever.
pub struct CascadeRetriever {
    // Thresholds and result limits used by every retrieval call;
    // exposed read-only through `config()`.
    config: CascadeConfig,
}
74
75impl CascadeRetriever {
76 /// Create a new cascade retriever with given configuration.
77 pub fn new(config: CascadeConfig) -> Self {
78 Self { config }
79 }
80
81 /// Execute the cascading retrieval pipeline.
82 ///
83 /// This is a placeholder that returns empty results when the `csm`
84 /// feature is not enabled. With `csm` enabled, it uses BM25, HDC,
85 /// and ConceptGraph from the `chaotic_semantic_memory` crate.
86 ///
87 /// Note: This placeholder is non-async. The full CSM implementation
88 /// will be async to allow for concurrent tier queries.
89 pub fn retrieve(&self, _query: &str) -> Result<CascadeResult> {
90 // Placeholder implementation - returns empty results
91 // Full implementation requires csm feature to be enabled
92 Ok(CascadeResult {
93 episode_ids: Vec::new(),
94 scores: Vec::new(),
95 contributing_tiers: Vec::new(),
96 api_calls: 0,
97 })
98 }
99
100 /// Get the configuration for this retriever.
101 pub fn config(&self) -> &CascadeConfig {
102 &self.config
103 }
104
105 /// Estimate the number of API calls that would be saved for a query.
106 ///
107 /// Returns 1.0 if the query would likely require an API call,
108 /// or 0.0 if CPU-local tiers would likely suffice.
109 pub fn estimate_api_call_probability(&self, _query: &str) -> f32 {
110 // Placeholder - in full implementation, this would analyze
111 // query characteristics (length, keywords, complexity)
112 // to estimate probability of needing API fallback
113 0.5
114 }
115}
116
/// Weight computation for query-length-dependent tier weighting.
///
/// Returns `(bm25, hdc, concept_graph)` weights. Short queries favor
/// BM25 (keyword matching); long queries favor HDC/semantic matching.
#[cfg(feature = "csm")]
pub fn compute_tier_weights(query: &str) -> (f32, f32, f32) {
    // Bucket by query byte length: short / medium / long.
    match query.len() {
        // Short query: favor keyword matching.
        0..=19 => (0.7, 0.2, 0.1),
        // Medium query: balanced weighting.
        20..=99 => (0.4, 0.4, 0.2),
        // Long query: favor semantic matching.
        _ => (0.2, 0.5, 0.3),
    }
}
135
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration matches its documented values.
    #[test]
    fn test_cascade_config_default() {
        let cfg = CascadeConfig::default();
        assert_eq!(cfg.top_k, 10);
        for threshold in [
            cfg.bm25_threshold,
            cfg.hdc_threshold,
            cfg.concept_graph_threshold,
        ] {
            assert!(threshold > 0.0);
        }
        assert!(cfg.merge_results);
    }

    /// A retriever exposes the configuration it was built with.
    #[test]
    fn test_cascade_retriever_creation() {
        let retriever = CascadeRetriever::new(CascadeConfig::default());
        assert_eq!(retriever.config().top_k, 10);
    }

    /// The placeholder pipeline yields empty results and zero API calls.
    #[test]
    fn test_placeholder_retrieve() {
        let retriever = CascadeRetriever::new(CascadeConfig::default());
        let result = retriever.retrieve("test query").unwrap();
        assert!(result.episode_ids.is_empty());
        assert!(result.scores.is_empty());
        assert_eq!(result.api_calls, 0);
    }

    /// The API-call probability estimate stays within [0, 1].
    #[test]
    fn test_estimate_api_call_probability() {
        let prob = CascadeRetriever::new(CascadeConfig::default())
            .estimate_api_call_probability("test");
        assert!((0.0..=1.0).contains(&prob));
    }
}