use crate::memory::Message;
use crate::memory_db::MemoryDatabase;
use crate::context_engine::{
    retrieval_planner::RetrievalPlan,
    retrieval_planner::RetrievalPlanner,
    tier_manager::{TierManager, TierManagerConfig},
    context_builder::{ContextBuilder, ContextBuilderConfig},
    smart_retrieval::{SmartRetrieval, SmartRetrievalConfig},
};
use crate::worker_threads::LLMWorker;

use std::sync::Arc;
use tracing::{info, debug, warn};
use tokio::sync::RwLock;

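/// Coordinates the context engine: it owns the retrieval planner, tier manager,
/// context builder, and (optionally) smart retrieval, and exposes the high-level
/// `process_conversation` entry point used to optimize a conversation's context.
///
/// Illustrative usage sketch (not a doctest): `database`, `llm_worker`, and
/// `messages` are assumed to come from the rest of the application.
///
/// ```ignore
/// let mut orchestrator = ContextOrchestrator::new(database, OrchestratorConfig::default()).await?;
/// orchestrator.set_llm_worker(llm_worker); // enables semantic search
/// let optimized = orchestrator
///     .process_conversation("session-1", &messages, Some("what did we decide earlier?"))
///     .await?;
/// ```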
pub struct ContextOrchestrator {
    database: Arc<MemoryDatabase>,
    retrieval_planner: Arc<RwLock<RetrievalPlanner>>,
    tier_manager: Arc<RwLock<TierManager>>,
    context_builder: Arc<RwLock<ContextBuilder>>,
    config: OrchestratorConfig,
    llm_worker: Option<Arc<LLMWorker>>,
    smart_retrieval: Option<Arc<SmartRetrieval>>,
}

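/// Tuning knobs for the orchestrator: a master `enabled` switch, the context token
/// budget handed to the retrieval planner (`max_context_tokens`), metrics and
/// auto-optimization toggles, a session timeout, and the smart-retrieval feature
/// flag with its nested configuration.
///
/// A sketch of overriding just the token budget (other fields keep their defaults):
///
/// ```ignore
/// let config = OrchestratorConfig {
///     max_context_tokens: 8000,
///     ..OrchestratorConfig::default()
/// };
/// ```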
#[derive(Debug, Clone)]
pub struct OrchestratorConfig {
    pub enabled: bool,
    pub max_context_tokens: usize,
    pub auto_optimize: bool,
    pub enable_metrics: bool,
    pub session_timeout_seconds: u64,
    pub enable_smart_retrieval: bool,
    pub smart_retrieval_config: SmartRetrievalConfig,
}

impl Default for OrchestratorConfig {
    fn default() -> Self {
        Self {
            enabled: true,
            max_context_tokens: 4000,
            auto_optimize: true,
            enable_metrics: true,
            session_timeout_seconds: 3600,
            enable_smart_retrieval: true,
            smart_retrieval_config: SmartRetrievalConfig::default(),
        }
    }
}

impl ContextOrchestrator {
    pub async fn new(
        database: Arc<MemoryDatabase>,
        config: OrchestratorConfig,
    ) -> anyhow::Result<Self> {
        let retrieval_planner = Arc::new(RwLock::new(RetrievalPlanner::new(database.clone())));

        let tier_manager_config = TierManagerConfig::default();
        let tier_manager = TierManager::new(
            database.clone(),
            tier_manager_config,
        );
        let tier_manager = Arc::new(RwLock::new(tier_manager));

        let context_builder_config = ContextBuilderConfig::default();
        let context_builder = Arc::new(RwLock::new(ContextBuilder::new(context_builder_config)));

        let smart_retrieval = if config.enable_smart_retrieval {
            let smart_ret = SmartRetrieval::new(
                Arc::clone(&tier_manager),
                config.smart_retrieval_config.clone(),
            );
            info!("Smart retrieval initialized (enabled)");
            Some(Arc::new(smart_ret))
        } else {
            info!("Smart retrieval disabled");
            None
        };

        let orchestrator = Self {
            database,
            retrieval_planner,
            tier_manager,
            context_builder,
            config,
            llm_worker: None,
            smart_retrieval,
        };

        info!("Context orchestrator initialized successfully");

        Ok(orchestrator)
    }

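    /// Attaches the shared LLM worker used to embed user queries; without it,
    /// the semantic-search step in `execute_retrieval_plan` is skipped.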
    pub fn set_llm_worker(&mut self, worker: Arc<LLMWorker>) {
        self.llm_worker = Some(worker);
        info!("Context orchestrator: LLM worker set for semantic search");
    }

    pub fn database(&self) -> &Arc<MemoryDatabase> {
        &self.database
    }

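    /// Main entry point: takes the current conversation and returns an optimized
    /// message list for the model. The steps are:
    ///
    /// 1. Record the current messages in tier 1 and persist the latest user message.
    /// 2. Ask the `RetrievalPlanner` for a plan (including past-reference detection).
    /// 3. Execute the plan across the tiers, keyword search, and semantic search.
    /// 4. Assemble the final context via `SmartRetrieval` when enabled, falling back
    ///    to the standard `ContextBuilder` otherwise.
    ///
    /// Returns the original messages unchanged when the engine is disabled or no
    /// retrieval is needed.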
    pub async fn process_conversation(
        &self,
        session_id: &str,
        messages: &[Message],
        user_query: Option<&str>,
    ) -> anyhow::Result<Vec<Message>> {
        if !self.config.enabled || messages.is_empty() {
            debug!("Context engine disabled or no messages");
            return Ok(messages.to_vec());
        }

        info!("Processing conversation for session {} ({} messages)", session_id, messages.len());

        // Record the current conversation in tier 1.
        {
            let tier_manager = self.tier_manager.write().await;
            tier_manager.store_tier1_content(session_id, messages).await;
        }

        // Persist the latest user message to tier 3 (database-backed) storage so it
        // can be retrieved in later turns.
        if let Some(last_message) = messages.last() {
            if last_message.role == "user" {
                let tier_manager = self.tier_manager.read().await;
                if let Err(e) = tier_manager.store_tier3_content(session_id, std::slice::from_ref(last_message)).await {
                    warn!("Failed to persist user query to database: {}", e);
                } else {
                    info!("✅ Persisted user query to database for session {}", session_id);
                }
            }
        }

        // Build a retrieval plan, flagging queries that refer back to earlier conversation.
        let plan = {
            let retrieval_planner = self.retrieval_planner.read().await;

            let has_past_refs = if let Some(query) = user_query {
                retrieval_planner.has_past_references_in_text(query)
            } else {
                false
            };

            retrieval_planner.create_plan(
                session_id,
                messages,
                self.config.max_context_tokens,
                user_query,
                has_past_refs,
            ).await?
        };

        if !plan.needs_retrieval {
            debug!("No retrieval needed, returning current messages");
            return Ok(messages.to_vec());
        }

        let retrieved_content = self.execute_retrieval_plan(session_id, &plan, user_query).await?;

        let optimized_context = if let Some(ref smart_retrieval) = self.smart_retrieval {
            match smart_retrieval.retrieve(
                session_id,
                messages,
                retrieved_content.tier2.clone(),
                retrieved_content.tier3.clone(),
                retrieved_content.cross_session.clone(),
            ).await {
                Ok(smart_result) => {
                    info!(
                        "🎯 Smart retrieval: Strategy={:?}, Tokens={}, Savings={:.1}%",
                        smart_result.strategy,
                        smart_result.retrieved_tokens,
                        smart_result.compute_savings * 100.0
                    );
                    smart_result.messages
                }
                Err(e) => {
                    warn!("Smart retrieval failed, falling back to standard: {}", e);
                    let mut context_builder = self.context_builder.write().await;
                    context_builder.build_context(
                        messages,
                        retrieved_content.tier1,
                        retrieved_content.tier2,
                        retrieved_content.tier3,
                        retrieved_content.cross_session,
                        user_query,
                    ).await?
                }
            }
        } else {
            let mut context_builder = self.context_builder.write().await;
            context_builder.build_context(
                messages,
                retrieved_content.tier1,
                retrieved_content.tier2,
                retrieved_content.tier3,
                retrieved_content.cross_session,
                user_query,
            ).await?
        };

        if let Some(query) = user_query {
            if let Some(response) = optimized_context.last() {
                if response.role == "assistant" {
                    self.update_engagement(query, &response.content).await;
                }
            }
        }

        info!(
            "Context optimization complete: {} -> {} messages",
            messages.len(),
            optimized_context.len()
        );

        Ok(optimized_context)
    }

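    /// Persists an assistant reply to tier 3 storage so it becomes part of the
    /// session's long-term history.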
    pub async fn save_assistant_response(
        &self,
        session_id: &str,
        response: &str,
    ) -> anyhow::Result<()> {
        let assistant_message = Message {
            role: "assistant".to_string(),
            content: response.to_string(),
        };

        let tier_manager = self.tier_manager.read().await;
        tier_manager.store_tier3_content(session_id, &[assistant_message]).await
    }

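    /// Executes a `RetrievalPlan`: pulls tier 1/2 content, runs optional semantic
    /// search over stored embeddings, merges those hits with tier 3 keyword results
    /// (de-duplicated by message id), and optionally adds cross-session matches.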
    async fn execute_retrieval_plan(
        &self,
        session_id: &str,
        plan: &RetrievalPlan,
        user_query: Option<&str>,
    ) -> anyhow::Result<RetrievedContent> {
        let mut retrieved = RetrievedContent::default();

        if plan.use_tier1 {
            let tier_manager = self.tier_manager.read().await;
            retrieved.tier1 = tier_manager.get_tier1_content(session_id).await;
        }

        if plan.use_tier2 {
            let tier_manager = self.tier_manager.read().await;
            retrieved.tier2 = tier_manager.get_tier2_content(session_id).await;
        }

        let mut semantic_results: Vec<crate::memory_db::StoredMessage> = Vec::new();

        // Only attempt semantic search when at least one embedding has been stored.
        let has_embeddings = self.database.embeddings.get_stats()
            .map(|s| s.total_embeddings > 0)
            .unwrap_or(false);

        // Semantic retrieval: embed the user query via the LLM worker, look up similar
        // stored embeddings, and hydrate the matching rows from the messages table.
        if plan.semantic_search && has_embeddings {
            if let (Some(ref llm_worker), Some(query)) = (&self.llm_worker, user_query) {
                match llm_worker.generate_embeddings(vec![query.to_string()]).await {
                    Ok(query_embeddings) if !query_embeddings.is_empty() => {
                        let query_vec = &query_embeddings[0];
                        match self.database.embeddings.find_similar_embeddings(
                            query_vec,
                            "llama-server",
                            (plan.max_messages * 2) as i32,
                            0.3,
                        ) {
                            Ok(similar) if !similar.is_empty() => {
                                info!("Semantic search found {} similar messages for context retrieval", similar.len());
                                for (message_id, _similarity) in &similar {
                                    let conn = self.database.conversations.get_conn_public();
                                    if let Ok(conn) = conn {
                                        let mut stmt = conn.prepare(
                                            "SELECT id, session_id, message_index, role, content, tokens,
                                                    timestamp, importance_score, embedding_generated
                                             FROM messages WHERE id = ?1"
                                        ).ok();
                                        if let Some(ref mut stmt) = stmt {
                                            if let Ok(mut rows) = stmt.query([message_id]) {
                                                if let Ok(Some(row)) = rows.next() {
                                                    let ts_str: String = row.get(6).unwrap_or_default();
                                                    let ts = chrono::DateTime::parse_from_rfc3339(&ts_str)
                                                        .map(|dt| dt.with_timezone(&chrono::Utc))
                                                        .unwrap_or_else(|_| chrono::Utc::now());
                                                    semantic_results.push(crate::memory_db::StoredMessage {
                                                        id: row.get(0).unwrap_or(0),
                                                        session_id: row.get(1).unwrap_or_default(),
                                                        message_index: row.get(2).unwrap_or(0),
                                                        role: row.get(3).unwrap_or_default(),
                                                        content: row.get(4).unwrap_or_default(),
                                                        tokens: row.get(5).unwrap_or(0),
                                                        timestamp: ts,
                                                        importance_score: row.get(7).unwrap_or(0.5),
                                                        embedding_generated: row.get(8).unwrap_or(true),
                                                    });
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            Ok(_) => debug!("Semantic search: no results above threshold"),
                            Err(e) => debug!("Semantic search failed: {}", e),
                        }
                    }
                    Ok(_) => debug!("Empty embedding response for query"),
                    Err(e) => debug!("Query embedding generation failed (semantic search skipped): {}", e),
                }
            }
        }

        // Tier 3 (long-term storage): merge keyword hits with any semantic hits,
        // de-duplicating by message id, and fall back to recent history otherwise.
        if plan.use_tier3 {
            let tier_manager = self.tier_manager.read().await;
            if plan.keyword_search && !plan.search_topics.is_empty() {
                for topic in &plan.search_topics {
                    let limit_per_topic = plan.max_messages / plan.search_topics.len().max(1);

                    if let Ok(results) = tier_manager.search_tier3_content(
                        session_id,
                        topic,
                        limit_per_topic,
                    ).await {
                        let semantic_ids: std::collections::HashSet<i64> = semantic_results.iter().map(|m| m.id).collect();
                        let mut merged = semantic_results.clone();
                        for msg in results {
                            if !semantic_ids.contains(&msg.id) {
                                merged.push(msg);
                            }
                        }
                        retrieved.tier3 = Some(merged);
                        break;
                    }
                }
                if retrieved.tier3.is_none() && !semantic_results.is_empty() {
                    retrieved.tier3 = Some(semantic_results.clone());
                }
            } else {
                if !semantic_results.is_empty() {
                    retrieved.tier3 = Some(semantic_results.clone());
                } else {
                    retrieved.tier3 = tier_manager.get_tier3_content(
                        session_id,
                        Some((plan.max_messages as i64).min(i32::MAX as i64) as i32),
                        Some(0),
                    ).await.ok();
                }
            }
        } else if !semantic_results.is_empty() {
            retrieved.tier3 = Some(semantic_results);
        }

        if plan.cross_session_search && !plan.search_topics.is_empty() {
            let tier_manager = self.tier_manager.read().await;
            if let Ok(cross_session_results) = tier_manager.search_cross_session_content(
                session_id,
                &plan.search_topics.join(" "),
                10,
            ).await {
                retrieved.cross_session = Some(cross_session_results);
            }
        }

        Ok(retrieved)
    }

    // Engagement tracking hook: currently only emits a debug log.
    async fn update_engagement(&self, user_query: &str, assistant_response: &str) {
        debug!("Engagement updated for query: {} (response length: {})",
            user_query, assistant_response.len());
    }

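    /// Snapshot of per-session tier statistics together with overall database statistics.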
    pub async fn get_session_stats(&self, session_id: &str) -> anyhow::Result<SessionStats> {
        let tier_manager = self.tier_manager.read().await;
        let tier_stats = tier_manager.get_tier_stats(session_id).await;
        let db_stats = self.database.get_stats()?;

        Ok(SessionStats {
            session_id: session_id.to_string(),
            tier_stats,
            database_stats: db_stats,
        })
    }

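    /// Removes old data from the database and the tier manager's cache. Note that
    /// the database cutoff is truncated to whole days (`older_than_seconds / 86400`).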
    pub async fn cleanup(&self, older_than_seconds: u64) -> anyhow::Result<CleanupStats> {
        info!("Starting cleanup of old data");
        let db_cleaned = self.database.cleanup_old_data((older_than_seconds / 86400) as i32)?;
        let tier_manager = self.tier_manager.read().await;
        let cache_cleaned = tier_manager.cleanup_cache(older_than_seconds).await;

        Ok(CleanupStats {
            sessions_cleaned: db_cleaned,
            cache_entries_cleaned: cache_cleaned,
        })
    }

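    /// Keyword search over stored messages, scoped to a single session. Returns an
    /// empty list when `keywords` is empty or no `session_id` is provided.
    ///
    /// Illustrative call (sketch):
    ///
    /// ```ignore
    /// let hits = orchestrator
    ///     .search_messages(Some("session-1"), &["deployment".to_string()], 20)
    ///     .await?;
    /// ```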
    pub async fn search_messages(
        &self,
        session_id: Option<&str>,
        keywords: &[String],
        limit: usize,
    ) -> anyhow::Result<Vec<crate::memory_db::StoredMessage>> {
        if keywords.is_empty() {
            return Ok(Vec::new());
        }

        if let Some(sid) = session_id {
            self.database.search_messages_by_keywords(sid, keywords, limit).await
        } else {
            // No session id: session-less (global) keyword search is not supported here.
            Ok(Vec::new())
        }
    }

    pub fn set_enabled(&mut self, enabled: bool) {
        self.config.enabled = enabled;
        info!("Context engine {}", if enabled { "enabled" } else { "disabled" });
    }

    pub fn update_config(&mut self, config: OrchestratorConfig) {
        self.config = config;
        info!("Context engine configuration updated");
    }

    pub fn get_config(&self) -> &OrchestratorConfig {
        &self.config
    }

    pub fn tier_manager(&self) -> &Arc<RwLock<TierManager>> {
        &self.tier_manager
    }
}

// Manual Clone: the Arc-backed fields are shared handles; the config is copied by value.
impl Clone for ContextOrchestrator {
    fn clone(&self) -> Self {
        Self {
            database: self.database.clone(),
            retrieval_planner: self.retrieval_planner.clone(),
            tier_manager: self.tier_manager.clone(),
            context_builder: self.context_builder.clone(),
            config: self.config.clone(),
            llm_worker: self.llm_worker.clone(),
            smart_retrieval: self.smart_retrieval.clone(),
        }
    }
}

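/// Internal container for content gathered by `execute_retrieval_plan`, grouped by
/// source: tier 1 messages, tier 2 summaries, tier 3 stored messages, and
/// cross-session matches.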
#[derive(Debug, Default)]
struct RetrievedContent {
    tier1: Option<Vec<Message>>,
    tier2: Option<Vec<crate::memory_db::Summary>>,
    tier3: Option<Vec<crate::memory_db::StoredMessage>>,
    cross_session: Option<Vec<crate::memory_db::StoredMessage>>,
}

#[derive(Debug, Clone)]
pub struct SessionStats {
    pub session_id: String,
    pub tier_stats: crate::context_engine::tier_manager::TierStats,
    pub database_stats: crate::memory_db::schema::DatabaseStats,
}

#[derive(Debug, Clone)]
pub struct CleanupStats {
    pub sessions_cleaned: usize,
    pub cache_entries_cleaned: usize,
}