Skip to main content

agy_bridge/hooks/
types.rs

//! Hook bridge for the Antigravity SDK.
//!
//! Defines Rust-side hook types that wrap callbacks for agent lifecycle
//! hook points: pre-turn, post-turn, pre-tool-call-decide, post-tool-call,
//! compaction, session start/end, tool errors, user interactions, and
//! tool-input transformation.
//!
//! The actual Python wrapping (creating `PyO3` classes that the SDK dispatches to)
//! requires the Python runtime and is gated behind integration tests.
10
11use std::time::Instant;
12
13use serde::{Deserialize, Serialize};
14
15/// Result of a hook decision (mirrors SDK `HookResult`).
16#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
17pub struct HookResult {
18    /// Whether execution should proceed.
19    pub allow: bool,
20    /// Optional explanation or response message.
21    pub message: String,
22}
23
24impl HookResult {
25    /// Create an "allow" result with an empty message.
26    #[must_use]
27    pub const fn allow() -> Self {
28        Self {
29            allow: true,
30            message: String::new(),
31        }
32    }
33
34    /// Create an "allow" result with a message.
35    #[must_use]
36    pub fn allow_with_message(message: impl Into<String>) -> Self {
37        Self {
38            allow: true,
39            message: message.into(),
40        }
41    }
42
43    /// Create a "deny" result with a reason.
44    #[must_use]
45    pub fn deny(reason: impl Into<String>) -> Self {
46        Self {
47            allow: false,
48            message: reason.into(),
49        }
50    }
51}
52
// ── Hook context structs ────────────────────────────────────────────────────
54
55/// Persistent session metadata passed to session-lifecycle hooks.
56///
57/// Created when a session starts and carried through to session-end hooks
58/// so hooks can correlate events, measure session duration, and identify
59/// the agent instance.
60#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
61pub struct SessionContext {
62    /// Unique identifier for this session.
63    pub session_id: String,
64    /// Numeric agent identifier within the bridge runtime.
65    pub agent_id: u64,
66    /// Monotonic timestamp of when the session was started.
67    #[serde(skip, default = "std::time::Instant::now")]
68    pub started_at: Instant,
69}
70
71/// Context passed to [`HookPoint::OnSessionStart`] hooks.
72#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
73pub struct OnSessionStartContext {
74    /// Session metadata for the newly started session.
75    pub session: SessionContext,
76}
77
78/// Context passed to [`HookPoint::OnSessionEnd`] hooks.
79#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
80pub struct OnSessionEndContext {
81    /// Session metadata for the ending session.
82    pub session: SessionContext,
83}
84
85/// Context passed to [`HookPoint::OnCompaction`] hooks.
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct OnCompactionContext {}
88
89/// Context passed to [`HookPoint::OnInteraction`] hooks.
90#[derive(Debug, Clone, Serialize, Deserialize)]
91pub struct OnInteractionContext {
92    /// The interaction message content.
93    pub message: String,
94}
95
96/// Context passed to [`HookPoint::PreTurn`] hooks.
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct PreTurnContext {
99    /// The user prompt for this turn.
100    pub prompt: String,
101    /// The 1-based turn number.
102    pub turn_number: u32,
103}
104
105/// Context passed to [`HookPoint::PostTurn`] hooks.
106#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct PostTurnContext {
108    /// The model's response text for this turn.
109    pub response_text: String,
110    /// The 1-based turn number.
111    pub turn_number: u32,
112}
113
114/// Context passed to [`HookPoint::PreToolCallDecide`] hooks.
115#[derive(Debug, Clone, Serialize, Deserialize)]
116pub struct PreToolCallDecideContext {
117    /// Name of the tool about to be called.
118    #[serde(alias = "name")]
119    pub tool_name: String,
120    /// Arguments the tool will receive.
121    #[serde(alias = "args", default)]
122    pub tool_args: serde_json::Value,
123}
124
125/// Context passed to [`HookPoint::PostToolCall`] hooks.
126#[derive(Debug, Clone, Serialize, Deserialize)]
127pub struct PostToolCallContext {
128    /// Name of the tool that was called.
129    #[serde(alias = "name")]
130    pub tool_name: String,
131    /// Arguments the tool received.
132    #[serde(alias = "args", default)]
133    pub tool_args: serde_json::Value,
134    /// The tool's return value (serialised).
135    pub result: String,
136}
137
138/// Context passed to [`HookPoint::OnToolError`] hooks.
139#[derive(Debug, Clone, Serialize, Deserialize)]
140pub struct OnToolErrorContext {
141    /// Name of the tool that errored.
142    #[serde(alias = "name")]
143    pub tool_name: String,
144    /// Arguments the tool received.
145    #[serde(alias = "args", default)]
146    pub tool_args: serde_json::Value,
147    /// The error message.
148    pub error: String,
149}
150
151/// Identifies the point in the agent lifecycle where a hook fires.
152#[non_exhaustive]
153#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
154pub enum HookPoint {
155    /// Before the model processes a turn (receives the user prompt).
156    PreTurn,
157    /// After the model completes a turn (receives the model response).
158    PostTurn,
159    /// Before a tool call is executed — can approve or deny.
160    PreToolCallDecide,
161    /// After a tool call completes (receives the tool result).
162    PostToolCall,
163    /// Fires when the context window is compacted (trimmed to fit limits).
164    OnCompaction,
165    /// Fires when a new agent session begins.
166    OnSessionStart,
167    /// Fires when an agent session ends.
168    OnSessionEnd,
169    /// Fires when a tool call returns an error.
170    OnToolError,
171    /// Fires on each user interaction (message received from user).
172    OnInteraction,
173}
174
175impl HookPoint {
176    /// Human-readable label for logging.
177    #[must_use]
178    pub const fn label(self) -> &'static str {
179        match self {
180            Self::PreTurn => "pre_turn",
181            Self::PostTurn => "post_turn",
182            Self::PreToolCallDecide => "pre_tool_call_decide",
183            Self::PostToolCall => "post_tool_call",
184            Self::OnCompaction => "on_compaction",
185            Self::OnSessionStart => "on_session_start",
186            Self::OnSessionEnd => "on_session_end",
187            Self::OnToolError => "on_tool_error",
188            Self::OnInteraction => "on_interaction",
189        }
190    }
191}
192
193/// A named hook registration that will be attached to an agent.
194///
195/// The `callback_id` is an opaque identifier used to look up the actual
196/// Rust callback in the hook runner. This decouples serialization from
197/// function pointers.
198///
199/// # Construction
200///
201/// Prefer [`HookEntry::new`] which validates eagerly. Direct struct
202/// construction is allowed for deserialization but skips validation —
203/// call [`HookEntry::validate`] before use if constructing manually.
204#[derive(Debug, Clone, Serialize, Deserialize)]
205pub struct HookEntry {
206    /// Descriptive name (e.g. `"safety_gate"`).
207    pub name: String,
208    /// Which lifecycle point this hook fires at.
209    pub point: HookPoint,
210    /// Opaque callback identifier for the hook runner to resolve.
211    pub callback_id: String,
212}
213
214impl HookEntry {
215    /// Create a new hook entry, validating that `name` and `callback_id`
216    /// are non-empty.
217    ///
218    /// # Errors
219    ///
220    /// Returns [`Error::InvalidConfig`](crate::error::Error::InvalidConfig)
221    /// if `name` or `callback_id` is empty or whitespace-only.
222    ///
223    /// # Examples
224    ///
225    /// ```
226    /// # use agy_bridge::hooks::{HookEntry, HookPoint};
227    /// let entry = HookEntry::new("safety_gate", HookPoint::PreToolCallDecide, "cb_safety")
228    ///     .expect("valid entry");
229    /// assert_eq!(entry.name, "safety_gate");
230    /// ```
231    pub fn new(
232        name: impl Into<String>,
233        point: HookPoint,
234        callback_id: impl Into<String>,
235    ) -> Result<Self, crate::error::Error> {
236        let entry = Self {
237            name: name.into(),
238            point,
239            callback_id: callback_id.into(),
240        };
241        entry.validate()?;
242        Ok(entry)
243    }
244
245    /// Validate that the entry has non-empty name and `callback_id`.
246    ///
247    /// # Errors
248    ///
249    /// Returns `Err` with a description if the name or `callback_id` is empty.
250    pub fn validate(&self) -> Result<(), crate::error::Error> {
251        if self.name.trim().is_empty() {
252            return Err(crate::error::Error::InvalidConfig {
253                message: "HookEntry name must not be empty".to_owned(),
254            });
255        }
256        if self.callback_id.trim().is_empty() {
257            return Err(crate::error::Error::InvalidConfig {
258                message: format!("HookEntry '{}' has an empty callback_id", self.name),
259            });
260        }
261        Ok(())
262    }
263}
264
265/// An ordered list of hooks to attach to an agent.
266///
267/// Hooks at the same [`HookPoint`] fire in registration order.
268#[derive(Debug, Clone, Default, Serialize, Deserialize)]
269pub struct HookSet {
270    entries: Vec<HookEntry>,
271}
272
273impl HookSet {
274    /// Create an empty hook set.
275    #[must_use]
276    pub const fn new() -> Self {
277        Self {
278            entries: Vec::new(),
279        }
280    }
281
282    /// Register a hook.
283    ///
284    /// If a hook with the same name AND hook point already exists, it is
285    /// replaced and a warning is logged.
286    ///
287    /// # Errors
288    ///
289    /// Returns `Err` if the entry fails validation (empty name or `callback_id`).
290    pub fn push(&mut self, entry: HookEntry) -> Result<(), crate::error::Error> {
291        entry.validate()?;
292        if let Some(pos) = self
293            .entries
294            .iter()
295            .position(|e| e.name == entry.name && e.point == entry.point)
296        {
297            tracing::warn!(
298                hook = %entry.name,
299                point = %entry.point.label(),
300                "duplicate hook name+point in HookSet — replacing previous entry"
301            );
302            self.entries[pos] = entry;
303        } else {
304            self.entries.push(entry);
305        }
306        Ok(())
307    }
308
309    /// Iterate over hooks at a specific point, in registration order.
310    pub fn at_point(&self, point: HookPoint) -> impl Iterator<Item = &HookEntry> {
311        self.entries.iter().filter(move |e| e.point == point)
312    }
313
314    /// Iterate over all hooks.
315    pub fn iter(&self) -> impl Iterator<Item = &HookEntry> {
316        self.entries.iter()
317    }
318
319    /// Number of registered hooks.
320    #[must_use]
321    pub const fn len(&self) -> usize {
322        self.entries.len()
323    }
324
325    /// Whether the set is empty.
326    #[must_use]
327    pub const fn is_empty(&self) -> bool {
328        self.entries.is_empty()
329    }
330}
331
332impl From<HookSet> for Vec<HookEntry> {
333    fn from(set: HookSet) -> Self {
334        set.entries
335    }
336}
337
338impl From<&HookSet> for Vec<HookEntry> {
339    fn from(set: &HookSet) -> Self {
340        set.entries.clone()
341    }
342}
343
344impl IntoIterator for HookSet {
345    type Item = HookEntry;
346    type IntoIter = std::vec::IntoIter<Self::Item>;
347
348    fn into_iter(self) -> Self::IntoIter {
349        self.entries.into_iter()
350    }
351}
352
353impl FromIterator<HookEntry> for HookSet {
354    fn from_iter<T: IntoIterator<Item = HookEntry>>(iter: T) -> Self {
355        let mut set = Self::new();
356        for entry in iter {
357            let name = entry.name.clone();
358            if let Err(e) = set.push(entry) {
359                tracing::error!(
360                    error = %e,
361                    hook = %name,
362                    "Failed to push hook entry during from_iter"
363                );
364            }
365        }
366        set
367    }
368}
369
370impl From<Vec<HookEntry>> for HookSet {
371    fn from(entries: Vec<HookEntry>) -> Self {
372        Self::from_iter(entries)
373    }
374}
375
376impl<const N: usize> From<[HookEntry; N]> for HookSet {
377    fn from(entries: [HookEntry; N]) -> Self {
378        Self::from_iter(entries)
379    }
380}

// ── Callback types ──────────────────────────────────────────────────────────
382
383/// Type alias for the transform-tool-input closure signature.
384///
385/// Accepts a pre-tool-call context and optionally returns replacement
386/// arguments.  `None` means "no change".
387type TransformToolInputFn =
388    dyn Fn(&PreToolCallDecideContext) -> Option<serde_json::Value> + Send + Sync;
389
390/// A registered hook callback, keyed by hook point.
391///
392/// Each variant wraps a boxed closure that receives the strongly-typed context
393/// for that hook point.  [`PreToolCallDecide`](Self::PreToolCallDecide) returns
394/// a [`HookResult`] so it can approve or deny tool execution; all other
395/// variants are fire-and-forget observers.
396#[non_exhaustive]
397pub enum HookCallback {
398    /// Callback invoked before each agent turn.
399    PreTurn(Box<dyn Fn(&PreTurnContext) + Send + Sync>),
400    /// Callback invoked after each agent turn completes.
401    PostTurn(Box<dyn Fn(&PostTurnContext) + Send + Sync>),
402    /// Callback invoked before deciding whether to execute a tool call.
403    PreToolCallDecide(Box<dyn Fn(&PreToolCallDecideContext) -> HookResult + Send + Sync>),
404    /// Callback invoked after a tool call completes.
405    PostToolCall(Box<dyn Fn(&PostToolCallContext) + Send + Sync>),
406    /// Callback invoked when a tool call produces an error.
407    OnToolError(Box<dyn Fn(&OnToolErrorContext) + Send + Sync>),
408    /// Callback invoked when a new agent session begins.
409    OnSessionStart(Box<dyn Fn(&OnSessionStartContext) + Send + Sync>),
410    /// Callback invoked when an agent session ends.
411    OnSessionEnd(Box<dyn Fn(&OnSessionEndContext) + Send + Sync>),
412    /// Callback invoked when conversation history is compacted.
413    OnCompaction(Box<dyn Fn(&OnCompactionContext) + Send + Sync>),
414    /// Callback invoked on each interaction event.
415    OnInteraction(Box<dyn Fn(&OnInteractionContext) -> HookResult + Send + Sync>),
416    /// Transform tool input arguments before execution.
417    ///
418    /// The closure receives the pre-tool-call context and may return
419    /// `Some(new_args)` to replace the tool arguments, or `None` to
420    /// leave them unchanged.  Multiple transform hooks are applied
421    /// sequentially — each receives the (possibly already-modified)
422    /// arguments from the previous transform.
423    TransformToolInput(Box<TransformToolInputFn>),
424}
425
426impl HookCallback {
427    /// Returns the [`HookPoint`] this callback is associated with.
428    #[must_use]
429    pub(crate) const fn hook_point(&self) -> HookPoint {
430        match self {
431            Self::PreTurn(_) => HookPoint::PreTurn,
432            Self::PostTurn(_) => HookPoint::PostTurn,
433            Self::PreToolCallDecide(_) | Self::TransformToolInput(_) => {
434                HookPoint::PreToolCallDecide
435            }
436            Self::PostToolCall(_) => HookPoint::PostToolCall,
437            Self::OnToolError(_) => HookPoint::OnToolError,
438            Self::OnSessionStart(_) => HookPoint::OnSessionStart,
439            Self::OnSessionEnd(_) => HookPoint::OnSessionEnd,
440            Self::OnCompaction(_) => HookPoint::OnCompaction,
441            Self::OnInteraction(_) => HookPoint::OnInteraction,
442        }
443    }
444}
445
446// Manual Debug impl because closures don't implement Debug.
447impl std::fmt::Debug for HookCallback {
448    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
449        f.write_str("HookCallback::")?;
450        match self {
451            Self::TransformToolInput(_) => f.write_str("transform_tool_input"),
452            other => f.write_str(other.hook_point().label()),
453        }
454    }
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn hook_result_allow() {
        let r = HookResult::allow();
        assert!(r.allow);
        assert!(r.message.is_empty());
    }

    #[test]
    fn hook_result_deny() {
        let r = HookResult::deny("blocked by policy");
        assert!(!r.allow);
        assert_eq!(r.message, "blocked by policy");
    }

    #[test]
    fn hook_result_allow_with_message() {
        let r = HookResult::allow_with_message("proceeding with caution");
        assert!(r.allow);
        assert_eq!(r.message, "proceeding with caution");
    }

    #[test]
    fn hook_point_labels() {
        assert_eq!(HookPoint::PreTurn.label(), "pre_turn");
        assert_eq!(HookPoint::PostTurn.label(), "post_turn");
        assert_eq!(HookPoint::PreToolCallDecide.label(), "pre_tool_call_decide");
        assert_eq!(HookPoint::PostToolCall.label(), "post_tool_call");
        assert_eq!(HookPoint::OnCompaction.label(), "on_compaction");
        assert_eq!(HookPoint::OnSessionStart.label(), "on_session_start");
        assert_eq!(HookPoint::OnSessionEnd.label(), "on_session_end");
        assert_eq!(HookPoint::OnToolError.label(), "on_tool_error");
        assert_eq!(HookPoint::OnInteraction.label(), "on_interaction");
    }

    #[test]
    fn hooks_fire_in_correct_order() {
        let mut set = HookSet::new();
        assert!(set.is_empty());

        set.push(HookEntry {
            name: "pre_turn_1".to_owned(),
            point: HookPoint::PreTurn,
            callback_id: "cb_pre1".to_owned(),
        })
        .unwrap();
        set.push(HookEntry {
            name: "pre_tool_decide".to_owned(),
            point: HookPoint::PreToolCallDecide,
            callback_id: "cb_decide".to_owned(),
        })
        .unwrap();
        set.push(HookEntry {
            name: "pre_turn_2".to_owned(),
            point: HookPoint::PreTurn,
            callback_id: "cb_pre2".to_owned(),
        })
        .unwrap();
        set.push(HookEntry {
            name: "post_turn_1".to_owned(),
            point: HookPoint::PostTurn,
            callback_id: "cb_post1".to_owned(),
        })
        .unwrap();
        set.push(HookEntry {
            name: "post_tool_1".to_owned(),
            point: HookPoint::PostToolCall,
            callback_id: "cb_posttool1".to_owned(),
        })
        .unwrap();

        assert_eq!(set.len(), 5);

        let pre_turn: Vec<&str> = set
            .at_point(HookPoint::PreTurn)
            .map(|e| e.name.as_str())
            .collect();
        assert_eq!(pre_turn, vec!["pre_turn_1", "pre_turn_2"]);

        let decide: Vec<&str> = set
            .at_point(HookPoint::PreToolCallDecide)
            .map(|e| e.name.as_str())
            .collect();
        assert_eq!(decide, vec!["pre_tool_decide"]);

        let post_turn: Vec<&str> = set
            .at_point(HookPoint::PostTurn)
            .map(|e| e.name.as_str())
            .collect();
        assert_eq!(post_turn, vec!["post_turn_1"]);

        let post_tool: Vec<&str> = set
            .at_point(HookPoint::PostToolCall)
            .map(|e| e.name.as_str())
            .collect();
        assert_eq!(post_tool, vec!["post_tool_1"]);
    }

    #[test]
    fn hook_entry_serde_roundtrip() {
        let entry = HookEntry {
            name: "my_hook".to_owned(),
            point: HookPoint::PreToolCallDecide,
            callback_id: "cb_123".to_owned(),
        };
        let json = serde_json::to_string(&entry).expect("serialize");
        let parsed: HookEntry = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.name, entry.name);
        assert_eq!(parsed.point, entry.point);
        assert_eq!(parsed.callback_id, entry.callback_id);
    }

    #[test]
    fn hook_result_serde_roundtrip() {
        let results = vec![
            HookResult::allow(),
            HookResult::deny("reason"),
            HookResult::allow_with_message("ok"),
        ];
        for result in &results {
            let json = serde_json::to_string(result).expect("serialize");
            let parsed: HookResult = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(&parsed, result);
        }
    }

    #[test]
    fn hook_set_serde_roundtrip() {
        let mut set = HookSet::new();
        set.push(HookEntry {
            name: "gate".to_owned(),
            point: HookPoint::PreTurn,
            callback_id: "cb_1".to_owned(),
        })
        .unwrap();
        set.push(HookEntry {
            name: "logger".to_owned(),
            point: HookPoint::PostToolCall,
            callback_id: "cb_2".to_owned(),
        })
        .unwrap();
        let json = serde_json::to_string(&set).expect("serialize");
        let parsed: HookSet = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.len(), 2);
        let names: Vec<&str> = parsed.iter().map(|e| e.name.as_str()).collect();
        assert_eq!(names, vec!["gate", "logger"]);
    }

    #[test]
    fn hook_set_from_conversions() {
        let mut set = HookSet::new();
        set.push(HookEntry {
            name: "gate".to_owned(),
            point: HookPoint::PreTurn,
            callback_id: "cb_1".to_owned(),
        })
        .unwrap();

        let vec_from_owned: Vec<HookEntry> = Vec::from(set.clone());
        assert_eq!(vec_from_owned.len(), 1);
        assert_eq!(vec_from_owned[0].name, "gate");

        let vec_from_ref: Vec<HookEntry> = Vec::from(&set);
        assert_eq!(vec_from_ref.len(), 1);
        assert_eq!(vec_from_ref[0].name, "gate");

        let entry = HookEntry {
            name: "gate".to_owned(),
            point: HookPoint::PreTurn,
            callback_id: "cb_1".to_owned(),
        };
        let set_from_arr = HookSet::from([entry.clone()]);
        assert_eq!(set_from_arr.len(), 1);

        let set_from_vec = HookSet::from(vec![entry]);
        assert_eq!(set_from_vec.len(), 1);
    }

    #[test]
    fn empty_hook_set_iteration_at_each_point() {
        let set = HookSet::new();
        for point in [
            HookPoint::PreTurn,
            HookPoint::PostTurn,
            HookPoint::PreToolCallDecide,
            HookPoint::PostToolCall,
            HookPoint::OnCompaction,
            HookPoint::OnSessionStart,
            HookPoint::OnSessionEnd,
            HookPoint::OnToolError,
            HookPoint::OnInteraction,
        ] {
            assert_eq!(
                set.at_point(point).count(),
                0,
                "Empty HookSet should have 0 hooks at {point:?}"
            );
        }
    }

    #[test]
    fn hook_point_serde_roundtrip() {
        let points = [
            HookPoint::PreTurn,
            HookPoint::PostTurn,
            HookPoint::PreToolCallDecide,
            HookPoint::PostToolCall,
            HookPoint::OnCompaction,
            HookPoint::OnSessionStart,
            HookPoint::OnSessionEnd,
            HookPoint::OnToolError,
            HookPoint::OnInteraction,
        ];
        for point in points {
            let json = serde_json::to_string(&point).expect("serialize");
            let parsed: HookPoint = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(parsed, point);
        }
    }

    #[test]
    fn hook_set_default_is_empty() {
        let set = HookSet::default();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
    }

    #[test]
    fn hook_set_multiple_hooks_at_same_point() {
        let mut set = HookSet::new();
        for i in 0..5 {
            set.push(HookEntry {
                name: format!("hook_{i}"),
                point: HookPoint::PreToolCallDecide,
                callback_id: format!("cb_{i}"),
            })
            .unwrap();
        }
        assert_eq!(set.len(), 5);
        assert_eq!(set.at_point(HookPoint::PreToolCallDecide).count(), 5);
        assert_eq!(set.at_point(HookPoint::PreTurn).count(), 0);
    }

    #[test]
    fn hook_set_push_replaces_duplicate_name_and_point() {
        let mut set = HookSet::new();
        set.push(HookEntry::new("gate", HookPoint::PreTurn, "cb_old").unwrap())
            .unwrap();
        set.push(HookEntry::new("other", HookPoint::PreTurn, "cb_other").unwrap())
            .unwrap();
        set.push(HookEntry::new("gate", HookPoint::PreTurn, "cb_new").unwrap())
            .unwrap();

        assert_eq!(set.len(), 2);
        // The replacement keeps the original entry's position in the order.
        let ids: Vec<&str> = set
            .at_point(HookPoint::PreTurn)
            .map(|e| e.callback_id.as_str())
            .collect();
        assert_eq!(ids, vec!["cb_new", "cb_other"]);
    }

    #[test]
    fn hook_set_push_keeps_same_name_at_different_points() {
        let mut set = HookSet::new();
        set.push(HookEntry::new("gate", HookPoint::PreTurn, "cb_1").unwrap())
            .unwrap();
        set.push(HookEntry::new("gate", HookPoint::PostTurn, "cb_2").unwrap())
            .unwrap();
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn hook_set_push_rejects_invalid_entry() {
        let mut set = HookSet::new();
        let invalid = HookEntry {
            name: String::new(),
            point: HookPoint::PreTurn,
            callback_id: "cb_1".to_owned(),
        };
        assert!(set.push(invalid).is_err());
        assert!(set.is_empty());
    }

    #[test]
    fn hook_set_from_iter_skips_invalid_entries() {
        let entries = vec![
            HookEntry {
                name: "ok".to_owned(),
                point: HookPoint::PreTurn,
                callback_id: "cb_ok".to_owned(),
            },
            HookEntry {
                name: "   ".to_owned(),
                point: HookPoint::PreTurn,
                callback_id: "cb_blank_name".to_owned(),
            },
            HookEntry {
                name: "no_callback".to_owned(),
                point: HookPoint::PostTurn,
                callback_id: String::new(),
            },
        ];
        let set: HookSet = entries.into_iter().collect();
        let names: Vec<&str> = set.iter().map(|e| e.name.as_str()).collect();
        assert_eq!(names, vec!["ok"]);
    }

    #[test]
    fn hook_result_deny_with_string_owned() {
        let reason = String::from("policy violation detected");
        let r = HookResult::deny(reason.clone());
        assert!(!r.allow);
        assert_eq!(r.message, reason);
    }

    #[test]
    fn hook_entry_with_new_hook_points() {
        let new_points = [
            (HookPoint::OnCompaction, "compaction_hook"),
            (HookPoint::OnSessionStart, "session_start_hook"),
            (HookPoint::OnSessionEnd, "session_end_hook"),
            (HookPoint::OnToolError, "tool_error_hook"),
            (HookPoint::OnInteraction, "interaction_hook"),
        ];
        let mut set = HookSet::new();
        for (point, name) in &new_points {
            set.push(HookEntry {
                name: (*name).to_owned(),
                point: *point,
                callback_id: format!("cb_{name}"),
            })
            .unwrap();
        }
        assert_eq!(set.len(), 5);
        for (point, name) in &new_points {
            let hooks: Vec<&str> = set.at_point(*point).map(|e| e.name.as_str()).collect();
            assert_eq!(hooks, vec![*name], "expected hook at {point:?}");
        }
    }

    #[test]
    fn hook_entry_serde_roundtrip_new_points() {
        let new_points = [
            HookPoint::OnCompaction,
            HookPoint::OnSessionStart,
            HookPoint::OnSessionEnd,
            HookPoint::OnToolError,
            HookPoint::OnInteraction,
        ];
        for point in new_points {
            let entry = HookEntry {
                name: format!("test_{}", point.label()),
                point,
                callback_id: format!("cb_{}", point.label()),
            };
            let json = serde_json::to_string(&entry).expect("serialize");
            let parsed: HookEntry = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(parsed.name, entry.name);
            assert_eq!(parsed.point, entry.point);
            assert_eq!(parsed.callback_id, entry.callback_id);
        }
    }

    // ── HookCallback tests ──────────────────────────────────────────────

    #[test]
    fn hook_callback_hook_point_and_debug() {
        let transform = HookCallback::TransformToolInput(Box::new(
            |_: &PreToolCallDecideContext| -> Option<serde_json::Value> { None },
        ));
        assert_eq!(transform.hook_point(), HookPoint::PreToolCallDecide);
        assert_eq!(format!("{transform:?}"), "HookCallback::transform_tool_input");

        let decide = HookCallback::PreToolCallDecide(Box::new(|_: &PreToolCallDecideContext| {
            HookResult::allow()
        }));
        assert_eq!(decide.hook_point(), HookPoint::PreToolCallDecide);
        assert_eq!(format!("{decide:?}"), "HookCallback::pre_tool_call_decide");

        let pre_turn = HookCallback::PreTurn(Box::new(|_: &PreTurnContext| {}));
        assert_eq!(pre_turn.hook_point(), HookPoint::PreTurn);
        assert_eq!(format!("{pre_turn:?}"), "HookCallback::pre_turn");
    }

    // ── SessionContext tests ────────────────────────────────────────────

    #[test]
    fn session_context_clone() {
        let ctx = SessionContext {
            session_id: "sess-1".into(),
            agent_id: 42,
            started_at: Instant::now(),
        };
        // Actually exercise `Clone` (a plain `let` binding would only move).
        let cloned = ctx.clone();
        assert_eq!(cloned.session_id, "sess-1");
        assert_eq!(cloned.agent_id, 42);
        assert_eq!(cloned.session_id, ctx.session_id);
        assert_eq!(cloned.agent_id, ctx.agent_id);
        assert_eq!(cloned.started_at, ctx.started_at);
    }

    #[test]
    fn session_context_debug_format() {
        let ctx = SessionContext {
            session_id: "sess-debug".into(),
            agent_id: 1,
            started_at: Instant::now(),
        };
        let dbg = format!("{ctx:?}");
        assert!(dbg.contains("sess-debug"));
        assert!(dbg.contains("agent_id: 1"));
    }

    // ── HookEntry::new validated constructor tests ──────────────────────

    #[test]
    fn hook_entry_new_valid() {
        let entry = HookEntry::new("safety_gate", HookPoint::PreToolCallDecide, "cb_safety")
            .expect("valid entry");
        assert_eq!(entry.name, "safety_gate");
        assert_eq!(entry.point, HookPoint::PreToolCallDecide);
        assert_eq!(entry.callback_id, "cb_safety");
    }

    #[test]
    fn hook_entry_new_rejects_empty_name() {
        let result = HookEntry::new("", HookPoint::PreTurn, "cb_1");
        assert!(result.is_err(), "should reject empty name");
    }

    #[test]
    fn hook_entry_new_rejects_whitespace_name() {
        let result = HookEntry::new("   ", HookPoint::PreTurn, "cb_1");
        assert!(result.is_err(), "should reject whitespace-only name");
    }

    #[test]
    fn hook_entry_new_rejects_empty_callback_id() {
        let result = HookEntry::new("my_hook", HookPoint::PreTurn, "");
        assert!(result.is_err(), "should reject empty callback_id");
    }

    #[test]
    fn hook_entry_new_rejects_whitespace_callback_id() {
        let result = HookEntry::new("my_hook", HookPoint::PostTurn, "  ");
        assert!(result.is_err(), "should reject whitespace-only callback_id");
    }

    #[test]
    fn pre_tool_call_decide_context_serde_aliases() {
        let json_std = r#"{"tool_name":"my_tool","tool_args":{"foo":"bar"}}"#;
        let parsed_std: PreToolCallDecideContext = serde_json::from_str(json_std).unwrap();
        assert_eq!(parsed_std.tool_name, "my_tool");
        assert_eq!(parsed_std.tool_args["foo"], "bar");

        let json_alias = r#"{"name":"my_tool","args":{"foo":"bar"}}"#;
        let parsed_alias: PreToolCallDecideContext = serde_json::from_str(json_alias).unwrap();
        assert_eq!(parsed_alias.tool_name, "my_tool");
        assert_eq!(parsed_alias.tool_args["foo"], "bar");
    }

    #[test]
    fn pre_tool_call_decide_context_serde_default() {
        let json_no_args = r#"{"name":"my_tool"}"#;
        let parsed_no_args: PreToolCallDecideContext = serde_json::from_str(json_no_args).unwrap();
        assert_eq!(parsed_no_args.tool_name, "my_tool");
        assert_eq!(parsed_no_args.tool_args, serde_json::Value::Null);
    }

    #[test]
    fn post_tool_call_context_serde_aliases_and_default() {
        let json_std = r#"{"tool_name":"my_tool","tool_args":{"foo":"bar"},"result":"success"}"#;
        let parsed_std: PostToolCallContext = serde_json::from_str(json_std).unwrap();
        assert_eq!(parsed_std.tool_name, "my_tool");
        assert_eq!(parsed_std.tool_args["foo"], "bar");
        assert_eq!(parsed_std.result, "success");

        let json_alias = r#"{"name":"my_tool","args":{"foo":"bar"},"result":"success"}"#;
        let parsed_alias: PostToolCallContext = serde_json::from_str(json_alias).unwrap();
        assert_eq!(parsed_alias.tool_name, "my_tool");
        assert_eq!(parsed_alias.tool_args["foo"], "bar");
        assert_eq!(parsed_alias.result, "success");

        let json_no_args = r#"{"name":"my_tool","result":"success"}"#;
        let parsed_no_args: PostToolCallContext = serde_json::from_str(json_no_args).unwrap();
        assert_eq!(parsed_no_args.tool_name, "my_tool");
        assert_eq!(parsed_no_args.tool_args, serde_json::Value::Null);
        assert_eq!(parsed_no_args.result, "success");
    }

    #[test]
    fn on_tool_error_context_serde_aliases_and_default() {
        let json_std = r#"{"tool_name":"my_tool","tool_args":{"foo":"bar"},"error":"failed"}"#;
        let parsed_std: OnToolErrorContext = serde_json::from_str(json_std).unwrap();
        assert_eq!(parsed_std.tool_name, "my_tool");
        assert_eq!(parsed_std.tool_args["foo"], "bar");
        assert_eq!(parsed_std.error, "failed");

        let json_alias = r#"{"name":"my_tool","args":{"foo":"bar"},"error":"failed"}"#;
        let parsed_alias: OnToolErrorContext = serde_json::from_str(json_alias).unwrap();
        assert_eq!(parsed_alias.tool_name, "my_tool");
        assert_eq!(parsed_alias.tool_args["foo"], "bar");
        assert_eq!(parsed_alias.error, "failed");

        let json_no_args = r#"{"name":"my_tool","error":"failed"}"#;
        let parsed_no_args: OnToolErrorContext = serde_json::from_str(json_no_args).unwrap();
        assert_eq!(parsed_no_args.tool_name, "my_tool");
        assert_eq!(parsed_no_args.tool_args, serde_json::Value::Null);
        assert_eq!(parsed_no_args.error, "failed");

        let json_no_name = r#"{"error":"failed"}"#;
        let parsed_no_name: Result<OnToolErrorContext, _> = serde_json::from_str(json_no_name);
        assert!(parsed_no_name.is_err());
    }
}