1use async_trait::async_trait;
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::path::PathBuf;
9use std::sync::{Arc, RwLock};
10
11use crate::error::Result;
12use crate::session::SessionContext;
13use crate::types::{ExtractionSource, SessionActivity, SupportTier};
14use nexus_agent::activity_monitor::ActivityMonitor;
15use nexus_agent::dream_cycle::run_nap;
16
/// Shared, thread-safe callback that receives a [`SessionContext`] by value.
///
/// Despite the name, this alias is also the callback type accepted by the
/// session-start, compact, checkpoint and error hook installers on
/// [`AgentHook`].
pub type SessionEndCallback = Arc<dyn Fn(SessionContext) + Send + Sync>;
19
/// Set of boolean flags, one per lifecycle event kind, returned by
/// [`AgentHook::lifecycle_capabilities`].
///
/// The derived `Default` leaves every flag `false`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct LifecycleCapabilities {
    /// Flag for session-start events.
    pub session_start: bool,
    /// Flag for session-end events; the only flag set by `end_only()`.
    pub session_end: bool,
    /// Flag for checkpoint events.
    pub checkpoint: bool,
    /// Flag for error events.
    pub error_hook: bool,
    /// Flag for compact events.
    pub compact: bool,
}
29
30impl LifecycleCapabilities {
31 pub fn end_only() -> Self {
32 Self {
33 session_end: true,
34 ..Default::default()
35 }
36 }
37 pub fn monitor_only() -> Self {
38 Self::default()
39 }
40}
41
/// Outcome record for a hook operation.
///
/// The field notes below describe what the `HookResult` constructors in this
/// file set; all fields are public, so other code may build values differently.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookResult {
    /// `true` from `success` / `success_with_context`, `false` from `failure`.
    pub success: bool,
    /// Agent identifier supplied by the caller.
    pub agent_type: String,
    /// Extraction source supplied by the caller.
    pub source: ExtractionSource,
    /// `Some` only when built through `success_with_context`.
    pub context: Option<SessionContext>,
    /// `Some` only when built through `failure`.
    pub error: Option<String>,
    /// `Utc::now()` captured when the constructor ran.
    pub timestamp: DateTime<Utc>,
}
52
53impl HookResult {
54 pub fn success(agent_type: impl Into<String>, source: ExtractionSource) -> Self {
55 Self {
56 success: true,
57 agent_type: agent_type.into(),
58 source,
59 context: None,
60 error: None,
61 timestamp: Utc::now(),
62 }
63 }
64 pub fn success_with_context(
65 agent_type: impl Into<String>,
66 source: ExtractionSource,
67 context: SessionContext,
68 ) -> Self {
69 Self {
70 success: true,
71 agent_type: agent_type.into(),
72 source,
73 context: Some(context),
74 error: None,
75 timestamp: Utc::now(),
76 }
77 }
78 pub fn failure(
79 agent_type: impl Into<String>,
80 source: ExtractionSource,
81 error: impl Into<String>,
82 ) -> Self {
83 Self {
84 success: false,
85 agent_type: agent_type.into(),
86 source,
87 context: None,
88 error: Some(error.into()),
89 timestamp: Utc::now(),
90 }
91 }
92}
93
94#[async_trait]
96pub trait AgentHook: Send + Sync {
97 fn agent_type(&self) -> &str;
98 async fn install_session_end_hook(&mut self, callback: SessionEndCallback) -> Result<()>;
99 async fn install_session_start_hook(&mut self, _callback: SessionEndCallback) -> Result<()> {
100 Err(crate::error::HookError::NotSupported(
101 "Session start hooks not supported for this agent".to_string(),
102 ))
103 }
104 async fn install_compact_hook(&mut self, _callback: SessionEndCallback) -> Result<()> {
105 Err(crate::error::HookError::NotSupported(
106 "Compact/checkpoint hooks not supported for this agent".to_string(),
107 ))
108 }
109 async fn detect_session_activity(&self) -> Result<SessionActivity>;
110 async fn extract_session_context(&self) -> Result<SessionContext>;
111 async fn install_checkpoint_hook(&mut self, _callback: SessionEndCallback) -> Result<()> {
112 Err(crate::error::HookError::NotSupported(
113 "Checkpoint hooks not supported for this agent".to_string(),
114 ))
115 }
116 async fn install_error_hook(&mut self, _callback: SessionEndCallback) -> Result<()> {
117 Err(crate::error::HookError::NotSupported(
118 "Error hooks not supported for this agent".to_string(),
119 ))
120 }
121 fn is_hook_installed(&self) -> bool {
122 false
123 }
124 async fn uninstall_hooks(&mut self) -> Result<()> {
125 Ok(())
126 }
127 fn reliability_score(&self) -> f32 {
128 1.0
129 }
130 fn lifecycle_capabilities(&self) -> LifecycleCapabilities {
131 LifecycleCapabilities::end_only()
132 }
133 fn support_tier(&self) -> SupportTier {
134 SupportTier::MonitorOnly
135 }
136 fn record_activity(&self) {}
137}
138
/// Reusable state for hook implementations: callback lists per lifecycle
/// event, an activity monitor, and a lazily created session rescorer.
pub struct BaseHook {
    /// Agent identifier given to `BaseHook::new`; reported in drift events.
    pub agent_type: String,
    /// Starts as `false` in `BaseHook::new`; nothing in this file changes it.
    pub installed: bool,
    /// Session-end callbacks, run by `trigger_callbacks`.
    pub callbacks: Vec<SessionEndCallback>,
    /// Callbacks run by `trigger_session_start_callbacks`.
    pub session_start_callbacks: Vec<SessionEndCallback>,
    /// Callbacks run by `trigger_checkpoint_callbacks`.
    pub checkpoint_callbacks: Vec<SessionEndCallback>,
    /// Callbacks run by `trigger_error_callbacks`.
    pub error_callbacks: Vec<SessionEndCallback>,
    /// In-memory copy of the activity monitor; overwritten with a freshly
    /// loaded and updated copy every time activity is recorded.
    pub activity_monitor: std::sync::Mutex<ActivityMonitor>,
    /// Session rescorer, `None` until the first activity is recorded.
    pub rescorer: RwLock<Option<Arc<crate::rescorer::SessionRescorer>>>,
    /// Working directory captured at construction (falls back to `"."`).
    project_root: PathBuf,
}
152
153impl BaseHook {
154 pub fn new(agent_type: impl Into<String>) -> Self {
155 let agent_type = agent_type.into();
156 let project_root = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
157
158 Self {
159 agent_type,
160 installed: false,
161 callbacks: Vec::new(),
162 session_start_callbacks: Vec::new(),
163 checkpoint_callbacks: Vec::new(),
164 error_callbacks: Vec::new(),
165 activity_monitor: std::sync::Mutex::new(ActivityMonitor::load()),
166 rescorer: RwLock::new(None),
167 project_root,
168 }
169 }
170 fn ensure_rescorer(&self) {
173 let read = self.rescorer.read().unwrap();
174 if read.is_none() {
175 drop(read);
176 let mut write = self.rescorer.write().unwrap();
177 if write.is_none() {
178 let project = nexus_core::ProjectIdentity::resolve(&self.project_root);
179 let config = nexus_core::Config::from_env().unwrap_or_default();
180 *write = Some(Arc::new(crate::rescorer::SessionRescorer::new(
181 project,
182 config.cognitive_system.rescore_turn_interval,
183 config.cognitive_system.rescore_drift_threshold,
184 )));
185 }
186 }
187 }
188
189 pub fn record_activity(&self) {
190 self.record_activity_with_content("activity recorded")
191 }
192
193 pub fn record_activity_with_content(&self, content: &str) {
194 self.ensure_rescorer();
196
197 if let Ok(mut monitor) = self.activity_monitor.lock() {
198 let mut disk = ActivityMonitor::load();
201 disk.record_activity();
202 *monitor = disk.clone();
203 drop(monitor);
204 if let Err(e) = disk.save() {
205 tracing::debug!("Failed to save activity monitor: {e}");
206 }
207 }
208
209 if content == "activity recorded" {
211 return;
212 }
213 let rescorer = self.rescorer.read().unwrap().clone();
215 if let Some(rescorer) = rescorer {
216 let content = content.to_string();
217 let agent_type = self.agent_type.clone();
218 if let Ok(handle) = tokio::runtime::Handle::try_current() {
219 handle.spawn(async move {
220 let config = nexus_core::Config::from_env().unwrap_or_default();
221 let embeddings = if config.embedding.enabled {
222 nexus_agent::runtime::create_embedding_service(&config).await
223 } else {
224 None
225 };
226 if let Some(similarity) =
227 rescorer.on_turn(&content, embeddings.as_deref()).await
228 {
229 let _ = rescorer.rescore(embeddings.as_deref(), &agent_type).await;
230
231 if similarity < rescorer.drift_threshold() {
235 let mut data = std::collections::HashMap::new();
236 data.insert("agent_type".to_string(), serde_json::json!(agent_type));
237 data.insert("drift_detected".to_string(), serde_json::json!(true));
238 data.insert("similarity".to_string(), serde_json::json!(similarity));
239 data.insert(
240 "threshold".to_string(),
241 serde_json::json!(rescorer.drift_threshold()),
242 );
243
244 let event = nexus_orchestrator::Event::with_data(
245 nexus_orchestrator::EventType::CognitiveDrift,
246 data,
247 )
248 .with_source("base_hook");
249
250 let event_bus = nexus_orchestrator::EventBus::global();
251 let _ = event_bus.publish(event);
252 }
253 }
254 });
255 }
256 }
257 }
258
259 pub fn add_callback(&mut self, callback: SessionEndCallback) {
260 self.callbacks.push(callback);
261 }
262
263 pub fn add_session_start_callback(&mut self, callback: SessionEndCallback) {
264 self.session_start_callbacks.push(callback);
265 }
266
267 pub fn add_checkpoint_callback(&mut self, callback: SessionEndCallback) {
268 self.checkpoint_callbacks.push(callback);
269 }
270
271 pub fn add_error_callback(&mut self, callback: SessionEndCallback) {
272 self.error_callbacks.push(callback);
273 }
274
275 pub fn trigger_session_start_callbacks(&self, context: SessionContext) {
276 for callback in &self.session_start_callbacks {
277 callback(context.clone());
278 }
279 }
280
281 pub fn trigger_checkpoint_callbacks(&self, context: SessionContext) {
282 for callback in &self.checkpoint_callbacks {
283 callback(context.clone());
284 }
285 }
286
287 pub fn trigger_error_callbacks(&self, context: SessionContext) {
288 for callback in &self.error_callbacks {
289 callback(context.clone());
290 }
291 }
292
293 pub fn trigger_callbacks(&self, context: SessionContext) {
294 for callback in &self.callbacks {
295 callback(context.clone());
296 }
297
298 if let Some(session_id) = context.session_id.as_ref() {
299 let session_id = session_id.clone();
300 let agent_type = context.agent_type.clone();
301 let project_root = self.project_root.clone();
302
303 if let Ok(handle) = tokio::runtime::Handle::try_current() {
304 handle.spawn(async move {
305 let config = nexus_core::Config::from_env().unwrap_or_default();
306 if config.cognitive_system.dream_triggers.nap_on_session_end {
307 let cwd = project_root;
308 let pool_url = config.database_url();
309 if let Some(parent) = config.database.path.parent() {
310 let _ = std::fs::create_dir_all(parent);
311 }
312 if let Ok(mut storage) =
313 nexus_storage::StorageManager::from_url(&pool_url).await
314 {
315 if let Err(e) = storage.initialize().await {
316 tracing::warn!("Failed to initialize storage for nap: {e}");
317 return;
318 }
319 let ns_repo = nexus_storage::repository::NamespaceRepository::new(
320 storage.pool().clone(),
321 );
322 if let Ok(namespace) =
323 ns_repo.get_or_create(&agent_type, &agent_type).await
324 {
325 let llm_result = nexus_llm::create_client_auto_with_fallback();
326 let llm = match llm_result {
327 Ok(client) => client,
328 Err(e) => {
329 tracing::warn!(
330 "Failed to create LLM client for session-end nap: {}",
331 e
332 );
333 return;
334 }
335 };
336 let embeddings = if config.embedding.enabled {
337 nexus_agent::runtime::create_embedding_service(&config).await
338 } else {
339 None
340 };
341 let timeout = std::time::Duration::from_secs(
342 config.cognition.session_end_dream_timeout_secs,
343 );
344 let services = nexus_agent::dream_cycle::DreamServices {
345 pool: storage.pool().clone(),
346 cognition: config.cognition.clone(),
347 agent: config.agent.clone(),
348 llm,
349 embeddings,
350 cognitive_system: config.cognitive_system.clone(),
351 };
352 match run_nap(&session_id, &cwd, namespace.id, &services, timeout)
353 .await
354 {
355 Ok(nap_result) => {
356 if nap_result.timed_out {
357 tracing::warn!(
358 session_id = %session_id,
359 "nap timed out; not publishing DreamCompleted"
360 );
361 } else {
362 let mut data = std::collections::HashMap::new();
364 data.insert(
365 "agent_type".to_string(),
366 serde_json::json!(agent_type),
367 );
368 data.insert(
369 "processed".to_string(),
370 serde_json::json!(nap_result.memories_processed),
371 );
372
373 let event = nexus_orchestrator::Event::with_data(
374 nexus_orchestrator::EventType::DreamCompleted,
375 data,
376 )
377 .with_source("agent_supervisor");
378
379 let event_bus = nexus_orchestrator::EventBus::global();
380 let _ = event_bus.publish(event);
381 }
382 }
383 Err(e) => {
384 tracing::warn!(
385 session_id = %session_id,
386 error = %e,
387 "Session-end nap failed"
388 );
389 }
390 }
391 } else {
392 tracing::debug!("Failed to get/create namespace for nap");
393 }
394 } else {
395 tracing::debug!("Failed to create storage for nap");
396 }
397 }
398 });
399 }
400 }
401 }
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// A plain success result is flagged successful and has no error.
    #[test]
    fn test_hook_result_success() {
        let outcome = HookResult::success("test-agent", ExtractionSource::Manual);
        assert!(outcome.success);
        assert_eq!(outcome.error, None);
    }

    /// A failure result is flagged unsuccessful and keeps the message verbatim.
    #[test]
    fn test_hook_result_failure() {
        let outcome = HookResult::failure(
            "test-agent",
            ExtractionSource::Manual,
            "Something went wrong",
        );
        assert!(!outcome.success);
        assert_eq!(outcome.error.as_deref(), Some("Something went wrong"));
    }

    /// A success result built with a context retains that context.
    #[test]
    fn test_hook_result_with_context() {
        let outcome = HookResult::success_with_context(
            "test-agent",
            ExtractionSource::NativeHook("skill".to_string()),
            SessionContext::new("test"),
        );
        assert!(outcome.success);
        assert!(matches!(outcome.context, Some(_)));
    }

    /// A new `BaseHook` is not installed and runs registered callbacks.
    #[test]
    fn test_base_hook() {
        let mut hook = BaseHook::new("test");
        assert_eq!(hook.agent_type, "test");
        assert!(!hook.installed);

        let fired = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&fired);
        hook.add_callback(Arc::new(move |_ctx| {
            flag.store(true, Ordering::SeqCst);
        }));

        hook.trigger_callbacks(SessionContext::new("test"));
        assert!(fired.load(Ordering::SeqCst));
    }

    /// `Default` leaves every capability flag off.
    #[test]
    fn test_lifecycle_capabilities_default() {
        let expected = LifecycleCapabilities {
            session_start: false,
            session_end: false,
            checkpoint: false,
            error_hook: false,
            compact: false,
        };
        assert_eq!(LifecycleCapabilities::default(), expected);
    }

    /// `end_only` turns on `session_end` and nothing else.
    #[test]
    fn test_lifecycle_capabilities_end_only() {
        let expected = LifecycleCapabilities {
            session_start: false,
            session_end: true,
            checkpoint: false,
            error_hook: false,
            compact: false,
        };
        assert_eq!(LifecycleCapabilities::end_only(), expected);
    }

    /// `monitor_only` does not enable `session_end`.
    #[test]
    fn test_lifecycle_capabilities_monitor_only() {
        assert!(!LifecycleCapabilities::monitor_only().session_end);
    }
}