1use std::time::SystemTime;
12
13use serde::{Deserialize, Serialize};
14
15#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
17pub struct HookResult {
18 pub allow: bool,
20 pub message: String,
22}
23
24impl HookResult {
25 #[must_use]
27 pub const fn allow() -> Self {
28 Self {
29 allow: true,
30 message: String::new(),
31 }
32 }
33
34 #[must_use]
36 pub fn allow_with_message(message: impl Into<String>) -> Self {
37 Self {
38 allow: true,
39 message: message.into(),
40 }
41 }
42
43 #[must_use]
45 pub fn deny(reason: impl Into<String>) -> Self {
46 Self {
47 allow: false,
48 message: reason.into(),
49 }
50 }
51}
52
53#[non_exhaustive]
61#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
62pub struct SessionContext {
63 pub session_id: String,
65 pub agent_id: u64,
67 #[serde(default = "SessionContext::default_started_at")]
72 pub started_at: SystemTime,
73}
74
75impl SessionContext {
76 fn default_started_at() -> SystemTime {
78 SystemTime::UNIX_EPOCH
79 }
80}
81
82#[non_exhaustive]
84#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
85pub struct OnSessionStartContext {
86 pub session: SessionContext,
88}
89
90#[non_exhaustive]
92#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
93pub struct OnSessionEndContext {
94 pub session: SessionContext,
96}
97
98#[non_exhaustive]
100#[derive(Debug, Clone, Serialize, Deserialize)]
101pub struct OnCompactionContext {}
102
103#[non_exhaustive]
105#[derive(Debug, Clone, Serialize, Deserialize)]
106pub struct OnInteractionContext {
107 pub message: String,
109}
110
111#[non_exhaustive]
113#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct PreTurnContext {
115 pub prompt: String,
117 pub turn_number: u32,
119}
120
121impl PreTurnContext {
122 #[must_use]
124 pub fn new(prompt: impl Into<String>, turn_number: u32) -> Self {
125 Self {
126 prompt: prompt.into(),
127 turn_number,
128 }
129 }
130}
131
132#[non_exhaustive]
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct PostTurnContext {
136 pub response_text: String,
138 pub turn_number: u32,
140}
141
142#[non_exhaustive]
144#[derive(Debug, Clone, Serialize, Deserialize)]
145pub struct PreToolCallDecideContext {
146 #[serde(alias = "name")]
148 pub tool_name: String,
149 #[serde(alias = "args", default)]
151 pub tool_args: serde_json::Value,
152}
153
154impl PreToolCallDecideContext {
155 #[must_use]
157 pub fn new(tool_name: impl Into<String>, tool_args: serde_json::Value) -> Self {
158 Self {
159 tool_name: tool_name.into(),
160 tool_args,
161 }
162 }
163}
164
165#[non_exhaustive]
167#[derive(Debug, Clone, Serialize, Deserialize)]
168pub struct PostToolCallContext {
169 #[serde(alias = "name")]
171 pub tool_name: String,
172 #[serde(alias = "args", default)]
174 pub tool_args: serde_json::Value,
175 pub result: String,
177 #[serde(default)]
179 pub metadata: serde_json::Value,
180}
181
182#[non_exhaustive]
184#[derive(Debug, Clone, Serialize, Deserialize)]
185pub struct OnToolErrorContext {
186 #[serde(alias = "name")]
188 pub tool_name: String,
189 #[serde(alias = "args", default)]
191 pub tool_args: serde_json::Value,
192 pub error: String,
194}
195
196#[non_exhaustive]
198#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
199pub enum HookPoint {
200 PreTurn,
202 PostTurn,
204 PreToolCallDecide,
206 PostToolCall,
208 OnCompaction,
210 OnSessionStart,
212 OnSessionEnd,
214 OnToolError,
216 OnInteraction,
218}
219
220impl HookPoint {
221 #[must_use]
223 pub const fn label(self) -> &'static str {
224 match self {
225 Self::PreTurn => "pre_turn",
226 Self::PostTurn => "post_turn",
227 Self::PreToolCallDecide => "pre_tool_call_decide",
228 Self::PostToolCall => "post_tool_call",
229 Self::OnCompaction => "on_compaction",
230 Self::OnSessionStart => "on_session_start",
231 Self::OnSessionEnd => "on_session_end",
232 Self::OnToolError => "on_tool_error",
233 Self::OnInteraction => "on_interaction",
234 }
235 }
236}
237
238#[derive(Debug, Clone, Serialize, Deserialize)]
250pub struct HookEntry {
251 pub name: String,
253 pub point: HookPoint,
255 pub callback_id: String,
257}
258
259impl HookEntry {
260 pub fn new(
277 name: impl Into<String>,
278 point: HookPoint,
279 callback_id: impl Into<String>,
280 ) -> Result<Self, crate::error::Error> {
281 let entry = Self {
282 name: name.into(),
283 point,
284 callback_id: callback_id.into(),
285 };
286 entry.validate()?;
287 Ok(entry)
288 }
289
290 pub fn validate(&self) -> Result<(), crate::error::Error> {
296 if self.name.trim().is_empty() {
297 return Err(crate::error::Error::InvalidConfig {
298 message: "HookEntry name must not be empty".to_owned(),
299 });
300 }
301 if self.callback_id.trim().is_empty() {
302 return Err(crate::error::Error::InvalidConfig {
303 message: format!("HookEntry '{}' has an empty callback_id", self.name),
304 });
305 }
306 Ok(())
307 }
308}
309
310#[derive(Debug, Clone, Default, Serialize, Deserialize)]
314pub struct HookSet {
315 entries: Vec<HookEntry>,
316}
317
318impl HookSet {
319 #[must_use]
321 pub const fn new() -> Self {
322 Self {
323 entries: Vec::new(),
324 }
325 }
326
327 pub fn push(&mut self, entry: HookEntry) -> Result<(), crate::error::Error> {
336 entry.validate()?;
337 if let Some(pos) = self
338 .entries
339 .iter()
340 .position(|e| e.name == entry.name && e.point == entry.point)
341 {
342 tracing::warn!(
343 hook = %entry.name,
344 point = %entry.point.label(),
345 "duplicate hook name+point in HookSet — replacing previous entry"
346 );
347 self.entries[pos] = entry;
348 } else {
349 self.entries.push(entry);
350 }
351 Ok(())
352 }
353
354 pub fn at_point(&self, point: HookPoint) -> impl Iterator<Item = &HookEntry> {
356 self.entries.iter().filter(move |e| e.point == point)
357 }
358
359 pub fn iter(&self) -> impl Iterator<Item = &HookEntry> {
361 self.entries.iter()
362 }
363
364 #[must_use]
366 pub const fn len(&self) -> usize {
367 self.entries.len()
368 }
369
370 #[must_use]
372 pub const fn is_empty(&self) -> bool {
373 self.entries.is_empty()
374 }
375}
376
377impl From<HookSet> for Vec<HookEntry> {
378 fn from(set: HookSet) -> Self {
379 set.entries
380 }
381}
382
383impl From<&HookSet> for Vec<HookEntry> {
384 fn from(set: &HookSet) -> Self {
385 set.entries.clone()
386 }
387}
388
389impl IntoIterator for HookSet {
390 type Item = HookEntry;
391 type IntoIter = std::vec::IntoIter<Self::Item>;
392
393 fn into_iter(self) -> Self::IntoIter {
394 self.entries.into_iter()
395 }
396}
397
398impl FromIterator<HookEntry> for HookSet {
399 fn from_iter<T: IntoIterator<Item = HookEntry>>(iter: T) -> Self {
400 let mut set = Self::new();
401 for entry in iter {
402 let name = entry.name.clone();
403 if let Err(e) = set.push(entry) {
404 tracing::error!(
405 error = %e,
406 hook = %name,
407 "Failed to push hook entry during from_iter"
408 );
409 }
410 }
411 set
412 }
413}
414
415impl From<Vec<HookEntry>> for HookSet {
416 fn from(entries: Vec<HookEntry>) -> Self {
417 Self::from_iter(entries)
418 }
419}
420
421impl<const N: usize> From<[HookEntry; N]> for HookSet {
422 fn from(entries: [HookEntry; N]) -> Self {
423 Self::from_iter(entries)
424 }
425}
426type TransformToolInputFn =
433 dyn Fn(&PreToolCallDecideContext) -> Option<serde_json::Value> + Send + Sync;
434
435#[non_exhaustive]
442pub enum HookCallback {
443 PreTurn(Box<dyn Fn(&PreTurnContext) + Send + Sync>),
445 PostTurn(Box<dyn Fn(&PostTurnContext) + Send + Sync>),
447 PreToolCallDecide(Box<dyn Fn(&PreToolCallDecideContext) -> HookResult + Send + Sync>),
449 PostToolCall(Box<dyn Fn(&PostToolCallContext) + Send + Sync>),
451 OnToolError(Box<dyn Fn(&OnToolErrorContext) + Send + Sync>),
453 OnSessionStart(Box<dyn Fn(&OnSessionStartContext) + Send + Sync>),
455 OnSessionEnd(Box<dyn Fn(&OnSessionEndContext) + Send + Sync>),
457 OnCompaction(Box<dyn Fn(&OnCompactionContext) + Send + Sync>),
459 OnInteraction(Box<dyn Fn(&OnInteractionContext) -> HookResult + Send + Sync>),
461 TransformToolInput(Box<TransformToolInputFn>),
469}
470
471impl HookCallback {
472 #[must_use]
474 pub(crate) const fn hook_point(&self) -> HookPoint {
475 match self {
476 Self::PreTurn(_) => HookPoint::PreTurn,
477 Self::PostTurn(_) => HookPoint::PostTurn,
478 Self::PreToolCallDecide(_) | Self::TransformToolInput(_) => {
479 HookPoint::PreToolCallDecide
480 }
481 Self::PostToolCall(_) => HookPoint::PostToolCall,
482 Self::OnToolError(_) => HookPoint::OnToolError,
483 Self::OnSessionStart(_) => HookPoint::OnSessionStart,
484 Self::OnSessionEnd(_) => HookPoint::OnSessionEnd,
485 Self::OnCompaction(_) => HookPoint::OnCompaction,
486 Self::OnInteraction(_) => HookPoint::OnInteraction,
487 }
488 }
489}
490
491impl std::fmt::Debug for HookCallback {
493 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
494 f.write_str("HookCallback::")?;
495 match self {
496 Self::TransformToolInput(_) => f.write_str("transform_tool_input"),
497 other => f.write_str(other.hook_point().label()),
498 }
499 }
500}
501
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `HookEntry` with a struct literal (no validation) to keep the
    /// tests short.
    fn entry(name: &str, point: HookPoint, callback_id: &str) -> HookEntry {
        HookEntry {
            name: name.to_owned(),
            point,
            callback_id: callback_id.to_owned(),
        }
    }

    /// Names of the entries `set` yields at `point`, in iteration order.
    fn names_at(set: &HookSet, point: HookPoint) -> Vec<&str> {
        set.at_point(point).map(|e| e.name.as_str()).collect()
    }

    /// Every `HookPoint` variant.
    const ALL_POINTS: [HookPoint; 9] = [
        HookPoint::PreTurn,
        HookPoint::PostTurn,
        HookPoint::PreToolCallDecide,
        HookPoint::PostToolCall,
        HookPoint::OnCompaction,
        HookPoint::OnSessionStart,
        HookPoint::OnSessionEnd,
        HookPoint::OnToolError,
        HookPoint::OnInteraction,
    ];

    // ---- HookResult ---------------------------------------------------

    #[test]
    fn hook_result_allow() {
        let r = HookResult::allow();
        assert!(r.allow);
        assert!(r.message.is_empty());
    }

    #[test]
    fn hook_result_deny() {
        let r = HookResult::deny("blocked by policy");
        assert!(!r.allow);
        assert_eq!(r.message, "blocked by policy");
    }

    #[test]
    fn hook_result_allow_with_message() {
        let r = HookResult::allow_with_message("proceeding with caution");
        assert!(r.allow);
        assert_eq!(r.message, "proceeding with caution");
    }

    #[test]
    fn hook_result_deny_with_string_owned() {
        let reason = String::from("policy violation detected");
        let r = HookResult::deny(reason.clone());
        assert!(!r.allow);
        assert_eq!(r.message, reason);
    }

    #[test]
    fn hook_result_serde_roundtrip() {
        let results = vec![
            HookResult::allow(),
            HookResult::deny("reason"),
            HookResult::allow_with_message("ok"),
        ];
        for result in &results {
            let json = serde_json::to_string(result).expect("serialize");
            let parsed: HookResult = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(&parsed, result);
        }
    }

    // ---- HookPoint ----------------------------------------------------

    #[test]
    fn hook_point_labels() {
        assert_eq!(HookPoint::PreTurn.label(), "pre_turn");
        assert_eq!(HookPoint::PostTurn.label(), "post_turn");
        assert_eq!(HookPoint::PreToolCallDecide.label(), "pre_tool_call_decide");
        assert_eq!(HookPoint::PostToolCall.label(), "post_tool_call");
        assert_eq!(HookPoint::OnCompaction.label(), "on_compaction");
        assert_eq!(HookPoint::OnSessionStart.label(), "on_session_start");
        assert_eq!(HookPoint::OnSessionEnd.label(), "on_session_end");
        assert_eq!(HookPoint::OnToolError.label(), "on_tool_error");
        assert_eq!(HookPoint::OnInteraction.label(), "on_interaction");
    }

    #[test]
    fn hook_point_serde_roundtrip() {
        for point in ALL_POINTS {
            let json = serde_json::to_string(&point).expect("serialize");
            let parsed: HookPoint = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(parsed, point);
        }
    }

    // ---- HookSet ------------------------------------------------------

    #[test]
    fn hooks_fire_in_correct_order() {
        // Checks that `at_point` yields entries filtered by point, in
        // registration order.
        let mut set = HookSet::new();
        assert!(set.is_empty());

        for e in [
            entry("pre_turn_1", HookPoint::PreTurn, "cb_pre1"),
            entry("pre_tool_decide", HookPoint::PreToolCallDecide, "cb_decide"),
            entry("pre_turn_2", HookPoint::PreTurn, "cb_pre2"),
            entry("post_turn_1", HookPoint::PostTurn, "cb_post1"),
            entry("post_tool_1", HookPoint::PostToolCall, "cb_posttool1"),
        ] {
            set.push(e).unwrap();
        }

        assert_eq!(set.len(), 5);
        assert_eq!(
            names_at(&set, HookPoint::PreTurn),
            vec!["pre_turn_1", "pre_turn_2"]
        );
        assert_eq!(
            names_at(&set, HookPoint::PreToolCallDecide),
            vec!["pre_tool_decide"]
        );
        assert_eq!(names_at(&set, HookPoint::PostTurn), vec!["post_turn_1"]);
        assert_eq!(names_at(&set, HookPoint::PostToolCall), vec!["post_tool_1"]);
    }

    #[test]
    fn hook_set_push_replaces_duplicate_name_and_point() {
        let mut set = HookSet::new();
        set.push(entry("gate", HookPoint::PreTurn, "cb_old")).unwrap();
        set.push(entry("other", HookPoint::PreTurn, "cb_other")).unwrap();
        set.push(entry("gate", HookPoint::PreTurn, "cb_new")).unwrap();
        // Same name at a different point is a distinct registration.
        set.push(entry("gate", HookPoint::PostTurn, "cb_post")).unwrap();

        assert_eq!(set.len(), 3);
        let pre: Vec<(&str, &str)> = set
            .at_point(HookPoint::PreTurn)
            .map(|e| (e.name.as_str(), e.callback_id.as_str()))
            .collect();
        // Replacement happens in place: "gate" keeps its original slot.
        assert_eq!(pre, vec![("gate", "cb_new"), ("other", "cb_other")]);
    }

    #[test]
    fn hook_set_push_rejects_invalid_entry() {
        let mut set = HookSet::new();
        assert!(set.push(entry("", HookPoint::PreTurn, "cb")).is_err());
        assert!(set.push(entry("ok", HookPoint::PreTurn, "  ")).is_err());
        assert!(set.is_empty());
    }

    #[test]
    fn hook_set_from_iter_skips_invalid_entries() {
        let set: HookSet = vec![
            entry("ok", HookPoint::PreTurn, "cb"),
            entry("bad", HookPoint::PreTurn, " "),
            entry("ok", HookPoint::PreTurn, "cb_replaced"),
        ]
        .into_iter()
        .collect();
        assert_eq!(set.len(), 1);
        let only = set.iter().next().expect("one entry");
        assert_eq!(only.name, "ok");
        assert_eq!(only.callback_id, "cb_replaced");
    }

    #[test]
    fn hook_set_into_iter_yields_owned_entries_in_order() {
        let set = HookSet::from(vec![
            entry("first", HookPoint::PreTurn, "cb_1"),
            entry("second", HookPoint::PostTurn, "cb_2"),
        ]);
        let names: Vec<String> = set.into_iter().map(|e| e.name).collect();
        assert_eq!(names, vec!["first", "second"]);
    }

    #[test]
    fn hook_set_serde_roundtrip() {
        let mut set = HookSet::new();
        set.push(entry("gate", HookPoint::PreTurn, "cb_1")).unwrap();
        set.push(entry("logger", HookPoint::PostToolCall, "cb_2")).unwrap();
        let json = serde_json::to_string(&set).expect("serialize");
        let parsed: HookSet = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.len(), 2);
        let names: Vec<&str> = parsed.iter().map(|e| e.name.as_str()).collect();
        assert_eq!(names, vec!["gate", "logger"]);
    }

    #[test]
    fn hook_set_from_conversions() {
        let mut set = HookSet::new();
        set.push(entry("gate", HookPoint::PreTurn, "cb_1")).unwrap();

        let vec_from_owned: Vec<HookEntry> = Vec::from(set.clone());
        assert_eq!(vec_from_owned.len(), 1);
        assert_eq!(vec_from_owned[0].name, "gate");

        let vec_from_ref: Vec<HookEntry> = Vec::from(&set);
        assert_eq!(vec_from_ref.len(), 1);
        assert_eq!(vec_from_ref[0].name, "gate");

        let e = entry("gate", HookPoint::PreTurn, "cb_1");
        let set_from_arr = HookSet::from([e.clone()]);
        assert_eq!(set_from_arr.len(), 1);

        let set_from_vec = HookSet::from(vec![e]);
        assert_eq!(set_from_vec.len(), 1);
    }

    #[test]
    fn empty_hook_set_iteration_at_each_point() {
        let set = HookSet::new();
        for point in ALL_POINTS {
            assert_eq!(
                set.at_point(point).count(),
                0,
                "Empty HookSet should have 0 hooks at {point:?}"
            );
        }
    }

    #[test]
    fn hook_set_default_is_empty() {
        let set = HookSet::default();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
    }

    #[test]
    fn hook_set_multiple_hooks_at_same_point() {
        let mut set = HookSet::new();
        for i in 0..5 {
            set.push(entry(
                &format!("hook_{i}"),
                HookPoint::PreToolCallDecide,
                &format!("cb_{i}"),
            ))
            .unwrap();
        }
        assert_eq!(set.len(), 5);
        assert_eq!(set.at_point(HookPoint::PreToolCallDecide).count(), 5);
        assert_eq!(set.at_point(HookPoint::PreTurn).count(), 0);
    }

    // ---- HookEntry ----------------------------------------------------

    #[test]
    fn hook_entry_serde_roundtrip() {
        let original = entry("my_hook", HookPoint::PreToolCallDecide, "cb_123");
        let json = serde_json::to_string(&original).expect("serialize");
        let parsed: HookEntry = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(parsed.name, original.name);
        assert_eq!(parsed.point, original.point);
        assert_eq!(parsed.callback_id, original.callback_id);
    }

    #[test]
    fn hook_entry_with_new_hook_points() {
        let new_points = [
            (HookPoint::OnCompaction, "compaction_hook"),
            (HookPoint::OnSessionStart, "session_start_hook"),
            (HookPoint::OnSessionEnd, "session_end_hook"),
            (HookPoint::OnToolError, "tool_error_hook"),
            (HookPoint::OnInteraction, "interaction_hook"),
        ];
        let mut set = HookSet::new();
        for (point, name) in &new_points {
            set.push(entry(name, *point, &format!("cb_{name}"))).unwrap();
        }
        assert_eq!(set.len(), 5);
        for (point, name) in &new_points {
            assert_eq!(
                names_at(&set, *point),
                vec![*name],
                "expected hook at {point:?}"
            );
        }
    }

    #[test]
    fn hook_entry_serde_roundtrip_new_points() {
        for point in [
            HookPoint::OnCompaction,
            HookPoint::OnSessionStart,
            HookPoint::OnSessionEnd,
            HookPoint::OnToolError,
            HookPoint::OnInteraction,
        ] {
            let original = entry(
                &format!("test_{}", point.label()),
                point,
                &format!("cb_{}", point.label()),
            );
            let json = serde_json::to_string(&original).expect("serialize");
            let parsed: HookEntry = serde_json::from_str(&json).expect("deserialize");
            assert_eq!(parsed.name, original.name);
            assert_eq!(parsed.point, original.point);
            assert_eq!(parsed.callback_id, original.callback_id);
        }
    }

    #[test]
    fn hook_entry_new_valid() {
        let e = HookEntry::new("safety_gate", HookPoint::PreToolCallDecide, "cb_safety")
            .expect("valid entry");
        assert_eq!(e.name, "safety_gate");
        assert_eq!(e.point, HookPoint::PreToolCallDecide);
        assert_eq!(e.callback_id, "cb_safety");
    }

    #[test]
    fn hook_entry_new_rejects_empty_name() {
        let result = HookEntry::new("", HookPoint::PreTurn, "cb_1");
        assert!(result.is_err(), "should reject empty name");
    }

    #[test]
    fn hook_entry_new_rejects_whitespace_name() {
        let result = HookEntry::new("   ", HookPoint::PreTurn, "cb_1");
        assert!(result.is_err(), "should reject whitespace-only name");
    }

    #[test]
    fn hook_entry_new_rejects_empty_callback_id() {
        let result = HookEntry::new("my_hook", HookPoint::PreTurn, "");
        assert!(result.is_err(), "should reject empty callback_id");
    }

    #[test]
    fn hook_entry_new_rejects_whitespace_callback_id() {
        let result = HookEntry::new("my_hook", HookPoint::PostTurn, "   ");
        assert!(result.is_err(), "should reject whitespace-only callback_id");
    }

    // ---- SessionContext -----------------------------------------------

    #[test]
    fn session_context_clone() {
        let ctx = SessionContext {
            session_id: "sess-1".into(),
            agent_id: 42,
            started_at: SystemTime::now(),
        };
        // A real clone (the original is still used below), not a move.
        let cloned = ctx.clone();
        assert_eq!(cloned.session_id, "sess-1");
        assert_eq!(cloned.agent_id, 42);
        assert_eq!(cloned.session_id, ctx.session_id);
        assert_eq!(cloned.started_at, ctx.started_at);
    }

    #[test]
    fn session_context_debug_format() {
        let ctx = SessionContext {
            session_id: "sess-debug".into(),
            agent_id: 1,
            started_at: SystemTime::now(),
        };
        let dbg = format!("{ctx:?}");
        assert!(dbg.contains("sess-debug"));
        assert!(dbg.contains("agent_id: 1"));
    }

    #[test]
    fn session_context_serde_roundtrip_preserves_started_at() {
        let original = SessionContext {
            session_id: "sess-rt".into(),
            agent_id: 99,
            started_at: SystemTime::now(),
        };
        let json = serde_json::to_string(&original).expect("serialize");
        let parsed: SessionContext = serde_json::from_str(&json).expect("deserialize");

        assert_eq!(parsed.session_id, original.session_id);
        assert_eq!(parsed.agent_id, original.agent_id);
        assert_eq!(parsed.started_at, original.started_at);
    }

    #[test]
    fn session_context_missing_started_at_defaults_to_epoch() {
        let json = r#"{"session_id":"sess-old","agent_id":7}"#;
        let parsed: SessionContext = serde_json::from_str(json).expect("deserialize");
        assert_eq!(parsed.session_id, "sess-old");
        assert_eq!(parsed.agent_id, 7);
        assert_eq!(parsed.started_at, SystemTime::UNIX_EPOCH);
    }

    // ---- Contexts & callbacks -----------------------------------------

    #[test]
    fn context_constructors() {
        let pre = PreTurnContext::new("hello", 3);
        assert_eq!(pre.prompt, "hello");
        assert_eq!(pre.turn_number, 3);

        let decide = PreToolCallDecideContext::new("shell", serde_json::json!({"cmd": "ls"}));
        assert_eq!(decide.tool_name, "shell");
        assert_eq!(decide.tool_args["cmd"], "ls");
    }

    #[test]
    fn hook_callback_hook_point_and_debug() {
        let decide = HookCallback::PreToolCallDecide(Box::new(|_: &PreToolCallDecideContext| {
            HookResult::allow()
        }));
        assert_eq!(decide.hook_point(), HookPoint::PreToolCallDecide);
        assert_eq!(format!("{decide:?}"), "HookCallback::pre_tool_call_decide");

        let transform = HookCallback::TransformToolInput(Box::new(
            |_: &PreToolCallDecideContext| -> Option<serde_json::Value> { None },
        ));
        assert_eq!(transform.hook_point(), HookPoint::PreToolCallDecide);
        assert_eq!(format!("{transform:?}"), "HookCallback::transform_tool_input");

        let session_end = HookCallback::OnSessionEnd(Box::new(|_: &OnSessionEndContext| {}));
        assert_eq!(session_end.hook_point(), HookPoint::OnSessionEnd);
        assert_eq!(format!("{session_end:?}"), "HookCallback::on_session_end");
    }

    #[test]
    fn pre_tool_call_decide_context_serde_aliases() {
        let json_std = r#"{"tool_name":"my_tool","tool_args":{"foo":"bar"}}"#;
        let parsed_std: PreToolCallDecideContext = serde_json::from_str(json_std).unwrap();
        assert_eq!(parsed_std.tool_name, "my_tool");
        assert_eq!(parsed_std.tool_args["foo"], "bar");

        let json_alias = r#"{"name":"my_tool","args":{"foo":"bar"}}"#;
        let parsed_alias: PreToolCallDecideContext = serde_json::from_str(json_alias).unwrap();
        assert_eq!(parsed_alias.tool_name, "my_tool");
        assert_eq!(parsed_alias.tool_args["foo"], "bar");
    }

    #[test]
    fn pre_tool_call_decide_context_serde_default() {
        let json_no_args = r#"{"name":"my_tool"}"#;
        let parsed_no_args: PreToolCallDecideContext =
            serde_json::from_str(json_no_args).unwrap();
        assert_eq!(parsed_no_args.tool_name, "my_tool");
        assert_eq!(parsed_no_args.tool_args, serde_json::Value::Null);
    }

    #[test]
    fn post_tool_call_context_serde_aliases_and_default() {
        let json_std = r#"{"tool_name":"my_tool","tool_args":{"foo":"bar"},"result":"success"}"#;
        let parsed_std: PostToolCallContext = serde_json::from_str(json_std).unwrap();
        assert_eq!(parsed_std.tool_name, "my_tool");
        assert_eq!(parsed_std.tool_args["foo"], "bar");
        assert_eq!(parsed_std.result, "success");

        let json_alias = r#"{"name":"my_tool","args":{"foo":"bar"},"result":"success"}"#;
        let parsed_alias: PostToolCallContext = serde_json::from_str(json_alias).unwrap();
        assert_eq!(parsed_alias.tool_name, "my_tool");
        assert_eq!(parsed_alias.tool_args["foo"], "bar");
        assert_eq!(parsed_alias.result, "success");

        let json_no_args = r#"{"name":"my_tool","result":"success"}"#;
        let parsed_no_args: PostToolCallContext = serde_json::from_str(json_no_args).unwrap();
        assert_eq!(parsed_no_args.tool_name, "my_tool");
        assert_eq!(parsed_no_args.tool_args, serde_json::Value::Null);
        assert_eq!(parsed_no_args.result, "success");
    }

    #[test]
    fn on_tool_error_context_serde_aliases_and_default() {
        let json_std = r#"{"tool_name":"my_tool","tool_args":{"foo":"bar"},"error":"failed"}"#;
        let parsed_std: OnToolErrorContext = serde_json::from_str(json_std).unwrap();
        assert_eq!(parsed_std.tool_name, "my_tool");
        assert_eq!(parsed_std.tool_args["foo"], "bar");
        assert_eq!(parsed_std.error, "failed");

        let json_alias = r#"{"name":"my_tool","args":{"foo":"bar"},"error":"failed"}"#;
        let parsed_alias: OnToolErrorContext = serde_json::from_str(json_alias).unwrap();
        assert_eq!(parsed_alias.tool_name, "my_tool");
        assert_eq!(parsed_alias.tool_args["foo"], "bar");
        assert_eq!(parsed_alias.error, "failed");

        let json_no_args = r#"{"name":"my_tool","error":"failed"}"#;
        let parsed_no_args: OnToolErrorContext = serde_json::from_str(json_no_args).unwrap();
        assert_eq!(parsed_no_args.tool_name, "my_tool");
        assert_eq!(parsed_no_args.tool_args, serde_json::Value::Null);
        assert_eq!(parsed_no_args.error, "failed");

        let json_no_name = r#"{"error":"failed"}"#;
        let parsed_no_name: Result<OnToolErrorContext, _> = serde_json::from_str(json_no_name);
        assert!(parsed_no_name.is_err());
    }
}