1use crate::memory::Message;
4use crate::memory_db::MemoryDatabase;
5use crate::context_engine::{
6 retrieval_planner::RetrievalPlan,
7 retrieval_planner::RetrievalPlanner,
8 tier_manager::{TierManager, TierManagerConfig},
9 context_builder::{ContextBuilder, ContextBuilderConfig},
10 smart_retrieval::{SmartRetrieval, SmartRetrievalConfig},
11};
12use crate::worker_threads::LLMWorker;
13
14use std::sync::Arc;
15use tracing::{info, debug, warn};
16use tokio::sync::RwLock;
17
/// Central coordinator for the context engine: owns the storage tiers,
/// retrieval planning, and context assembly used to shape a conversation's
/// message list before it reaches the model.
pub struct ContextOrchestrator {
    // Persistent storage for messages, session summaries and embeddings.
    database: Arc<MemoryDatabase>,
    // Decides whether retrieval is needed and which tiers/topics to query.
    retrieval_planner: Arc<RwLock<RetrievalPlanner>>,
    // Access to the tiered stores (tier1 hot cache, tier3 persistent).
    tier_manager: Arc<RwLock<TierManager>>,
    // Assembles the final message list from current + retrieved content.
    context_builder: Arc<RwLock<ContextBuilder>>,
    config: OrchestratorConfig,
    // Optional LLM worker; enables semantic search and background summary
    // generation once set via `set_llm_worker`.
    llm_worker: Option<Arc<LLMWorker>>,
    // Present only when `config.enable_smart_retrieval` was true at build time.
    smart_retrieval: Option<Arc<SmartRetrieval>>,
}
30
/// Tunables for [`ContextOrchestrator`].
#[derive(Debug, Clone)]
pub struct OrchestratorConfig {
    /// Master switch; when false, conversations pass through untouched.
    pub enabled: bool,
    /// Token budget for the assembled context.
    pub max_context_tokens: usize,
    pub auto_optimize: bool,
    pub enable_metrics: bool,
    pub session_timeout_seconds: u64,
    /// When true, a `SmartRetrieval` instance is built and preferred over the
    /// standard context builder (with fallback on error).
    pub enable_smart_retrieval: bool,
    pub smart_retrieval_config: SmartRetrievalConfig,
    /// Model context window size in tokens; 0 means "unknown", in which case
    /// downstream components fall back to their own defaults.
    pub ctx_size: u32,
}
46
47impl Default for OrchestratorConfig {
48 fn default() -> Self {
49 Self {
50 enabled: true,
51 max_context_tokens: 4000,
52 auto_optimize: true,
53 enable_metrics: true,
54 session_timeout_seconds: 3600,
55 enable_smart_retrieval: true, smart_retrieval_config: SmartRetrievalConfig::default(),
57 ctx_size: 0,
58 }
59 }
60}
61
62impl OrchestratorConfig {
63 pub fn from_ctx_size(ctx_size: u32) -> Self {
67 let max_context_tokens = (ctx_size as f32 * 0.75) as usize;
68 Self {
69 max_context_tokens,
70 smart_retrieval_config: SmartRetrievalConfig::from_ctx_size(ctx_size),
71 ctx_size,
72 ..Self::default()
73 }
74 }
75}
76
77impl ContextOrchestrator {
    /// Builds an orchestrator around `database`, constructing the tier
    /// manager, context builder and (optionally) smart retrieval from
    /// `config`. The LLM worker is attached separately via `set_llm_worker`.
    ///
    /// NOTE(review): this fn is `async` but contains no `.await`; kept async
    /// for interface stability with existing callers.
    pub async fn new(
        database: Arc<MemoryDatabase>,
        config: OrchestratorConfig,
    ) -> anyhow::Result<Self> {
        let retrieval_planner = Arc::new(RwLock::new(RetrievalPlanner::new(database.clone())));

        // ctx_size > 0 means a real window size is known; size sub-configs from it.
        let tier_manager_config = if config.ctx_size > 0 {
            TierManagerConfig::from_ctx_size(config.ctx_size)
        } else {
            TierManagerConfig::default()
        };
        let tier_manager = TierManager::new(
            database.clone(),
            tier_manager_config,
        );
        let tier_manager = Arc::new(RwLock::new(tier_manager));

        let context_builder_config = if config.ctx_size > 0 {
            ContextBuilderConfig::from_ctx_size(config.ctx_size)
        } else {
            ContextBuilderConfig::default()
        };
        let context_builder = Arc::new(RwLock::new(ContextBuilder::new(context_builder_config)));

        // Smart retrieval shares the tier manager; it is optional by config.
        let smart_retrieval = if config.enable_smart_retrieval {
            let smart_ret = SmartRetrieval::new(
                Arc::clone(&tier_manager),
                config.smart_retrieval_config.clone(),
            );
            info!("Smart retrieval initialized (enabled)");
            Some(Arc::new(smart_ret))
        } else {
            info!("Smart retrieval disabled");
            None
        };

        let orchestrator = Self {
            database,
            retrieval_planner,
            tier_manager,
            context_builder,
            config,
            // Attached later via `set_llm_worker`; semantic search and
            // background summaries stay disabled until then.
            llm_worker: None,
            smart_retrieval,
        };

        info!("Context orchestrator initialized successfully");

        Ok(orchestrator)
    }
133
134 pub fn set_llm_worker(&mut self, worker: Arc<LLMWorker>) {
136 self.llm_worker = Some(worker);
137 info!("Context orchestrator: LLM worker set for semantic search");
138 }
139
    /// Borrows the underlying database handle (shared with other components).
    pub fn database(&self) -> &Arc<MemoryDatabase> {
        &self.database
    }
144
    /// Main pipeline: given the current `messages` of a session, stores them,
    /// possibly kicks off a background summary, plans and executes retrieval,
    /// and returns an optimized message list for the model.
    ///
    /// Returns the input unchanged when the engine is disabled, the input is
    /// empty, or the planner decides no retrieval is needed.
    pub async fn process_conversation(
        &self,
        session_id: &str,
        messages: &[Message],
        user_query: Option<&str>,
    ) -> anyhow::Result<Vec<Message>> {
        if !self.config.enabled || messages.is_empty() {
            debug!("Context engine disabled or no messages");
            return Ok(messages.to_vec());
        }

        info!("Processing conversation for session {} ({} messages)", session_id, messages.len());

        // Refresh the tier-1 hot cache; the block scopes the write lock so it
        // is released before any later read locks are taken.
        {
            let tier_manager = self.tier_manager.write().await;
            tier_manager.store_tier1_content(session_id, messages).await;
        }

        // Rough token estimate (~4 bytes/token). Once the conversation exceeds
        // 60% of the budget, refresh the cumulative summary in the background.
        // NOTE(review): this spawns on every call while above the threshold —
        // presumably acceptable because the summary is an idempotent upsert,
        // but confirm there is no unwanted duplicate LLM work.
        let estimated_tokens: usize = messages.iter().map(|m| m.content.len() / 4).sum();
        let summary_threshold = (self.config.max_context_tokens as f32 * 0.60) as usize;
        if estimated_tokens >= summary_threshold {
            if let Some(worker) = self.llm_worker.clone() {
                let db = Arc::clone(&self.database);
                let sid = session_id.to_string();
                let msgs = messages.to_vec();
                tokio::spawn(async move {
                    generate_and_store_summary(&db, &worker, &sid, &msgs).await;
                });
            }
        }

        // Persist the newest user turn to durable tier-3 storage; failures are
        // logged but do not abort processing.
        if let Some(last_message) = messages.last() {
            if last_message.role == "user" {
                let tier_manager = self.tier_manager.read().await;
                if let Err(e) = tier_manager.store_tier3_content(session_id, std::slice::from_ref(last_message)).await {
                    warn!("Failed to persist user query to database: {}", e);
                } else {
                    info!("✅ Persisted user query to database for session {}", session_id);
                }
            }
        }

        // Ask the planner what (if anything) to retrieve for this turn.
        let plan = {
            let retrieval_planner = self.retrieval_planner.read().await;

            // Does the query reference earlier conversation ("as we discussed")?
            let has_past_refs = if let Some(query) = user_query {
                retrieval_planner.has_past_references_in_text(query)
            } else {
                false
            };

            retrieval_planner.create_plan(
                session_id,
                messages,
                self.config.max_context_tokens,
                user_query,
                has_past_refs, ).await?
        };

        if !plan.needs_retrieval {
            debug!("No retrieval needed, returning current messages");
            return Ok(messages.to_vec());
        }

        let retrieved_content = self.execute_retrieval_plan(session_id, &plan, user_query).await?;

        // Prefer smart retrieval when available; fall back to the standard
        // context builder on error (and when smart retrieval is disabled).
        let optimized_context = if let Some(ref smart_retrieval) = self.smart_retrieval {
            match smart_retrieval.retrieve(
                session_id,
                messages,
                retrieved_content.tier3.clone(),
                retrieved_content.cross_session.clone(),
            ).await {
                Ok(smart_result) => {
                    info!(
                        "🎯 Smart retrieval: Strategy={:?}, Tokens={}, Savings={:.1}%",
                        smart_result.strategy,
                        smart_result.retrieved_tokens,
                        smart_result.compute_savings * 100.0
                    );
                    smart_result.messages
                }
                Err(e) => {
                    warn!("Smart retrieval failed, falling back to standard: {}", e);
                    let mut context_builder = self.context_builder.write().await;
                    context_builder.build_context(
                        messages,
                        retrieved_content.tier1,
                        retrieved_content.tier3,
                        retrieved_content.cross_session,
                        user_query,
                    ).await?
                }
            }
        } else {
            let mut context_builder = self.context_builder.write().await;
            context_builder.build_context(
                messages,
                retrieved_content.tier1,
                retrieved_content.tier3,
                retrieved_content.cross_session,
                user_query,
            ).await?
        };

        // Inject the cumulative session summary (if any) as a leading system message.
        // NOTE(review): `mut` appears unused — the vec is not mutated after this.
        let mut final_context = self.prepend_session_summary(session_id, optimized_context).await;

        // Engagement hook — only fires when the context already ends with an
        // assistant turn (currently a no-op debug log; see `update_engagement`).
        if let Some(query) = user_query {
            if let Some(response) = final_context.last() {
                if response.role == "assistant" {
                    self.update_engagement(query, &response.content).await;
                }
            }
        }

        info!(
            "Context optimization complete: {} -> {} messages",
            messages.len(),
            final_context.len()
        );

        Ok(final_context)
    }
288
289 async fn prepend_session_summary(
295 &self,
296 session_id: &str,
297 mut context: Vec<Message>,
298 ) -> Vec<Message> {
299 match self.database.session_summaries.get(session_id) {
300 Ok(Some(summary)) => {
301 debug!(
302 "Prepending cumulative summary for session {} (clear #{}, {} tokens)",
303 session_id, summary.clear_count, summary.token_count
304 );
305 context.insert(0, Message {
306 role: "system".to_string(),
307 content: format!(
308 "[Conversation history summary — covers everything before this window:]\n{}",
309 summary.summary_text
310 ),
311 });
312 context
313 }
314 Ok(None) => context,
315 Err(e) => {
316 debug!("Could not fetch summary for session {}: {}", session_id, e);
317 context
318 }
319 }
320 }
321
322 pub async fn save_assistant_response(
324 &self,
325 session_id: &str,
326 response: &str,
327 ) -> anyhow::Result<()> {
328 let assistant_message = Message {
329 role: "assistant".to_string(),
330 content: response.to_string(),
331 };
332
333 let tier_manager = self.tier_manager.read().await;
334 tier_manager.store_tier3_content(session_id, &[assistant_message]).await
335 }
336
337 async fn execute_retrieval_plan(
341 &self,
342 session_id: &str,
343 plan: &RetrievalPlan,
344 user_query: Option<&str>,
345 ) -> anyhow::Result<RetrievedContent> {
346 let mut retrieved = RetrievedContent::default();
347
348 if plan.use_tier1 {
350 let tier_manager = self.tier_manager.read().await;
351 retrieved.tier1 = tier_manager.get_tier1_content(session_id).await;
352 }
353
354 let mut semantic_results: Vec<crate::memory_db::StoredMessage> = Vec::new();
362
363 let has_embeddings = self.database.embeddings.get_stats()
364 .map(|s| s.total_embeddings > 0)
365 .unwrap_or(false);
366
367 if plan.semantic_search && has_embeddings {
368 if let (Some(ref llm_worker), Some(query)) = (&self.llm_worker, user_query) {
369 match llm_worker.generate_embeddings(vec![query.to_string()]).await {
370 Ok(query_embeddings) if !query_embeddings.is_empty() => {
371 let query_vec = &query_embeddings[0];
372 match self.database.embeddings.find_similar_embeddings(
374 query_vec,
375 "llama-server",
376 (plan.max_messages * 2) as i32,
377 0.3, ) {
379 Ok(similar) if !similar.is_empty() => {
380 info!("Semantic search found {} similar messages for context retrieval", similar.len());
381 for (message_id, _similarity) in &similar {
383 let conn = self.database.conversations.get_conn_public();
385 if let Ok(conn) = conn {
386 let mut stmt = conn.prepare(
387 "SELECT id, session_id, message_index, role, content, tokens,
388 timestamp, importance_score, embedding_generated
389 FROM messages WHERE id = ?1"
390 ).ok();
391 if let Some(ref mut stmt) = stmt {
392 if let Ok(mut rows) = stmt.query([message_id]) {
393 if let Ok(Some(row)) = rows.next() {
394 let ts_str: String = row.get(6).unwrap_or_default();
395 let ts = chrono::DateTime::parse_from_rfc3339(&ts_str)
396 .map(|dt| dt.with_timezone(&chrono::Utc))
397 .unwrap_or_else(|_| chrono::Utc::now());
398 semantic_results.push(crate::memory_db::StoredMessage {
399 id: row.get(0).unwrap_or(0),
400 session_id: row.get(1).unwrap_or_default(),
401 message_index: row.get(2).unwrap_or(0),
402 role: row.get(3).unwrap_or_default(),
403 content: row.get(4).unwrap_or_default(),
404 tokens: row.get(5).unwrap_or(0),
405 timestamp: ts,
406 importance_score: row.get(7).unwrap_or(0.5),
407 embedding_generated: row.get(8).unwrap_or(true),
408 });
409 }
410 }
411 }
412 }
413 }
414 }
415 Ok(_) => debug!("Semantic search: no results above threshold"),
416 Err(e) => debug!("Semantic search failed: {}", e),
417 }
418 }
419 Ok(_) => debug!("Empty embedding response for query"),
420 Err(e) => debug!("Query embedding generation failed (semantic search skipped): {}", e),
421 }
422 }
423 }
424
425 if plan.use_tier3 {
427 let tier_manager = self.tier_manager.read().await;
428 if plan.keyword_search && !plan.search_topics.is_empty() {
429 for topic in &plan.search_topics {
430 let limit_per_topic = plan.max_messages / plan.search_topics.len().max(1);
431
432 if let Ok(results) = tier_manager.search_tier3_content(
433 session_id,
434 topic,
435 limit_per_topic,
436 ).await {
437 let semantic_ids: std::collections::HashSet<i64> = semantic_results.iter().map(|m| m.id).collect();
439 let mut merged = semantic_results.clone();
440 for msg in results {
441 if !semantic_ids.contains(&msg.id) {
442 merged.push(msg);
443 }
444 }
445 retrieved.tier3 = Some(merged);
446 break;
447 }
448 }
449 if retrieved.tier3.is_none() && !semantic_results.is_empty() {
451 retrieved.tier3 = Some(semantic_results.clone());
452 }
453 } else {
454 if !semantic_results.is_empty() {
455 retrieved.tier3 = Some(semantic_results.clone());
457 } else {
458 retrieved.tier3 = tier_manager.get_tier3_content(
459 session_id,
460 Some((plan.max_messages as i64).min(i32::MAX as i64) as i32),
461 Some(0),
462 ).await.ok();
463 }
464 }
465 } else if !semantic_results.is_empty() {
466 retrieved.tier3 = Some(semantic_results);
468 }
469
470 if plan.cross_session_search && !plan.search_topics.is_empty() {
472 let tier_manager = self.tier_manager.read().await;
473 if let Ok(cross_session_results) = tier_manager.search_cross_session_content(
474 session_id,
475 &plan.search_topics.join(" "),
476 10,
477 ).await {
478 retrieved.cross_session = Some(cross_session_results);
479 }
480 }
481
482 Ok(retrieved)
483 }
484
485 async fn update_engagement(&self, user_query: &str, assistant_response: &str) {
486 debug!("Engagement updated for query: {} (response length: {})",
487 user_query, assistant_response.len());
488 }
489
490 pub async fn get_session_stats(&self, session_id: &str) -> anyhow::Result<SessionStats> {
491 let tier_manager = self.tier_manager.read().await;
492 let tier_stats = tier_manager.get_tier_stats(session_id).await;
493 let db_stats = self.database.get_stats()?;
494
495 Ok(SessionStats {
496 session_id: session_id.to_string(),
497 tier_stats,
498 database_stats: db_stats,
499 })
500 }
501
502 pub async fn cleanup(&self, older_than_seconds: u64) -> anyhow::Result<CleanupStats> {
503 info!("Starting cleanup of old data");
504 let db_cleaned = self.database.cleanup_old_data((older_than_seconds / 86400) as i32)?;
505 let tier_manager = self.tier_manager.read().await;
506 let cache_cleaned = tier_manager.cleanup_cache(older_than_seconds).await;
507
508 Ok(CleanupStats {
509 sessions_cleaned: db_cleaned,
510 cache_entries_cleaned: cache_cleaned,
511 })
512 }
513
514 pub async fn search_messages(
516 &self,
517 session_id: Option<&str>,
518 keywords: &[String],
519 limit: usize,
520 ) -> anyhow::Result<Vec<crate::memory_db::StoredMessage>> {
521 if keywords.is_empty() {
522 return Ok(Vec::new());
523 }
524
525 if let Some(sid) = session_id {
526 self.database.search_messages_by_keywords(sid, keywords, limit).await
528 } else {
529 Ok(Vec::new())
532 }
533 }
534
535 pub fn set_enabled(&mut self, enabled: bool) {
536 self.config.enabled = enabled;
537 info!("Context engine {}", if enabled { "enabled" } else { "disabled" });
538 }
539
    /// Replaces the entire configuration.
    ///
    /// NOTE(review): components built from the old config (tier manager,
    /// context builder, smart retrieval) are NOT rebuilt here — confirm
    /// callers expect only orchestrator-level fields to take effect.
    pub fn update_config(&mut self, config: OrchestratorConfig) {
        self.config = config;
        info!("Context engine configuration updated");
    }
544
    /// Borrows the current configuration.
    pub fn get_config(&self) -> &OrchestratorConfig {
        &self.config
    }
548
    /// Borrows the shared tier manager (e.g. for components needing direct
    /// tier access).
    pub fn tier_manager(&self) -> &Arc<RwLock<TierManager>> {
        &self.tier_manager
    }
553}
554
555async fn generate_and_store_summary(
558 database: &Arc<crate::memory_db::MemoryDatabase>,
559 llm_worker: &Arc<LLMWorker>,
560 session_id: &str,
561 messages: &[Message],
562) {
563 if messages.len() < 4 {
564 return;
565 }
566
567 let existing = database.session_summaries.get(session_id).unwrap_or(None);
568
569 let system_content = match &existing {
570 Some(prev) => format!(
571 "You are a concise summarizer. You have a running summary of a conversation \
572 and new messages that occurred since that summary. Produce ONE updated summary \
573 covering EVERYTHING — the prior summary and the new messages combined. \
574 Target under 400 tokens. Include key facts, decisions, code, numbers, names. \
575 No commentary.\n\nPRIOR SUMMARY:\n{}",
576 prev.summary_text
577 ),
578 None => "You are a concise summarizer. Summarize the following conversation \
579 into key facts, decisions, code snippets, and figures. \
580 Target under 300 tokens. No commentary.".to_string(),
581 };
582
583 let mut context: Vec<Message> = vec![Message {
584 role: "system".to_string(),
585 content: system_content,
586 }];
587
588 let tail = if messages.len() > 40 { &messages[messages.len() - 40..] } else { messages };
589 context.extend_from_slice(tail);
590
591 let user_prompt = if existing.is_some() {
592 "Produce the updated cumulative summary now, covering both the prior summary and these new messages."
593 } else {
594 "Summarize the conversation above."
595 };
596 context.push(Message { role: "user".to_string(), content: user_prompt.to_string() });
597
598 match llm_worker.generate_response(session_id.to_string(), context).await {
599 Ok(summary) if !summary.trim().is_empty() => {
600 let token_estimate = (summary.len() / 4) as i32;
601 let clear_num = existing.as_ref().map(|s| s.clear_count + 1).unwrap_or(1);
602 match database.session_summaries.upsert(
603 session_id, &summary, token_estimate, messages.len() as i32,
604 ) {
605 Ok(_) => info!(
606 "Background: updated cumulative summary #{} for session {} ({} tokens)",
607 clear_num, session_id, token_estimate
608 ),
609 Err(e) => info!("Background: could not persist summary for {}: {}", session_id, e),
610 }
611 }
612 Ok(_) => debug!("Background: summary was empty for session {}", session_id),
613 Err(e) => debug!("Background: summary skipped for {}: {}", session_id, e),
614 }
615}
616
617impl Clone for ContextOrchestrator {
618 fn clone(&self) -> Self {
619 Self {
620 database: self.database.clone(),
621 retrieval_planner: self.retrieval_planner.clone(),
622 tier_manager: self.tier_manager.clone(),
623 context_builder: self.context_builder.clone(),
624 config: self.config.clone(),
625 llm_worker: self.llm_worker.clone(),
626 smart_retrieval: self.smart_retrieval.clone(),
627 }
628 }
629}
630
/// Intermediate bundle produced by `execute_retrieval_plan`; `None` means the
/// corresponding tier was not requested or yielded nothing.
#[derive(Debug, Default)]
struct RetrievedContent {
    // Hot-cache messages for the current session.
    tier1: Option<Vec<Message>>,
    // Persistent messages (keyword and/or semantic hits).
    tier3: Option<Vec<crate::memory_db::StoredMessage>>,
    // Relevant messages pulled from other sessions.
    cross_session: Option<Vec<crate::memory_db::StoredMessage>>,
}
637
/// Combined per-session and global statistics returned by `get_session_stats`.
#[derive(Debug, Clone)]
pub struct SessionStats {
    pub session_id: String,
    pub tier_stats: crate::context_engine::tier_manager::TierStats,
    pub database_stats: crate::memory_db::schema::DatabaseStats,
}
644
/// Counts reported by `cleanup`.
#[derive(Debug, Clone)]
pub struct CleanupStats {
    // Sessions removed from the database.
    pub sessions_cleaned: usize,
    // Entries evicted from the in-memory tier cache.
    pub cache_entries_cleaned: usize,
}