// offline_intelligence/context_engine/smart_retrieval.rs
//! Smart retrieval with two-tier context optimization
//!
//! This module implements intelligent retrieval that minimizes recomputation cost
//! by enforcing strict token budgets and importance filtering.
//!
//! Key optimizations:
//! - Tier 1 (hot cache) → O(1) return, 100% compute savings
//! - Tier 3 (cold storage) → Importance-filtered, token-budgeted SQLite retrieval
use std::sync::Arc;

use tokio::sync::RwLock;
use tracing::{debug, info};

use crate::context_engine::tier_manager::TierManager;
use crate::memory::Message;
use crate::memory_db::StoredMessage;
/// Configuration knobs for smart retrieval.
#[derive(Debug, Clone)]
pub struct SmartRetrievalConfig {
    /// Hard ceiling on tokens of retrieved historical context
    /// (excludes current messages).
    pub max_retrieved_tokens: usize,

    /// Messages scoring below this importance value (0.0-1.0) are dropped.
    pub importance_threshold: f32,

    /// Keep contiguous messages grouped into chunks so llama.cpp caches better.
    pub chunk_contiguous_messages: bool,

    /// Master switch; when false the engine falls back to original behavior.
    pub enabled: bool,
}
33impl Default for SmartRetrievalConfig {
34    fn default() -> Self {
35        Self {
36            max_retrieved_tokens: 1000,
37            importance_threshold: 0.5,
38            chunk_contiguous_messages: true,
39            enabled: true,
40        }
41    }
42}
44impl SmartRetrievalConfig {
45    /// Derive the historical retrieval budget from the model's context window.
46    /// 25% of CTX_SIZE is allocated to retrieved history (summaries + cold SQLite),
47    /// ensuring the current conversation always gets the lion's share.
48    pub fn from_ctx_size(ctx_size: u32) -> Self {
49        Self {
50            max_retrieved_tokens: (ctx_size as f32 * 0.25) as usize,
51            ..Self::default()
52        }
53    }
54}
56/// Result of smart retrieval operation
57#[derive(Debug, Clone)]
58pub struct RetrievalResult {
59    /// Strategy used for retrieval
60    pub strategy: RetrievalStrategy,
61
62    /// Optimized messages to send to LLM
63    pub messages: Vec<Message>,
64
65    /// Estimated computation cost saved (0.0-1.0)
66    pub compute_savings: f32,
67
68    /// Token count of retrieved context
69    pub retrieved_tokens: usize,
70
71    /// Sessions referenced in the retrieval
72    pub sessions_referenced: Vec<String>,
73}
/// Which path the retrieval engine took.
#[derive(Debug, Clone, PartialEq)]
pub enum RetrievalStrategy {
    /// The current session was already resident in the Tier 1 hot cache.
    HotCacheHit,

    /// Chunks were retrieved with importance-score filtering applied.
    ImportanceFiltered,

    /// Everything was retrieved with no optimization (fallback path).
    FullRetrieval,

    /// Nothing needed retrieving (fresh context).
    NoRetrieval,
}
91/// Smart retrieval orchestrator
92pub struct SmartRetrieval {
93    tier_manager: Arc<RwLock<TierManager>>,
94    config: SmartRetrievalConfig,
95}
97impl SmartRetrieval {
98    /// Create a new smart retrieval instance
99    pub fn new(tier_manager: Arc<RwLock<TierManager>>, config: SmartRetrievalConfig) -> Self {
100        Self {
101            tier_manager,
102            config,
103        }
104    }
105
106    /// Main retrieval function with smart optimization
107    pub async fn retrieve(
108        &self,
109        session_id: &str,
110        current_messages: &[Message],
111        tier3_messages: Option<Vec<StoredMessage>>,
112        cross_session_messages: Option<Vec<StoredMessage>>,
113    ) -> anyhow::Result<RetrievalResult> {
114        if !self.config.enabled {
115            debug!("Smart retrieval disabled, using fallback");
116            return self.fallback_retrieval(current_messages);
117        }
118
119        // Step 1: Check Tier 1 hot cache
120        let tier_manager = self.tier_manager.read().await;
121        if let Some(hot_messages) = tier_manager.get_tier1_content(session_id).await {
122            let retrieved_tokens = self.count_tokens(&hot_messages);
123            info!("🚀 Smart retrieval: Tier 1 hot cache hit for session {}", session_id);
124            return Ok(RetrievalResult {
125                strategy: RetrievalStrategy::HotCacheHit,
126                messages: hot_messages,
127                compute_savings: 1.0,
128                retrieved_tokens,
129                sessions_referenced: vec![session_id.to_string()],
130            });
131        }
132        drop(tier_manager);
133
134        // Step 2: Check if we have any historical content
135        let has_tier3 = tier3_messages.as_ref().map(|m| !m.is_empty()).unwrap_or(false);
136        let has_cross_session = cross_session_messages.as_ref().map(|m| !m.is_empty()).unwrap_or(false);
137
138        if !has_tier3 && !has_cross_session {
139            debug!("No historical content available, returning current messages");
140            return Ok(RetrievalResult {
141                strategy: RetrievalStrategy::NoRetrieval,
142                messages: current_messages.to_vec(),
143                compute_savings: 0.0,
144                retrieved_tokens: 0,
145                sessions_referenced: vec![],
146            });
147        }
148
149        // Step 3: Build optimized context from Tier 1 (hot) and Tier 3 (cold) only
150        let optimized_context = self.build_context_from_tiers(
151            current_messages,
152            tier3_messages.as_ref(),
153            cross_session_messages.as_ref(),
154        ).await?;
155
156        let strategy = if self.config.importance_threshold > 0.0 {
157            RetrievalStrategy::ImportanceFiltered
158        } else {
159            RetrievalStrategy::FullRetrieval
160        };
161
162        let compute_savings = self.estimate_compute_savings(&strategy, &optimized_context.messages);
163
164        info!(
165            "Smart retrieval complete: Strategy={:?}, Tokens={}, Savings={:.1}%",
166            strategy,
167            optimized_context.retrieved_tokens,
168            compute_savings * 100.0
169        );
170
171        Ok(optimized_context)
172    }
173
174    /// Build context from Tier 1 (hot) and Tier 3 (cold storage) with importance filtering
175    async fn build_context_from_tiers(
176        &self,
177        current_messages: &[Message],
178        tier3_messages: Option<&Vec<StoredMessage>>,
179        cross_session_messages: Option<&Vec<StoredMessage>>,
180    ) -> anyhow::Result<RetrievalResult> {
181        let mut context = Vec::new();
182        let mut retrieved_tokens = 0;
183        let mut sessions_referenced = Vec::new();
184
185        let current_tokens: usize = current_messages.iter()
186            .map(|m| self.estimate_message_tokens(m))
187            .sum();
188
189        let budget_for_history = self.config.max_retrieved_tokens.saturating_sub(current_tokens);
190
191        // Add cross-session context (highest priority, 1/3 of budget)
192        if let Some(cross_msgs) = cross_session_messages {
193            if !cross_msgs.is_empty() {
194                let cross_context = self.add_cross_session_context(cross_msgs, budget_for_history / 3);
195                retrieved_tokens += self.count_tokens(&cross_context);
196
197                for msg in cross_msgs.iter().take(3) {
198                    if !sessions_referenced.contains(&msg.session_id) {
199                        sessions_referenced.push(msg.session_id.clone());
200                    }
201                }
202
203                context.extend(cross_context);
204            }
205        }
206
207        // Add importance-filtered messages from Tier 3 (cold storage)
208        if let Some(tier3_msgs) = tier3_messages {
209            let remaining_budget = budget_for_history.saturating_sub(retrieved_tokens);
210            let detail_context = self.add_important_details(tier3_msgs, remaining_budget);
211            retrieved_tokens += self.count_tokens(&detail_context);
212            context.extend(detail_context);
213        }
214
215        // Always append current messages last
216        context.extend_from_slice(current_messages);
217
218        Ok(RetrievalResult {
219            strategy: RetrievalStrategy::ImportanceFiltered,
220            messages: context,
221            compute_savings: 0.0,
222            retrieved_tokens,
223            sessions_referenced,
224        })
225    }
226
227    /// Add cross-session context with budget enforcement
228    fn add_cross_session_context(
229        &self,
230        cross_messages: &[StoredMessage],
231        token_budget: usize,
232    ) -> Vec<Message> {
233        let mut context = Vec::new();
234        let mut used_tokens = 0;
235
236        // Add bridge message
237        context.push(Message {
238            role: "system".to_string(),
239            content: "[Context from previous conversations]".to_string(),
240        });
241        used_tokens += 8;
242
243        // Add top 3 most important cross-session messages
244        let mut scored: Vec<_> = cross_messages.iter()
245            .map(|m| (m, m.importance_score))
246            .collect();
247        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
248
249        for (msg, _score) in scored.iter().take(3) {
250            let msg_tokens = msg.tokens as usize;
251            if used_tokens + msg_tokens > token_budget {
252                break;
253            }
254
255            context.push(Message {
256                role: msg.role.clone(),
257                content: format!("[From earlier: {}]", msg.content),
258            });
259            used_tokens += msg_tokens;
260        }
261
262        debug!("Added {} cross-session messages ({} tokens)", context.len() - 1, used_tokens);
263        context
264    }
265
266    /// Add important details with importance filtering and budget
267    fn add_important_details(
268        &self,
269        messages: &[StoredMessage],
270        token_budget: usize,
271    ) -> Vec<Message> {
272        let mut context = Vec::new();
273        let mut used_tokens = 0;
274
275        // Filter by importance threshold
276        let important: Vec<_> = messages.iter()
277            .filter(|m| m.importance_score >= self.config.importance_threshold)
278            .collect();
279
280        if important.is_empty() {
281            debug!("No messages meet importance threshold {}", self.config.importance_threshold);
282            return context;
283        }
284
285        // Sort by importance score descending
286        let mut scored = important.clone();
287        scored.sort_by(|a, b| b.importance_score.partial_cmp(&a.importance_score).unwrap_or(std::cmp::Ordering::Equal));
288
289        // Add messages until budget exhausted
290        for msg in scored {
291            let msg_tokens = msg.tokens as usize;
292            if used_tokens + msg_tokens > token_budget {
293                break;
294            }
295
296            context.push(Message {
297                role: msg.role.clone(),
298                content: msg.content.clone(),
299            });
300            used_tokens += msg_tokens;
301        }
302
303        info!("Added {} important messages ({} tokens, threshold={:.2})",
304              context.len(),
305              used_tokens,
306              self.config.importance_threshold
307        );
308
309        context
310    }
311
312    /// Estimate compute savings based on strategy
313    fn estimate_compute_savings(&self, strategy: &RetrievalStrategy, _messages: &[Message]) -> f32 {
314        match strategy {
315            RetrievalStrategy::HotCacheHit => 1.0,
316            RetrievalStrategy::ImportanceFiltered => 0.6,
317            RetrievalStrategy::FullRetrieval => 0.0,
318            RetrievalStrategy::NoRetrieval => 0.0,
319        }
320    }
321
322    /// Count total tokens in messages
323    fn count_tokens(&self, messages: &[Message]) -> usize {
324        messages.iter()
325            .map(|m| self.estimate_message_tokens(m))
326            .sum()
327    }
328
329    /// Estimate tokens for a message (rough approximation: 4 chars per token)
330    fn estimate_message_tokens(&self, message: &Message) -> usize {
331        message.content.len() / 4
332    }
333
334    /// Fallback retrieval (disabled smart retrieval)
335    fn fallback_retrieval(&self, current_messages: &[Message]) -> anyhow::Result<RetrievalResult> {
336        Ok(RetrievalResult {
337            strategy: RetrievalStrategy::FullRetrieval,
338            messages: current_messages.to_vec(),
339            compute_savings: 0.0,
340            retrieved_tokens: 0,
341            sessions_referenced: vec![],
342        })
343    }
344}
346impl Clone for SmartRetrieval {
347    fn clone(&self) -> Self {
348        Self {
349            tier_manager: Arc::clone(&self.tier_manager),
350            config: self.config.clone(),
351        }
352    }
353}