
offline_intelligence/context_engine/orchestrator.rs

//! Main orchestrator that coordinates all memory subsystems

use crate::memory::Message;
use crate::memory_db::MemoryDatabase;
use crate::context_engine::{
    retrieval_planner::{RetrievalPlan, RetrievalPlanner},
    tier_manager::{TierManager, TierManagerConfig},
    context_builder::{ContextBuilder, ContextBuilderConfig},
    smart_retrieval::{SmartRetrieval, SmartRetrievalConfig},
};
use crate::worker_threads::LLMWorker;

use std::sync::Arc;
use tracing::{info, debug, warn};
use tokio::sync::RwLock;

/// Main orchestrator for the context engine
pub struct ContextOrchestrator {
    database: Arc<MemoryDatabase>,
    retrieval_planner: Arc<RwLock<RetrievalPlanner>>,
    tier_manager: Arc<RwLock<TierManager>>,
    context_builder: Arc<RwLock<ContextBuilder>>,
    config: OrchestratorConfig,
    /// LLM worker for generating query embeddings during semantic search
    llm_worker: Option<Arc<LLMWorker>>,
    /// Smart retrieval for optimized context assembly
    smart_retrieval: Option<Arc<SmartRetrieval>>,
}

/// Configuration for the orchestrator
#[derive(Debug, Clone)]
pub struct OrchestratorConfig {
    pub enabled: bool,
    pub max_context_tokens: usize,
    pub auto_optimize: bool,
    pub enable_metrics: bool,
    pub session_timeout_seconds: u64,
    /// Enable smart retrieval optimization (default: true)
    pub enable_smart_retrieval: bool,
    /// Smart retrieval configuration
    pub smart_retrieval_config: SmartRetrievalConfig,
    /// Model context window size in tokens — 0 means not set (use defaults)
    pub ctx_size: u32,
}

impl Default for OrchestratorConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            max_context_tokens: 4000,
            auto_optimize: true,
            enable_metrics: true,
            session_timeout_seconds: 3600,
            enable_smart_retrieval: true, // Enable by default for production
            smart_retrieval_config: SmartRetrievalConfig::default(),
            ctx_size: 0,
        }
    }
}

impl OrchestratorConfig {
    /// Derive token limits from the model's context window.
    /// 75% of `ctx_size` is the ceiling for the total context sent to the LLM,
    /// leaving 25% headroom for the model's own generation output.
    pub fn from_ctx_size(ctx_size: u32) -> Self {
        let max_context_tokens = (ctx_size as f32 * 0.75) as usize;
        Self {
            max_context_tokens,
            smart_retrieval_config: SmartRetrievalConfig::from_ctx_size(ctx_size),
            ctx_size,
            ..Self::default()
        }
    }
}
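
// Worked example of the 75/25 split above, as a runnable sketch: a hypothetical
// 8192-token window yields a 6144-token context ceiling, leaving 2048 tokens of
// headroom for generation. The assertions follow directly from the truncating
// f32 arithmetic in `from_ctx_size`.
#[cfg(test)]
mod orchestrator_config_tests {
    use super::*;

    #[test]
    fn from_ctx_size_reserves_generation_headroom() {
        let cfg = OrchestratorConfig::from_ctx_size(8192);
        assert_eq!(cfg.max_context_tokens, 6144); // 8192 * 0.75
        assert_eq!(cfg.ctx_size, 8192);
        assert!(cfg.enabled); // remaining fields fall back to Default
    }
}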

impl ContextOrchestrator {
    /// Create a new context orchestrator
    pub async fn new(
        database: Arc<MemoryDatabase>,
        config: OrchestratorConfig,
    ) -> anyhow::Result<Self> {
        // Create retrieval planner wrapped in Arc<RwLock>
        let retrieval_planner = Arc::new(RwLock::new(RetrievalPlanner::new(database.clone())));

        // Create tier manager — derive limits from ctx_size if available
        let tier_manager_config = if config.ctx_size > 0 {
            TierManagerConfig::from_ctx_size(config.ctx_size)
        } else {
            TierManagerConfig::default()
        };
        let tier_manager = TierManager::new(
            database.clone(),
            tier_manager_config,
        );
        let tier_manager = Arc::new(RwLock::new(tier_manager));

        // Create context builder — derive limits from ctx_size if available
        let context_builder_config = if config.ctx_size > 0 {
            ContextBuilderConfig::from_ctx_size(config.ctx_size)
        } else {
            ContextBuilderConfig::default()
        };
        let context_builder = Arc::new(RwLock::new(ContextBuilder::new(context_builder_config)));

        // Initialize smart retrieval if enabled
        let smart_retrieval = if config.enable_smart_retrieval {
            let smart_ret = SmartRetrieval::new(
                Arc::clone(&tier_manager),
                config.smart_retrieval_config.clone(),
            );
            info!("Smart retrieval initialized (enabled)");
            Some(Arc::new(smart_ret))
        } else {
            info!("Smart retrieval disabled");
            None
        };

        let orchestrator = Self {
            database,
            retrieval_planner,
            tier_manager,
            context_builder,
            config,
            llm_worker: None,
            smart_retrieval,
        };

        info!("Context orchestrator initialized successfully");

        Ok(orchestrator)
    }

    /// Set the LLM worker for embedding-based semantic search
    pub fn set_llm_worker(&mut self, worker: Arc<LLMWorker>) {
        self.llm_worker = Some(worker);
        info!("Context orchestrator: LLM worker set for semantic search");
    }

    /// Chat persistence: expose the database for conversation API handlers
    pub fn database(&self) -> &Arc<MemoryDatabase> {
        &self.database
    }

    /// Process a conversation and return optimized context
    pub async fn process_conversation(
        &self,
        session_id: &str,
        messages: &[Message],
        user_query: Option<&str>,
    ) -> anyhow::Result<Vec<Message>> {
        if !self.config.enabled || messages.is_empty() {
            debug!("Context engine disabled or no messages");
            return Ok(messages.to_vec());
        }

        info!("Processing conversation for session {} ({} messages)", session_id, messages.len());

        // Update current messages in Tier 1
        {
            let tier_manager = self.tier_manager.write().await;
            tier_manager.store_tier1_content(session_id, messages).await;
        }

        // Check if the conversation is approaching the context window limit.
        // At 60% of max_context_tokens, fire off a background summarization task —
        // non-blocking, so the current request returns immediately. The updated summary
        // will be prepended on the *next* turn, not this one.
        let estimated_tokens: usize = messages.iter().map(|m| m.content.len() / 4).sum();
        let summary_threshold = (self.config.max_context_tokens as f32 * 0.60) as usize;
        if estimated_tokens >= summary_threshold {
            if let Some(worker) = self.llm_worker.clone() {
                let db = Arc::clone(&self.database);
                let sid = session_id.to_string();
                let msgs = messages.to_vec();
                tokio::spawn(async move {
                    generate_and_store_summary(&db, &worker, &sid, &msgs).await;
                });
            }
        }
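
        // Worked numbers (illustrative, assuming the defaults above): with
        // max_context_tokens = 4000 the threshold is 2400 tokens, i.e. roughly
        // 9,600 characters under the len()/4 heuristic; with from_ctx_size(8192)
        // it is 0.60 * 6144 = 3686 tokens.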

        // Save ONLY the last user message (the new query) to the database
        if let Some(last_message) = messages.last() {
            if last_message.role == "user" {
                let tier_manager = self.tier_manager.read().await;
                if let Err(e) = tier_manager.store_tier3_content(session_id, std::slice::from_ref(last_message)).await {
                    warn!("Failed to persist user query to database: {}", e);
                } else {
                    info!("✅ Persisted user query to database for session {}", session_id);
                }
            }
        }

        // Create the retrieval plan
        let plan = {
            let retrieval_planner = self.retrieval_planner.read().await;

            // Detect whether the user is referring to past conversations
            let has_past_refs = if let Some(query) = user_query {
                retrieval_planner.has_past_references_in_text(query)
            } else {
                false
            };

            // Create the plan using the detected references and the user query
            retrieval_planner.create_plan(
                session_id,
                messages,
                self.config.max_context_tokens,
                user_query,
                has_past_refs, // pass the reference check to the planner
            ).await?
        };

        if !plan.needs_retrieval {
            debug!("No retrieval needed, returning current messages");
            return Ok(messages.to_vec());
        }

        // Execute retrieval plan (includes semantic search when KV cache misses)
        let retrieved_content = self.execute_retrieval_plan(session_id, &plan, user_query).await?;

        // === SMART RETRIEVAL INTEGRATION ===
        // If smart retrieval is enabled, use it to optimize the context assembly
        let optimized_context = if let Some(ref smart_retrieval) = self.smart_retrieval {
            match smart_retrieval.retrieve(
                session_id,
                messages,
                retrieved_content.tier3.clone(),
                retrieved_content.cross_session.clone(),
            ).await {
                Ok(smart_result) => {
                    info!(
                        "🎯 Smart retrieval: Strategy={:?}, Tokens={}, Savings={:.1}%",
                        smart_result.strategy,
                        smart_result.retrieved_tokens,
                        smart_result.compute_savings * 100.0
                    );
                    smart_result.messages
                }
                Err(e) => {
                    warn!("Smart retrieval failed, falling back to standard: {}", e);
                    let mut context_builder = self.context_builder.write().await;
                    context_builder.build_context(
                        messages,
                        retrieved_content.tier1,
                        retrieved_content.tier3,
                        retrieved_content.cross_session,
                        user_query,
                    ).await?
                }
            }
        } else {
            // Smart retrieval disabled, use standard context building
            let mut context_builder = self.context_builder.write().await;
            context_builder.build_context(
                messages,
                retrieved_content.tier1,
                retrieved_content.tier3,
                retrieved_content.cross_session,
                user_query,
            ).await?
        };

        // If a cumulative summary exists for this session (generated at a prior threshold
        // crossing), prepend it so the LLM always has full history context — not just
        // the recent window. This is the re-feed path after a context compression event.
        let final_context = self.prepend_session_summary(session_id, optimized_context).await;

        // If we used retrieval, update statistics
        if let Some(query) = user_query {
            if let Some(response) = final_context.last() {
                if response.role == "assistant" {
                    self.update_engagement(query, &response.content).await;
                }
            }
        }

        info!(
            "Context optimization complete: {} -> {} messages",
            messages.len(),
            final_context.len()
        );

        Ok(final_context)
    }

    /// Prepend the stored cumulative summary as the first system message in the context,
    /// so the LLM has full history even after the active window was trimmed.
    /// Returns the context unchanged if no summary exists for this session.
    async fn prepend_session_summary(
        &self,
        session_id: &str,
        mut context: Vec<Message>,
    ) -> Vec<Message> {
        match self.database.session_summaries.get(session_id) {
            Ok(Some(summary)) => {
                debug!(
                    "Prepending cumulative summary for session {} (clear #{}, {} tokens)",
                    session_id, summary.clear_count, summary.token_count
                );
                context.insert(0, Message {
                    role: "system".to_string(),
                    content: format!(
                        "[Conversation history summary — covers everything before this window:]\n{}",
                        summary.summary_text
                    ),
                });
                context
            }
            Ok(None) => context,
            Err(e) => {
                debug!("Could not fetch summary for session {}: {}", session_id, e);
                context
            }
        }
    }
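
    // Resulting shape (illustrative): given context [user, assistant, user] and a
    // stored summary, the method returns
    //   [system "[Conversation history summary — …]", user, assistant, user]
    // so the summary, when present, is always messages[0].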

    /// Save assistant response to database (Tier 3)
    pub async fn save_assistant_response(
        &self,
        session_id: &str,
        response: &str,
    ) -> anyhow::Result<()> {
        let assistant_message = Message {
            role: "assistant".to_string(),
            content: response.to_string(),
        };

        let tier_manager = self.tier_manager.read().await;
        tier_manager.store_tier3_content(session_id, &[assistant_message]).await
    }

    /// Execute retrieval plan across all tiers.
    /// When semantic_search is enabled and we have an LLM worker, we embed the query
    /// and find similar messages via HNSW — this is the core "KV cache miss → DB retrieval" path.
    async fn execute_retrieval_plan(
        &self,
        session_id: &str,
        plan: &RetrievalPlan,
        user_query: Option<&str>,
    ) -> anyhow::Result<RetrievedContent> {
        let mut retrieved = RetrievedContent::default();

        // Retrieve from Tier 1 (current context — hot KV cache)
        if plan.use_tier1 {
            let tier_manager = self.tier_manager.read().await;
            retrieved.tier1 = tier_manager.get_tier1_content(session_id).await;
        }

        // ── Semantic Search: KV cache miss path ──
        // If the retrieval plan calls for semantic search AND we have embeddings available,
        // embed the user query and find semantically similar past messages from the DB.
        // This avoids re-computing full context — we retrieve just the relevant history.
        //
        // IMPORTANT: Skip entirely when no embeddings exist yet (first conversation / fresh DB).
        // This avoids a wasted round-trip to llama-server /v1/embeddings when there's nothing to search.
        let mut semantic_results: Vec<crate::memory_db::StoredMessage> = Vec::new();

        let has_embeddings = self.database.embeddings.get_stats()
            .map(|s| s.total_embeddings > 0)
            .unwrap_or(false);

        if plan.semantic_search && has_embeddings {
            if let (Some(ref llm_worker), Some(query)) = (&self.llm_worker, user_query) {
                match llm_worker.generate_embeddings(vec![query.to_string()]).await {
                    Ok(query_embeddings) if !query_embeddings.is_empty() => {
                        let query_vec = &query_embeddings[0];
                        // Search HNSW index for similar past messages
                        match self.database.embeddings.find_similar_embeddings(
                            query_vec,
                            "llama-server",
                            (plan.max_messages * 2) as i32,
                            0.3, // similarity threshold
                        ) {
                            Ok(similar) if !similar.is_empty() => {
                                info!("Semantic search found {} similar messages for context retrieval", similar.len());
                                // Fetch actual message content for each match
                                for (message_id, _similarity) in &similar {
                                    // Get message from DB by ID
                                    let conn = self.database.conversations.get_conn_public();
                                    if let Ok(conn) = conn {
                                        let mut stmt = conn.prepare(
                                            "SELECT id, session_id, message_index, role, content, tokens,
                                                    timestamp, importance_score, embedding_generated
                                             FROM messages WHERE id = ?1"
                                        ).ok();
                                        if let Some(ref mut stmt) = stmt {
                                            if let Ok(mut rows) = stmt.query([message_id]) {
                                                if let Ok(Some(row)) = rows.next() {
                                                    let ts_str: String = row.get(6).unwrap_or_default();
                                                    let ts = chrono::DateTime::parse_from_rfc3339(&ts_str)
                                                        .map(|dt| dt.with_timezone(&chrono::Utc))
                                                        .unwrap_or_else(|_| chrono::Utc::now());
                                                    semantic_results.push(crate::memory_db::StoredMessage {
                                                        id: row.get(0).unwrap_or(0),
                                                        session_id: row.get(1).unwrap_or_default(),
                                                        message_index: row.get(2).unwrap_or(0),
                                                        role: row.get(3).unwrap_or_default(),
                                                        content: row.get(4).unwrap_or_default(),
                                                        tokens: row.get(5).unwrap_or(0),
                                                        timestamp: ts,
                                                        importance_score: row.get(7).unwrap_or(0.5),
                                                        embedding_generated: row.get(8).unwrap_or(true),
                                                    });
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            Ok(_) => debug!("Semantic search: no results above threshold"),
                            Err(e) => debug!("Semantic search failed: {}", e),
                        }
                    }
                    Ok(_) => debug!("Empty embedding response for query"),
                    Err(e) => debug!("Query embedding generation failed (semantic search skipped): {}", e),
                }
            }
        }

        // Retrieve from Tier 3 (full database) — keyword fallback or supplement
        if plan.use_tier3 {
            let tier_manager = self.tier_manager.read().await;
            if plan.keyword_search && !plan.search_topics.is_empty() {
                for topic in &plan.search_topics {
                    let limit_per_topic = plan.max_messages / plan.search_topics.len().max(1);

                    if let Ok(results) = tier_manager.search_tier3_content(
                        session_id,
                        topic,
                        limit_per_topic,
                    ).await {
                        // Merge with semantic results, deduplicating by message ID
                        let semantic_ids: std::collections::HashSet<i64> = semantic_results.iter().map(|m| m.id).collect();
                        let mut merged = semantic_results.clone();
                        for msg in results {
                            if !semantic_ids.contains(&msg.id) {
                                merged.push(msg);
                            }
                        }
                        retrieved.tier3 = Some(merged);
                        break;
                    }
                }
                // If keyword search found nothing but semantic did, use semantic results
                if retrieved.tier3.is_none() && !semantic_results.is_empty() {
                    retrieved.tier3 = Some(semantic_results.clone());
                }
            } else if !semantic_results.is_empty() {
                // Use semantic results as tier3 content
                retrieved.tier3 = Some(semantic_results.clone());
            } else {
                retrieved.tier3 = tier_manager.get_tier3_content(
                    session_id,
                    Some((plan.max_messages as i64).min(i32::MAX as i64) as i32),
                    Some(0),
                ).await.ok();
            }
        } else if !semantic_results.is_empty() {
            // Even if tier3 wasn't planned, if semantic search found relevant content, use it
            retrieved.tier3 = Some(semantic_results);
        }
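
        // Tier 3 selection precedence (summarizing the branches above):
        //   keyword hits + semantic hits -> merged, deduplicated by message ID
        //   keyword hits only            -> keyword results
        //   semantic hits only           -> semantic results (even when use_tier3 is false)
        //   neither, use_tier3 == true   -> most recent messages, up to plan.max_messages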

        // Add cross-session search if needed
        if plan.cross_session_search && !plan.search_topics.is_empty() {
            let tier_manager = self.tier_manager.read().await;
            if let Ok(cross_session_results) = tier_manager.search_cross_session_content(
                session_id,
                &plan.search_topics.join(" "),
                10,
            ).await {
                retrieved.cross_session = Some(cross_session_results);
            }
        }

        Ok(retrieved)
    }

    /// Record engagement for a query/response pair (currently log-only)
    async fn update_engagement(&self, user_query: &str, assistant_response: &str) {
        debug!("Engagement updated for query: {} (response length: {})",
               user_query, assistant_response.len());
    }

    /// Get combined tier and database statistics for a session
    pub async fn get_session_stats(&self, session_id: &str) -> anyhow::Result<SessionStats> {
        let tier_manager = self.tier_manager.read().await;
        let tier_stats = tier_manager.get_tier_stats(session_id).await;
        let db_stats = self.database.get_stats()?;

        Ok(SessionStats {
            session_id: session_id.to_string(),
            tier_stats,
            database_stats: db_stats,
        })
    }

    /// Remove persisted data and cache entries older than the given age
    pub async fn cleanup(&self, older_than_seconds: u64) -> anyhow::Result<CleanupStats> {
        info!("Starting cleanup of old data");
        // The database API takes whole days; sub-day ages round down to 0 days
        let db_cleaned = self.database.cleanup_old_data((older_than_seconds / 86400) as i32)?;
        let tier_manager = self.tier_manager.read().await;
        let cache_cleaned = tier_manager.cleanup_cache(older_than_seconds).await;

        Ok(CleanupStats {
            sessions_cleaned: db_cleaned,
            cache_entries_cleaned: cache_cleaned,
        })
    }

    /// Search messages across sessions by keywords
    pub async fn search_messages(
        &self,
        session_id: Option<&str>,
        keywords: &[String],
        limit: usize,
    ) -> anyhow::Result<Vec<crate::memory_db::StoredMessage>> {
        if keywords.is_empty() {
            return Ok(Vec::new());
        }

        if let Some(sid) = session_id {
            // Search within the specific session
            self.database.search_messages_by_keywords(sid, keywords, limit).await
        } else {
            // Global search across all sessions is not implemented yet;
            // return empty results for now
            Ok(Vec::new())
        }
    }

    /// Enable or disable the context engine at runtime
    pub fn set_enabled(&mut self, enabled: bool) {
        self.config.enabled = enabled;
        info!("Context engine {}", if enabled { "enabled" } else { "disabled" });
    }

    /// Replace the orchestrator configuration
    pub fn update_config(&mut self, config: OrchestratorConfig) {
        self.config = config;
        info!("Context engine configuration updated");
    }

    /// Get the current configuration
    pub fn get_config(&self) -> &OrchestratorConfig {
        &self.config
    }

    // Chat persistence: expose the tier manager to ensure sessions exist before processing
    pub fn tier_manager(&self) -> &Arc<RwLock<TierManager>> {
        &self.tier_manager
    }
}
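
// Usage sketch (hedged — `MemoryDatabase::open` is a placeholder; substitute the
// crate's actual constructor, and assume `llm_worker`, `messages`, and `reply`
// are in scope). Shows the intended wiring: derive limits from the model window,
// attach the LLM worker for semantic search, then process turns.
//
//     let db = Arc::new(MemoryDatabase::open("memory.db")?);
//     let cfg = OrchestratorConfig::from_ctx_size(8192);
//     let mut orch = ContextOrchestrator::new(db, cfg).await?;
//     orch.set_llm_worker(llm_worker);
//     let context = orch.process_conversation("session-1", &messages, Some("what did we decide?")).await?;
//     // ... run inference with `context`, then persist the reply:
//     orch.save_assistant_response("session-1", &reply).await?;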

/// Free function — runs in a background `tokio::spawn` task so it never blocks a request.
/// Generates or updates the single cumulative session summary and persists it to SQLite.
async fn generate_and_store_summary(
    database: &Arc<crate::memory_db::MemoryDatabase>,
    llm_worker: &Arc<LLMWorker>,
    session_id: &str,
    messages: &[Message],
) {
    // Too little history to be worth summarizing
    if messages.len() < 4 {
        return;
    }

    let existing = database.session_summaries.get(session_id).unwrap_or(None);

    let system_content = match &existing {
        Some(prev) => format!(
            "You are a concise summarizer. You have a running summary of a conversation \
             and new messages that occurred since that summary. Produce ONE updated summary \
             covering EVERYTHING — the prior summary and the new messages combined. \
             Target under 400 tokens. Include key facts, decisions, code, numbers, names. \
             No commentary.\n\nPRIOR SUMMARY:\n{}",
            prev.summary_text
        ),
        None => "You are a concise summarizer. Summarize the following conversation \
                 into key facts, decisions, code snippets, and figures. \
                 Target under 300 tokens. No commentary.".to_string(),
    };

    let mut context: Vec<Message> = vec![Message {
        role: "system".to_string(),
        content: system_content,
    }];

    // Feed only the most recent 40 messages to keep the summarization prompt bounded
    let tail = if messages.len() > 40 { &messages[messages.len() - 40..] } else { messages };
    context.extend_from_slice(tail);

    let user_prompt = if existing.is_some() {
        "Produce the updated cumulative summary now, covering both the prior summary and these new messages."
    } else {
        "Summarize the conversation above."
    };
    context.push(Message { role: "user".to_string(), content: user_prompt.to_string() });

    match llm_worker.generate_response(session_id.to_string(), context).await {
        Ok(summary) if !summary.trim().is_empty() => {
            let token_estimate = (summary.len() / 4) as i32;
            let clear_num = existing.as_ref().map(|s| s.clear_count + 1).unwrap_or(1);
            match database.session_summaries.upsert(
                session_id, &summary, token_estimate, messages.len() as i32,
            ) {
                Ok(_) => info!(
                    "Background: updated cumulative summary #{} for session {} ({} tokens)",
                    clear_num, session_id, token_estimate
                ),
                Err(e) => warn!("Background: could not persist summary for {}: {}", session_id, e),
            }
        }
        Ok(_) => debug!("Background: summary was empty for session {}", session_id),
        Err(e) => debug!("Background: summary skipped for {}: {}", session_id, e),
    }
}
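
// Lifecycle note: the summary written here is picked up by
// `prepend_session_summary` on the next `process_conversation` call, closing the
// threshold -> background summarize -> prepend loop described above.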

// All subsystems live behind Arc, so cloning the orchestrator is cheap and the
// clones share the same underlying state.
impl Clone for ContextOrchestrator {
    fn clone(&self) -> Self {
        Self {
            database: self.database.clone(),
            retrieval_planner: self.retrieval_planner.clone(),
            tier_manager: self.tier_manager.clone(),
            context_builder: self.context_builder.clone(),
            config: self.config.clone(),
            llm_worker: self.llm_worker.clone(),
            smart_retrieval: self.smart_retrieval.clone(),
        }
    }
}

/// Content gathered by `execute_retrieval_plan`, grouped by tier
#[derive(Debug, Default)]
struct RetrievedContent {
    tier1: Option<Vec<Message>>,
    tier3: Option<Vec<crate::memory_db::StoredMessage>>,
    cross_session: Option<Vec<crate::memory_db::StoredMessage>>,
}

/// Combined tier and database statistics for a single session
#[derive(Debug, Clone)]
pub struct SessionStats {
    pub session_id: String,
    pub tier_stats: crate::context_engine::tier_manager::TierStats,
    pub database_stats: crate::memory_db::schema::DatabaseStats,
}

/// Counts returned by `cleanup`
#[derive(Debug, Clone)]
pub struct CleanupStats {
    pub sessions_cleaned: usize,
    pub cache_entries_cleaned: usize,
}