// Source: offline_intelligence/context_engine/orchestrator.rs

//! Main orchestrator that coordinates all memory subsystems.

3use crate::memory::Message;
4use crate::memory_db::MemoryDatabase;
5use crate::context_engine::{
6    retrieval_planner::RetrievalPlan,
7    retrieval_planner::RetrievalPlanner,
8    tier_manager::{TierManager, TierManagerConfig},
9    context_builder::{ContextBuilder, ContextBuilderConfig},
10    smart_retrieval::{SmartRetrieval, SmartRetrievalConfig},
11};
12use crate::worker_threads::LLMWorker;
13
14use std::sync::Arc;
15use tracing::{info, debug, warn};
16use tokio::sync::RwLock;
17
18/// Main orchestrator for the context engine
19pub struct ContextOrchestrator {
20    database: Arc<MemoryDatabase>,
21    retrieval_planner: Arc<RwLock<RetrievalPlanner>>,
22    tier_manager: Arc<RwLock<TierManager>>,
23    context_builder: Arc<RwLock<ContextBuilder>>,
24    config: OrchestratorConfig,
25    /// LLM worker for generating query embeddings during semantic search
26    llm_worker: Option<Arc<LLMWorker>>,
27    /// Smart retrieval for optimized context assembly
28    smart_retrieval: Option<Arc<SmartRetrieval>>,
29}
30
31/// Configuration for the orchestrator
32#[derive(Debug, Clone)]
33pub struct OrchestratorConfig {
34    pub enabled: bool,
35    pub max_context_tokens: usize,
36    pub auto_optimize: bool,
37    pub enable_metrics: bool,
38    pub session_timeout_seconds: u64,
39    /// Enable smart retrieval optimization (default: true)
40    pub enable_smart_retrieval: bool,
41    /// Smart retrieval configuration
42    pub smart_retrieval_config: SmartRetrievalConfig,
43}
44
45impl Default for OrchestratorConfig {
46    fn default() -> Self {
47        Self {
48            enabled: true,
49            max_context_tokens: 4000,
50            auto_optimize: true,
51            enable_metrics: true,
52            session_timeout_seconds: 3600,
53            enable_smart_retrieval: true,  // Enable by default for production
54            smart_retrieval_config: SmartRetrievalConfig::default(),
55        }
56    }
57}
58
59impl ContextOrchestrator {
60    /// Create a new context orchestrator
61    pub async fn new(
62        database: Arc<MemoryDatabase>,
63        config: OrchestratorConfig,
64    ) -> anyhow::Result<Self> {
65        // Create retrieval planner wrapped in Arc<RwLock>
66        let retrieval_planner = Arc::new(RwLock::new(RetrievalPlanner::new(database.clone())));
67        
68        // Create tier manager
69        let tier_manager_config = TierManagerConfig::default();
70        let tier_manager = TierManager::new(
71            database.clone(),
72            tier_manager_config,
73        );
74        let tier_manager = Arc::new(RwLock::new(tier_manager));
75        
76        // Create context builder wrapped in Arc<RwLock>
77        let context_builder_config = ContextBuilderConfig::default();
78        let context_builder = Arc::new(RwLock::new(ContextBuilder::new(context_builder_config)));
79        
80        // Initialize smart retrieval if enabled
81        let smart_retrieval = if config.enable_smart_retrieval {
82            let smart_ret = SmartRetrieval::new(
83                Arc::clone(&tier_manager),
84                config.smart_retrieval_config.clone(),
85            );
86            info!("Smart retrieval initialized (enabled)");
87            Some(Arc::new(smart_ret))
88        } else {
89            info!("Smart retrieval disabled");
90            None
91        };
92
93        let orchestrator = Self {
94            database,
95            retrieval_planner,
96            tier_manager,
97            context_builder,
98            config,
99            llm_worker: None,
100            smart_retrieval,
101        };
102
103        info!("Context orchestrator initialized successfully");
104
105        Ok(orchestrator)
106    }
107
108    /// Set the LLM worker for embedding-based semantic search
109    pub fn set_llm_worker(&mut self, worker: Arc<LLMWorker>) {
110        self.llm_worker = Some(worker);
111        info!("Context orchestrator: LLM worker set for semantic search");
112    }
113    
114    /// Chat persistence: Expose database for conversation API handlers
115    pub fn database(&self) -> &Arc<MemoryDatabase> {
116        &self.database
117    }
118    
119    /// Process conversation and return optimized context
120    pub async fn process_conversation(
121        &self,
122        session_id: &str,
123        messages: &[Message],
124        user_query: Option<&str>,
125    ) -> anyhow::Result<Vec<Message>> {
126        if !self.config.enabled || messages.is_empty() {
127            debug!("Context engine disabled or no messages");
128            return Ok(messages.to_vec());
129        }
130        
131        info!("Processing conversation for session {} ({} messages)", session_id, messages.len());
132        
133        // Update current messages in Tier 1
134        {
135            let tier_manager = self.tier_manager.write().await;
136            tier_manager.store_tier1_content(session_id, messages).await;
137        }
138        
139        // Save ONLY the last user message (new query) to database
140        if let Some(last_message) = messages.last() {
141            if last_message.role == "user" {
142                let tier_manager = self.tier_manager.read().await;
143                if let Err(e) = tier_manager.store_tier3_content(session_id, std::slice::from_ref(last_message)).await {
144                    warn!("Failed to persist user query to database: {}", e);
145                } else {
146                    info!("✅ Persisted user query to database for session {}", session_id);
147                }
148            }
149        }
150        
151        // Create retrieval plan
152        let plan = {
153            let retrieval_planner = self.retrieval_planner.read().await;
154            
155            // --- UPDATED CALL ---
156            // Detect if the user is referring to past conversations
157            let has_past_refs = if let Some(query) = user_query {
158                retrieval_planner.has_past_references_in_text(query)
159            } else {
160                false
161            };
162            
163            // Now create the plan using the detected references and the user query
164            retrieval_planner.create_plan(
165                session_id,
166                messages,
167                self.config.max_context_tokens,
168                user_query,
169                has_past_refs, // Passing the reference check to the planner
170            ).await?
171        };
172        
173        if !plan.needs_retrieval {
174            debug!("No retrieval needed, returning current messages");
175            return Ok(messages.to_vec());
176        }
177        
178        // Execute retrieval plan (includes semantic search when KV cache misses)
179        let retrieved_content = self.execute_retrieval_plan(session_id, &plan, user_query).await?;
180
181        // === SMART RETRIEVAL INTEGRATION ===
182        // If smart retrieval is enabled, use it to optimize the context assembly
183        let optimized_context = if let Some(ref smart_retrieval) = self.smart_retrieval {
184            match smart_retrieval.retrieve(
185                session_id,
186                messages,
187                retrieved_content.tier2.clone(),
188                retrieved_content.tier3.clone(),
189                retrieved_content.cross_session.clone(),
190            ).await {
191                Ok(smart_result) => {
192                    info!(
193                        "🎯 Smart retrieval: Strategy={:?}, Tokens={}, Savings={:.1}%",
194                        smart_result.strategy,
195                        smart_result.retrieved_tokens,
196                        smart_result.compute_savings * 100.0
197                    );
198                    smart_result.messages
199                }
200                Err(e) => {
201                    warn!("Smart retrieval failed, falling back to standard: {}", e);
202                    // Fallback to standard context building
203                    let mut context_builder = self.context_builder.write().await;
204                    context_builder.build_context(
205                        messages,
206                        retrieved_content.tier1,
207                        retrieved_content.tier2,
208                        retrieved_content.tier3,
209                        retrieved_content.cross_session,
210                        user_query,
211                    ).await?
212                }
213            }
214        } else {
215            // Smart retrieval disabled, use standard context building
216            let mut context_builder = self.context_builder.write().await;
217            context_builder.build_context(
218                messages,
219                retrieved_content.tier1,
220                retrieved_content.tier2,
221                retrieved_content.tier3,
222                retrieved_content.cross_session,
223                user_query,
224            ).await?
225        };
226        
227        // If we used retrieval, update statistics
228        if let Some(query) = user_query {
229            if let Some(response) = optimized_context.last() {
230                if response.role == "assistant" {
231                    self.update_engagement(query, &response.content).await;
232                }
233            }
234        }
235        
236        info!(
237            "Context optimization complete: {} -> {} messages",
238            messages.len(),
239            optimized_context.len()
240        );
241        
242        Ok(optimized_context)
243    }
244    
245    /// Save assistant response to database (Tier 3)
246    pub async fn save_assistant_response(
247        &self,
248        session_id: &str,
249        response: &str,
250    ) -> anyhow::Result<()> {
251        let assistant_message = Message {
252            role: "assistant".to_string(),
253            content: response.to_string(),
254        };
255        
256        let tier_manager = self.tier_manager.read().await;
257        tier_manager.store_tier3_content(session_id, &[assistant_message]).await
258    }
259    
260    /// Execute retrieval plan across all tiers.
261    /// When semantic_search is enabled and we have an LLM worker, we embed the query
262    /// and find similar messages via HNSW — this is the core "KV cache miss → DB retrieval" path.
263    async fn execute_retrieval_plan(
264        &self,
265        session_id: &str,
266        plan: &RetrievalPlan,
267        user_query: Option<&str>,
268    ) -> anyhow::Result<RetrievedContent> {
269        let mut retrieved = RetrievedContent::default();
270
271        // Retrieve from Tier 1 (current context — hot KV cache)
272        if plan.use_tier1 {
273            let tier_manager = self.tier_manager.read().await;
274            retrieved.tier1 = tier_manager.get_tier1_content(session_id).await;
275        }
276
277        // Retrieve from Tier 2 (summaries)
278        if plan.use_tier2 {
279            let tier_manager = self.tier_manager.read().await;
280            retrieved.tier2 = tier_manager.get_tier2_content(session_id).await;
281        }
282
283        // ── Semantic Search: KV cache miss path ──
284        // If the retrieval plan calls for semantic search AND we have embeddings available,
285        // embed the user query and find semantically similar past messages from the DB.
286        // This avoids re-computing full context — we retrieve just the relevant history.
287        //
288        // IMPORTANT: Skip entirely when no embeddings exist yet (first conversation / fresh DB).
289        // This avoids a wasted round-trip to llama-server /v1/embeddings when there's nothing to search.
290        let mut semantic_results: Vec<crate::memory_db::StoredMessage> = Vec::new();
291
292        let has_embeddings = self.database.embeddings.get_stats()
293            .map(|s| s.total_embeddings > 0)
294            .unwrap_or(false);
295
296        if plan.semantic_search && has_embeddings {
297            if let (Some(ref llm_worker), Some(query)) = (&self.llm_worker, user_query) {
298                match llm_worker.generate_embeddings(vec![query.to_string()]).await {
299                    Ok(query_embeddings) if !query_embeddings.is_empty() => {
300                        let query_vec = &query_embeddings[0];
301                        // Search HNSW index for similar past messages
302                        match self.database.embeddings.find_similar_embeddings(
303                            query_vec,
304                            "llama-server",
305                            (plan.max_messages * 2) as i32,
306                            0.3, // similarity threshold
307                        ) {
308                            Ok(similar) if !similar.is_empty() => {
309                                info!("Semantic search found {} similar messages for context retrieval", similar.len());
310                                // Fetch actual message content for each match
311                                for (message_id, _similarity) in &similar {
312                                    // Get message from DB by ID
313                                    let conn = self.database.conversations.get_conn_public();
314                                    if let Ok(conn) = conn {
315                                        let mut stmt = conn.prepare(
316                                            "SELECT id, session_id, message_index, role, content, tokens,
317                                                    timestamp, importance_score, embedding_generated
318                                             FROM messages WHERE id = ?1"
319                                        ).ok();
320                                        if let Some(ref mut stmt) = stmt {
321                                            if let Ok(mut rows) = stmt.query([message_id]) {
322                                                if let Ok(Some(row)) = rows.next() {
323                                                    let ts_str: String = row.get(6).unwrap_or_default();
324                                                    let ts = chrono::DateTime::parse_from_rfc3339(&ts_str)
325                                                        .map(|dt| dt.with_timezone(&chrono::Utc))
326                                                        .unwrap_or_else(|_| chrono::Utc::now());
327                                                    semantic_results.push(crate::memory_db::StoredMessage {
328                                                        id: row.get(0).unwrap_or(0),
329                                                        session_id: row.get(1).unwrap_or_default(),
330                                                        message_index: row.get(2).unwrap_or(0),
331                                                        role: row.get(3).unwrap_or_default(),
332                                                        content: row.get(4).unwrap_or_default(),
333                                                        tokens: row.get(5).unwrap_or(0),
334                                                        timestamp: ts,
335                                                        importance_score: row.get(7).unwrap_or(0.5),
336                                                        embedding_generated: row.get(8).unwrap_or(true),
337                                                    });
338                                                }
339                                            }
340                                        }
341                                    }
342                                }
343                            }
344                            Ok(_) => debug!("Semantic search: no results above threshold"),
345                            Err(e) => debug!("Semantic search failed: {}", e),
346                        }
347                    }
348                    Ok(_) => debug!("Empty embedding response for query"),
349                    Err(e) => debug!("Query embedding generation failed (semantic search skipped): {}", e),
350                }
351            }
352        }
353
354        // Retrieve from Tier 3 (full database) — keyword fallback or supplement
355        if plan.use_tier3 {
356            let tier_manager = self.tier_manager.read().await;
357            if plan.keyword_search && !plan.search_topics.is_empty() {
358                for topic in &plan.search_topics {
359                    let limit_per_topic = plan.max_messages / plan.search_topics.len().max(1);
360
361                    if let Ok(results) = tier_manager.search_tier3_content(
362                        session_id,
363                        topic,
364                        limit_per_topic,
365                    ).await {
366                        // Merge with semantic results, deduplicating by message ID
367                        let semantic_ids: std::collections::HashSet<i64> = semantic_results.iter().map(|m| m.id).collect();
368                        let mut merged = semantic_results.clone();
369                        for msg in results {
370                            if !semantic_ids.contains(&msg.id) {
371                                merged.push(msg);
372                            }
373                        }
374                        retrieved.tier3 = Some(merged);
375                        break;
376                    }
377                }
378                // If keyword search found nothing but semantic did, use semantic results
379                if retrieved.tier3.is_none() && !semantic_results.is_empty() {
380                    retrieved.tier3 = Some(semantic_results.clone());
381                }
382            } else {
383                if !semantic_results.is_empty() {
384                    // Use semantic results as tier3 content
385                    retrieved.tier3 = Some(semantic_results.clone());
386                } else {
387                    retrieved.tier3 = tier_manager.get_tier3_content(
388                        session_id,
389                        Some((plan.max_messages as i64).min(i32::MAX as i64) as i32),
390                        Some(0),
391                    ).await.ok();
392                }
393            }
394        } else if !semantic_results.is_empty() {
395            // Even if tier3 wasn't planned, if semantic search found relevant content, use it
396            retrieved.tier3 = Some(semantic_results);
397        }
398
399        // Add cross-session search if needed
400        if plan.cross_session_search && !plan.search_topics.is_empty() {
401            let tier_manager = self.tier_manager.read().await;
402            if let Ok(cross_session_results) = tier_manager.search_cross_session_content(
403                session_id,
404                &plan.search_topics.join(" "),
405                10,
406            ).await {
407                retrieved.cross_session = Some(cross_session_results);
408            }
409        }
410
411        Ok(retrieved)
412    }
413    
414    async fn update_engagement(&self, user_query: &str, assistant_response: &str) {
415        debug!("Engagement updated for query: {} (response length: {})", 
416               user_query, assistant_response.len());
417    }
418    
419    pub async fn get_session_stats(&self, session_id: &str) -> anyhow::Result<SessionStats> {
420        let tier_manager = self.tier_manager.read().await;
421        let tier_stats = tier_manager.get_tier_stats(session_id).await;
422        let db_stats = self.database.get_stats()?;
423        
424        Ok(SessionStats {
425            session_id: session_id.to_string(),
426            tier_stats,
427            database_stats: db_stats,
428        })
429    }
430    
431    pub async fn cleanup(&self, older_than_seconds: u64) -> anyhow::Result<CleanupStats> {
432        info!("Starting cleanup of old data");
433        let db_cleaned = self.database.cleanup_old_data((older_than_seconds / 86400) as i32)?;
434        let tier_manager = self.tier_manager.read().await;
435        let cache_cleaned = tier_manager.cleanup_cache(older_than_seconds).await;
436        
437        Ok(CleanupStats {
438            sessions_cleaned: db_cleaned,
439            cache_entries_cleaned: cache_cleaned,
440        })
441    }
442    
443    /// Search messages across sessions by keywords
444    pub async fn search_messages(
445        &self,
446        session_id: Option<&str>,
447        keywords: &[String],
448        limit: usize,
449    ) -> anyhow::Result<Vec<crate::memory_db::StoredMessage>> {
450        if keywords.is_empty() {
451            return Ok(Vec::new());
452        }
453        
454        if let Some(sid) = session_id {
455            // Search within specific session
456            self.database.search_messages_by_keywords(sid, keywords, limit).await
457        } else {
458            // Search across all sessions (would need cross-session search implementation)
459            // For now, return empty results for global search
460            Ok(Vec::new())
461        }
462    }
463    
464    pub fn set_enabled(&mut self, enabled: bool) {
465        self.config.enabled = enabled;
466        info!("Context engine {}", if enabled { "enabled" } else { "disabled" });
467    }
468    
469    pub fn update_config(&mut self, config: OrchestratorConfig) {
470        self.config = config;
471        info!("Context engine configuration updated");
472    }
473    
474    pub fn get_config(&self) -> &OrchestratorConfig {
475        &self.config
476    }
477
478    // Chat persistence: Expose tier manager to ensure sessions exist before processing
479    pub fn tier_manager(&self) -> &Arc<RwLock<TierManager>> {
480        &self.tier_manager
481    }
482}
483
484impl Clone for ContextOrchestrator {
485    fn clone(&self) -> Self {
486        Self {
487            database: self.database.clone(),
488            retrieval_planner: self.retrieval_planner.clone(),
489            tier_manager: self.tier_manager.clone(),
490            context_builder: self.context_builder.clone(),
491            config: self.config.clone(),
492            llm_worker: self.llm_worker.clone(),
493            smart_retrieval: self.smart_retrieval.clone(),
494        }
495    }
496}
497
498#[derive(Debug, Default)]
499struct RetrievedContent {
500    tier1: Option<Vec<Message>>,
501    tier2: Option<Vec<crate::memory_db::Summary>>,
502    tier3: Option<Vec<crate::memory_db::StoredMessage>>,
503    cross_session: Option<Vec<crate::memory_db::StoredMessage>>,
504}
505
506#[derive(Debug, Clone)]
507pub struct SessionStats {
508    pub session_id: String,
509    pub tier_stats: crate::context_engine::tier_manager::TierStats,
510    pub database_stats: crate::memory_db::schema::DatabaseStats,
511}
512
/// Counts reported by `ContextOrchestrator::cleanup`.
#[derive(Debug, Clone)]
pub struct CleanupStats {
    /// Count returned by the database's `cleanup_old_data`.
    pub sessions_cleaned: usize,
    /// Count returned by the tier manager's `cleanup_cache`.
    pub cache_entries_cleaned: usize,
}