Skip to main content

nexus_memory_hooks/
base.rs

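//! AgentHook trait definition
//!
//! This module defines the core `AgentHook` trait that all agent hooks must implement,
//! together with the supporting `SessionEndCallback`, `LifecycleCapabilities` and
//! `HookResult` types and the shared `BaseHook` helper.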
4
5use async_trait::async_trait;
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::path::PathBuf;
9use std::sync::{Arc, RwLock};
10
11use crate::error::Result;
12use crate::session::SessionContext;
13use crate::types::{ExtractionSource, SessionActivity, SupportTier};
14use nexus_agent::activity_monitor::ActivityMonitor;
15use nexus_agent::dream_cycle::run_nap;
16
/// Callback type for session end events.
///
/// A shared, thread-safe closure that receives a [`SessionContext`] by value.
/// Despite the name, the same alias is the callback type accepted by the
/// session-start, compact, checkpoint and error hook installers in this module.
pub type SessionEndCallback = Arc<dyn Fn(SessionContext) + Send + Sync>;
19
20/// Describes which lifecycle events an agent hook can handle.
21#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
22pub struct LifecycleCapabilities {
23    pub session_start: bool,
24    pub session_end: bool,
25    pub checkpoint: bool,
26    pub error_hook: bool,
27    pub compact: bool,
28}
29
30impl LifecycleCapabilities {
31    pub fn end_only() -> Self {
32        Self {
33            session_end: true,
34            ..Default::default()
35        }
36    }
37    pub fn monitor_only() -> Self {
38        Self::default()
39    }
40}
41
42/// Result of a hook operation
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct HookResult {
45    pub success: bool,
46    pub agent_type: String,
47    pub source: ExtractionSource,
48    pub context: Option<SessionContext>,
49    pub error: Option<String>,
50    pub timestamp: DateTime<Utc>,
51}
52
53impl HookResult {
54    pub fn success(agent_type: impl Into<String>, source: ExtractionSource) -> Self {
55        Self {
56            success: true,
57            agent_type: agent_type.into(),
58            source,
59            context: None,
60            error: None,
61            timestamp: Utc::now(),
62        }
63    }
64    pub fn success_with_context(
65        agent_type: impl Into<String>,
66        source: ExtractionSource,
67        context: SessionContext,
68    ) -> Self {
69        Self {
70            success: true,
71            agent_type: agent_type.into(),
72            source,
73            context: Some(context),
74            error: None,
75            timestamp: Utc::now(),
76        }
77    }
78    pub fn failure(
79        agent_type: impl Into<String>,
80        source: ExtractionSource,
81        error: impl Into<String>,
82    ) -> Self {
83        Self {
84            success: false,
85            agent_type: agent_type.into(),
86            source,
87            context: None,
88            error: Some(error.into()),
89            timestamp: Utc::now(),
90        }
91    }
92}
93
94/// AgentHook trait - all agent hooks must implement this
95#[async_trait]
96pub trait AgentHook: Send + Sync {
97    fn agent_type(&self) -> &str;
98    async fn install_session_end_hook(&mut self, callback: SessionEndCallback) -> Result<()>;
99    async fn install_session_start_hook(&mut self, _callback: SessionEndCallback) -> Result<()> {
100        Err(crate::error::HookError::NotSupported(
101            "Session start hooks not supported for this agent".to_string(),
102        ))
103    }
104    async fn install_compact_hook(&mut self, _callback: SessionEndCallback) -> Result<()> {
105        Err(crate::error::HookError::NotSupported(
106            "Compact/checkpoint hooks not supported for this agent".to_string(),
107        ))
108    }
109    async fn detect_session_activity(&self) -> Result<SessionActivity>;
110    async fn extract_session_context(&self) -> Result<SessionContext>;
111    async fn install_checkpoint_hook(&mut self, _callback: SessionEndCallback) -> Result<()> {
112        Err(crate::error::HookError::NotSupported(
113            "Checkpoint hooks not supported for this agent".to_string(),
114        ))
115    }
116    async fn install_error_hook(&mut self, _callback: SessionEndCallback) -> Result<()> {
117        Err(crate::error::HookError::NotSupported(
118            "Error hooks not supported for this agent".to_string(),
119        ))
120    }
121    fn is_hook_installed(&self) -> bool {
122        false
123    }
124    async fn uninstall_hooks(&mut self) -> Result<()> {
125        Ok(())
126    }
127    fn reliability_score(&self) -> f32 {
128        1.0
129    }
130    fn lifecycle_capabilities(&self) -> LifecycleCapabilities {
131        LifecycleCapabilities::end_only()
132    }
133    fn support_tier(&self) -> SupportTier {
134        SupportTier::MonitorOnly
135    }
136    fn record_activity(&self) {}
137}
138
139/// Base hook implementation with common functionality
140pub struct BaseHook {
141    pub agent_type: String,
142    pub installed: bool,
143    pub callbacks: Vec<SessionEndCallback>,
144    pub session_start_callbacks: Vec<SessionEndCallback>,
145    pub checkpoint_callbacks: Vec<SessionEndCallback>,
146    pub error_callbacks: Vec<SessionEndCallback>,
147    pub activity_monitor: std::sync::Mutex<ActivityMonitor>,
148    pub rescorer: RwLock<Option<Arc<crate::rescorer::SessionRescorer>>>,
149    /// Project root resolved at construction time, used for nap/dream cycles.
150    project_root: PathBuf,
151}
152
153impl BaseHook {
154    pub fn new(agent_type: impl Into<String>) -> Self {
155        let agent_type = agent_type.into();
156        let project_root = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
157
158        Self {
159            agent_type,
160            installed: false,
161            callbacks: Vec::new(),
162            session_start_callbacks: Vec::new(),
163            checkpoint_callbacks: Vec::new(),
164            error_callbacks: Vec::new(),
165            activity_monitor: std::sync::Mutex::new(ActivityMonitor::load()),
166            rescorer: RwLock::new(None),
167            project_root,
168        }
169    }
170    /// Lazily initialize the rescorer on first use
171    /// Initialize the rescorer using the project root associated with this hook
172    fn ensure_rescorer(&self) {
173        let read = self.rescorer.read().unwrap();
174        if read.is_none() {
175            drop(read);
176            let mut write = self.rescorer.write().unwrap();
177            if write.is_none() {
178                let project = nexus_core::ProjectIdentity::resolve(&self.project_root);
179                let config = nexus_core::Config::from_env().unwrap_or_default();
180                *write = Some(Arc::new(crate::rescorer::SessionRescorer::new(
181                    project,
182                    config.cognitive_system.rescore_turn_interval,
183                    config.cognitive_system.rescore_drift_threshold,
184                )));
185            }
186        }
187    }
188
189    pub fn record_activity(&self) {
190        self.record_activity_with_content("activity recorded")
191    }
192
193    pub fn record_activity_with_content(&self, content: &str) {
194        // Ensure rescorer is initialized lazily
195        self.ensure_rescorer();
196
197        if let Ok(mut monitor) = self.activity_monitor.lock() {
198            // Reload full monitor from disk to avoid overwriting concurrent writes
199            // to fields like last_deep_dream, deep_dream_cooldown, etc.
200            let mut disk = ActivityMonitor::load();
201            disk.record_activity();
202            *monitor = disk.clone();
203            drop(monitor);
204            if let Err(e) = disk.save() {
205                tracing::debug!("Failed to save activity monitor: {e}");
206            }
207        }
208
209        // Skip drift/rescore for activity-only sampling (placeholder content)
210        if content == "activity recorded" {
211            return;
212        }
213        // Trigger real-time re-scoring if drift detected
214        let rescorer = self.rescorer.read().unwrap().clone();
215        if let Some(rescorer) = rescorer {
216            let content = content.to_string();
217            let agent_type = self.agent_type.clone();
218            if let Ok(handle) = tokio::runtime::Handle::try_current() {
219                handle.spawn(async move {
220                    let config = nexus_core::Config::from_env().unwrap_or_default();
221                    let embeddings = if config.embedding.enabled {
222                        nexus_agent::runtime::create_embedding_service(&config).await
223                    } else {
224                        None
225                    };
226                    if let Some(similarity) =
227                        rescorer.on_turn(&content, embeddings.as_deref()).await
228                    {
229                        let _ = rescorer.rescore(embeddings.as_deref(), &agent_type).await;
230
231                        // PHASE 11: Notify orchestrator of drift
232                        // Only publish drift event for actual topic drift, not interval triggers
233                        // (interval triggers return similarity = 1.0)
234                        if similarity < rescorer.drift_threshold() {
235                            let mut data = std::collections::HashMap::new();
236                            data.insert("agent_type".to_string(), serde_json::json!(agent_type));
237                            data.insert("drift_detected".to_string(), serde_json::json!(true));
238                            data.insert("similarity".to_string(), serde_json::json!(similarity));
239                            data.insert(
240                                "threshold".to_string(),
241                                serde_json::json!(rescorer.drift_threshold()),
242                            );
243
244                            let event = nexus_orchestrator::Event::with_data(
245                                nexus_orchestrator::EventType::CognitiveDrift,
246                                data,
247                            )
248                            .with_source("base_hook");
249
250                            let event_bus = nexus_orchestrator::EventBus::global();
251                            let _ = event_bus.publish(event);
252                        }
253                    }
254                });
255            }
256        }
257    }
258
259    pub fn add_callback(&mut self, callback: SessionEndCallback) {
260        self.callbacks.push(callback);
261    }
262
263    pub fn add_session_start_callback(&mut self, callback: SessionEndCallback) {
264        self.session_start_callbacks.push(callback);
265    }
266
267    pub fn add_checkpoint_callback(&mut self, callback: SessionEndCallback) {
268        self.checkpoint_callbacks.push(callback);
269    }
270
271    pub fn add_error_callback(&mut self, callback: SessionEndCallback) {
272        self.error_callbacks.push(callback);
273    }
274
275    pub fn trigger_session_start_callbacks(&self, context: SessionContext) {
276        for callback in &self.session_start_callbacks {
277            callback(context.clone());
278        }
279    }
280
281    pub fn trigger_checkpoint_callbacks(&self, context: SessionContext) {
282        for callback in &self.checkpoint_callbacks {
283            callback(context.clone());
284        }
285    }
286
287    pub fn trigger_error_callbacks(&self, context: SessionContext) {
288        for callback in &self.error_callbacks {
289            callback(context.clone());
290        }
291    }
292
293    pub fn trigger_callbacks(&self, context: SessionContext) {
294        for callback in &self.callbacks {
295            callback(context.clone());
296        }
297
298        if let Some(session_id) = context.session_id.as_ref() {
299            let session_id = session_id.clone();
300            let agent_type = context.agent_type.clone();
301            let project_root = self.project_root.clone();
302
303            if let Ok(handle) = tokio::runtime::Handle::try_current() {
304                handle.spawn(async move {
305                    let config = nexus_core::Config::from_env().unwrap_or_default();
306                    if config.cognitive_system.dream_triggers.nap_on_session_end {
307                        let cwd = project_root;
308                        let pool_url = config.database_url();
309                        if let Some(parent) = config.database.path.parent() {
310                            let _ = std::fs::create_dir_all(parent);
311                        }
312                        if let Ok(mut storage) =
313                            nexus_storage::StorageManager::from_url(&pool_url).await
314                        {
315                            if let Err(e) = storage.initialize().await {
316                                tracing::warn!("Failed to initialize storage for nap: {e}");
317                                return;
318                            }
319                            let ns_repo = nexus_storage::repository::NamespaceRepository::new(
320                                storage.pool().clone(),
321                            );
322                            if let Ok(namespace) =
323                                ns_repo.get_or_create(&agent_type, &agent_type).await
324                            {
325                                let llm_result = nexus_llm::create_client_auto_with_fallback();
326                                let llm = match llm_result {
327                                    Ok(client) => client,
328                                    Err(e) => {
329                                        tracing::warn!(
330                                            "Failed to create LLM client for session-end nap: {}",
331                                            e
332                                        );
333                                        return;
334                                    }
335                                };
336                                let embeddings = if config.embedding.enabled {
337                                    nexus_agent::runtime::create_embedding_service(&config).await
338                                } else {
339                                    None
340                                };
341                                let timeout = std::time::Duration::from_secs(
342                                    config.cognition.session_end_dream_timeout_secs,
343                                );
344                                let services = nexus_agent::dream_cycle::DreamServices {
345                                    pool: storage.pool().clone(),
346                                    cognition: config.cognition.clone(),
347                                    agent: config.agent.clone(),
348                                    llm,
349                                    embeddings,
350                                    cognitive_system: config.cognitive_system.clone(),
351                                };
352                                match run_nap(&session_id, &cwd, namespace.id, &services, timeout)
353                                    .await
354                                {
355                                    Ok(nap_result) => {
356                                        if nap_result.timed_out {
357                                            tracing::warn!(
358                                                session_id = %session_id,
359                                                "nap timed out; not publishing DreamCompleted"
360                                            );
361                                        } else {
362                                            // PHASE 11: Notify of dream completion
363                                            let mut data = std::collections::HashMap::new();
364                                            data.insert(
365                                                "agent_type".to_string(),
366                                                serde_json::json!(agent_type),
367                                            );
368                                            data.insert(
369                                                "processed".to_string(),
370                                                serde_json::json!(nap_result.memories_processed),
371                                            );
372
373                                            let event = nexus_orchestrator::Event::with_data(
374                                                nexus_orchestrator::EventType::DreamCompleted,
375                                                data,
376                                            )
377                                            .with_source("agent_supervisor");
378
379                                            let event_bus = nexus_orchestrator::EventBus::global();
380                                            let _ = event_bus.publish(event);
381                                        }
382                                    }
383                                    Err(e) => {
384                                        tracing::warn!(
385                                            session_id = %session_id,
386                                            error = %e,
387                                            "Session-end nap failed"
388                                        );
389                                    }
390                                }
391                            } else {
392                                tracing::debug!("Failed to get/create namespace for nap");
393                            }
394                        } else {
395                            tracing::debug!("Failed to create storage for nap");
396                        }
397                    }
398                });
399            }
400        }
401    }
402}
403
/// Unit tests for the result/capability types and the callback plumbing of
/// `BaseHook`. None of them run inside a Tokio runtime, so the background
/// tasks that `BaseHook` spawns are never started here.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    #[test]
    fn test_hook_result_success() {
        let result = HookResult::success("test-agent", ExtractionSource::Manual);
        assert!(result.success);
        assert_eq!(result.agent_type, "test-agent");
        assert!(result.context.is_none());
        assert!(result.error.is_none());
    }

    #[test]
    fn test_hook_result_failure() {
        let result = HookResult::failure(
            "test-agent",
            ExtractionSource::Manual,
            "Something went wrong",
        );
        assert!(!result.success);
        assert_eq!(result.agent_type, "test-agent");
        assert!(result.context.is_none());
        assert_eq!(result.error.as_deref(), Some("Something went wrong"));
    }

    #[test]
    fn test_hook_result_with_context() {
        let ctx = SessionContext::new("test");
        let result = HookResult::success_with_context(
            "test-agent",
            ExtractionSource::NativeHook("skill".to_string()),
            ctx,
        );
        assert!(result.success);
        assert_eq!(result.agent_type, "test-agent");
        assert!(result.context.is_some());
        assert!(result.error.is_none());
    }

    #[test]
    fn test_base_hook() {
        let mut hook = BaseHook::new("test");
        assert_eq!(hook.agent_type, "test");
        assert!(!hook.installed);

        let called = Arc::new(AtomicBool::new(false));
        let called_clone = called.clone();
        hook.add_callback(Arc::new(move |_ctx: SessionContext| {
            called_clone.store(true, Ordering::SeqCst);
        }));

        hook.trigger_callbacks(SessionContext::new("test"));
        assert!(called.load(Ordering::SeqCst));
    }

    /// Each trigger must run only the callbacks registered for its own event.
    #[test]
    fn test_base_hook_auxiliary_callbacks_are_independent() {
        let mut hook = BaseHook::new("test");

        let starts = Arc::new(AtomicUsize::new(0));
        let checkpoints = Arc::new(AtomicUsize::new(0));
        let errors = Arc::new(AtomicUsize::new(0));

        let counter = starts.clone();
        hook.add_session_start_callback(Arc::new(move |_ctx: SessionContext| {
            counter.fetch_add(1, Ordering::SeqCst);
        }));
        let counter = checkpoints.clone();
        hook.add_checkpoint_callback(Arc::new(move |_ctx: SessionContext| {
            counter.fetch_add(1, Ordering::SeqCst);
        }));
        let counter = errors.clone();
        hook.add_error_callback(Arc::new(move |_ctx: SessionContext| {
            counter.fetch_add(1, Ordering::SeqCst);
        }));

        hook.trigger_session_start_callbacks(SessionContext::new("test"));
        assert_eq!(starts.load(Ordering::SeqCst), 1);
        assert_eq!(checkpoints.load(Ordering::SeqCst), 0);
        assert_eq!(errors.load(Ordering::SeqCst), 0);

        hook.trigger_checkpoint_callbacks(SessionContext::new("test"));
        assert_eq!(starts.load(Ordering::SeqCst), 1);
        assert_eq!(checkpoints.load(Ordering::SeqCst), 1);
        assert_eq!(errors.load(Ordering::SeqCst), 0);

        hook.trigger_error_callbacks(SessionContext::new("test"));
        assert_eq!(starts.load(Ordering::SeqCst), 1);
        assert_eq!(checkpoints.load(Ordering::SeqCst), 1);
        assert_eq!(errors.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_lifecycle_capabilities_default() {
        let caps = LifecycleCapabilities::default();
        assert!(!caps.session_start);
        assert!(!caps.session_end);
        assert!(!caps.checkpoint);
        assert!(!caps.error_hook);
        assert!(!caps.compact);
    }

    #[test]
    fn test_lifecycle_capabilities_end_only() {
        let caps = LifecycleCapabilities::end_only();
        assert!(!caps.session_start);
        assert!(caps.session_end);
        assert!(!caps.checkpoint);
        assert!(!caps.error_hook);
        assert!(!caps.compact);
    }

    #[test]
    fn test_lifecycle_capabilities_monitor_only() {
        let caps = LifecycleCapabilities::monitor_only();
        // Check every flag, not just `session_end`.
        assert!(!caps.session_start);
        assert!(!caps.session_end);
        assert!(!caps.checkpoint);
        assert!(!caps.error_hook);
        assert!(!caps.compact);
        assert_eq!(caps, LifecycleCapabilities::default());
    }
}