Skip to main content

codex_runtime/runtime/
hooks.rs

1use std::collections::{HashMap, HashSet};
2use std::sync::{Arc, RwLock};
3
4use crate::plugin::{
5    BlockReason, HookAction, HookContext, HookIssue, HookPhase, HookReport, PostHook, PreHook,
6};
7
8#[derive(Clone, Default)]
9pub struct RuntimeHookConfig {
10    pub pre_hooks: Vec<Arc<dyn PreHook>>,
11    pub post_hooks: Vec<Arc<dyn PostHook>>,
12    /// Hooks that fire specifically for PreToolUse phase via the internal approval loop.
13    /// When non-empty, the runtime manages the approval channel internally and auto-escalates
14    /// ApprovalPolicy from Never → Untrusted so codex sends approval requests.
15    pub pre_tool_use_hooks: Vec<Arc<dyn PreHook>>,
16}
17
18impl std::fmt::Debug for RuntimeHookConfig {
19    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
20        f.debug_struct("RuntimeHookConfig")
21            .field("pre_hooks", &hook_names(&self.pre_hooks))
22            .field("post_hooks", &hook_names(&self.post_hooks))
23            .field("pre_tool_use_hooks", &hook_names(&self.pre_tool_use_hooks))
24            .finish()
25    }
26}
27
28impl PartialEq for RuntimeHookConfig {
29    fn eq(&self, other: &Self) -> bool {
30        hook_names(&self.pre_hooks) == hook_names(&other.pre_hooks)
31            && hook_names(&self.post_hooks) == hook_names(&other.post_hooks)
32            && hook_names(&self.pre_tool_use_hooks) == hook_names(&other.pre_tool_use_hooks)
33    }
34}
35
36impl Eq for RuntimeHookConfig {}
37
38impl RuntimeHookConfig {
39    /// Create empty hook config.
40    /// Allocation: none. Complexity: O(1).
41    pub fn new() -> Self {
42        Self::default()
43    }
44
45    /// Register one pre hook.
46    /// Allocation: amortized O(1) push. Complexity: O(1).
47    pub fn with_pre_hook(mut self, hook: Arc<dyn PreHook>) -> Self {
48        self.pre_hooks.push(hook);
49        self
50    }
51
52    /// Register one post hook.
53    /// Allocation: amortized O(1) push. Complexity: O(1).
54    pub fn with_post_hook(mut self, hook: Arc<dyn PostHook>) -> Self {
55        self.post_hooks.push(hook);
56        self
57    }
58
59    /// Register one pre-tool-use hook (fires in PreToolUse phase via the approval loop).
60    /// Allocation: amortized O(1) push. Complexity: O(1).
61    pub fn with_pre_tool_use_hook(mut self, hook: Arc<dyn PreHook>) -> Self {
62        self.pre_tool_use_hooks.push(hook);
63        self
64    }
65
66    /// True when at least one tool-use hook is registered.
67    /// Allocation: none. Complexity: O(1).
68    pub fn has_pre_tool_use_hooks(&self) -> bool {
69        !self.pre_tool_use_hooks.is_empty()
70    }
71
72    /// True when at least one hook of any kind is configured.
73    /// Allocation: none. Complexity: O(1).
74    pub fn is_empty(&self) -> bool {
75        self.pre_hooks.is_empty()
76            && self.post_hooks.is_empty()
77            && self.pre_tool_use_hooks.is_empty()
78    }
79}
80
81/// Merge default hooks with overlay hooks.
82/// Ordering is overlay-first so duplicate names prefer overlay entries.
83pub(crate) fn merge_hook_configs(
84    defaults: &RuntimeHookConfig,
85    overlay: &RuntimeHookConfig,
86) -> RuntimeHookConfig {
87    if defaults.is_empty() {
88        return overlay.clone();
89    }
90    if overlay.is_empty() {
91        return defaults.clone();
92    }
93    RuntimeHookConfig {
94        pre_hooks: merge_preferred_hooks(&overlay.pre_hooks, &defaults.pre_hooks),
95        post_hooks: merge_preferred_hooks(&overlay.post_hooks, &defaults.post_hooks),
96        pre_tool_use_hooks: merge_preferred_hooks(
97            &overlay.pre_tool_use_hooks,
98            &defaults.pre_tool_use_hooks,
99        ),
100    }
101}
102
103pub(crate) struct HookKernel {
104    pre_hooks: RwLock<Vec<Arc<dyn PreHook>>>,
105    post_hooks: RwLock<Vec<Arc<dyn PostHook>>>,
106    pre_tool_use_hooks: RwLock<Vec<Arc<dyn PreHook>>>,
107    thread_scoped_pre_tool_use_hooks: RwLock<HashMap<String, Vec<Arc<dyn PreHook>>>>,
108    latest_report: RwLock<HookReport>,
109}
110
111#[derive(Clone, Debug)]
112pub(crate) struct PreHookDecision {
113    pub hook_name: String,
114    pub action: HookAction,
115}
116
117impl HookKernel {
118    pub(crate) fn new(config: RuntimeHookConfig) -> Self {
119        Self {
120            pre_hooks: RwLock::new(config.pre_hooks),
121            post_hooks: RwLock::new(config.post_hooks),
122            pre_tool_use_hooks: RwLock::new(config.pre_tool_use_hooks),
123            thread_scoped_pre_tool_use_hooks: RwLock::new(HashMap::new()),
124            latest_report: RwLock::new(HookReport::default()),
125        }
126    }
127
128    pub(crate) fn is_enabled(&self) -> bool {
129        rwlock_len(&self.pre_hooks) > 0
130            || rwlock_len(&self.post_hooks) > 0
131            || rwlock_len(&self.pre_tool_use_hooks) > 0
132    }
133
134    /// True when at least one pre-tool-use hook is registered.
135    /// Allocation: none (read lock only). Complexity: O(1).
136    pub(crate) fn has_pre_tool_use_hooks(&self) -> bool {
137        rwlock_len(&self.pre_tool_use_hooks) > 0
138            || match self.thread_scoped_pre_tool_use_hooks.read() {
139                Ok(guard) => guard.values().any(|hooks| !hooks.is_empty()),
140                Err(poisoned) => poisoned
141                    .into_inner()
142                    .values()
143                    .any(|hooks| !hooks.is_empty()),
144            }
145    }
146
147    pub(crate) fn register_thread_scoped_pre_tool_use_hooks(
148        &self,
149        thread_id: &str,
150        hooks: &[Arc<dyn PreHook>],
151    ) {
152        if hooks.is_empty() {
153            return;
154        }
155        let mut guard = match self.thread_scoped_pre_tool_use_hooks.write() {
156            Ok(guard) => guard,
157            Err(poisoned) => poisoned.into_inner(),
158        };
159        let entry = guard.entry(thread_id.to_owned()).or_default();
160        let mut names: HashSet<&'static str> = entry.iter().map(|hook| hook.hook_name()).collect();
161        for hook in hooks {
162            if names.insert(hook.hook_name()) {
163                entry.push(Arc::clone(hook));
164            }
165        }
166    }
167
168    pub(crate) fn clear_thread_scoped_pre_tool_use_hooks(&self, thread_id: &str) {
169        let mut guard = match self.thread_scoped_pre_tool_use_hooks.write() {
170            Ok(guard) => guard,
171            Err(poisoned) => poisoned.into_inner(),
172        };
173        guard.remove(thread_id);
174    }
175
176    /// Register additional hooks into runtime kernel.
177    /// Duplicate names are ignored to keep execution deterministic.
178    /// Allocation: O(n) for name set snapshot. Complexity: O(n + m), n=existing, m=incoming.
179    pub(crate) fn register(&self, config: RuntimeHookConfig) {
180        if config.is_empty() {
181            return;
182        }
183        register_dedup_hooks(&self.pre_hooks, config.pre_hooks);
184        register_dedup_hooks(&self.post_hooks, config.post_hooks);
185        register_dedup_hooks(&self.pre_tool_use_hooks, config.pre_tool_use_hooks);
186    }
187
188    pub(crate) fn report_snapshot(&self) -> HookReport {
189        match self.latest_report.read() {
190            Ok(guard) => guard.clone(),
191            Err(poisoned) => poisoned.into_inner().clone(),
192        }
193    }
194
195    pub(crate) fn set_latest_report(&self, report: HookReport) {
196        match self.latest_report.write() {
197            Ok(mut guard) => *guard = report,
198            Err(poisoned) => *poisoned.into_inner() = report,
199        }
200    }
201
202    /// Execute global pre hooks plus optional scoped hooks for one call.
203    /// Scoped hooks are appended after globals and deduplicated by hook name.
204    /// Returns `Err(BlockReason)` on the first hook that returns `HookAction::Block`.
205    /// Subsequent hooks are not executed. Allocation: O(n) decisions vec.
206    pub(crate) async fn run_pre_with(
207        &self,
208        ctx: &HookContext,
209        report: &mut HookReport,
210        scoped: Option<&RuntimeHookConfig>,
211    ) -> Result<Vec<PreHookDecision>, BlockReason> {
212        let hooks = merge_owned_with_overlay(
213            read_rwlock_vec(&self.pre_hooks),
214            scoped.map(|cfg| cfg.pre_hooks.as_slice()),
215        );
216        let mut decisions = Vec::with_capacity(hooks.len());
217        for hook in hooks {
218            match hook.call(ctx).await {
219                Ok(HookAction::Block(reason)) => return Err(reason),
220                Ok(action) => decisions.push(PreHookDecision {
221                    hook_name: hook.name().to_owned(),
222                    action,
223                }),
224                Err(issue) => report.push(normalize_issue(issue, hook.name(), ctx.phase)),
225            }
226        }
227        Ok(decisions)
228    }
229
230    /// Execute pre-tool-use hooks for one approval request.
231    /// Returns `Err(BlockReason)` on the first hook that blocks (→ deny approval).
232    /// Returns `Ok(())` when all hooks pass (→ approve).
233    /// Allocation: O(n) hook vec clone. Complexity: O(n), n = hook count.
234    pub(crate) async fn run_pre_tool_use_with(
235        &self,
236        ctx: &HookContext,
237        report: &mut HookReport,
238    ) -> Result<(), BlockReason> {
239        let mut hooks = read_rwlock_vec(&self.pre_tool_use_hooks);
240        if let Some(thread_id) = ctx.thread_id.as_deref() {
241            let scoped = self.thread_scoped_pre_tool_use_hooks_for(thread_id);
242            hooks = merge_owned_with_overlay(hooks, scoped.as_deref());
243        }
244        for hook in hooks {
245            match hook.call(ctx).await {
246                Ok(HookAction::Block(reason)) => return Err(reason),
247                Ok(_) => {}
248                Err(issue) => report.push(normalize_issue(issue, hook.name(), ctx.phase)),
249            }
250        }
251        Ok(())
252    }
253
254    fn thread_scoped_pre_tool_use_hooks_for(
255        &self,
256        thread_id: &str,
257    ) -> Option<Vec<Arc<dyn PreHook>>> {
258        let guard = match self.thread_scoped_pre_tool_use_hooks.read() {
259            Ok(guard) => guard,
260            Err(poisoned) => poisoned.into_inner(),
261        };
262        guard.get(thread_id).cloned()
263    }
264
265    /// Execute global post hooks plus optional scoped hooks for one call.
266    /// Scoped hooks are appended after globals and deduplicated by hook name.
267    pub(crate) async fn run_post_with(
268        &self,
269        ctx: &HookContext,
270        report: &mut HookReport,
271        scoped: Option<&RuntimeHookConfig>,
272    ) {
273        let hooks = merge_owned_with_overlay(
274            read_rwlock_vec(&self.post_hooks),
275            scoped.map(|cfg| cfg.post_hooks.as_slice()),
276        );
277        for hook in hooks {
278            if let Err(issue) = hook.call(ctx).await {
279                report.push(normalize_issue(issue, hook.name(), ctx.phase));
280            }
281        }
282    }
283}
284
285fn normalize_issue(mut issue: HookIssue, fallback_name: &str, phase: HookPhase) -> HookIssue {
286    if issue.hook_name.trim().is_empty() {
287        issue.hook_name = fallback_name.to_owned();
288    }
289    issue.phase = phase;
290    issue
291}
292
293fn hook_names<T>(hooks: &[Arc<T>]) -> Vec<&'static str>
294where
295    T: ?Sized + HookName,
296{
297    hooks.iter().map(|hook| hook.hook_name()).collect()
298}
299
300trait HookName {
301    fn hook_name(&self) -> &'static str;
302}
303
304impl HookName for dyn PreHook {
305    fn hook_name(&self) -> &'static str {
306        self.name()
307    }
308}
309
310impl HookName for dyn PostHook {
311    fn hook_name(&self) -> &'static str {
312        self.name()
313    }
314}
315
/// Number of hooks in `target`, read under a shared lock without cloning.
/// A poisoned lock is recovered rather than propagated.
/// Allocation: none. Complexity: O(1).
fn rwlock_len<T: ?Sized>(target: &RwLock<Vec<Arc<T>>>) -> usize {
    target
        .read()
        .map_or_else(|poisoned| poisoned.into_inner().len(), |guard| guard.len())
}
324
/// Owned snapshot of the hooks in `target`, so the lock can be released before use.
/// A poisoned lock is recovered rather than propagated.
/// Allocation: one Vec plus one refcount bump per hook. Complexity: O(n), n=hook count.
fn read_rwlock_vec<T: ?Sized>(target: &RwLock<Vec<Arc<T>>>) -> Vec<Arc<T>> {
    let guard = target.read().unwrap_or_else(|poisoned| poisoned.into_inner());
    guard.iter().map(Arc::clone).collect()
}
333
334fn merge_preferred_hooks<T>(preferred: &[Arc<T>], fallback: &[Arc<T>]) -> Vec<Arc<T>>
335where
336    T: ?Sized + HookName,
337{
338    let mut merged = Vec::with_capacity(preferred.len() + fallback.len());
339    let mut names: HashSet<&'static str> = HashSet::with_capacity(preferred.len() + fallback.len());
340    for hook in preferred {
341        if names.insert(hook.hook_name()) {
342            merged.push(Arc::clone(hook));
343        }
344    }
345    for hook in fallback {
346        if names.insert(hook.hook_name()) {
347            merged.push(Arc::clone(hook));
348        }
349    }
350    merged
351}
352
353fn merge_owned_with_overlay<T>(mut base: Vec<Arc<T>>, overlay: Option<&[Arc<T>]>) -> Vec<Arc<T>>
354where
355    T: ?Sized + HookName,
356{
357    let Some(overlay) = overlay else {
358        return base;
359    };
360    if overlay.is_empty() {
361        return base;
362    }
363    let mut names: HashSet<&'static str> = base.iter().map(|hook| hook.hook_name()).collect();
364    for hook in overlay {
365        if names.insert(hook.hook_name()) {
366            base.push(Arc::clone(hook));
367        }
368    }
369    base
370}
371
372/// Register incoming hooks deduplicating by name. Poison-safe.
373/// Allocation: one HashSet per call. Complexity: O(n + m), n=existing, m=incoming.
374fn register_dedup_hooks<T>(target: &RwLock<Vec<Arc<T>>>, incoming: Vec<Arc<T>>)
375where
376    T: ?Sized + HookName,
377{
378    let mut guard = match target.write() {
379        Ok(guard) => guard,
380        Err(poisoned) => poisoned.into_inner(),
381    };
382    let mut names: HashSet<&'static str> = guard.iter().map(|hook| hook.hook_name()).collect();
383    for hook in incoming {
384        if names.insert(hook.hook_name()) {
385            guard.push(hook);
386        }
387    }
388}