Skip to main content

nexus_memory_agent/
supervisor.rs

1//! Agent supervisor - manages background agent loops
2
3use std::sync::Arc;
4
5use chrono::Utc;
6use nexus_core::config::{AgentConfig, CognitionConfig};
7use nexus_core::traits::EmbeddingService;
8use nexus_core::{CognitiveLevel, CognitiveMetadata, Config};
9use nexus_llm::LlmClient;
10use nexus_storage::repository::{ListMemoryFilters, MemoryRepository, ProcessedFileRepository};
11use sqlx::SqlitePool;
12use tokio::sync::RwLock;
13use tokio::task::JoinHandle;
14use tokio::time::{interval, Duration};
15use tokio_util::sync::CancellationToken;
16use tracing::{debug, error, info};
17
18use crate::error::AgentError;
19use crate::inbox::InboxScanner;
20use crate::ingest::IngestService;
21use crate::pulse;
22use crate::query::QueryService;
23use crate::runtime::{drain_cognition_jobs, run_dream_cycle, DreamCycleRequest};
24use crate::types::{AgentStatus, QueryIntrospection};
25
/// Upper bound on how long `AgentSupervisor::stop` waits for the background
/// tasks to exit after cancellation has been signalled.
const GRACEFUL_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
28
29pub struct AgentSupervisor {
30    config: AgentConfig,
31    llm: Arc<dyn LlmClient>,
32    query_embedder: Option<Arc<dyn EmbeddingService>>,
33    pool: SqlitePool,
34    namespace_id: i64,
35    status: Arc<RwLock<AgentStatus>>,
36    cancel_token: CancellationToken,
37    tasks: Vec<JoinHandle<()>>,
38}
39
40impl AgentSupervisor {
41    pub fn new(
42        config: AgentConfig,
43        llm: Arc<dyn LlmClient>,
44        pool: SqlitePool,
45        namespace_id: i64,
46    ) -> Self {
47        let status = Arc::new(RwLock::new(AgentStatus {
48            enabled: config.enabled,
49            namespace: config.namespace.clone(),
50            inbox_dir: config.inbox_dir.clone(),
51            last_scan: None,
52            last_consolidation: None,
53            files_processed: 0,
54            memories_consolidated: 0,
55            queries_answered: 0,
56            errors: Vec::new(),
57        }));
58
59        Self {
60            config,
61            llm,
62            query_embedder: None,
63            pool,
64            namespace_id,
65            status,
66            cancel_token: CancellationToken::new(),
67            tasks: Vec::new(),
68        }
69    }
70
71    pub async fn start(&mut self) -> Result<(), AgentError> {
72        if !self.config.enabled {
73            info!("Agent is disabled, not starting supervisor");
74            return Ok(());
75        }
76
77        info!("Starting agent supervisor");
78
79        // Spawn inbox scanner task
80        let inbox_handle = self.spawn_inbox_scanner().await?;
81        self.tasks.push(inbox_handle);
82
83        // Spawn consolidation task
84        let consolidation_handle = self.spawn_consolidation_task().await?;
85        self.tasks.push(consolidation_handle);
86
87        let cognition_handle = self.spawn_cognition_worker_task().await?;
88        self.tasks.push(cognition_handle);
89
90        info!("Agent supervisor started with {} tasks", self.tasks.len());
91        Ok(())
92    }
93
94    pub async fn stop(&mut self) {
95        info!("Stopping agent supervisor (signaling graceful shutdown)");
96
97        self.cancel_token.cancel();
98
99        // Wait for tasks to complete gracefully
100        let mut remaining: Vec<JoinHandle<()>> = Vec::new();
101        for task in self.tasks.drain(..) {
102            if task.is_finished() {
103                let _ = task.await;
104            } else {
105                remaining.push(task);
106            }
107        }
108
109        if !remaining.is_empty() {
110            info!(
111                "Waiting up to {}s for {} task(s) to finish gracefully",
112                GRACEFUL_SHUTDOWN_TIMEOUT.as_secs(),
113                remaining.len()
114            );
115            match tokio::time::timeout(GRACEFUL_SHUTDOWN_TIMEOUT, async {
116                for task in remaining {
117                    let _ = task.await;
118                }
119            })
120            .await
121            {
122                Ok(()) => info!("All tasks shut down gracefully"),
123                Err(_) => {
124                    // Tasks didn't finish in time — they were already cancelled,
125                    // and their JoinHandles will return once the loop exits.
126                    info!("Graceful shutdown timed out");
127                }
128            }
129        }
130
131        self.tasks.clear();
132        info!("Agent supervisor stopped");
133    }
134
135    pub async fn get_status(&self) -> AgentStatus {
136        self.status.read().await.clone()
137    }
138
139    pub fn with_query_embedder(mut self, embedder: Arc<dyn EmbeddingService>) -> Self {
140        self.query_embedder = Some(embedder);
141        self
142    }
143
144    /// Increment the queries answered counter (for external callers like the web API).
145    pub async fn increment_queries_answered(&self) {
146        let mut s = self.status.write().await;
147        s.queries_answered += 1;
148    }
149
150    pub fn query_service(&self) -> QueryService {
151        if let Some(embedder) = &self.query_embedder {
152            QueryService::with_embedder(self.llm.clone(), self.config.clone(), embedder.clone())
153        } else {
154            QueryService::new(self.llm.clone(), self.config.clone())
155        }
156    }
157
158    pub fn ingest_service(&self) -> IngestService {
159        IngestService::new(self.llm.clone(), self.config.clone())
160    }
161
162    /// Get the agent namespace ID
163    pub fn namespace_id(&self) -> i64 {
164        self.namespace_id
165    }
166
167    /// Compute query introspection (ranking decisions) without calling the LLM.
168    pub async fn query_introspection(
169        &self,
170        question: &str,
171        namespace_id: i64,
172        memory_repo: &MemoryRepository,
173    ) -> Result<QueryIntrospection, AgentError> {
174        self.query_service()
175            .query_introspection(question, namespace_id, memory_repo)
176            .await
177    }
178
179    async fn spawn_inbox_scanner(&self) -> Result<JoinHandle<()>, AgentError> {
180        let config = self.config.clone();
181        let llm = self.llm.clone();
182        let pool = self.pool.clone();
183        let namespace_id = self.namespace_id;
184        let status = self.status.clone();
185        let interval_secs = config.scan_interval_secs;
186        let cancel = self.cancel_token.clone();
187
188        let handle = tokio::spawn(async move {
189            let ingest_service = IngestService::new(llm.clone(), config.clone());
190            let scanner = InboxScanner::new(config, ingest_service);
191            let mut ticker = interval(Duration::from_secs(interval_secs));
192
193            loop {
194                tokio::select! {
195                    _ = ticker.tick() => {}
196                    _ = cancel.cancelled() => {
197                        info!("Inbox scanner received shutdown signal");
198                        break;
199                    }
200                }
201
202                let processed_repo = ProcessedFileRepository::new(&pool);
203                let memory_repo = MemoryRepository::new(pool.clone());
204
205                match scanner
206                    .run(namespace_id, &processed_repo, &memory_repo)
207                    .await
208                {
209                    Ok(result) => {
210                        let mut s = status.write().await;
211                        s.last_scan = Some(Utc::now());
212                        s.files_processed += result.processed;
213                        pulse::write_pulse(
214                            "inbox_scan",
215                            s.memories_consolidated,
216                            s.files_processed,
217                        );
218                    }
219                    Err(e) => {
220                        error!(error = %e, namespace_id, "Inbox scan failed");
221                        let mut s = status.write().await;
222                        s.errors.push(format!("Scan error: {}", e));
223                        if s.errors.len() > 10 {
224                            s.errors.remove(0);
225                        }
226                    }
227                }
228            }
229        });
230
231        Ok(handle)
232    }
233
234    async fn spawn_consolidation_task(&self) -> Result<JoinHandle<()>, AgentError> {
235        let config = self.config.clone();
236        let llm = self.llm.clone();
237        let pool = self.pool.clone();
238        let namespace_id = self.namespace_id;
239        let status = self.status.clone();
240        let base_interval_secs = config.consolidation_interval_mins * 60;
241        let cancel = self.cancel_token.clone();
242        let cognition = Config::from_env()
243            .map(|config| config.cognition)
244            .unwrap_or_default();
245
246        let handle = tokio::spawn(async move {
247            loop {
248                let sleep_duration = if cognition.adaptive_dream_enabled {
249                    compute_adaptive_dream_interval(
250                        pool.clone(),
251                        namespace_id,
252                        base_interval_secs,
253                        &cognition,
254                    )
255                    .await
256                } else {
257                    Duration::from_secs(base_interval_secs)
258                };
259
260                tokio::select! {
261                    _ = tokio::time::sleep(sleep_duration) => {}
262                    _ = cancel.cancelled() => {
263                        info!("Consolidation task received shutdown signal");
264                        break;
265                    }
266                }
267
268                let lease_owner = format!("supervisor-dream-{}", namespace_id);
269                match run_dream_cycle(
270                    pool.clone(),
271                    &cognition,
272                    &config,
273                    llm.clone(),
274                    None,
275                    DreamCycleRequest {
276                        namespace_id,
277                        lease_owner: &lease_owner,
278                        perspective: None,
279                        session_key: None,
280                        reflect_reason: "namespace_dream",
281                        digest_reason: "dream_digest",
282                    },
283                )
284                .await
285                {
286                    Ok(processed) if processed > 0 => {
287                        let mut s = status.write().await;
288                        s.last_consolidation = Some(Utc::now());
289                        s.memories_consolidated += processed as u64;
290                        pulse::write_pulse(
291                            "consolidation",
292                            s.memories_consolidated,
293                            s.files_processed,
294                        );
295                    }
296                    Ok(_) => {
297                        debug!("No memories to consolidate");
298                    }
299                    Err(e) => {
300                        error!(error = %e, namespace_id, "Consolidation failed");
301                        let mut s = status.write().await;
302                        s.errors.push(format!("Consolidation error: {}", e));
303                        if s.errors.len() > 10 {
304                            s.errors.remove(0);
305                        }
306                    }
307                }
308            }
309        });
310
311        Ok(handle)
312    }
313
314    async fn spawn_cognition_worker_task(&self) -> Result<JoinHandle<()>, AgentError> {
315        let config = self.config.clone();
316        let pool = self.pool.clone();
317        let namespace_id = self.namespace_id;
318        let status = self.status.clone();
319        let llm = self.llm.clone();
320        let cancel = self.cancel_token.clone();
321        let cognition = nexus_core::Config::from_env()
322            .map(|config| config.cognition)
323            .unwrap_or_default();
324
325        let handle = tokio::spawn(async move {
326            let mut ticker = interval(Duration::from_secs(config.scan_interval_secs.max(1)));
327
328            loop {
329                tokio::select! {
330                    _ = ticker.tick() => {}
331                    _ = cancel.cancelled() => {
332                        info!("Cognition worker received shutdown signal");
333                        break;
334                    }
335                }
336
337                match drain_cognition_jobs(
338                    pool.clone(),
339                    namespace_id,
340                    &cognition,
341                    &config,
342                    llm.clone(),
343                    None,
344                    &format!("supervisor-{}", namespace_id),
345                )
346                .await
347                {
348                    Ok(processed) => {
349                        if processed > 0 {
350                            debug!(namespace_id, processed, "Cognition worker drained jobs");
351                            let mut s = status.write().await;
352                            s.last_consolidation = Some(Utc::now());
353                            s.memories_consolidated += processed as u64;
354                            pulse::write_pulse(
355                                "cognition",
356                                s.memories_consolidated,
357                                s.files_processed,
358                            );
359                        }
360                    }
361                    Err(error) => {
362                        error!(error = %error, namespace_id, "Cognition worker failed");
363                        let mut s = status.write().await;
364                        s.errors.push(format!("Cognition error: {}", error));
365                        if s.errors.len() > 10 {
366                            s.errors.remove(0);
367                        }
368                    }
369                }
370            }
371        });
372
373        Ok(handle)
374    }
375}
376
377/// Compute the next dream-cycle interval based on contradiction density.
378///
379/// When contradictions exist in the namespace, the interval is shortened
380/// proportionally (down to `adaptive_dream_min_interval_secs`). Otherwise
381/// the base interval is used (capped at `adaptive_dream_max_interval_secs`).
382async fn compute_adaptive_dream_interval(
383    pool: SqlitePool,
384    namespace_id: i64,
385    base_interval_secs: u64,
386    cognition: &CognitionConfig,
387) -> Duration {
388    let repo = MemoryRepository::new(pool);
389    let min = cognition.adaptive_dream_min_interval_secs;
390    let max = cognition.adaptive_dream_max_interval_secs;
391    let base = base_interval_secs.clamp(min, max);
392
393    let contradiction_count = match repo
394        .list_filtered(
395            namespace_id,
396            ListMemoryFilters {
397                category: None,
398                since: None,
399                until: None,
400                content_like: None,
401                include_raw: false,
402                limit: 256,
403                offset: 0,
404            },
405        )
406        .await
407    {
408        Ok(memories) => memories
409            .iter()
410            .filter(|m| {
411                CognitiveMetadata::from_metadata(&m.metadata)
412                    .map(|c| c.level == CognitiveLevel::Contradiction)
413                    .unwrap_or(false)
414            })
415            .count(),
416        Err(_) => return Duration::from_secs(base),
417    };
418
419    if contradiction_count == 0 {
420        return Duration::from_secs(base);
421    }
422
423    // Shorten interval proportionally: each contradiction shaves 10% off the base,
424    // down to the minimum.
425    let factor = 1.0 - ((contradiction_count as f32 * 0.10).min(0.9));
426    let adapted = (base as f32 * factor) as u64;
427    Duration::from_secs(adapted.clamp(min, max))
428}