1use async_trait::async_trait;
6use chrono::{DateTime, Utc};
7use serde::{Deserialize, Serialize};
8use std::path::PathBuf;
9use std::sync::{Arc, RwLock};
10
11use crate::error::Result;
12use crate::session::SessionContext;
13use crate::types::{ExtractionSource, SessionActivity, SupportTier};
14use nexus_agent::activity_monitor::ActivityMonitor;
15use nexus_agent::dream_cycle::run_nap;
16
/// Shared, thread-safe callback invoked with the [`SessionContext`] of a lifecycle event.
///
/// Despite the name, this is the callback type accepted by every `install_*_hook` method of
/// [`AgentHook`] (session start, session end, compact, checkpoint and error), not only the
/// session-end hook. The `Send + Sync` bounds allow it to be shared across threads via `Arc`.
pub type SessionEndCallback = Arc<dyn Fn(SessionContext) + Send + Sync>;
19
/// Flags advertising which lifecycle hook kinds an [`AgentHook`] implementation supports.
///
/// Returned by `AgentHook::lifecycle_capabilities`. Every flag defaults to `false`
/// (see `LifecycleCapabilities::monitor_only`).
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct LifecycleCapabilities {
    /// Session-start hooks are supported (cf. `AgentHook::install_session_start_hook`).
    pub session_start: bool,
    /// Session-end hooks are supported (cf. `AgentHook::install_session_end_hook`).
    pub session_end: bool,
    /// Checkpoint hooks are supported (cf. `AgentHook::install_checkpoint_hook`).
    pub checkpoint: bool,
    /// Error hooks are supported (cf. `AgentHook::install_error_hook`).
    pub error_hook: bool,
    /// Compact hooks are supported (cf. `AgentHook::install_compact_hook`).
    pub compact: bool,
}
29
30impl LifecycleCapabilities {
31 pub fn end_only() -> Self {
32 Self {
33 session_end: true,
34 ..Default::default()
35 }
36 }
37 pub fn monitor_only() -> Self {
38 Self::default()
39 }
40}
41
/// Outcome of a hook invocation for a single agent.
///
/// Build it with `HookResult::success`, `HookResult::success_with_context` or
/// `HookResult::failure`; each constructor stamps `timestamp` with the current UTC time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookResult {
    /// `true` for the success constructors, `false` for `failure`.
    pub success: bool,
    /// Agent type the result belongs to.
    pub agent_type: String,
    /// How the session data was obtained.
    pub source: ExtractionSource,
    /// Extracted session context; only populated by `success_with_context`.
    pub context: Option<SessionContext>,
    /// Error description; only populated by `failure`.
    pub error: Option<String>,
    /// Creation time of this result (UTC).
    pub timestamp: DateTime<Utc>,
}
52
53impl HookResult {
54 pub fn success(agent_type: impl Into<String>, source: ExtractionSource) -> Self {
55 Self {
56 success: true,
57 agent_type: agent_type.into(),
58 source,
59 context: None,
60 error: None,
61 timestamp: Utc::now(),
62 }
63 }
64 pub fn success_with_context(
65 agent_type: impl Into<String>,
66 source: ExtractionSource,
67 context: SessionContext,
68 ) -> Self {
69 Self {
70 success: true,
71 agent_type: agent_type.into(),
72 source,
73 context: Some(context),
74 error: None,
75 timestamp: Utc::now(),
76 }
77 }
78 pub fn failure(
79 agent_type: impl Into<String>,
80 source: ExtractionSource,
81 error: impl Into<String>,
82 ) -> Self {
83 Self {
84 success: false,
85 agent_type: agent_type.into(),
86 source,
87 context: None,
88 error: Some(error.into()),
89 timestamp: Utc::now(),
90 }
91 }
92}
93
94#[async_trait]
96pub trait AgentHook: Send + Sync {
97 fn agent_type(&self) -> &str;
98 async fn install_session_end_hook(&mut self, callback: SessionEndCallback) -> Result<()>;
99 async fn install_session_start_hook(&mut self, _callback: SessionEndCallback) -> Result<()> {
100 Err(crate::error::HookError::NotSupported(
101 "Session start hooks not supported for this agent".to_string(),
102 ))
103 }
104 async fn install_compact_hook(&mut self, _callback: SessionEndCallback) -> Result<()> {
105 Err(crate::error::HookError::NotSupported(
106 "Compact/checkpoint hooks not supported for this agent".to_string(),
107 ))
108 }
109 async fn detect_session_activity(&self) -> Result<SessionActivity>;
110 async fn extract_session_context(&self) -> Result<SessionContext>;
111 async fn install_checkpoint_hook(&mut self, _callback: SessionEndCallback) -> Result<()> {
112 Err(crate::error::HookError::NotSupported(
113 "Checkpoint hooks not supported for this agent".to_string(),
114 ))
115 }
116 async fn install_error_hook(&mut self, _callback: SessionEndCallback) -> Result<()> {
117 Err(crate::error::HookError::NotSupported(
118 "Error hooks not supported for this agent".to_string(),
119 ))
120 }
121 fn is_hook_installed(&self) -> bool {
122 false
123 }
124 async fn uninstall_hooks(&mut self) -> Result<()> {
125 Ok(())
126 }
127 fn reliability_score(&self) -> f32 {
128 1.0
129 }
130 fn lifecycle_capabilities(&self) -> LifecycleCapabilities {
131 LifecycleCapabilities::end_only()
132 }
133 fn support_tier(&self) -> SupportTier {
134 SupportTier::MonitorOnly
135 }
136 fn record_activity(&self) {}
137}
138
/// Shared state and helpers that concrete agent hooks can build on.
///
/// Holds per-event callback lists, an in-memory copy of the persisted activity monitor, and a
/// lazily created session rescorer.
pub struct BaseHook {
    /// Agent type this hook belongs to.
    pub agent_type: String,
    /// Whether the hook has been installed; starts as `false`.
    pub installed: bool,
    /// Session-end callbacks, fired by `trigger_callbacks`.
    pub callbacks: Vec<SessionEndCallback>,
    /// Session-start callbacks, fired by `trigger_session_start_callbacks`.
    pub session_start_callbacks: Vec<SessionEndCallback>,
    /// Checkpoint callbacks, fired by `trigger_checkpoint_callbacks`.
    pub checkpoint_callbacks: Vec<SessionEndCallback>,
    /// Error callbacks, fired by `trigger_error_callbacks`.
    pub error_callbacks: Vec<SessionEndCallback>,
    /// In-memory copy of the activity monitor; replaced with the freshly loaded on-disk
    /// state each time activity is recorded.
    pub activity_monitor: std::sync::Mutex<ActivityMonitor>,
    /// Session rescorer; `None` until first initialized by `ensure_rescorer`.
    pub rescorer: RwLock<Option<Arc<crate::rescorer::SessionRescorer>>>,
    /// Working directory captured in `BaseHook::new` (falls back to `.`); the session-end
    /// nap runs against this directory.
    project_root: PathBuf,
}
152
153impl BaseHook {
154 pub fn new(agent_type: impl Into<String>) -> Self {
155 let agent_type = agent_type.into();
156 let project_root = std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."));
157
158 Self {
159 agent_type,
160 installed: false,
161 callbacks: Vec::new(),
162 session_start_callbacks: Vec::new(),
163 checkpoint_callbacks: Vec::new(),
164 error_callbacks: Vec::new(),
165 activity_monitor: std::sync::Mutex::new(ActivityMonitor::load()),
166 rescorer: RwLock::new(None),
167 project_root,
168 }
169 }
170 fn ensure_rescorer(&self) {
172 let read = self.rescorer.read().unwrap();
173 if read.is_none() {
174 drop(read);
175 let mut write = self.rescorer.write().unwrap();
176 if write.is_none() {
177 if let Ok(cwd) = std::env::current_dir() {
178 let project = nexus_core::ProjectIdentity::resolve(&cwd);
179 let config = nexus_core::Config::from_env().unwrap_or_default();
180 *write = Some(Arc::new(crate::rescorer::SessionRescorer::new(
181 project,
182 config.cognitive_system.rescore_turn_interval,
183 config.cognitive_system.rescore_drift_threshold,
184 )));
185 }
186 }
187 }
188 }
189
190 pub fn record_activity(&self) {
191 self.record_activity_with_content("activity recorded")
192 }
193
194 pub fn record_activity_with_content(&self, content: &str) {
195 self.ensure_rescorer();
197
198 if let Ok(mut monitor) = self.activity_monitor.lock() {
199 let mut disk = ActivityMonitor::load();
202 disk.record_activity();
203 *monitor = disk.clone();
204 drop(monitor);
205 if let Err(e) = disk.save() {
206 tracing::debug!("Failed to save activity monitor: {e}");
207 }
208 }
209
210 if content == "activity recorded" {
212 return;
213 }
214 let rescorer = self.rescorer.read().unwrap().clone();
216 if let Some(rescorer) = rescorer {
217 let content = content.to_string();
218 let agent_type = self.agent_type.clone();
219 if let Ok(handle) = tokio::runtime::Handle::try_current() {
220 handle.spawn(async move {
221 let config = nexus_core::Config::from_env().unwrap_or_default();
222 let embeddings = if config.embedding.enabled {
223 nexus_agent::runtime::create_embedding_service(&config).await
224 } else {
225 None
226 };
227 if let Some(similarity) =
228 rescorer.on_turn(&content, embeddings.as_deref()).await
229 {
230 let _ = rescorer.rescore(embeddings.as_deref(), &agent_type).await;
231
232 if similarity < rescorer.drift_threshold() {
236 let mut data = std::collections::HashMap::new();
237 data.insert("agent_type".to_string(), serde_json::json!(agent_type));
238 data.insert("drift_detected".to_string(), serde_json::json!(true));
239 data.insert("similarity".to_string(), serde_json::json!(similarity));
240 data.insert(
241 "threshold".to_string(),
242 serde_json::json!(rescorer.drift_threshold()),
243 );
244
245 let event = nexus_orchestrator::Event::with_data(
246 nexus_orchestrator::EventType::CognitiveDrift,
247 data,
248 )
249 .with_source("base_hook");
250
251 let event_bus = nexus_orchestrator::EventBus::global();
252 let _ = event_bus.publish(event);
253 }
254 }
255 });
256 }
257 }
258 }
259
260 pub fn add_callback(&mut self, callback: SessionEndCallback) {
261 self.callbacks.push(callback);
262 }
263
264 pub fn add_session_start_callback(&mut self, callback: SessionEndCallback) {
265 self.session_start_callbacks.push(callback);
266 }
267
268 pub fn add_checkpoint_callback(&mut self, callback: SessionEndCallback) {
269 self.checkpoint_callbacks.push(callback);
270 }
271
272 pub fn add_error_callback(&mut self, callback: SessionEndCallback) {
273 self.error_callbacks.push(callback);
274 }
275
276 pub fn trigger_session_start_callbacks(&self, context: SessionContext) {
277 for callback in &self.session_start_callbacks {
278 callback(context.clone());
279 }
280 }
281
282 pub fn trigger_checkpoint_callbacks(&self, context: SessionContext) {
283 for callback in &self.checkpoint_callbacks {
284 callback(context.clone());
285 }
286 }
287
288 pub fn trigger_error_callbacks(&self, context: SessionContext) {
289 for callback in &self.error_callbacks {
290 callback(context.clone());
291 }
292 }
293
294 pub fn trigger_callbacks(&self, context: SessionContext) {
295 for callback in &self.callbacks {
296 callback(context.clone());
297 }
298
299 if let Some(session_id) = context.session_id.as_ref() {
300 let session_id = session_id.clone();
301 let agent_type = context.agent_type.clone();
302 let project_root = self.project_root.clone();
303
304 if let Ok(handle) = tokio::runtime::Handle::try_current() {
305 handle.spawn(async move {
306 let config = nexus_core::Config::from_env().unwrap_or_default();
307 if config.cognitive_system.dream_triggers.nap_on_session_end {
308 let cwd = project_root;
309 let pool_url = config.database_url();
310 if let Some(parent) = config.database.path.parent() {
311 let _ = std::fs::create_dir_all(parent);
312 }
313 if let Ok(mut storage) =
314 nexus_storage::StorageManager::from_url(&pool_url).await
315 {
316 if let Err(e) = storage.initialize().await {
317 tracing::warn!("Failed to initialize storage for nap: {e}");
318 return;
319 }
320 let ns_repo = nexus_storage::repository::NamespaceRepository::new(
321 storage.pool().clone(),
322 );
323 if let Ok(namespace) =
324 ns_repo.get_or_create(&agent_type, &agent_type).await
325 {
326 let llm_result = nexus_llm::create_client_auto_with_fallback();
327 let llm = match llm_result {
328 Ok(client) => client,
329 Err(e) => {
330 tracing::warn!(
331 "Failed to create LLM client for session-end nap: {}",
332 e
333 );
334 return;
335 }
336 };
337 let embeddings = if config.embedding.enabled {
338 nexus_agent::runtime::create_embedding_service(&config).await
339 } else {
340 None
341 };
342 let timeout = std::time::Duration::from_secs(
343 config.cognition.session_end_dream_timeout_secs,
344 );
345 let services = nexus_agent::dream_cycle::DreamServices {
346 pool: storage.pool().clone(),
347 cognition: config.cognition.clone(),
348 agent: config.agent.clone(),
349 llm,
350 embeddings,
351 cognitive_system: config.cognitive_system.clone(),
352 };
353 match run_nap(&session_id, &cwd, namespace.id, &services, timeout)
354 .await
355 {
356 Ok(nap_result) => {
357 if nap_result.timed_out {
358 tracing::warn!(
359 session_id = %session_id,
360 "nap timed out; not publishing DreamCompleted"
361 );
362 } else {
363 let mut data = std::collections::HashMap::new();
365 data.insert(
366 "agent_type".to_string(),
367 serde_json::json!(agent_type),
368 );
369 data.insert(
370 "processed".to_string(),
371 serde_json::json!(nap_result.memories_processed),
372 );
373
374 let event = nexus_orchestrator::Event::with_data(
375 nexus_orchestrator::EventType::DreamCompleted,
376 data,
377 )
378 .with_source("agent_supervisor");
379
380 let event_bus = nexus_orchestrator::EventBus::global();
381 let _ = event_bus.publish(event);
382 }
383 }
384 Err(e) => {
385 tracing::warn!(
386 session_id = %session_id,
387 error = %e,
388 "Session-end nap failed"
389 );
390 }
391 }
392 } else {
393 tracing::debug!("Failed to get/create namespace for nap");
394 }
395 } else {
396 tracing::debug!("Failed to create storage for nap");
397 }
398 }
399 });
400 }
401 }
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

    /// Builds a callback that increments `counter` every time it fires.
    fn counting_callback(counter: &Arc<AtomicUsize>) -> SessionEndCallback {
        let counter = Arc::clone(counter);
        Arc::new(move |_ctx: SessionContext| {
            counter.fetch_add(1, Ordering::SeqCst);
        })
    }

    #[test]
    fn test_hook_result_success() {
        let result = HookResult::success("test-agent", ExtractionSource::Manual);
        assert!(result.success);
        assert_eq!(result.agent_type, "test-agent");
        assert!(result.context.is_none());
        assert!(result.error.is_none());
    }

    #[test]
    fn test_hook_result_failure() {
        let result = HookResult::failure(
            "test-agent",
            ExtractionSource::Manual,
            "Something went wrong",
        );
        assert!(!result.success);
        assert!(result.context.is_none());
        assert_eq!(result.error.as_deref(), Some("Something went wrong"));
    }

    #[test]
    fn test_hook_result_with_context() {
        let ctx = SessionContext::new("test");
        let result = HookResult::success_with_context(
            "test-agent",
            ExtractionSource::NativeHook("skill".to_string()),
            ctx,
        );
        assert!(result.success);
        assert!(result.context.is_some());
        assert!(result.error.is_none());
    }

    #[test]
    fn test_base_hook() {
        let mut hook = BaseHook::new("test");
        assert_eq!(hook.agent_type, "test");
        assert!(!hook.installed);

        let called = Arc::new(AtomicBool::new(false));
        let called_clone = Arc::clone(&called);
        hook.add_callback(Arc::new(move |_ctx: SessionContext| {
            called_clone.store(true, Ordering::SeqCst);
        }));

        hook.trigger_callbacks(SessionContext::new("test"));
        assert!(called.load(Ordering::SeqCst));
    }

    #[test]
    fn test_lifecycle_callbacks_fire_only_their_own_list() {
        let mut hook = BaseHook::new("test");
        let start = Arc::new(AtomicUsize::new(0));
        let checkpoint = Arc::new(AtomicUsize::new(0));
        let error = Arc::new(AtomicUsize::new(0));
        hook.add_session_start_callback(counting_callback(&start));
        hook.add_checkpoint_callback(counting_callback(&checkpoint));
        hook.add_error_callback(counting_callback(&error));

        hook.trigger_session_start_callbacks(SessionContext::new("test"));
        assert_eq!(start.load(Ordering::SeqCst), 1);
        assert_eq!(checkpoint.load(Ordering::SeqCst), 0);
        assert_eq!(error.load(Ordering::SeqCst), 0);

        hook.trigger_checkpoint_callbacks(SessionContext::new("test"));
        assert_eq!(start.load(Ordering::SeqCst), 1);
        assert_eq!(checkpoint.load(Ordering::SeqCst), 1);
        assert_eq!(error.load(Ordering::SeqCst), 0);

        hook.trigger_error_callbacks(SessionContext::new("test"));
        assert_eq!(start.load(Ordering::SeqCst), 1);
        assert_eq!(checkpoint.load(Ordering::SeqCst), 1);
        assert_eq!(error.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_lifecycle_capabilities_default() {
        let caps = LifecycleCapabilities::default();
        assert!(!caps.session_start);
        assert!(!caps.session_end);
        assert!(!caps.checkpoint);
        assert!(!caps.error_hook);
        assert!(!caps.compact);
    }

    #[test]
    fn test_lifecycle_capabilities_end_only() {
        let caps = LifecycleCapabilities::end_only();
        assert!(!caps.session_start);
        assert!(caps.session_end);
        assert!(!caps.checkpoint);
        assert!(!caps.error_hook);
        assert!(!caps.compact);
    }

    #[test]
    fn test_lifecycle_capabilities_monitor_only() {
        // Every capability must be off, not just session_end.
        assert_eq!(
            LifecycleCapabilities::monitor_only(),
            LifecycleCapabilities::default()
        );
    }
}