Skip to main content

agy_bridge/hooks/
types.rs

//! Hook bridge for the Antigravity SDK.
//!
//! Defines Rust-side hook types that wrap callbacks for agent lifecycle
//! hook points: pre-turn, post-turn, pre-tool-call-decide, post-tool-call,
//! compaction, session start/end, tool errors, user interactions, and
//! tool-input transformation.
//!
//! The actual Python wrapping (creating `PyO3` classes that the SDK dispatches to)
//! requires the Python runtime, so it is exercised only by integration tests.
10
11use std::time::SystemTime;
12
13use serde::{Deserialize, Serialize};
14
15/// Result of a hook decision (mirrors SDK `HookResult`).
16#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
17pub struct HookResult {
18    /// Whether execution should proceed.
19    pub allow: bool,
20    /// Optional explanation or response message.
21    pub message: String,
22}
23
24impl HookResult {
25    /// Create an "allow" result with an empty message.
26    #[must_use]
27    pub const fn allow() -> Self {
28        Self {
29            allow: true,
30            message: String::new(),
31        }
32    }
33
34    /// Create an "allow" result with a message.
35    #[must_use]
36    pub fn allow_with_message(message: impl Into<String>) -> Self {
37        Self {
38            allow: true,
39            message: message.into(),
40        }
41    }
42
43    /// Create a "deny" result with a reason.
44    #[must_use]
45    pub fn deny(reason: impl Into<String>) -> Self {
46        Self {
47            allow: false,
48            message: reason.into(),
49        }
50    }
51}
52
// ── Hook context structs ────────────────────────────────────────────────────
54
55/// Persistent session metadata passed to session-lifecycle hooks.
56///
57/// Created when a session starts and carried through to session-end hooks
58/// so hooks can correlate events, measure session duration, and identify
59/// the agent instance.
60#[non_exhaustive]
61#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
62pub struct SessionContext {
63    /// Unique identifier for this session.
64    pub session_id: String,
65    /// Numeric agent identifier within the bridge runtime.
66    pub agent_id: u64,
67    /// Wall-clock timestamp of when the session was started.
68    ///
69    /// Defaults to [`UNIX_EPOCH`](SystemTime::UNIX_EPOCH) if the backend
70    /// does not supply this field.
71    #[serde(default = "SessionContext::default_started_at")]
72    pub started_at: SystemTime,
73}
74
75impl SessionContext {
76    /// Fallback value when `started_at` is absent from the JSON payload.
77    fn default_started_at() -> SystemTime {
78        SystemTime::UNIX_EPOCH
79    }
80}
81
82/// Context passed to [`HookPoint::OnSessionStart`] hooks.
83#[non_exhaustive]
84#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
85pub struct OnSessionStartContext {
86    /// Session metadata for the newly started session.
87    pub session: SessionContext,
88}
89
90/// Context passed to [`HookPoint::OnSessionEnd`] hooks.
91#[non_exhaustive]
92#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
93pub struct OnSessionEndContext {
94    /// Session metadata for the ending session.
95    pub session: SessionContext,
96}
97
98/// Context passed to [`HookPoint::OnCompaction`] hooks.
99#[non_exhaustive]
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct OnCompactionContext {}
102
103/// Context passed to [`HookPoint::OnInteraction`] hooks.
104#[non_exhaustive]
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct OnInteractionContext {
107    /// The interaction message content.
108    pub message: String,
109}
110
111/// Context passed to [`HookPoint::PreTurn`] hooks.
112#[non_exhaustive]
113#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct PreTurnContext {
115    /// The user prompt for this turn.
116    pub prompt: String,
117    /// The 1-based turn number.
118    pub turn_number: u32,
119}
120
121impl PreTurnContext {
122    /// Create a new pre-turn context.
123    #[must_use]
124    pub fn new(prompt: impl Into<String>, turn_number: u32) -> Self {
125        Self {
126            prompt: prompt.into(),
127            turn_number,
128        }
129    }
130}
131
132/// Context passed to [`HookPoint::PostTurn`] hooks.
133#[non_exhaustive]
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct PostTurnContext {
136    /// The model's response text for this turn.
137    pub response_text: String,
138    /// The 1-based turn number.
139    pub turn_number: u32,
140}
141
142/// Context passed to [`HookPoint::PreToolCallDecide`] hooks.
143#[non_exhaustive]
144#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct PreToolCallDecideContext {
146    /// Name of the tool about to be called.
147    #[serde(alias = "name")]
148    pub tool_name: String,
149    /// Arguments the tool will receive.
150    #[serde(alias = "args", default)]
151    pub tool_args: serde_json::Value,
152}
153
154impl PreToolCallDecideContext {
155    /// Create a new pre-tool-call-decide context.
156    #[must_use]
157    pub fn new(tool_name: impl Into<String>, tool_args: serde_json::Value) -> Self {
158        Self {
159            tool_name: tool_name.into(),
160            tool_args,
161        }
162    }
163}
164
165/// Context passed to [`HookPoint::PostToolCall`] hooks.
166#[non_exhaustive]
167#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct PostToolCallContext {
169    /// Name of the tool that was called.
170    #[serde(alias = "name")]
171    pub tool_name: String,
172    /// Arguments the tool received.
173    #[serde(alias = "args", default)]
174    pub tool_args: serde_json::Value,
175    /// The tool's return value (serialised).
176    pub result: String,
177    /// Structured metadata from the tool response (if any).
178    #[serde(default)]
179    pub metadata: serde_json::Value,
180}
181
182/// Context passed to [`HookPoint::OnToolError`] hooks.
183#[non_exhaustive]
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct OnToolErrorContext {
186    /// Name of the tool that errored.
187    #[serde(alias = "name")]
188    pub tool_name: String,
189    /// Arguments the tool received.
190    #[serde(alias = "args", default)]
191    pub tool_args: serde_json::Value,
192    /// The error message.
193    pub error: String,
194}
195
196/// Identifies the point in the agent lifecycle where a hook fires.
197#[non_exhaustive]
198#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
199pub enum HookPoint {
200    /// Before the model processes a turn (receives the user prompt).
201    PreTurn,
202    /// After the model completes a turn (receives the model response).
203    PostTurn,
204    /// Before a tool call is executed — can approve or deny.
205    PreToolCallDecide,
206    /// After a tool call completes (receives the tool result).
207    PostToolCall,
208    /// Fires when the context window is compacted (trimmed to fit limits).
209    OnCompaction,
210    /// Fires when a new agent session begins.
211    OnSessionStart,
212    /// Fires when an agent session ends.
213    OnSessionEnd,
214    /// Fires when a tool call returns an error.
215    OnToolError,
216    /// Fires on each user interaction (message received from user).
217    OnInteraction,
218}
219
220impl HookPoint {
221    /// Human-readable label for logging.
222    #[must_use]
223    pub const fn label(self) -> &'static str {
224        match self {
225            Self::PreTurn => "pre_turn",
226            Self::PostTurn => "post_turn",
227            Self::PreToolCallDecide => "pre_tool_call_decide",
228            Self::PostToolCall => "post_tool_call",
229            Self::OnCompaction => "on_compaction",
230            Self::OnSessionStart => "on_session_start",
231            Self::OnSessionEnd => "on_session_end",
232            Self::OnToolError => "on_tool_error",
233            Self::OnInteraction => "on_interaction",
234        }
235    }
236}
237
238/// A named hook registration that will be attached to an agent.
239///
240/// The `callback_id` is an opaque identifier used to look up the actual
241/// Rust callback in the hook runner. This decouples serialization from
242/// function pointers.
243///
244/// # Construction
245///
246/// Prefer [`HookEntry::new`] which validates eagerly. Direct struct
247/// construction is allowed for deserialization but skips validation —
248/// call [`HookEntry::validate`] before use if constructing manually.
249#[derive(Debug, Clone, Serialize, Deserialize)]
250pub struct HookEntry {
251    /// Descriptive name (e.g. `"safety_gate"`).
252    pub name: String,
253    /// Which lifecycle point this hook fires at.
254    pub point: HookPoint,
255    /// Opaque callback identifier for the hook runner to resolve.
256    pub callback_id: String,
257}
258
259impl HookEntry {
260    /// Create a new hook entry, validating that `name` and `callback_id`
261    /// are non-empty.
262    ///
263    /// # Errors
264    ///
265    /// Returns [`Error::InvalidConfig`](crate::error::Error::InvalidConfig)
266    /// if `name` or `callback_id` is empty or whitespace-only.
267    ///
268    /// # Examples
269    ///
270    /// ```
271    /// # use agy_bridge::hooks::{HookEntry, HookPoint};
272    /// let entry = HookEntry::new("safety_gate", HookPoint::PreToolCallDecide, "cb_safety")
273    ///     .expect("valid entry");
274    /// assert_eq!(entry.name, "safety_gate");
275    /// ```
276    pub fn new(
277        name: impl Into<String>,
278        point: HookPoint,
279        callback_id: impl Into<String>,
280    ) -> Result<Self, crate::error::Error> {
281        let entry = Self {
282            name: name.into(),
283            point,
284            callback_id: callback_id.into(),
285        };
286        entry.validate()?;
287        Ok(entry)
288    }
289
290    /// Validate that the entry has non-empty name and `callback_id`.
291    ///
292    /// # Errors
293    ///
294    /// Returns `Err` with a description if the name or `callback_id` is empty.
295    pub fn validate(&self) -> Result<(), crate::error::Error> {
296        if self.name.trim().is_empty() {
297            return Err(crate::error::Error::InvalidConfig {
298                message: "HookEntry name must not be empty".to_owned(),
299            });
300        }
301        if self.callback_id.trim().is_empty() {
302            return Err(crate::error::Error::InvalidConfig {
303                message: format!("HookEntry '{}' has an empty callback_id", self.name),
304            });
305        }
306        Ok(())
307    }
308}
309
310/// An ordered list of hooks to attach to an agent.
311///
312/// Hooks at the same [`HookPoint`] fire in registration order.
313#[derive(Debug, Clone, Default, Serialize, Deserialize)]
314pub struct HookSet {
315    entries: Vec<HookEntry>,
316}
317
318impl HookSet {
319    /// Create an empty hook set.
320    #[must_use]
321    pub const fn new() -> Self {
322        Self {
323            entries: Vec::new(),
324        }
325    }
326
327    /// Register a hook.
328    ///
329    /// If a hook with the same name AND hook point already exists, it is
330    /// replaced and a warning is logged.
331    ///
332    /// # Errors
333    ///
334    /// Returns `Err` if the entry fails validation (empty name or `callback_id`).
335    pub fn push(&mut self, entry: HookEntry) -> Result<(), crate::error::Error> {
336        entry.validate()?;
337        if let Some(pos) = self
338            .entries
339            .iter()
340            .position(|e| e.name == entry.name && e.point == entry.point)
341        {
342            tracing::warn!(
343                hook = %entry.name,
344                point = %entry.point.label(),
345                "duplicate hook name+point in HookSet — replacing previous entry"
346            );
347            self.entries[pos] = entry;
348        } else {
349            self.entries.push(entry);
350        }
351        Ok(())
352    }
353
354    /// Iterate over hooks at a specific point, in registration order.
355    pub fn at_point(&self, point: HookPoint) -> impl Iterator<Item = &HookEntry> {
356        self.entries.iter().filter(move |e| e.point == point)
357    }
358
359    /// Iterate over all hooks.
360    pub fn iter(&self) -> impl Iterator<Item = &HookEntry> {
361        self.entries.iter()
362    }
363
364    /// Number of registered hooks.
365    #[must_use]
366    pub const fn len(&self) -> usize {
367        self.entries.len()
368    }
369
370    /// Whether the set is empty.
371    #[must_use]
372    pub const fn is_empty(&self) -> bool {
373        self.entries.is_empty()
374    }
375}
376
377impl From<HookSet> for Vec<HookEntry> {
378    fn from(set: HookSet) -> Self {
379        set.entries
380    }
381}
382
383impl From<&HookSet> for Vec<HookEntry> {
384    fn from(set: &HookSet) -> Self {
385        set.entries.clone()
386    }
387}
388
389impl IntoIterator for HookSet {
390    type Item = HookEntry;
391    type IntoIter = std::vec::IntoIter<Self::Item>;
392
393    fn into_iter(self) -> Self::IntoIter {
394        self.entries.into_iter()
395    }
396}
397
398impl FromIterator<HookEntry> for HookSet {
399    fn from_iter<T: IntoIterator<Item = HookEntry>>(iter: T) -> Self {
400        let mut set = Self::new();
401        for entry in iter {
402            let name = entry.name.clone();
403            if let Err(e) = set.push(entry) {
404                tracing::error!(
405                    error = %e,
406                    hook = %name,
407                    "Failed to push hook entry during from_iter"
408                );
409            }
410        }
411        set
412    }
413}
414
415impl From<Vec<HookEntry>> for HookSet {
416    fn from(entries: Vec<HookEntry>) -> Self {
417        Self::from_iter(entries)
418    }
419}
420
421impl<const N: usize> From<[HookEntry; N]> for HookSet {
422    fn from(entries: [HookEntry; N]) -> Self {
423        Self::from_iter(entries)
424    }
425}

// ── Callback types ──────────────────────────────────────────────────────────
427
428/// Type alias for the transform-tool-input closure signature.
429///
430/// Accepts a pre-tool-call context and optionally returns replacement
431/// arguments.  `None` means "no change".
432type TransformToolInputFn =
433    dyn Fn(&PreToolCallDecideContext) -> Option<serde_json::Value> + Send + Sync;
434
435/// A registered hook callback, keyed by hook point.
436///
437/// Each variant wraps a boxed closure that receives the strongly-typed context
438/// for that hook point.  [`PreToolCallDecide`](Self::PreToolCallDecide) returns
439/// a [`HookResult`] so it can approve or deny tool execution; all other
440/// variants are fire-and-forget observers.
441#[non_exhaustive]
442pub enum HookCallback {
443    /// Callback invoked before each agent turn.
444    PreTurn(Box<dyn Fn(&PreTurnContext) + Send + Sync>),
445    /// Callback invoked after each agent turn completes.
446    PostTurn(Box<dyn Fn(&PostTurnContext) + Send + Sync>),
447    /// Callback invoked before deciding whether to execute a tool call.
448    PreToolCallDecide(Box<dyn Fn(&PreToolCallDecideContext) -> HookResult + Send + Sync>),
449    /// Callback invoked after a tool call completes.
450    PostToolCall(Box<dyn Fn(&PostToolCallContext) + Send + Sync>),
451    /// Callback invoked when a tool call produces an error.
452    OnToolError(Box<dyn Fn(&OnToolErrorContext) + Send + Sync>),
453    /// Callback invoked when a new agent session begins.
454    OnSessionStart(Box<dyn Fn(&OnSessionStartContext) + Send + Sync>),
455    /// Callback invoked when an agent session ends.
456    OnSessionEnd(Box<dyn Fn(&OnSessionEndContext) + Send + Sync>),
457    /// Callback invoked when conversation history is compacted.
458    OnCompaction(Box<dyn Fn(&OnCompactionContext) + Send + Sync>),
459    /// Callback invoked on each interaction event.
460    OnInteraction(Box<dyn Fn(&OnInteractionContext) -> HookResult + Send + Sync>),
461    /// Transform tool input arguments before execution.
462    ///
463    /// The closure receives the pre-tool-call context and may return
464    /// `Some(new_args)` to replace the tool arguments, or `None` to
465    /// leave them unchanged.  Multiple transform hooks are applied
466    /// sequentially — each receives the (possibly already-modified)
467    /// arguments from the previous transform.
468    TransformToolInput(Box<TransformToolInputFn>),
469}
470
471impl HookCallback {
472    /// Returns the [`HookPoint`] this callback is associated with.
473    #[must_use]
474    pub(crate) const fn hook_point(&self) -> HookPoint {
475        match self {
476            Self::PreTurn(_) => HookPoint::PreTurn,
477            Self::PostTurn(_) => HookPoint::PostTurn,
478            Self::PreToolCallDecide(_) | Self::TransformToolInput(_) => {
479                HookPoint::PreToolCallDecide
480            }
481            Self::PostToolCall(_) => HookPoint::PostToolCall,
482            Self::OnToolError(_) => HookPoint::OnToolError,
483            Self::OnSessionStart(_) => HookPoint::OnSessionStart,
484            Self::OnSessionEnd(_) => HookPoint::OnSessionEnd,
485            Self::OnCompaction(_) => HookPoint::OnCompaction,
486            Self::OnInteraction(_) => HookPoint::OnInteraction,
487        }
488    }
489}
490
491// Manual Debug impl because closures don't implement Debug.
492impl std::fmt::Debug for HookCallback {
493    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
494        f.write_str("HookCallback::")?;
495        match self {
496            Self::TransformToolInput(_) => f.write_str("transform_tool_input"),
497            other => f.write_str(other.hook_point().label()),
498        }
499    }
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn hook_result_allow() {
        let r = HookResult::allow();
        assert!(r.allow);
        assert!(r.message.is_empty());
    }

    #[test]
    fn hook_result_deny() {
        let r = HookResult::deny("blocked by policy");
        assert!(!r.allow);
        assert_eq!(r.message, "blocked by policy");
    }

    #[test]
    fn hook_result_allow_with_message() {
        let r = HookResult::allow_with_message("proceeding with caution");
        assert!(r.allow);
        assert_eq!(r.message, "proceeding with caution");
    }

    #[test]
    fn hook_point_labels() {
        assert_eq!(HookPoint::PreTurn.label(), "pre_turn");
        assert_eq!(HookPoint::PostTurn.label(), "post_turn");
        assert_eq!(HookPoint::PreToolCallDecide.label(), "pre_tool_call_decide");
        assert_eq!(HookPoint::PostToolCall.label(), "post_tool_call");
        assert_eq!(HookPoint::OnCompaction.label(), "on_compaction");
        assert_eq!(HookPoint::OnSessionStart.label(), "on_session_start");
        assert_eq!(HookPoint::OnSessionEnd.label(), "on_session_end");
        assert_eq!(HookPoint::OnToolError.label(), "on_tool_error");
        assert_eq!(HookPoint::OnInteraction.label(), "on_interaction");
    }

    #[test]
    fn hooks_fire_in_correct_order() {
        let mut set = HookSet::new();
        assert!(set.is_empty());

        set.push(HookEntry {
            name: "pre_turn_1".to_owned(),
            point: HookPoint::PreTurn,
            callback_id: "cb_pre1".to_owned(),
        })
        .unwrap();
        set.push(HookEntry {
            name: "pre_tool_decide".to_owned(),
            point: HookPoint::PreToolCallDecide,
            callback_id: "cb_decide".to_owned(),
        })
        .unwrap();
        set.push(HookEntry {
            name: "pre_turn_2".to_owned(),
            point: HookPoint::PreTurn,
            callback_id: "cb_pre2".to_owned(),
        })
        .unwrap();
        set.push(HookEntry {
            name: "post_turn_1".to_owned(),
            point: HookPoint::PostTurn,
            callback_id: "cb_post1".to_owned(),
        })
        .unwrap();
        set.push(HookEntry {
            name: "post_tool_1".to_owned(),
            point: HookPoint::PostToolCall,
            callback_id: "cb_posttool1".to_owned(),
        })
        .unwrap();

        assert_eq!(set.len(), 5);

        let pre_turn: Vec<&str> = set
            .at_point(HookPoint::PreTurn)
            .map(|e| e.name.as_str())
            .collect();
        assert_eq!(pre_turn, vec!["pre_turn_1", "pre_turn_2"]);

        let decide: Vec<&str> = set
            .at_point(HookPoint::PreToolCallDecide)
            .map(|e| e.name.as_str())
            .collect();
        assert_eq!(decide, vec!["pre_tool_decide"]);

        let post_turn: Vec<&str> = set
            .at_point(HookPoint::PostTurn)
            .map(|e| e.name.as_str())
            .collect();
        assert_eq!(post_turn, vec!["post_turn_1"]);

        let post_tool: Vec<&str> = set
            .at_point(HookPoint::PostToolCall)
            .map(|e| e.name.as_str())
            .collect();
        assert_eq!(post_tool, vec!["post_tool_1"]);
    }

    #[test]
    fn hook_entry_serde_roundtrip() {
        let entry = HookEntry {
            name: "my_hook".to_owned(),
            point: HookPoint::PreToolCallDecide,
            callback_id: "cb_123".to_owned(),
        };
        let json = serde_json::to_string(&entry).expect("serialize");
        let parsed: HookEntry = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.name, entry.name);
        assert_eq!(parsed.point, entry.point);
        assert_eq!(parsed.callback_id, entry.callback_id);
    }

    #[test]
    fn hook_result_serde_roundtrip() {
        let results = vec![
            HookResult::allow(),
            HookResult::deny("reason"),
            HookResult::allow_with_message("ok"),
        ];
        for result in &results {
            let json = serde_json::to_string(result).expect("serialize");
            let parsed: HookResult = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(&parsed, result);
        }
    }

    #[test]
    fn hook_set_serde_roundtrip() {
        let mut set = HookSet::new();
        set.push(HookEntry {
            name: "gate".to_owned(),
            point: HookPoint::PreTurn,
            callback_id: "cb_1".to_owned(),
        })
        .unwrap();
        set.push(HookEntry {
            name: "logger".to_owned(),
            point: HookPoint::PostToolCall,
            callback_id: "cb_2".to_owned(),
        })
        .unwrap();
        let json = serde_json::to_string(&set).expect("serialize");
        let parsed: HookSet = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.len(), 2);
        let names: Vec<&str> = parsed.iter().map(|e| e.name.as_str()).collect();
        assert_eq!(names, vec!["gate", "logger"]);
    }

    #[test]
    fn hook_set_from_conversions() {
        let mut set = HookSet::new();
        set.push(HookEntry {
            name: "gate".to_owned(),
            point: HookPoint::PreTurn,
            callback_id: "cb_1".to_owned(),
        })
        .unwrap();

        let vec_from_owned: Vec<HookEntry> = Vec::from(set.clone());
        assert_eq!(vec_from_owned.len(), 1);
        assert_eq!(vec_from_owned[0].name, "gate");

        let vec_from_ref: Vec<HookEntry> = Vec::from(&set);
        assert_eq!(vec_from_ref.len(), 1);
        assert_eq!(vec_from_ref[0].name, "gate");

        let entry = HookEntry {
            name: "gate".to_owned(),
            point: HookPoint::PreTurn,
            callback_id: "cb_1".to_owned(),
        };
        let set_from_arr = HookSet::from([entry.clone()]);
        assert_eq!(set_from_arr.len(), 1);

        let set_from_vec = HookSet::from(vec![entry]);
        assert_eq!(set_from_vec.len(), 1);
    }

    #[test]
    fn empty_hook_set_iteration_at_each_point() {
        let set = HookSet::new();
        for point in [
            HookPoint::PreTurn,
            HookPoint::PostTurn,
            HookPoint::PreToolCallDecide,
            HookPoint::PostToolCall,
            HookPoint::OnCompaction,
            HookPoint::OnSessionStart,
            HookPoint::OnSessionEnd,
            HookPoint::OnToolError,
            HookPoint::OnInteraction,
        ] {
            assert_eq!(
                set.at_point(point).count(),
                0,
                "Empty HookSet should have 0 hooks at {point:?}"
            );
        }
    }

    #[test]
    fn hook_point_serde_roundtrip() {
        let points = [
            HookPoint::PreTurn,
            HookPoint::PostTurn,
            HookPoint::PreToolCallDecide,
            HookPoint::PostToolCall,
            HookPoint::OnCompaction,
            HookPoint::OnSessionStart,
            HookPoint::OnSessionEnd,
            HookPoint::OnToolError,
            HookPoint::OnInteraction,
        ];
        for point in points {
            let json = serde_json::to_string(&point).expect("serialize");
            let parsed: HookPoint = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(parsed, point);
        }
    }

    #[test]
    fn hook_set_default_is_empty() {
        let set = HookSet::default();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
    }

    #[test]
    fn hook_set_multiple_hooks_at_same_point() {
        let mut set = HookSet::new();
        for i in 0..5 {
            set.push(HookEntry {
                name: format!("hook_{i}"),
                point: HookPoint::PreToolCallDecide,
                callback_id: format!("cb_{i}"),
            })
            .unwrap();
        }
        assert_eq!(set.len(), 5);
        assert_eq!(set.at_point(HookPoint::PreToolCallDecide).count(), 5);
        assert_eq!(set.at_point(HookPoint::PreTurn).count(), 0);
    }

    #[test]
    fn hook_result_deny_with_string_owned() {
        let reason = String::from("policy violation detected");
        let r = HookResult::deny(reason.clone());
        assert!(!r.allow);
        assert_eq!(r.message, reason);
    }

    #[test]
    fn hook_entry_with_new_hook_points() {
        let new_points = [
            (HookPoint::OnCompaction, "compaction_hook"),
            (HookPoint::OnSessionStart, "session_start_hook"),
            (HookPoint::OnSessionEnd, "session_end_hook"),
            (HookPoint::OnToolError, "tool_error_hook"),
            (HookPoint::OnInteraction, "interaction_hook"),
        ];
        let mut set = HookSet::new();
        for (point, name) in &new_points {
            set.push(HookEntry {
                name: (*name).to_owned(),
                point: *point,
                callback_id: format!("cb_{name}"),
            })
            .unwrap();
        }
        assert_eq!(set.len(), 5);
        for (point, name) in &new_points {
            let hooks: Vec<&str> = set.at_point(*point).map(|e| e.name.as_str()).collect();
            assert_eq!(hooks, vec![*name], "expected hook at {point:?}");
        }
    }

    #[test]
    fn hook_entry_serde_roundtrip_new_points() {
        let new_points = [
            HookPoint::OnCompaction,
            HookPoint::OnSessionStart,
            HookPoint::OnSessionEnd,
            HookPoint::OnToolError,
            HookPoint::OnInteraction,
        ];
        for point in new_points {
            let entry = HookEntry {
                name: format!("test_{}", point.label()),
                point,
                callback_id: format!("cb_{}", point.label()),
            };
            let json = serde_json::to_string(&entry).expect("serialize");
            let parsed: HookEntry = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(parsed.name, entry.name);
            assert_eq!(parsed.point, entry.point);
            assert_eq!(parsed.callback_id, entry.callback_id);
        }
    }

    // ── SessionContext tests ────────────────────────────────────────────

    #[test]
    fn session_context_clone() {
        let ctx = SessionContext {
            session_id: "sess-1".into(),
            agent_id: 42,
            started_at: SystemTime::now(),
        };
        // Actually exercise `Clone` (a plain `let cloned = ctx;` is a move)
        // and keep using `ctx` afterwards to prove it is still intact.
        let cloned = ctx.clone();
        assert_eq!(cloned.session_id, "sess-1");
        assert_eq!(cloned.agent_id, 42);
        assert_eq!(cloned.session_id, ctx.session_id);
        assert_eq!(cloned.agent_id, ctx.agent_id);
        assert_eq!(cloned.started_at, ctx.started_at);
    }

    #[test]
    fn session_context_debug_format() {
        let ctx = SessionContext {
            session_id: "sess-debug".into(),
            agent_id: 1,
            started_at: SystemTime::now(),
        };
        let dbg = format!("{ctx:?}");
        assert!(dbg.contains("sess-debug"));
        assert!(dbg.contains("agent_id: 1"));
    }

    #[test]
    fn session_context_serde_roundtrip_preserves_started_at() {
        let original = SessionContext {
            session_id: "sess-rt".into(),
            agent_id: 99,
            started_at: SystemTime::now(),
        };
        let json = serde_json::to_string(&original).expect("serialize");
        let parsed: SessionContext = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(parsed.session_id, original.session_id);
        assert_eq!(parsed.agent_id, original.agent_id);
        // SystemTime roundtrips through serde; Instant did not.
        assert_eq!(parsed.started_at, original.started_at);
    }

    #[test]
    fn session_context_missing_started_at_defaults_to_epoch() {
        let json = r#"{"session_id":"sess-no-ts","agent_id":7}"#;
        let parsed: SessionContext = serde_json::from_str(json).expect("deserialize");
        assert_eq!(parsed.session_id, "sess-no-ts");
        assert_eq!(parsed.agent_id, 7);
        assert_eq!(parsed.started_at, SystemTime::UNIX_EPOCH);
    }

    // ── HookEntry::new validated constructor tests ──────────────────────

    #[test]
    fn hook_entry_new_valid() {
        let entry = HookEntry::new("safety_gate", HookPoint::PreToolCallDecide, "cb_safety")
            .expect("valid entry");
        assert_eq!(entry.name, "safety_gate");
        assert_eq!(entry.point, HookPoint::PreToolCallDecide);
        assert_eq!(entry.callback_id, "cb_safety");
    }

    #[test]
    fn hook_entry_new_rejects_empty_name() {
        let result = HookEntry::new("", HookPoint::PreTurn, "cb_1");
        assert!(result.is_err(), "should reject empty name");
    }

    #[test]
    fn hook_entry_new_rejects_whitespace_name() {
        let result = HookEntry::new("   ", HookPoint::PreTurn, "cb_1");
        assert!(result.is_err(), "should reject whitespace-only name");
    }

    #[test]
    fn hook_entry_new_rejects_empty_callback_id() {
        let result = HookEntry::new("my_hook", HookPoint::PreTurn, "");
        assert!(result.is_err(), "should reject empty callback_id");
    }

    #[test]
    fn hook_entry_new_rejects_whitespace_callback_id() {
        let result = HookEntry::new("my_hook", HookPoint::PostTurn, "  ");
        assert!(result.is_err(), "should reject whitespace-only callback_id");
    }

    // ── Context serde tests ─────────────────────────────────────────────

    #[test]
    fn pre_tool_call_decide_context_serde_aliases() {
        let json_std = r#"{"tool_name":"my_tool","tool_args":{"foo":"bar"}}"#;
        let parsed_std: PreToolCallDecideContext = serde_json::from_str(json_std).unwrap();
        assert_eq!(parsed_std.tool_name, "my_tool");
        assert_eq!(parsed_std.tool_args["foo"], "bar");

        let json_alias = r#"{"name":"my_tool","args":{"foo":"bar"}}"#;
        let parsed_alias: PreToolCallDecideContext = serde_json::from_str(json_alias).unwrap();
        assert_eq!(parsed_alias.tool_name, "my_tool");
        assert_eq!(parsed_alias.tool_args["foo"], "bar");
    }

    #[test]
    fn pre_tool_call_decide_context_serde_default() {
        let json_no_args = r#"{"name":"my_tool"}"#;
        let parsed_no_args: PreToolCallDecideContext = serde_json::from_str(json_no_args).unwrap();
        assert_eq!(parsed_no_args.tool_name, "my_tool");
        assert_eq!(parsed_no_args.tool_args, serde_json::Value::Null);
    }

    #[test]
    fn post_tool_call_context_serde_aliases_and_default() {
        let json_std = r#"{"tool_name":"my_tool","tool_args":{"foo":"bar"},"result":"success"}"#;
        let parsed_std: PostToolCallContext = serde_json::from_str(json_std).unwrap();
        assert_eq!(parsed_std.tool_name, "my_tool");
        assert_eq!(parsed_std.tool_args["foo"], "bar");
        assert_eq!(parsed_std.result, "success");

        let json_alias = r#"{"name":"my_tool","args":{"foo":"bar"},"result":"success"}"#;
        let parsed_alias: PostToolCallContext = serde_json::from_str(json_alias).unwrap();
        assert_eq!(parsed_alias.tool_name, "my_tool");
        assert_eq!(parsed_alias.tool_args["foo"], "bar");
        assert_eq!(parsed_alias.result, "success");

        let json_no_args = r#"{"name":"my_tool","result":"success"}"#;
        let parsed_no_args: PostToolCallContext = serde_json::from_str(json_no_args).unwrap();
        assert_eq!(parsed_no_args.tool_name, "my_tool");
        assert_eq!(parsed_no_args.tool_args, serde_json::Value::Null);
        assert_eq!(parsed_no_args.result, "success");
    }

    #[test]
    fn on_tool_error_context_serde_aliases_and_default() {
        let json_std = r#"{"tool_name":"my_tool","tool_args":{"foo":"bar"},"error":"failed"}"#;
        let parsed_std: OnToolErrorContext = serde_json::from_str(json_std).unwrap();
        assert_eq!(parsed_std.tool_name, "my_tool");
        assert_eq!(parsed_std.tool_args["foo"], "bar");
        assert_eq!(parsed_std.error, "failed");

        let json_alias = r#"{"name":"my_tool","args":{"foo":"bar"},"error":"failed"}"#;
        let parsed_alias: OnToolErrorContext = serde_json::from_str(json_alias).unwrap();
        assert_eq!(parsed_alias.tool_name, "my_tool");
        assert_eq!(parsed_alias.tool_args["foo"], "bar");
        assert_eq!(parsed_alias.error, "failed");

        let json_no_args = r#"{"name":"my_tool","error":"failed"}"#;
        let parsed_no_args: OnToolErrorContext = serde_json::from_str(json_no_args).unwrap();
        assert_eq!(parsed_no_args.tool_name, "my_tool");
        assert_eq!(parsed_no_args.tool_args, serde_json::Value::Null);
        assert_eq!(parsed_no_args.error, "failed");

        let json_no_name = r#"{"error":"failed"}"#;
        let parsed_no_name: Result<OnToolErrorContext, _> = serde_json::from_str(json_no_name);
        assert!(parsed_no_name.is_err());
    }

    // ── HookSet registration semantics ──────────────────────────────────

    #[test]
    fn hook_set_push_replaces_duplicate_in_place() {
        let mut set = HookSet::new();
        set.push(HookEntry::new("gate", HookPoint::PreTurn, "cb_old").unwrap())
            .unwrap();
        set.push(HookEntry::new("other", HookPoint::PreTurn, "cb_other").unwrap())
            .unwrap();
        set.push(HookEntry::new("gate", HookPoint::PreTurn, "cb_new").unwrap())
            .unwrap();

        // Same name + point replaces the first registration in its slot.
        assert_eq!(set.len(), 2);
        let ids: Vec<&str> = set
            .at_point(HookPoint::PreTurn)
            .map(|e| e.callback_id.as_str())
            .collect();
        assert_eq!(ids, vec!["cb_new", "cb_other"]);

        // Same name at a different point is a distinct registration.
        set.push(HookEntry::new("gate", HookPoint::PostTurn, "cb_post").unwrap())
            .unwrap();
        assert_eq!(set.len(), 3);
    }

    #[test]
    fn hook_set_push_rejects_invalid_entry() {
        let mut set = HookSet::new();
        let result = set.push(HookEntry {
            name: String::new(),
            point: HookPoint::PreTurn,
            callback_id: "cb_1".to_owned(),
        });
        assert!(result.is_err());
        assert!(set.is_empty(), "rejected entry must not be stored");
    }

    #[test]
    fn hook_set_from_iter_skips_invalid_entries() {
        let set = HookSet::from(vec![
            HookEntry {
                name: "valid".to_owned(),
                point: HookPoint::PreTurn,
                callback_id: "cb_valid".to_owned(),
            },
            HookEntry {
                name: "   ".to_owned(),
                point: HookPoint::PreTurn,
                callback_id: "cb_blank_name".to_owned(),
            },
            HookEntry {
                name: "no_callback".to_owned(),
                point: HookPoint::PostTurn,
                callback_id: String::new(),
            },
        ]);
        assert_eq!(set.len(), 1);
        let names: Vec<&str> = set.iter().map(|e| e.name.as_str()).collect();
        assert_eq!(names, vec!["valid"]);
    }

    // ── HookCallback tests ──────────────────────────────────────────────

    #[test]
    fn hook_callback_hook_point_and_debug() {
        let transform = HookCallback::TransformToolInput(Box::new(
            |_ctx: &PreToolCallDecideContext| -> Option<serde_json::Value> { None },
        ));
        assert_eq!(transform.hook_point(), HookPoint::PreToolCallDecide);
        assert_eq!(format!("{transform:?}"), "HookCallback::transform_tool_input");

        let pre_turn = HookCallback::PreTurn(Box::new(|_ctx: &PreTurnContext| {}));
        assert_eq!(pre_turn.hook_point(), HookPoint::PreTurn);
        assert_eq!(format!("{pre_turn:?}"), "HookCallback::pre_turn");
    }
}