offline_intelligence/context_engine/orchestrator.rs

//! Main orchestrator that coordinates all memory subsystems
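//!
//! Memory is organized in three tiers (see `execute_retrieval_plan` below):
//! Tier 1 holds the current in-context messages, Tier 2 holds summaries,
//! and Tier 3 is the full message database, which can additionally be
//! searched semantically via query embeddings when an LLM worker is set.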

use crate::memory::Message;
use crate::memory_db::MemoryDatabase;
use crate::context_engine::{
    retrieval_planner::{RetrievalPlan, RetrievalPlanner},
    tier_manager::{TierManager, TierManagerConfig},
    context_builder::{ContextBuilder, ContextBuilderConfig},
};
use crate::worker_threads::LLMWorker;

use std::sync::Arc;
use tracing::{info, debug, warn};
use tokio::sync::RwLock;

/// Main orchestrator for the context engine
pub struct ContextOrchestrator {
    database: Arc<MemoryDatabase>,
    retrieval_planner: Arc<RwLock<RetrievalPlanner>>,
    tier_manager: Arc<RwLock<TierManager>>,
    context_builder: Arc<RwLock<ContextBuilder>>,
    config: OrchestratorConfig,
    /// LLM worker for generating query embeddings during semantic search
    llm_worker: Option<Arc<LLMWorker>>,
}

/// Configuration for the orchestrator
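///
/// A minimal construction sketch; any field not listed falls back to the
/// defaults below (the override value shown is illustrative only):
///
/// ```ignore
/// let config = OrchestratorConfig {
///     max_context_tokens: 8192,
///     ..Default::default()
/// };
/// ```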
#[derive(Debug, Clone)]
pub struct OrchestratorConfig {
    pub enabled: bool,
    pub max_context_tokens: usize,
    pub auto_optimize: bool,
    pub enable_metrics: bool,
    pub session_timeout_seconds: u64,
}

impl Default for OrchestratorConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            max_context_tokens: 4000,
            auto_optimize: true,
            enable_metrics: true,
            session_timeout_seconds: 3600,
        }
    }
}

impl ContextOrchestrator {
    /// Create a new context orchestrator
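    ///
    /// Usage sketch, assuming an `Arc<MemoryDatabase>` handle is already
    /// available (how the database is opened is outside this module):
    ///
    /// ```ignore
    /// let orchestrator = ContextOrchestrator::new(
    ///     database.clone(),
    ///     OrchestratorConfig::default(),
    /// ).await?;
    /// ```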
    pub async fn new(
        database: Arc<MemoryDatabase>,
        config: OrchestratorConfig,
    ) -> anyhow::Result<Self> {
        // Create retrieval planner wrapped in Arc<RwLock>
        let retrieval_planner = Arc::new(RwLock::new(RetrievalPlanner::new(database.clone())));

        // Create tier manager
        let tier_manager_config = TierManagerConfig::default();
        let tier_manager = TierManager::new(
            database.clone(),
            tier_manager_config,
        );
        let tier_manager = Arc::new(RwLock::new(tier_manager));

        // Create context builder wrapped in Arc<RwLock>
        let context_builder_config = ContextBuilderConfig::default();
        let context_builder = Arc::new(RwLock::new(ContextBuilder::new(context_builder_config)));

        let orchestrator = Self {
            database,
            retrieval_planner,
            tier_manager,
            context_builder,
            config,
            llm_worker: None,
        };

        info!("Context orchestrator initialized successfully");

        Ok(orchestrator)
    }

    /// Set the LLM worker for embedding-based semantic search
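    ///
    /// Until a worker is set, `execute_retrieval_plan` skips the embedding
    /// path entirely and falls back to keyword and recency retrieval.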
    pub fn set_llm_worker(&mut self, worker: Arc<LLMWorker>) {
        self.llm_worker = Some(worker);
        info!("Context orchestrator: LLM worker set for semantic search");
    }

    /// Chat persistence: Expose database for conversation API handlers
    pub fn database(&self) -> &Arc<MemoryDatabase> {
        &self.database
    }

    /// Process conversation and return optimized context
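    ///
    /// Sketch of a typical call; the session ID and query are illustrative,
    /// and `messages` is the full chat history so far:
    ///
    /// ```ignore
    /// let optimized = orchestrator.process_conversation(
    ///     "session-123",
    ///     &messages,
    ///     Some("what did we decide about the schema?"),
    /// ).await?;
    /// ```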
    pub async fn process_conversation(
        &self,
        session_id: &str,
        messages: &[Message],
        user_query: Option<&str>,
    ) -> anyhow::Result<Vec<Message>> {
        if !self.config.enabled || messages.is_empty() {
            debug!("Context engine disabled or no messages");
            return Ok(messages.to_vec());
        }

        info!("Processing conversation for session {} ({} messages)", session_id, messages.len());

        // Update current messages in Tier 1
        {
            let tier_manager = self.tier_manager.write().await;
            tier_manager.store_tier1_content(session_id, messages).await;
        }

        // Save ONLY the last user message (new query) to database
        if let Some(last_message) = messages.last() {
            if last_message.role == "user" {
                let tier_manager = self.tier_manager.read().await;
                if let Err(e) = tier_manager.store_tier3_content(session_id, std::slice::from_ref(last_message)).await {
                    warn!("Failed to persist user query to database: {}", e);
                } else {
                    info!("✅ Persisted user query to database for session {}", session_id);
                }
            }
        }

        // Create retrieval plan
        let plan = {
            let retrieval_planner = self.retrieval_planner.read().await;

            // Detect whether the user is referring to past conversations
            let has_past_refs = if let Some(query) = user_query {
                retrieval_planner.has_past_references_in_text(query)
            } else {
                false
            };

            // Create the plan using the detected references and the user query
            retrieval_planner.create_plan(
                session_id,
                messages,
                self.config.max_context_tokens,
                user_query,
                has_past_refs,
            ).await?
        };

        if !plan.needs_retrieval {
            debug!("No retrieval needed, returning current messages");
            return Ok(messages.to_vec());
        }

        // Execute retrieval plan (includes semantic search when KV cache misses)
        let retrieved_content = self.execute_retrieval_plan(session_id, &plan, user_query).await?;

        // Build optimized context
        let optimized_context = {
            let mut context_builder = self.context_builder.write().await;
            context_builder.build_context(
                messages,
                retrieved_content.tier1,
                retrieved_content.tier2,
                retrieved_content.tier3,
                retrieved_content.cross_session,
                user_query,
            ).await?
        };

        // If the optimized context ends with an assistant turn, record engagement
        if let Some(query) = user_query {
            if let Some(response) = optimized_context.last() {
                if response.role == "assistant" {
                    self.update_engagement(query, &response.content).await;
                }
            }
        }

        info!(
            "Context optimization complete: {} -> {} messages",
            messages.len(),
            optimized_context.len()
        );

        Ok(optimized_context)
    }

    /// Save assistant response to database (Tier 3)
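    ///
    /// Typically called right after generation completes, e.g.:
    ///
    /// ```ignore
    /// orchestrator.save_assistant_response(session_id, &generated_text).await?;
    /// ```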
    pub async fn save_assistant_response(
        &self,
        session_id: &str,
        response: &str,
    ) -> anyhow::Result<()> {
        let assistant_message = Message {
            role: "assistant".to_string(),
            content: response.to_string(),
        };

        let tier_manager = self.tier_manager.read().await;
        tier_manager.store_tier3_content(session_id, &[assistant_message]).await
    }

    /// Execute retrieval plan across all tiers.
    /// When semantic_search is enabled and we have an LLM worker, we embed the query
    /// and find similar messages via HNSW — this is the core "KV cache miss → DB retrieval" path.
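    ///
    /// Retrieval order: Tier 1 (hot context), Tier 2 (summaries), semantic
    /// search over stored embeddings, Tier 3 keyword/recency fallback, then
    /// optional cross-session search.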
    async fn execute_retrieval_plan(
        &self,
        session_id: &str,
        plan: &RetrievalPlan,
        user_query: Option<&str>,
    ) -> anyhow::Result<RetrievedContent> {
        let mut retrieved = RetrievedContent::default();

        // Retrieve from Tier 1 (current context — hot KV cache)
        if plan.use_tier1 {
            let tier_manager = self.tier_manager.read().await;
            retrieved.tier1 = tier_manager.get_tier1_content(session_id).await;
        }

        // Retrieve from Tier 2 (summaries)
        if plan.use_tier2 {
            let tier_manager = self.tier_manager.read().await;
            retrieved.tier2 = tier_manager.get_tier2_content(session_id).await;
        }

        // ── Semantic Search: KV cache miss path ──
        // If the retrieval plan calls for semantic search AND we have embeddings available,
        // embed the user query and find semantically similar past messages from the DB.
        // This avoids re-computing full context — we retrieve just the relevant history.
        //
        // IMPORTANT: Skip entirely when no embeddings exist yet (first conversation / fresh DB).
        // This avoids a wasted round-trip to llama-server /v1/embeddings when there's nothing to search.
        let mut semantic_results: Vec<crate::memory_db::StoredMessage> = Vec::new();

        let has_embeddings = self.database.embeddings.get_stats()
            .map(|s| s.total_embeddings > 0)
            .unwrap_or(false);

        if plan.semantic_search && has_embeddings {
            if let (Some(ref llm_worker), Some(query)) = (&self.llm_worker, user_query) {
                match llm_worker.generate_embeddings(vec![query.to_string()]).await {
                    Ok(query_embeddings) if !query_embeddings.is_empty() => {
                        let query_vec = &query_embeddings[0];
                        // Search HNSW index for similar past messages
                        match self.database.embeddings.find_similar_embeddings(
                            query_vec,
                            "llama-server",
                            (plan.max_messages * 2) as i32,
                            0.3, // similarity threshold
                        ) {
                            Ok(similar) if !similar.is_empty() => {
                                info!("Semantic search found {} similar messages for context retrieval", similar.len());
                                // Fetch the actual message content for each match.
                                // Acquire the connection and prepare the statement once,
                                // then reuse them for every matched message ID.
                                if let Ok(conn) = self.database.conversations.get_conn_public() {
                                    if let Ok(mut stmt) = conn.prepare(
                                        "SELECT id, session_id, message_index, role, content, tokens,
                                                timestamp, importance_score, embedding_generated
                                         FROM messages WHERE id = ?1"
                                    ) {
                                        for (message_id, _similarity) in &similar {
                                            if let Ok(mut rows) = stmt.query([message_id]) {
                                                if let Ok(Some(row)) = rows.next() {
                                                    // Parse the stored timestamp as RFC 3339, falling back to now()
                                                    let ts_str: String = row.get(6).unwrap_or_default();
                                                    let ts = chrono::DateTime::parse_from_rfc3339(&ts_str)
                                                        .map(|dt| dt.with_timezone(&chrono::Utc))
                                                        .unwrap_or_else(|_| chrono::Utc::now());
                                                    semantic_results.push(crate::memory_db::StoredMessage {
                                                        id: row.get(0).unwrap_or(0),
                                                        session_id: row.get(1).unwrap_or_default(),
                                                        message_index: row.get(2).unwrap_or(0),
                                                        role: row.get(3).unwrap_or_default(),
                                                        content: row.get(4).unwrap_or_default(),
                                                        tokens: row.get(5).unwrap_or(0),
                                                        timestamp: ts,
                                                        importance_score: row.get(7).unwrap_or(0.5),
                                                        embedding_generated: row.get(8).unwrap_or(true),
                                                    });
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            Ok(_) => debug!("Semantic search: no results above threshold"),
                            Err(e) => debug!("Semantic search failed: {}", e),
                        }
                    }
                    Ok(_) => debug!("Empty embedding response for query"),
                    Err(e) => debug!("Query embedding generation failed (semantic search skipped): {}", e),
                }
            }
        }

        // Retrieve from Tier 3 (full database) — keyword fallback or supplement
        if plan.use_tier3 {
            let tier_manager = self.tier_manager.read().await;
            if plan.keyword_search && !plan.search_topics.is_empty() {
                // Split the message budget across topics, accumulate keyword hits
                // for every topic, and deduplicate against the semantic results
                // by message ID.
                let limit_per_topic = plan.max_messages / plan.search_topics.len().max(1);
                let mut seen_ids: std::collections::HashSet<i64> =
                    semantic_results.iter().map(|m| m.id).collect();
                let mut merged = semantic_results.clone();

                for topic in &plan.search_topics {
                    if let Ok(results) = tier_manager.search_tier3_content(
                        session_id,
                        topic,
                        limit_per_topic,
                    ).await {
                        for msg in results {
                            if seen_ids.insert(msg.id) {
                                merged.push(msg);
                            }
                        }
                    }
                }

                // If neither keyword nor semantic search found anything, leave tier3 unset
                if !merged.is_empty() {
                    retrieved.tier3 = Some(merged);
                }
            } else if !semantic_results.is_empty() {
                // Use semantic results as tier3 content
                retrieved.tier3 = Some(semantic_results.clone());
            } else {
                retrieved.tier3 = tier_manager.get_tier3_content(
                    session_id,
                    Some((plan.max_messages as i64).min(i32::MAX as i64) as i32),
                    Some(0),
                ).await.ok();
            }
        } else if !semantic_results.is_empty() {
            // Even if tier3 wasn't planned, if semantic search found relevant content, use it
            retrieved.tier3 = Some(semantic_results);
        }

        // Add cross-session search if needed
        if plan.cross_session_search && !plan.search_topics.is_empty() {
            let tier_manager = self.tier_manager.read().await;
            if let Ok(cross_session_results) = tier_manager.search_cross_session_content(
                session_id,
                &plan.search_topics.join(" "),
                10,
            ).await {
                retrieved.cross_session = Some(cross_session_results);
            }
        }

        Ok(retrieved)
    }

    /// Engagement hook; currently a placeholder that only logs the interaction
    async fn update_engagement(&self, user_query: &str, assistant_response: &str) {
        debug!("Engagement updated for query: {} (response length: {})",
               user_query, assistant_response.len());
    }

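    /// Aggregate per-tier and database statistics for one session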
    pub async fn get_session_stats(&self, session_id: &str) -> anyhow::Result<SessionStats> {
        let tier_manager = self.tier_manager.read().await;
        let tier_stats = tier_manager.get_tier_stats(session_id).await;
        let db_stats = self.database.get_stats()?;

        Ok(SessionStats {
            session_id: session_id.to_string(),
            tier_stats,
            database_stats: db_stats,
        })
    }

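    /// Remove stored data and cache entries older than the given age.
    /// The database cleanup works in whole days, so `older_than_seconds`
    /// is rounded down to a day count, e.g.:
    ///
    /// ```ignore
    /// // Remove everything older than 30 days
    /// let stats = orchestrator.cleanup(30 * 86400).await?;
    /// ```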
    pub async fn cleanup(&self, older_than_seconds: u64) -> anyhow::Result<CleanupStats> {
        info!("Starting cleanup of old data");
        let db_cleaned = self.database.cleanup_old_data((older_than_seconds / 86400) as i32)?;
        let tier_manager = self.tier_manager.read().await;
        let cache_cleaned = tier_manager.cleanup_cache(older_than_seconds).await;

        Ok(CleanupStats {
            sessions_cleaned: db_cleaned,
            cache_entries_cleaned: cache_cleaned,
        })
    }

    /// Search messages by keywords (currently only within a single session)
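    ///
    /// Sketch (the session ID and keyword are illustrative):
    ///
    /// ```ignore
    /// let hits = orchestrator
    ///     .search_messages(Some("session-123"), &["schema".to_string()], 5)
    ///     .await?;
    /// ```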
    pub async fn search_messages(
        &self,
        session_id: Option<&str>,
        keywords: &[String],
        limit: usize,
    ) -> anyhow::Result<Vec<crate::memory_db::StoredMessage>> {
        if keywords.is_empty() {
            return Ok(Vec::new());
        }

        if let Some(sid) = session_id {
            // Search within specific session
            self.database.search_messages_by_keywords(sid, keywords, limit).await
        } else {
            // Search across all sessions (would need cross-session search implementation)
            // For now, return empty results for global search
            Ok(Vec::new())
        }
    }

    /// Enable or disable the context engine at runtime
    pub fn set_enabled(&mut self, enabled: bool) {
        self.config.enabled = enabled;
        info!("Context engine {}", if enabled { "enabled" } else { "disabled" });
    }

    /// Replace the orchestrator configuration
    pub fn update_config(&mut self, config: OrchestratorConfig) {
        self.config = config;
        info!("Context engine configuration updated");
    }

    /// Current orchestrator configuration
    pub fn get_config(&self) -> &OrchestratorConfig {
        &self.config
    }

    /// Chat persistence: Expose tier manager to ensure sessions exist before processing
    pub fn tier_manager(&self) -> &Arc<RwLock<TierManager>> {
        &self.tier_manager
    }
}

impl Clone for ContextOrchestrator {
    fn clone(&self) -> Self {
        Self {
            database: self.database.clone(),
            retrieval_planner: self.retrieval_planner.clone(),
            tier_manager: self.tier_manager.clone(),
            context_builder: self.context_builder.clone(),
            config: self.config.clone(),
            llm_worker: self.llm_worker.clone(),
        }
    }
}

/// Content gathered by `execute_retrieval_plan`, grouped by memory tier
#[derive(Debug, Default)]
struct RetrievedContent {
    tier1: Option<Vec<Message>>,
    tier2: Option<Vec<crate::memory_db::Summary>>,
    tier3: Option<Vec<crate::memory_db::StoredMessage>>,
    cross_session: Option<Vec<crate::memory_db::StoredMessage>>,
}

/// Statistics for a single session, as returned by `get_session_stats`
#[derive(Debug, Clone)]
pub struct SessionStats {
    pub session_id: String,
    pub tier_stats: crate::context_engine::tier_manager::TierStats,
    pub database_stats: crate::memory_db::schema::DatabaseStats,
}

/// Counts of items removed by `cleanup`
#[derive(Debug, Clone)]
pub struct CleanupStats {
    pub sessions_cleaned: usize,
    pub cache_entries_cleaned: usize,
}