Skip to main content

ai_agent/utils/hooks/
session_hooks.rs

1// Source: ~/claudecode/openclaudecode/src/utils/hooks/sessionHooks.ts
2#![allow(dead_code)]
3
4use std::collections::HashMap;
5use std::sync::{Arc, Mutex};
6
7use crate::utils::hooks::hooks_settings::{HookCommand, HookEvent, is_hook_equal};
8
9/// Function hook callback - returns true if check passes, false to block
10pub type FunctionHookCallback = Box<dyn Fn(&[serde_json::Value]) -> bool + Send + Sync>;
11
12/// Function hook type with callback embedded.
13/// Session-scoped only, cannot be persisted to settings.json.
14#[derive(Clone)]
15pub struct FunctionHook {
16    pub id: Option<String>,
17    pub timeout: Option<u64>,
18    pub callback: Arc<dyn Fn(&[serde_json::Value]) -> bool + Send + Sync>,
19    pub error_message: String,
20    pub status_message: Option<String>,
21}
22
23impl FunctionHook {
24    pub fn new(
25        id: Option<String>,
26        timeout: Option<u64>,
27        callback: Arc<dyn Fn(&[serde_json::Value]) -> bool + Send + Sync>,
28        error_message: String,
29    ) -> Self {
30        Self {
31            id,
32            timeout,
33            callback,
34            error_message,
35            status_message: None,
36        }
37    }
38}
39
40/// Extended hook command that can be either a regular hook or a function hook
41#[derive(Clone)]
42pub enum SessionHookCommand {
43    Regular(HookCommand),
44    Function(FunctionHook),
45}
46
47/// On hook success callback
48pub type OnHookSuccess = Arc<dyn Fn(&SessionHookCommand, &AggregatedHookResult) + Send + Sync>;
49
/// Aggregated outcome of running a hook, as handed to [`OnHookSuccess`]
/// callbacks.
///
/// This is plain data, so it derives the standard inspection and comparison
/// traits. That lets callers log it, clone it, and compare it in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AggregatedHookResult {
    /// Whether the hook run counted as successful.
    pub success: bool,
    /// Optional output text produced by the hook.
    pub output: Option<String>,
}
55
56/// Session hook matcher
57#[derive(Clone)]
58pub struct SessionHookMatcher {
59    pub matcher: String,
60    pub skill_root: Option<String>,
61    pub hooks: Vec<SessionHookEntry>,
62}
63
64/// A single hook entry in a matcher
65#[derive(Clone)]
66pub struct SessionHookEntry {
67    pub hook: SessionHookCommand,
68    pub on_hook_success: Option<OnHookSuccess>,
69}
70
71/// Session store for hooks
72#[derive(Clone, Default)]
73pub struct SessionStore {
74    pub hooks: HashMap<HookEvent, Vec<SessionHookMatcher>>,
75}
76
77/// Session hooks state - uses Arc<Mutex<>> for interior mutability
78/// This mimics the TypeScript Map pattern where .set/.delete don't change
79/// the container's identity.
80pub struct SessionHooksState {
81    hooks: HashMap<String, SessionStore>,
82}
83
84impl SessionHooksState {
85    pub fn new() -> Self {
86        Self {
87            hooks: HashMap::new(),
88        }
89    }
90}
91
92lazy_static::lazy_static! {
93    static ref SESSION_HOOKS_STATE: Arc<Mutex<SessionHooksState>> = Arc::new(Mutex::new(
94        SessionHooksState::new()
95    ));
96}
97
98/// Add a command or prompt hook to the session.
99/// Session hooks are temporary, in-memory only, and cleared when session ends.
100pub fn add_session_hook(
101    set_app_state: &dyn Fn(&dyn Fn(&mut serde_json::Value)),
102    session_id: &str,
103    event: &HookEvent,
104    matcher: &str,
105    hook: HookCommand,
106    on_hook_success: Option<OnHookSuccess>,
107    skill_root: Option<&str>,
108) {
109    add_hook_to_session(
110        set_app_state,
111        session_id,
112        event,
113        matcher,
114        SessionHookCommand::Regular(hook),
115        on_hook_success,
116        skill_root.map(|s| s.to_string()),
117    );
118}
119
120/// Add a function hook to the session.
121/// Function hooks execute TypeScript callbacks in-memory for validation.
122/// Returns the hook ID (for removal)
123pub fn add_function_hook(
124    set_app_state: &dyn Fn(&dyn Fn(&mut serde_json::Value)),
125    session_id: &str,
126    event: &HookEvent,
127    matcher: &str,
128    callback: Arc<dyn Fn(&[serde_json::Value]) -> bool + Send + Sync>,
129    error_message: String,
130    timeout: Option<u64>,
131    id: Option<String>,
132) -> String {
133    let hook_id = id.unwrap_or_else(|| {
134        format!(
135            "function-hook-{}-{}",
136            std::time::SystemTime::now()
137                .duration_since(std::time::UNIX_EPOCH)
138                .unwrap()
139                .as_millis(),
140            rand::random::<u64>()
141        )
142    });
143
144    let hook = FunctionHook::new(Some(hook_id.clone()), timeout, callback, error_message);
145
146    add_hook_to_session(
147        set_app_state,
148        session_id,
149        event,
150        matcher,
151        SessionHookCommand::Function(hook),
152        None,
153        None,
154    );
155
156    hook_id
157}
158
159/// Remove a function hook by ID from the session
160pub fn remove_function_hook(
161    set_app_state: &dyn Fn(&dyn Fn(&mut serde_json::Value)),
162    session_id: &str,
163    event: &HookEvent,
164    hook_id: &str,
165) {
166    set_app_state(&|state: &mut serde_json::Value| {
167        // In a real implementation, we'd access the session hooks from app state
168        // For now, we use the global state
169    });
170
171    log_for_debugging(&format!(
172        "Removed function hook {} for event {} in session {}",
173        hook_id,
174        event.as_str(),
175        session_id
176    ));
177}
178
179/// Internal helper to add a hook to session state
180fn add_hook_to_session(
181    set_app_state: &dyn Fn(&dyn Fn(&mut serde_json::Value)),
182    session_id: &str,
183    event: &HookEvent,
184    matcher: &str,
185    hook: SessionHookCommand,
186    on_hook_success: Option<OnHookSuccess>,
187    skill_root: Option<String>,
188) {
189    // Call set_app_state to notify state change (matches TypeScript behavior)
190    set_app_state(&|state: &mut serde_json::Value| {
191        // Update state with the new hook
192        if let Some(session_hooks) = state.get_mut("session_hooks") {
193            if let Some(session_map) = session_hooks.as_object_mut() {
194                let _ = session_map.entry(session_id.to_string());
195            }
196        }
197    });
198
199    let mut state = SESSION_HOOKS_STATE.lock().unwrap();
200    let store = state
201        .hooks
202        .entry(session_id.to_string())
203        .or_insert_with(SessionStore::default);
204
205    let event_matchers = store.hooks.entry(event.clone()).or_default();
206
207    // Find existing matcher or create new one
208    let existing_matcher_index = event_matchers
209        .iter()
210        .position(|m| m.matcher == matcher && m.skill_root == skill_root);
211
212    if let Some(idx) = existing_matcher_index {
213        // Add to existing matcher
214        event_matchers[idx].hooks.push(SessionHookEntry {
215            hook,
216            on_hook_success,
217        });
218    } else {
219        // Create new matcher
220        event_matchers.push(SessionHookMatcher {
221            matcher: matcher.to_string(),
222            skill_root,
223            hooks: vec![SessionHookEntry {
224                hook,
225                on_hook_success,
226            }],
227        });
228    }
229
230    log_for_debugging(&format!(
231        "Added session hook for event {} in session {}",
232        event.as_str(),
233        session_id
234    ));
235}
236
237/// Remove a specific hook from the session
238pub fn remove_session_hook(
239    set_app_state: &dyn Fn(&dyn Fn(&mut serde_json::Value)),
240    session_id: &str,
241    event: &HookEvent,
242    hook: &HookCommand,
243) {
244    set_app_state(&|state: &mut serde_json::Value| {
245        // In a real implementation, we'd access the session hooks from app state
246    });
247
248    let mut state = SESSION_HOOKS_STATE.lock().unwrap();
249    if let Some(store) = state.hooks.get_mut(session_id) {
250        if let Some(event_matchers) = store.hooks.get_mut(event) {
251            // Remove the hook from all matchers
252            for matcher in event_matchers.iter_mut() {
253                matcher.hooks.retain(|entry| {
254                    if let SessionHookCommand::Regular(ref regular_hook) = entry.hook {
255                        !is_hook_equal(regular_hook, hook)
256                    } else {
257                        true // Don't remove function hooks by HookCommand
258                    }
259                });
260            }
261            // Remove empty matchers
262            event_matchers.retain(|m| !m.hooks.is_empty());
263
264            // Remove empty event matchers
265            store.hooks.retain(|_, matchers| !matchers.is_empty());
266        }
267    }
268
269    log_for_debugging(&format!(
270        "Removed session hook for event {} in session {}",
271        event.as_str(),
272        session_id
273    ));
274}
275
276/// Extended hook matcher that includes optional skillRoot for skill-scoped hooks
277#[derive(Clone)]
278pub struct SessionDerivedHookMatcher {
279    pub matcher: String,
280    pub hooks: Vec<HookCommand>,
281    pub skill_root: Option<String>,
282}
283
284/// Function hook matcher
285#[derive(Clone)]
286pub struct FunctionHookMatcher {
287    pub matcher: String,
288    pub hooks: Vec<FunctionHook>,
289}
290
291/// Get all session hooks for a specific event (excluding function hooks)
292pub fn get_session_hooks(
293    _session_id: &str,
294    event: Option<&HookEvent>,
295) -> HashMap<HookEvent, Vec<SessionDerivedHookMatcher>> {
296    let state = SESSION_HOOKS_STATE.lock().unwrap();
297    let store = match state.hooks.get(_session_id) {
298        Some(s) => s,
299        None => return HashMap::new(),
300    };
301
302    let mut result = HashMap::new();
303
304    if let Some(event) = event {
305        if let Some(session_matchers) = store.hooks.get(event) {
306            let derived_matchers = convert_to_hook_matchers(session_matchers);
307            if !derived_matchers.is_empty() {
308                result.insert(event.clone(), derived_matchers);
309            }
310        }
311    } else {
312        for (evt, session_matchers) in &store.hooks {
313            let derived_matchers = convert_to_hook_matchers(session_matchers);
314            if !derived_matchers.is_empty() {
315                result.insert(evt.clone(), derived_matchers);
316            }
317        }
318    }
319
320    result
321}
322
323/// Get all session function hooks for a specific event
324pub fn get_session_function_hooks(
325    session_id: &str,
326    event: Option<&HookEvent>,
327) -> HashMap<HookEvent, Vec<FunctionHookMatcher>> {
328    let state = SESSION_HOOKS_STATE.lock().unwrap();
329    let store = match state.hooks.get(session_id) {
330        Some(s) => s,
331        None => return HashMap::new(),
332    };
333
334    let mut result = HashMap::new();
335
336    let extract_function_hooks =
337        |session_matchers: &[SessionHookMatcher]| -> Vec<FunctionHookMatcher> {
338            session_matchers
339                .iter()
340                .map(|sm| {
341                    let function_hooks: Vec<FunctionHook> = sm
342                        .hooks
343                        .iter()
344                        .filter_map(|entry| {
345                            if let SessionHookCommand::Function(ref fh) = entry.hook {
346                                Some(fh.clone())
347                            } else {
348                                None
349                            }
350                        })
351                        .collect();
352                    FunctionHookMatcher {
353                        matcher: sm.matcher.clone(),
354                        hooks: function_hooks,
355                    }
356                })
357                .filter(|m| !m.hooks.is_empty())
358                .collect()
359        };
360
361    if let Some(event) = event {
362        if let Some(session_matchers) = store.hooks.get(event) {
363            let function_matchers = extract_function_hooks(session_matchers);
364            if !function_matchers.is_empty() {
365                result.insert(event.clone(), function_matchers);
366            }
367        }
368    } else {
369        for (evt, session_matchers) in &store.hooks {
370            let function_matchers = extract_function_hooks(session_matchers);
371            if !function_matchers.is_empty() {
372                result.insert(evt.clone(), function_matchers);
373            }
374        }
375    }
376
377    result
378}
379
380/// Get the full hook entry (including callbacks) for a specific session hook
381pub fn get_session_hook_callback(
382    session_id: &str,
383    event: &HookEvent,
384    matcher: &str,
385    hook: &HookCommand,
386) -> Option<SessionHookEntry> {
387    let state = SESSION_HOOKS_STATE.lock().unwrap();
388    let store = state.hooks.get(session_id)?;
389    let event_matchers = store.hooks.get(event)?;
390
391    // Find the hook in the matchers
392    for matcher_entry in event_matchers {
393        if matcher_entry.matcher == matcher || matcher.is_empty() {
394            for entry in &matcher_entry.hooks {
395                if let SessionHookCommand::Regular(ref regular_hook) = entry.hook {
396                    if is_hook_equal(regular_hook, hook) {
397                        return Some(entry.clone());
398                    }
399                }
400            }
401        }
402    }
403
404    None
405}
406
407/// Clear all session hooks for a specific session
408pub fn clear_session_hooks(
409    set_app_state: &dyn Fn(&dyn Fn(&mut serde_json::Value)),
410    session_id: &str,
411) {
412    // Call set_app_state to notify state change (matches TypeScript behavior)
413    set_app_state(&|state: &mut serde_json::Value| {
414        if let Some(session_hooks) = state.get_mut("session_hooks") {
415            if let Some(session_map) = session_hooks.as_object_mut() {
416                session_map.remove(session_id);
417            }
418        }
419    });
420
421    let mut state = SESSION_HOOKS_STATE.lock().unwrap();
422    state.hooks.remove(session_id);
423
424    log_for_debugging(&format!(
425        "Cleared all session hooks for session {}",
426        session_id
427    ));
428}
429
430/// Convert session hook matchers to regular hook matchets
431fn convert_to_hook_matchers(
432    session_matchers: &[SessionHookMatcher],
433) -> Vec<SessionDerivedHookMatcher> {
434    session_matchers
435        .iter()
436        .map(|sm| SessionDerivedHookMatcher {
437            matcher: sm.matcher.clone(),
438            skill_root: sm.skill_root.clone(),
439            // Filter out function hooks - they can't be persisted to HookMatcher format
440            hooks: sm
441                .hooks
442                .iter()
443                .filter_map(|entry| {
444                    if let SessionHookCommand::Regular(ref h) = entry.hook {
445                        Some(h.clone())
446                    } else {
447                        None
448                    }
449                })
450                .collect(),
451        })
452        .collect()
453}
454
455/// Log for debugging
456fn log_for_debugging(msg: &str) {
457    log::debug!("{}", msg);
458}
459
#[cfg(test)]
mod tests {
    use super::*;

    /// No-op `set_app_state`. Most tests only inspect the global
    /// `SESSION_HOOKS_STATE`, so app-state notifications are ignored.
    fn noop_set_app_state(_: &dyn Fn(&mut serde_json::Value)) {}

    /// A `command` hook with every optional field unset.
    fn command_hook(command: &str) -> HookCommand {
        HookCommand::Command {
            command: command.to_string(),
            shell: None,
            if_condition: None,
            timeout: None,
            status_message: None,
            once: None,
            r#async: None,
            async_rewake: None,
        }
    }

    /// Discard any hooks left over for `session_id`.
    ///
    /// The store is a process-wide global shared by tests running in
    /// parallel, so each test also uses its own session ID.
    fn reset_session(session_id: &str) {
        clear_session_hooks(&noop_set_app_state, session_id);
    }

    #[test]
    fn test_add_and_get_session_hook() {
        let session = "test-session";
        reset_session(session);

        add_session_hook(
            &noop_set_app_state,
            session,
            &HookEvent::Stop,
            "",
            command_hook("echo test"),
            None,
            None,
        );

        let hooks = get_session_hooks(session, Some(&HookEvent::Stop));
        let stop_hooks = hooks
            .get(&HookEvent::Stop)
            .expect("Stop hooks should be registered");
        assert_eq!(stop_hooks.len(), 1);
        assert_eq!(stop_hooks[0].matcher, "");
        assert!(stop_hooks[0].skill_root.is_none());
        assert_eq!(stop_hooks[0].hooks.len(), 1);
        assert!(matches!(
            &stop_hooks[0].hooks[0],
            HookCommand::Command { command, .. } if command == "echo test"
        ));

        // Without an event filter the same single event is reported.
        assert_eq!(get_session_hooks(session, None).len(), 1);
    }

    #[test]
    fn test_hooks_grouped_by_matcher_and_skill_root() {
        let session = "grouping-test-session";
        reset_session(session);

        add_session_hook(
            &noop_set_app_state,
            session,
            &HookEvent::Stop,
            "Bash",
            command_hook("echo a"),
            None,
            None,
        );
        add_session_hook(
            &noop_set_app_state,
            session,
            &HookEvent::Stop,
            "Bash",
            command_hook("echo b"),
            None,
            None,
        );
        add_session_hook(
            &noop_set_app_state,
            session,
            &HookEvent::Stop,
            "Bash",
            command_hook("echo c"),
            None,
            Some("/skills/demo"),
        );

        let hooks = get_session_hooks(session, Some(&HookEvent::Stop));
        let stop_hooks = &hooks[&HookEvent::Stop];
        assert_eq!(stop_hooks.len(), 2);

        let unscoped = stop_hooks
            .iter()
            .find(|m| m.skill_root.is_none())
            .expect("unscoped matcher");
        assert_eq!(unscoped.matcher, "Bash");
        assert_eq!(unscoped.hooks.len(), 2);

        let scoped = stop_hooks
            .iter()
            .find(|m| m.skill_root.as_deref() == Some("/skills/demo"))
            .expect("skill-scoped matcher");
        assert_eq!(scoped.matcher, "Bash");
        assert_eq!(scoped.hooks.len(), 1);
    }

    #[test]
    fn test_remove_session_hook() {
        let session = "remove-test-session";
        reset_session(session);

        let doomed = command_hook("echo remove-me");
        let kept = command_hook("echo keep-me");
        add_session_hook(
            &noop_set_app_state,
            session,
            &HookEvent::Stop,
            "",
            doomed.clone(),
            None,
            None,
        );
        add_session_hook(
            &noop_set_app_state,
            session,
            &HookEvent::Stop,
            "",
            kept.clone(),
            None,
            None,
        );

        remove_session_hook(&noop_set_app_state, session, &HookEvent::Stop, &doomed);

        let hooks = get_session_hooks(session, Some(&HookEvent::Stop));
        let stop_hooks = &hooks[&HookEvent::Stop];
        assert_eq!(stop_hooks.len(), 1);
        assert_eq!(stop_hooks[0].hooks.len(), 1);
        assert!(matches!(
            &stop_hooks[0].hooks[0],
            HookCommand::Command { command, .. } if command == "echo keep-me"
        ));

        // Removing the last hook prunes both the matcher and the event entry.
        remove_session_hook(&noop_set_app_state, session, &HookEvent::Stop, &kept);
        assert!(get_session_hooks(session, None).is_empty());
    }

    #[test]
    fn test_get_session_hook_callback() {
        let session = "callback-test-session";
        reset_session(session);

        let hook = command_hook("echo callback");
        let on_success: OnHookSuccess =
            Arc::new(|_: &SessionHookCommand, _: &AggregatedHookResult| {});
        add_session_hook(
            &noop_set_app_state,
            session,
            &HookEvent::Stop,
            "Write",
            hook.clone(),
            Some(on_success),
            None,
        );

        let entry = get_session_hook_callback(session, &HookEvent::Stop, "Write", &hook)
            .expect("hook entry should be found");
        assert!(entry.on_hook_success.is_some());

        // An empty matcher argument matches any matcher group...
        assert!(get_session_hook_callback(session, &HookEvent::Stop, "", &hook).is_some());
        // ...while a different matcher does not.
        assert!(get_session_hook_callback(session, &HookEvent::Stop, "Read", &hook).is_none());
    }

    #[test]
    fn test_add_and_get_function_hooks() {
        let session = "function-hook-test-session";
        reset_session(session);
        let no_messages: &[serde_json::Value] = &[];

        let id = add_function_hook(
            &noop_set_app_state,
            session,
            &HookEvent::Stop,
            "",
            Arc::new(|_: &[serde_json::Value]| true),
            "Function hook failed".to_string(),
            Some(5000),
            Some("test-fn-hook".to_string()),
        );
        assert_eq!(id, "test-fn-hook");

        let generated = add_function_hook(
            &noop_set_app_state,
            session,
            &HookEvent::Stop,
            "",
            Arc::new(|_: &[serde_json::Value]| false),
            "blocked".to_string(),
            None,
            None,
        );
        assert!(generated.starts_with("function-hook-"));

        let fn_hooks = get_session_function_hooks(session, Some(&HookEvent::Stop));
        let matchers = &fn_hooks[&HookEvent::Stop];
        assert_eq!(matchers.len(), 1);
        assert_eq!(matchers[0].hooks.len(), 2);

        let explicit = &matchers[0].hooks[0];
        assert_eq!(explicit.id.as_deref(), Some("test-fn-hook"));
        assert_eq!(explicit.timeout, Some(5000));
        assert_eq!(explicit.error_message, "Function hook failed");
        assert!((explicit.callback)(no_messages));

        let second = &matchers[0].hooks[1];
        assert_eq!(second.id.as_deref(), Some(generated.as_str()));
        assert!(!(second.callback)(no_messages));

        // Function hooks never appear in the regular-hook view.
        let regular = get_session_hooks(session, None);
        assert!(regular.values().flatten().all(|m| m.hooks.is_empty()));
    }

    #[test]
    fn test_clear_session_hooks() {
        let session = "clear-test-session";
        reset_session(session);

        add_session_hook(
            &noop_set_app_state,
            session,
            &HookEvent::Stop,
            "",
            command_hook("echo test"),
            None,
            None,
        );
        add_function_hook(
            &noop_set_app_state,
            session,
            &HookEvent::Stop,
            "",
            Arc::new(|_: &[serde_json::Value]| true),
            "Function hook failed".to_string(),
            None,
            None,
        );
        assert!(!get_session_hooks(session, None).is_empty());

        // App state that mirrors the session, to check the key is dropped too.
        let mut sessions = serde_json::Map::new();
        sessions.insert(
            session.to_string(),
            serde_json::Value::Object(serde_json::Map::new()),
        );
        let mut root = serde_json::Map::new();
        root.insert(
            "session_hooks".to_string(),
            serde_json::Value::Object(sessions),
        );
        let app_state = std::cell::RefCell::new(serde_json::Value::Object(root));
        let set_app_state =
            |update: &dyn Fn(&mut serde_json::Value)| update(&mut *app_state.borrow_mut());

        clear_session_hooks(&set_app_state, session);

        assert!(get_session_hooks(session, None).is_empty());
        assert!(get_session_function_hooks(session, None).is_empty());
        assert!(app_state.borrow()["session_hooks"].get(session).is_none());
        // Read under the lock first so a failing assert cannot poison the mutex.
        let still_present = SESSION_HOOKS_STATE
            .lock()
            .unwrap()
            .hooks
            .contains_key(session);
        assert!(!still_present);
    }

    #[test]
    fn test_function_hook() {
        let hook = FunctionHook::new(
            Some("test-fn-hook".to_string()),
            Some(5000),
            Arc::new(|_messages: &[serde_json::Value]| true),
            "Function hook failed".to_string(),
        );

        assert_eq!(hook.id, Some("test-fn-hook".to_string()));
        assert_eq!(hook.timeout, Some(5000));
        assert!(hook.status_message.is_none());
        let no_messages: &[serde_json::Value] = &[];
        assert!((hook.callback)(no_messages));
    }
}