Skip to main content

agentic_memory/v3/
retrieval.rs

1//! Smart multi-index context retrieval engine.
2
3use super::block::*;
4use super::immortal_log::*;
5use super::indexes::*;
6use super::tiered::*;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// Request for smart context retrieval.
#[derive(Debug, Clone)]
pub struct RetrievalRequest {
    /// Natural language query
    pub query: String,
    /// Token budget for the context
    pub token_budget: u32,
    /// Retrieval strategy
    pub strategy: RetrievalStrategy,
    /// Minimum relevance score (0.0 - 1.0)
    // NOTE(review): the engine sums scores from multiple indexes before
    // applying this threshold, so a block's combined score can exceed 1.0.
    pub min_relevance: f32,
}
22
/// Retrieval strategy.
///
/// Controls how scores from the individual indexes (temporal, semantic,
/// causal, entity) are weighted when candidates are merged.
#[derive(Debug, Clone, Copy)]
pub enum RetrievalStrategy {
    /// Prioritize recent blocks
    Recency,
    /// Prioritize relevant blocks
    Relevance,
    /// Prioritize causal chains
    Causal,
    /// Balanced mix (every index weighted 0.5)
    Balanced,
    /// Custom weights
    Custom {
        recency_weight: f32,
        relevance_weight: f32,
        causal_weight: f32,
    },
}
41
/// Result of smart retrieval.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalResult {
    /// Assembled context — blocks sorted chronologically by sequence number.
    pub blocks: Vec<Block>,
    /// Estimated tokens consumed by `blocks`.
    pub tokens_used: u32,
    /// Coverage metrics
    pub coverage: RetrievalCoverage,
    /// Blocks that matched the query but didn't fit in the token budget.
    pub omitted: Vec<BlockHash>,
    /// Retrieval duration in ms
    pub retrieval_ms: u64,
}
56
/// Coverage metrics for a retrieval result.
///
/// Each field is a heuristic ratio capped at 1.0, not an exact measure.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalCoverage {
    /// Selected block count scaled against a nominal 100 blocks.
    pub semantic: f32,
    /// Share of selected blocks timestamped within the last 24 hours,
    /// scaled against a nominal 50 blocks.
    pub temporal: f32,
    /// Number of selected decision blocks scaled against a nominal 10.
    pub causal: f32,
}
64
/// Smart retrieval engine.
///
/// Stateless: all inputs (log, storage, indexes) are passed per call, so a
/// unit struct suffices. `SmartRetrievalEngine {}` construction syntax from
/// existing callers remains valid for a unit struct.
pub struct SmartRetrievalEngine;
67
68impl SmartRetrievalEngine {
69    pub fn new() -> Self {
70        Self {}
71    }
72
73    /// Main retrieval function
74    #[allow(clippy::too_many_arguments)]
75    pub fn retrieve(
76        &self,
77        request: RetrievalRequest,
78        _log: &ImmortalLog,
79        storage: &TieredStorage,
80        temporal: &temporal::TemporalIndex,
81        semantic: &semantic::SemanticIndex,
82        causal: &causal::CausalIndex,
83        entity: &entity::EntityIndex,
84        _procedural: &procedural::ProceduralIndex,
85    ) -> RetrievalResult {
86        let start = std::time::Instant::now();
87
88        // Step 1: Gather candidates from all indexes
89        let mut candidates: HashMap<u64, f32> = HashMap::new();
90
91        // Semantic search
92        let semantic_results = semantic.search_by_text(&request.query, 100);
93        for result in &semantic_results {
94            let score = result.score * self.get_weight(&request.strategy, "semantic");
95            *candidates.entry(result.block_sequence).or_insert(0.0) += score;
96        }
97
98        // Temporal search (recent)
99        let recent_results = temporal.query_recent(3600); // Last hour
100        for (i, result) in recent_results.iter().enumerate() {
101            let recency_score = 1.0 - (i as f32 / recent_results.len().max(1) as f32);
102            let score = recency_score * self.get_weight(&request.strategy, "temporal");
103            *candidates.entry(result.block_sequence).or_insert(0.0) += score;
104        }
105
106        // Entity search (extract entities from query)
107        for word in request.query.split_whitespace() {
108            if word.contains('/') || word.contains('.') {
109                let entity_results = entity.query_entity(word);
110                for result in entity_results {
111                    let score = 0.8 * self.get_weight(&request.strategy, "entity");
112                    *candidates.entry(result.block_sequence).or_insert(0.0) += score;
113                }
114            }
115        }
116
117        // Step 2: Sort by score
118        let mut sorted: Vec<(u64, f32)> = candidates.into_iter().collect();
119        sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
120
121        // Step 3: Filter by minimum relevance
122        sorted.retain(|(_, score)| *score >= request.min_relevance);
123
124        // Step 4: Fit within token budget
125        let mut selected_blocks = Vec::new();
126        let mut tokens_used = 0u32;
127        let mut omitted = Vec::new();
128
129        for (seq, _score) in &sorted {
130            if let Some(block) = storage.get(*seq) {
131                let block_tokens = self.estimate_tokens(&block);
132                if tokens_used + block_tokens <= request.token_budget {
133                    tokens_used += block_tokens;
134                    selected_blocks.push(block);
135                } else {
136                    omitted.push(block.hash);
137                }
138            }
139        }
140
141        // Step 5: Causal expansion (add context for decisions)
142        let decision_blocks: Vec<u64> = selected_blocks
143            .iter()
144            .filter(|b| matches!(b.block_type, BlockType::Decision))
145            .map(|b| b.sequence)
146            .collect();
147
148        for decision_seq in &decision_blocks {
149            let ancestors = causal.get_ancestors(*decision_seq, 3);
150            for result in ancestors {
151                if !selected_blocks
152                    .iter()
153                    .any(|b| b.sequence == result.block_sequence)
154                {
155                    if let Some(block) = storage.get(result.block_sequence) {
156                        let block_tokens = self.estimate_tokens(&block);
157                        if tokens_used + block_tokens <= request.token_budget {
158                            tokens_used += block_tokens;
159                            selected_blocks.push(block);
160                        }
161                    }
162                }
163            }
164        }
165
166        // Step 6: Sort by sequence (chronological order)
167        selected_blocks.sort_by_key(|b| b.sequence);
168
169        // Step 7: Calculate coverage
170        let coverage = RetrievalCoverage {
171            semantic: (selected_blocks.len() as f32 / 100.0).min(1.0),
172            temporal: (selected_blocks
173                .iter()
174                .filter(|b| {
175                    chrono::Utc::now()
176                        .signed_duration_since(b.timestamp)
177                        .num_hours()
178                        < 24
179                })
180                .count() as f32
181                / 50.0)
182                .min(1.0),
183            causal: (decision_blocks.len() as f32 / 10.0).min(1.0),
184        };
185
186        RetrievalResult {
187            blocks: selected_blocks,
188            tokens_used,
189            coverage,
190            omitted,
191            retrieval_ms: start.elapsed().as_millis() as u64,
192        }
193    }
194
195    fn get_weight(&self, strategy: &RetrievalStrategy, index_type: &str) -> f32 {
196        match strategy {
197            RetrievalStrategy::Recency => match index_type {
198                "temporal" => 1.0,
199                "semantic" => 0.3,
200                _ => 0.2,
201            },
202            RetrievalStrategy::Relevance => match index_type {
203                "semantic" => 1.0,
204                "entity" => 0.8,
205                _ => 0.2,
206            },
207            RetrievalStrategy::Causal => match index_type {
208                "causal" => 1.0,
209                "semantic" => 0.5,
210                _ => 0.2,
211            },
212            RetrievalStrategy::Balanced => 0.5,
213            RetrievalStrategy::Custom {
214                recency_weight,
215                relevance_weight,
216                causal_weight,
217            } => match index_type {
218                "temporal" => *recency_weight,
219                "semantic" => *relevance_weight,
220                "causal" => *causal_weight,
221                _ => 0.3,
222            },
223        }
224    }
225
226    fn estimate_tokens(&self, block: &Block) -> u32 {
227        let content_size = match &block.content {
228            BlockContent::Text { text, .. } => text.len(),
229            BlockContent::Tool {
230                tool_name,
231                input,
232                output,
233                ..
234            } => {
235                tool_name.len()
236                    + serde_json::to_string(input).map(|s| s.len()).unwrap_or(0)
237                    + output
238                        .as_ref()
239                        .and_then(|o| serde_json::to_string(o).ok())
240                        .map(|s| s.len())
241                        .unwrap_or(0)
242            }
243            BlockContent::File { path, diff, .. } => {
244                path.len() + diff.as_ref().map(|d| d.len()).unwrap_or(0)
245            }
246            BlockContent::Decision {
247                decision,
248                reasoning,
249                ..
250            } => decision.len() + reasoning.as_ref().map(|r| r.len()).unwrap_or(0),
251            BlockContent::Boundary { summary, .. } => summary.len(),
252            BlockContent::Error {
253                message,
254                resolution,
255                ..
256            } => message.len() + resolution.as_ref().map(|r| r.len()).unwrap_or(0),
257            BlockContent::Checkpoint {
258                working_context, ..
259            } => working_context.len(),
260            BlockContent::Binary { data, .. } => data.len(),
261        };
262
263        ((content_size / 4) + 10) as u32
264    }
265}
266
267impl Default for SmartRetrievalEngine {
268    fn default() -> Self {
269        Self::new()
270    }
271}