Skip to main content

nexus_memory_hooks/
base.rs

//! AgentHook trait definition
//!
//! This module defines the core [`AgentHook`] trait that all agent hooks must implement,
//! together with the shared [`BaseHook`] helper, [`HookResult`], and [`LifecycleCapabilities`].
4
5use async_trait::async_trait;
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::path::PathBuf;
9use std::sync::{Arc, RwLock};
10
11use crate::error::Result;
12use crate::session::SessionContext;
13use crate::types::{ExtractionSource, SessionActivity, SupportTier};
14use nexus_agent::activity_monitor::ActivityMonitor;
15use nexus_agent::dream_cycle::run_nap;
16
17/// Callback type for session end events
18pub type SessionEndCallback = Arc<dyn Fn(SessionContext) + Send + Sync>;
19
20/// Describes which lifecycle events an agent hook can handle.
21#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
22pub struct LifecycleCapabilities {
23    pub session_start: bool,
24    pub session_end: bool,
25    pub checkpoint: bool,
26    pub error_hook: bool,
27    pub compact: bool,
28}
29
30impl LifecycleCapabilities {
31    pub fn end_only() -> Self {
32        Self {
33            session_end: true,
34            ..Default::default()
35        }
36    }
37    pub fn monitor_only() -> Self {
38        Self::default()
39    }
40}
41
42/// Result of a hook operation
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct HookResult {
45    pub success: bool,
46    pub agent_type: String,
47    pub source: ExtractionSource,
48    pub context: Option<SessionContext>,
49    pub error: Option<String>,
50    pub timestamp: DateTime<Utc>,
51}
52
53impl HookResult {
54    pub fn success(agent_type: impl Into<String>, source: ExtractionSource) -> Self {
55        Self {
56            success: true,
57            agent_type: agent_type.into(),
58            source,
59            context: None,
60            error: None,
61            timestamp: Utc::now(),
62        }
63    }
64    pub fn success_with_context(
65        agent_type: impl Into<String>,
66        source: ExtractionSource,
67        context: SessionContext,
68    ) -> Self {
69        Self {
70            success: true,
71            agent_type: agent_type.into(),
72            source,
73            context: Some(context),
74            error: None,
75            timestamp: Utc::now(),
76        }
77    }
78    pub fn failure(
79        agent_type: impl Into<String>,
80        source: ExtractionSource,
81        error: impl Into<String>,
82    ) -> Self {
83        Self {
84            success: false,
85            agent_type: agent_type.into(),
86            source,
87            context: None,
88            error: Some(error.into()),
89            timestamp: Utc::now(),
90        }
91    }
92}
93
94/// AgentHook trait - all agent hooks must implement this
95#[async_trait]
96pub trait AgentHook: Send + Sync {
97    fn agent_type(&self) -> &str;
98    async fn install_session_end_hook(&mut self, callback: SessionEndCallback) -> Result<()>;
99    async fn install_session_start_hook(&mut self, _callback: SessionEndCallback) -> Result<()> {
100        Err(crate::error::HookError::NotSupported(
101            "Session start hooks not supported for this agent".to_string(),
102        ))
103    }
104    async fn install_compact_hook(&mut self, _callback: SessionEndCallback) -> Result<()> {
105        Err(crate::error::HookError::NotSupported(
106            "Compact/checkpoint hooks not supported for this agent".to_string(),
107        ))
108    }
109    async fn detect_session_activity(&self) -> Result<SessionActivity>;
110    async fn extract_session_context(&self) -> Result<SessionContext>;
111    async fn install_checkpoint_hook(&mut self, _callback: SessionEndCallback) -> Result<()> {
112        Err(crate::error::HookError::NotSupported(
113            "Checkpoint hooks not supported for this agent".to_string(),
114        ))
115    }
116    async fn install_error_hook(&mut self, _callback: SessionEndCallback) -> Result<()> {
117        Err(crate::error::HookError::NotSupported(
118            "Error hooks not supported for this agent".to_string(),
119        ))
120    }
121    fn is_hook_installed(&self) -> bool {
122        false
123    }
124    async fn uninstall_hooks(&mut self) -> Result<()> {
125        Ok(())
126    }
127    fn reliability_score(&self) -> f32 {
128        1.0
129    }
130    fn lifecycle_capabilities(&self) -> LifecycleCapabilities {
131        LifecycleCapabilities::end_only()
132    }
133    fn support_tier(&self) -> SupportTier {
134        SupportTier::MonitorOnly
135    }
136    fn record_activity(&self) {}
137}
138
139/// Base hook implementation with common functionality
140pub struct BaseHook {
141    pub agent_type: String,
142    pub installed: bool,
143    pub callbacks: Vec<SessionEndCallback>,
144    pub session_start_callbacks: Vec<SessionEndCallback>,
145    pub checkpoint_callbacks: Vec<SessionEndCallback>,
146    pub error_callbacks: Vec<SessionEndCallback>,
147    pub activity_monitor: std::sync::Mutex<ActivityMonitor>,
148    pub rescorer: RwLock<Option<Arc<crate::rescorer::SessionRescorer>>>,
149    /// Project root resolved at construction time, used for nap/dream cycles.
150    project_root: PathBuf,
151}
152
153impl BaseHook {
154    pub fn new(agent_type: impl Into<String>) -> Self {
155        let agent_type = agent_type.into();
156        let project_root = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
157
158        Self {
159            agent_type,
160            installed: false,
161            callbacks: Vec::new(),
162            session_start_callbacks: Vec::new(),
163            checkpoint_callbacks: Vec::new(),
164            error_callbacks: Vec::new(),
165            activity_monitor: std::sync::Mutex::new(ActivityMonitor::load()),
166            rescorer: RwLock::new(None),
167            project_root,
168        }
169    }
170    /// Lazily initialize the rescorer on first use
171    fn ensure_rescorer(&self) {
172        let read = self.rescorer.read().unwrap();
173        if read.is_none() {
174            drop(read);
175            let mut write = self.rescorer.write().unwrap();
176            if write.is_none() {
177                if let Ok(cwd) = std::env::current_dir() {
178                    let project = nexus_core::ProjectIdentity::resolve(&cwd);
179                    let config = nexus_core::Config::from_env().unwrap_or_default();
180                    *write = Some(Arc::new(crate::rescorer::SessionRescorer::new(
181                        project,
182                        config.cognitive_system.rescore_turn_interval,
183                        config.cognitive_system.rescore_drift_threshold,
184                    )));
185                }
186            }
187        }
188    }
189
190    pub fn record_activity(&self) {
191        self.record_activity_with_content("activity recorded")
192    }
193
194    pub fn record_activity_with_content(&self, content: &str) {
195        // Ensure rescorer is initialized lazily
196        self.ensure_rescorer();
197
198        if let Ok(mut monitor) = self.activity_monitor.lock() {
199            // Reload full monitor from disk to avoid overwriting concurrent writes
200            // to fields like last_deep_dream, deep_dream_cooldown, etc.
201            let mut disk = ActivityMonitor::load();
202            disk.record_activity();
203            *monitor = disk.clone();
204            drop(monitor);
205            if let Err(e) = disk.save() {
206                tracing::debug!("Failed to save activity monitor: {e}");
207            }
208        }
209
210        // Skip drift/rescore for activity-only sampling (placeholder content)
211        if content == "activity recorded" {
212            return;
213        }
214        // Trigger real-time re-scoring if drift detected
215        let rescorer = self.rescorer.read().unwrap().clone();
216        if let Some(rescorer) = rescorer {
217            let content = content.to_string();
218            let agent_type = self.agent_type.clone();
219            if let Ok(handle) = tokio::runtime::Handle::try_current() {
220                handle.spawn(async move {
221                    let config = nexus_core::Config::from_env().unwrap_or_default();
222                    let embeddings = if config.embedding.enabled {
223                        nexus_agent::runtime::create_embedding_service(&config).await
224                    } else {
225                        None
226                    };
227                    if let Some(similarity) =
228                        rescorer.on_turn(&content, embeddings.as_deref()).await
229                    {
230                        let _ = rescorer.rescore(embeddings.as_deref(), &agent_type).await;
231
232                        // PHASE 11: Notify orchestrator of drift
233                        // Only publish drift event for actual topic drift, not interval triggers
234                        // (interval triggers return similarity = 1.0)
235                        if similarity < rescorer.drift_threshold() {
236                            let mut data = std::collections::HashMap::new();
237                            data.insert("agent_type".to_string(), serde_json::json!(agent_type));
238                            data.insert("drift_detected".to_string(), serde_json::json!(true));
239                            data.insert("similarity".to_string(), serde_json::json!(similarity));
240                            data.insert(
241                                "threshold".to_string(),
242                                serde_json::json!(rescorer.drift_threshold()),
243                            );
244
245                            let event = nexus_orchestrator::Event::with_data(
246                                nexus_orchestrator::EventType::CognitiveDrift,
247                                data,
248                            )
249                            .with_source("base_hook");
250
251                            let event_bus = nexus_orchestrator::EventBus::global();
252                            let _ = event_bus.publish(event);
253                        }
254                    }
255                });
256            }
257        }
258    }
259
260    pub fn add_callback(&mut self, callback: SessionEndCallback) {
261        self.callbacks.push(callback);
262    }
263
264    pub fn add_session_start_callback(&mut self, callback: SessionEndCallback) {
265        self.session_start_callbacks.push(callback);
266    }
267
268    pub fn add_checkpoint_callback(&mut self, callback: SessionEndCallback) {
269        self.checkpoint_callbacks.push(callback);
270    }
271
272    pub fn add_error_callback(&mut self, callback: SessionEndCallback) {
273        self.error_callbacks.push(callback);
274    }
275
276    pub fn trigger_session_start_callbacks(&self, context: SessionContext) {
277        for callback in &self.session_start_callbacks {
278            callback(context.clone());
279        }
280    }
281
282    pub fn trigger_checkpoint_callbacks(&self, context: SessionContext) {
283        for callback in &self.checkpoint_callbacks {
284            callback(context.clone());
285        }
286    }
287
288    pub fn trigger_error_callbacks(&self, context: SessionContext) {
289        for callback in &self.error_callbacks {
290            callback(context.clone());
291        }
292    }
293
294    pub fn trigger_callbacks(&self, context: SessionContext) {
295        for callback in &self.callbacks {
296            callback(context.clone());
297        }
298
299        if let Some(session_id) = context.session_id.as_ref() {
300            let session_id = session_id.clone();
301            let agent_type = context.agent_type.clone();
302            let project_root = self.project_root.clone();
303
304            if let Ok(handle) = tokio::runtime::Handle::try_current() {
305                handle.spawn(async move {
306                    let config = nexus_core::Config::from_env().unwrap_or_default();
307                    if config.cognitive_system.dream_triggers.nap_on_session_end {
308                        let cwd = project_root;
309                        let pool_url = config.database_url();
310                        if let Some(parent) = config.database.path.parent() {
311                            let _ = std::fs::create_dir_all(parent);
312                        }
313                        if let Ok(mut storage) =
314                            nexus_storage::StorageManager::from_url(&pool_url).await
315                        {
316                            if let Err(e) = storage.initialize().await {
317                                tracing::warn!("Failed to initialize storage for nap: {e}");
318                                return;
319                            }
320                            let ns_repo = nexus_storage::repository::NamespaceRepository::new(
321                                storage.pool().clone(),
322                            );
323                            if let Ok(namespace) =
324                                ns_repo.get_or_create(&agent_type, &agent_type).await
325                            {
326                                let llm_result = nexus_llm::create_client_auto_with_fallback();
327                                let llm = match llm_result {
328                                    Ok(client) => client,
329                                    Err(e) => {
330                                        tracing::warn!(
331                                            "Failed to create LLM client for session-end nap: {}",
332                                            e
333                                        );
334                                        return;
335                                    }
336                                };
337                                let embeddings = if config.embedding.enabled {
338                                    nexus_agent::runtime::create_embedding_service(&config).await
339                                } else {
340                                    None
341                                };
342                                let timeout = std::time::Duration::from_secs(
343                                    config.cognition.session_end_dream_timeout_secs,
344                                );
345                                let services = nexus_agent::dream_cycle::DreamServices {
346                                    pool: storage.pool().clone(),
347                                    cognition: config.cognition.clone(),
348                                    agent: config.agent.clone(),
349                                    llm,
350                                    embeddings,
351                                    cognitive_system: config.cognitive_system.clone(),
352                                };
353                                match run_nap(&session_id, &cwd, namespace.id, &services, timeout)
354                                    .await
355                                {
356                                    Ok(nap_result) => {
357                                        if nap_result.timed_out {
358                                            tracing::warn!(
359                                                session_id = %session_id,
360                                                "nap timed out; not publishing DreamCompleted"
361                                            );
362                                        } else {
363                                            // PHASE 11: Notify of dream completion
364                                            let mut data = std::collections::HashMap::new();
365                                            data.insert(
366                                                "agent_type".to_string(),
367                                                serde_json::json!(agent_type),
368                                            );
369                                            data.insert(
370                                                "processed".to_string(),
371                                                serde_json::json!(nap_result.memories_processed),
372                                            );
373
374                                            let event = nexus_orchestrator::Event::with_data(
375                                                nexus_orchestrator::EventType::DreamCompleted,
376                                                data,
377                                            )
378                                            .with_source("agent_supervisor");
379
380                                            let event_bus = nexus_orchestrator::EventBus::global();
381                                            let _ = event_bus.publish(event);
382                                        }
383                                    }
384                                    Err(e) => {
385                                        tracing::warn!(
386                                            session_id = %session_id,
387                                            error = %e,
388                                            "Session-end nap failed"
389                                        );
390                                    }
391                                }
392                            } else {
393                                tracing::debug!("Failed to get/create namespace for nap");
394                            }
395                        } else {
396                            tracing::debug!("Failed to create storage for nap");
397                        }
398                    }
399                });
400            }
401        }
402    }
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    /// Builds a callback that increments `counter` each time it fires.
    fn counting_callback(counter: &Arc<AtomicUsize>) -> SessionEndCallback {
        let counter = Arc::clone(counter);
        Arc::new(move |_ctx: SessionContext| {
            counter.fetch_add(1, Ordering::SeqCst);
        })
    }

    #[test]
    fn test_hook_result_success() {
        let result = HookResult::success("test-agent", ExtractionSource::Manual);
        assert!(result.success);
        assert!(result.error.is_none());
    }

    #[test]
    fn test_hook_result_failure() {
        let result = HookResult::failure(
            "test-agent",
            ExtractionSource::Manual,
            "Something went wrong",
        );
        assert!(!result.success);
        assert!(result.error.is_some());
        assert_eq!(result.error.unwrap(), "Something went wrong");
    }

    #[test]
    fn test_hook_result_with_context() {
        let ctx = SessionContext::new("test");
        let result = HookResult::success_with_context(
            "test-agent",
            ExtractionSource::NativeHook("skill".to_string()),
            ctx,
        );
        assert!(result.success);
        assert!(result.context.is_some());
    }

    #[test]
    fn test_base_hook() {
        let mut hook = BaseHook::new("test");
        assert_eq!(hook.agent_type, "test");
        assert!(!hook.installed);

        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();
        hook.add_callback(Arc::new(move |_ctx| {
            called_clone.store(true, Ordering::SeqCst);
        }));

        hook.trigger_callbacks(SessionContext::new("test"));
        assert!(called.load(Ordering::SeqCst));
    }

    /// Each trigger must fire only its own callback list, and fire every entry in it.
    #[test]
    fn test_base_hook_lifecycle_callbacks_fire_independently() {
        let mut hook = BaseHook::new("test");
        let starts = Arc::new(AtomicUsize::new(0));
        let checkpoints = Arc::new(AtomicUsize::new(0));
        let errors = Arc::new(AtomicUsize::new(0));

        hook.add_session_start_callback(counting_callback(&starts));
        hook.add_session_start_callback(counting_callback(&starts));
        hook.add_checkpoint_callback(counting_callback(&checkpoints));
        hook.add_error_callback(counting_callback(&errors));

        hook.trigger_session_start_callbacks(SessionContext::new("test"));
        assert_eq!(starts.load(Ordering::SeqCst), 2);
        assert_eq!(checkpoints.load(Ordering::SeqCst), 0);
        assert_eq!(errors.load(Ordering::SeqCst), 0);

        hook.trigger_checkpoint_callbacks(SessionContext::new("test"));
        hook.trigger_error_callbacks(SessionContext::new("test"));
        assert_eq!(starts.load(Ordering::SeqCst), 2);
        assert_eq!(checkpoints.load(Ordering::SeqCst), 1);
        assert_eq!(errors.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_lifecycle_capabilities_default() {
        let caps = LifecycleCapabilities::default();
        assert!(!caps.session_start);
        assert!(!caps.session_end);
        assert!(!caps.checkpoint);
        assert!(!caps.error_hook);
        assert!(!caps.compact);
    }

    #[test]
    fn test_lifecycle_capabilities_end_only() {
        let caps = LifecycleCapabilities::end_only();
        assert!(!caps.session_start);
        assert!(caps.session_end);
        assert!(!caps.checkpoint);
        assert!(!caps.error_hook);
        assert!(!caps.compact);
    }

    #[test]
    fn test_lifecycle_capabilities_monitor_only() {
        // Monitor-only hooks handle no lifecycle events, so every flag must be off.
        let caps = LifecycleCapabilities::monitor_only();
        assert!(!caps.session_end);
        assert_eq!(caps, LifecycleCapabilities::default());
    }
}