1use std::time::Instant;
12
13use serde::{Deserialize, Serialize};
14
/// Outcome reported by a hook: an allow/deny flag plus an accompanying message.
///
/// Values are normally built with [`HookResult::allow`],
/// [`HookResult::allow_with_message`], or [`HookResult::deny`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct HookResult {
    /// `true` when the hook allows, `false` when it denies.
    pub allow: bool,
    /// Message attached to the decision. For a denial this holds the reason;
    /// it is empty when no message was supplied.
    pub message: String,
}
23
24impl HookResult {
25 #[must_use]
27 pub const fn allow() -> Self {
28 Self {
29 allow: true,
30 message: String::new(),
31 }
32 }
33
34 #[must_use]
36 pub fn allow_with_message(message: impl Into<String>) -> Self {
37 Self {
38 allow: true,
39 message: message.into(),
40 }
41 }
42
43 #[must_use]
45 pub fn deny(reason: impl Into<String>) -> Self {
46 Self {
47 allow: false,
48 message: reason.into(),
49 }
50 }
51}
52
53#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
61pub struct SessionContext {
62 pub session_id: String,
64 pub agent_id: u64,
66 #[serde(skip, default = "std::time::Instant::now")]
68 pub started_at: Instant,
69}
70
71#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
73pub struct OnSessionStartContext {
74 pub session: SessionContext,
76}
77
78#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
80pub struct OnSessionEndContext {
81 pub session: SessionContext,
83}
84
/// Context handed to [`HookCallback::OnCompaction`] callbacks.
///
/// Currently carries no data; it is a braced (not unit) struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnCompactionContext {}
88
/// Context handed to [`HookCallback::OnInteraction`] callbacks, which answer
/// with a [`HookResult`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnInteractionContext {
    /// Message text of the interaction.
    pub message: String,
}
95
/// Context handed to [`HookCallback::PreTurn`] callbacks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreTurnContext {
    /// Prompt text for the turn.
    pub prompt: String,
    /// Number of the turn (numbering base is not established in this file).
    pub turn_number: u32,
}
104
/// Context handed to [`HookCallback::PostTurn`] callbacks.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PostTurnContext {
    /// Response text of the turn.
    pub response_text: String,
    /// Number of the turn (numbering base is not established in this file).
    pub turn_number: u32,
}
113
/// Context handed to [`HookCallback::PreToolCallDecide`] and
/// [`HookCallback::TransformToolInput`] callbacks.
///
/// When deserializing, the short keys `name` and `args` are accepted as
/// aliases for `tool_name` and `tool_args`; serialization always uses the
/// field names.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreToolCallDecideContext {
    /// Name of the tool. Required when deserializing.
    #[serde(alias = "name")]
    pub tool_name: String,
    /// Tool arguments as arbitrary JSON. Falls back to the default JSON value
    /// (`null`) when the key is absent from the input.
    #[serde(alias = "args", default)]
    pub tool_args: serde_json::Value,
}
124
/// Context handed to [`HookCallback::PostToolCall`] callbacks.
///
/// When deserializing, the short keys `name` and `args` are accepted as
/// aliases for `tool_name` and `tool_args`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PostToolCallContext {
    /// Name of the tool. Required when deserializing.
    #[serde(alias = "name")]
    pub tool_name: String,
    /// Tool arguments as arbitrary JSON. Falls back to the default JSON value
    /// (`null`) when the key is absent from the input.
    #[serde(alias = "args", default)]
    pub tool_args: serde_json::Value,
    /// Result of the tool call, as text.
    pub result: String,
}
137
/// Context handed to [`HookCallback::OnToolError`] callbacks.
///
/// When deserializing, the short keys `name` and `args` are accepted as
/// aliases for `tool_name` and `tool_args`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OnToolErrorContext {
    /// Name of the tool. Required when deserializing.
    #[serde(alias = "name")]
    pub tool_name: String,
    /// Tool arguments as arbitrary JSON. Falls back to the default JSON value
    /// (`null`) when the key is absent from the input.
    #[serde(alias = "args", default)]
    pub tool_args: serde_json::Value,
    /// Error reported for the tool call, as text.
    pub error: String,
}
150
/// The points at which hooks can be registered.
///
/// Each variant has a stable snake_case name available through
/// [`HookPoint::label`]. The enum is `#[non_exhaustive]`, so code outside this
/// crate must include a wildcard arm when matching on it.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum HookPoint {
    /// Pre-turn hook point (label `pre_turn`).
    PreTurn,
    /// Post-turn hook point (label `post_turn`).
    PostTurn,
    /// Pre-tool-call decision hook point (label `pre_tool_call_decide`).
    PreToolCallDecide,
    /// Post-tool-call hook point (label `post_tool_call`).
    PostToolCall,
    /// Compaction hook point (label `on_compaction`).
    OnCompaction,
    /// Session-start hook point (label `on_session_start`).
    OnSessionStart,
    /// Session-end hook point (label `on_session_end`).
    OnSessionEnd,
    /// Tool-error hook point (label `on_tool_error`).
    OnToolError,
    /// Interaction hook point (label `on_interaction`).
    OnInteraction,
}
174
175impl HookPoint {
176 #[must_use]
178 pub const fn label(self) -> &'static str {
179 match self {
180 Self::PreTurn => "pre_turn",
181 Self::PostTurn => "post_turn",
182 Self::PreToolCallDecide => "pre_tool_call_decide",
183 Self::PostToolCall => "post_tool_call",
184 Self::OnCompaction => "on_compaction",
185 Self::OnSessionStart => "on_session_start",
186 Self::OnSessionEnd => "on_session_end",
187 Self::OnToolError => "on_tool_error",
188 Self::OnInteraction => "on_interaction",
189 }
190 }
191}
192
/// A serializable hook registration: a named hook bound to a [`HookPoint`]
/// and identified by a callback id.
///
/// The fields are public, so a struct literal can create an entry without any
/// checks; [`HookEntry::new`] and [`HookEntry::validate`] enforce that `name`
/// and `callback_id` are not blank.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HookEntry {
    /// Name of the hook. Must contain at least one non-whitespace character
    /// to pass validation.
    pub name: String,
    /// Hook point this entry is registered at.
    pub point: HookPoint,
    /// Identifier of the callback. Must contain at least one non-whitespace
    /// character to pass validation.
    pub callback_id: String,
}
213
214impl HookEntry {
215 pub fn new(
232 name: impl Into<String>,
233 point: HookPoint,
234 callback_id: impl Into<String>,
235 ) -> Result<Self, crate::error::Error> {
236 let entry = Self {
237 name: name.into(),
238 point,
239 callback_id: callback_id.into(),
240 };
241 entry.validate()?;
242 Ok(entry)
243 }
244
245 pub fn validate(&self) -> Result<(), crate::error::Error> {
251 if self.name.trim().is_empty() {
252 return Err(crate::error::Error::InvalidConfig {
253 message: "HookEntry name must not be empty".to_owned(),
254 });
255 }
256 if self.callback_id.trim().is_empty() {
257 return Err(crate::error::Error::InvalidConfig {
258 message: format!("HookEntry '{}' has an empty callback_id", self.name),
259 });
260 }
261 Ok(())
262 }
263}
264
/// An ordered collection of [`HookEntry`] values.
///
/// Entries added through [`HookSet::push`] are validated, and an entry with
/// the same name and hook point as an existing one replaces it in place.
///
/// NOTE(review): the derived `Deserialize` fills `entries` directly and does
/// not go through `push`, so deserialized sets appear to skip validation and
/// duplicate replacement — confirm whether that is intended.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct HookSet {
    /// Entries in insertion order (a replaced duplicate keeps its position).
    entries: Vec<HookEntry>,
}
272
273impl HookSet {
274 #[must_use]
276 pub const fn new() -> Self {
277 Self {
278 entries: Vec::new(),
279 }
280 }
281
282 pub fn push(&mut self, entry: HookEntry) -> Result<(), crate::error::Error> {
291 entry.validate()?;
292 if let Some(pos) = self
293 .entries
294 .iter()
295 .position(|e| e.name == entry.name && e.point == entry.point)
296 {
297 tracing::warn!(
298 hook = %entry.name,
299 point = %entry.point.label(),
300 "duplicate hook name+point in HookSet — replacing previous entry"
301 );
302 self.entries[pos] = entry;
303 } else {
304 self.entries.push(entry);
305 }
306 Ok(())
307 }
308
309 pub fn at_point(&self, point: HookPoint) -> impl Iterator<Item = &HookEntry> {
311 self.entries.iter().filter(move |e| e.point == point)
312 }
313
314 pub fn iter(&self) -> impl Iterator<Item = &HookEntry> {
316 self.entries.iter()
317 }
318
319 #[must_use]
321 pub const fn len(&self) -> usize {
322 self.entries.len()
323 }
324
325 #[must_use]
327 pub const fn is_empty(&self) -> bool {
328 self.entries.is_empty()
329 }
330}
331
332impl From<HookSet> for Vec<HookEntry> {
333 fn from(set: HookSet) -> Self {
334 set.entries
335 }
336}
337
338impl From<&HookSet> for Vec<HookEntry> {
339 fn from(set: &HookSet) -> Self {
340 set.entries.clone()
341 }
342}
343
344impl IntoIterator for HookSet {
345 type Item = HookEntry;
346 type IntoIter = std::vec::IntoIter<Self::Item>;
347
348 fn into_iter(self) -> Self::IntoIter {
349 self.entries.into_iter()
350 }
351}
352
353impl FromIterator<HookEntry> for HookSet {
354 fn from_iter<T: IntoIterator<Item = HookEntry>>(iter: T) -> Self {
355 let mut set = Self::new();
356 for entry in iter {
357 let name = entry.name.clone();
358 if let Err(e) = set.push(entry) {
359 tracing::error!(
360 error = %e,
361 hook = %name,
362 "Failed to push hook entry during from_iter"
363 );
364 }
365 }
366 set
367 }
368}
369
370impl From<Vec<HookEntry>> for HookSet {
371 fn from(entries: Vec<HookEntry>) -> Self {
372 Self::from_iter(entries)
373 }
374}
375
376impl<const N: usize> From<[HookEntry; N]> for HookSet {
377 fn from(entries: [HookEntry; N]) -> Self {
378 Self::from_iter(entries)
379 }
380}
/// Trait-object type of the closure stored in
/// [`HookCallback::TransformToolInput`]: it receives the pre-tool-call context
/// and returns an optional JSON value.
///
/// NOTE(review): presumably `Some(value)` is a replacement tool input and
/// `None` means "leave the input unchanged" — confirm against the caller.
type TransformToolInputFn =
    dyn Fn(&PreToolCallDecideContext) -> Option<serde_json::Value> + Send + Sync;
389
/// A boxed, thread-safe (`Send + Sync`) hook closure, tagged by the kind of
/// context it receives.
///
/// Most variants are observers that return nothing; `PreToolCallDecide` and
/// `OnInteraction` return a [`HookResult`], and `TransformToolInput` returns
/// an optional JSON value. The enum is `#[non_exhaustive]`.
#[non_exhaustive]
pub enum HookCallback {
    /// Receives a [`PreTurnContext`].
    PreTurn(Box<dyn Fn(&PreTurnContext) + Send + Sync>),
    /// Receives a [`PostTurnContext`].
    PostTurn(Box<dyn Fn(&PostTurnContext) + Send + Sync>),
    /// Receives a [`PreToolCallDecideContext`] and returns a [`HookResult`].
    PreToolCallDecide(Box<dyn Fn(&PreToolCallDecideContext) -> HookResult + Send + Sync>),
    /// Receives a [`PostToolCallContext`].
    PostToolCall(Box<dyn Fn(&PostToolCallContext) + Send + Sync>),
    /// Receives an [`OnToolErrorContext`].
    OnToolError(Box<dyn Fn(&OnToolErrorContext) + Send + Sync>),
    /// Receives an [`OnSessionStartContext`].
    OnSessionStart(Box<dyn Fn(&OnSessionStartContext) + Send + Sync>),
    /// Receives an [`OnSessionEndContext`].
    OnSessionEnd(Box<dyn Fn(&OnSessionEndContext) + Send + Sync>),
    /// Receives an [`OnCompactionContext`].
    OnCompaction(Box<dyn Fn(&OnCompactionContext) + Send + Sync>),
    /// Receives an [`OnInteractionContext`] and returns a [`HookResult`].
    OnInteraction(Box<dyn Fn(&OnInteractionContext) -> HookResult + Send + Sync>),
    /// Receives a [`PreToolCallDecideContext`] and returns an optional JSON
    /// value. It has no hook point of its own; it is associated with
    /// [`HookPoint::PreToolCallDecide`].
    TransformToolInput(Box<TransformToolInputFn>),
}
425
426impl HookCallback {
427 #[must_use]
429 pub(crate) const fn hook_point(&self) -> HookPoint {
430 match self {
431 Self::PreTurn(_) => HookPoint::PreTurn,
432 Self::PostTurn(_) => HookPoint::PostTurn,
433 Self::PreToolCallDecide(_) | Self::TransformToolInput(_) => {
434 HookPoint::PreToolCallDecide
435 }
436 Self::PostToolCall(_) => HookPoint::PostToolCall,
437 Self::OnToolError(_) => HookPoint::OnToolError,
438 Self::OnSessionStart(_) => HookPoint::OnSessionStart,
439 Self::OnSessionEnd(_) => HookPoint::OnSessionEnd,
440 Self::OnCompaction(_) => HookPoint::OnCompaction,
441 Self::OnInteraction(_) => HookPoint::OnInteraction,
442 }
443 }
444}
445
446impl std::fmt::Debug for HookCallback {
448 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
449 f.write_str("HookCallback::")?;
450 match self {
451 Self::TransformToolInput(_) => f.write_str("transform_tool_input"),
452 other => f.write_str(other.hook_point().label()),
453 }
454 }
455}
456
457#[cfg(test)]
458mod tests {
459 use super::*;
460
461 #[test]
462 fn hook_result_allow() {
463 let r = HookResult::allow();
464 assert!(r.allow);
465 assert!(r.message.is_empty());
466 }
467
468 #[test]
469 fn hook_result_deny() {
470 let r = HookResult::deny("blocked by policy");
471 assert!(!r.allow);
472 assert_eq!(r.message, "blocked by policy");
473 }
474
475 #[test]
476 fn hook_result_allow_with_message() {
477 let r = HookResult::allow_with_message("proceeding with caution");
478 assert!(r.allow);
479 assert_eq!(r.message, "proceeding with caution");
480 }
481
482 #[test]
483 fn hook_point_labels() {
484 assert_eq!(HookPoint::PreTurn.label(), "pre_turn");
485 assert_eq!(HookPoint::PostTurn.label(), "post_turn");
486 assert_eq!(HookPoint::PreToolCallDecide.label(), "pre_tool_call_decide");
487 assert_eq!(HookPoint::PostToolCall.label(), "post_tool_call");
488 assert_eq!(HookPoint::OnCompaction.label(), "on_compaction");
489 assert_eq!(HookPoint::OnSessionStart.label(), "on_session_start");
490 assert_eq!(HookPoint::OnSessionEnd.label(), "on_session_end");
491 assert_eq!(HookPoint::OnToolError.label(), "on_tool_error");
492 assert_eq!(HookPoint::OnInteraction.label(), "on_interaction");
493 }
494
495 #[test]
496 fn hooks_fire_in_correct_order() {
497 let mut set = HookSet::new();
498 assert!(set.is_empty());
499
500 set.push(HookEntry {
501 name: "pre_turn_1".to_owned(),
502 point: HookPoint::PreTurn,
503 callback_id: "cb_pre1".to_owned(),
504 })
505 .unwrap();
506 set.push(HookEntry {
507 name: "pre_tool_decide".to_owned(),
508 point: HookPoint::PreToolCallDecide,
509 callback_id: "cb_decide".to_owned(),
510 })
511 .unwrap();
512 set.push(HookEntry {
513 name: "pre_turn_2".to_owned(),
514 point: HookPoint::PreTurn,
515 callback_id: "cb_pre2".to_owned(),
516 })
517 .unwrap();
518 set.push(HookEntry {
519 name: "post_turn_1".to_owned(),
520 point: HookPoint::PostTurn,
521 callback_id: "cb_post1".to_owned(),
522 })
523 .unwrap();
524 set.push(HookEntry {
525 name: "post_tool_1".to_owned(),
526 point: HookPoint::PostToolCall,
527 callback_id: "cb_posttool1".to_owned(),
528 })
529 .unwrap();
530
531 assert_eq!(set.len(), 5);
532
533 let pre_turn: Vec<&str> = set
534 .at_point(HookPoint::PreTurn)
535 .map(|e| e.name.as_str())
536 .collect();
537 assert_eq!(pre_turn, vec!["pre_turn_1", "pre_turn_2"]);
538
539 let decide: Vec<&str> = set
540 .at_point(HookPoint::PreToolCallDecide)
541 .map(|e| e.name.as_str())
542 .collect();
543 assert_eq!(decide, vec!["pre_tool_decide"]);
544
545 let post_turn: Vec<&str> = set
546 .at_point(HookPoint::PostTurn)
547 .map(|e| e.name.as_str())
548 .collect();
549 assert_eq!(post_turn, vec!["post_turn_1"]);
550
551 let post_tool: Vec<&str> = set
552 .at_point(HookPoint::PostToolCall)
553 .map(|e| e.name.as_str())
554 .collect();
555 assert_eq!(post_tool, vec!["post_tool_1"]);
556 }
557
558 #[test]
559 fn hook_entry_serde_roundtrip() {
560 let entry = HookEntry {
561 name: "my_hook".to_owned(),
562 point: HookPoint::PreToolCallDecide,
563 callback_id: "cb_123".to_owned(),
564 };
565 let json = serde_json::to_string(&entry).expect("serialize");
566 let parsed: HookEntry = serde_json::from_str(&json).expect("deserialize");
567 assert_eq!(parsed.name, entry.name);
568 assert_eq!(parsed.point, entry.point);
569 assert_eq!(parsed.callback_id, entry.callback_id);
570 }
571
572 #[test]
573 fn hook_result_serde_roundtrip() {
574 let results = vec![
575 HookResult::allow(),
576 HookResult::deny("reason"),
577 HookResult::allow_with_message("ok"),
578 ];
579 for result in &results {
580 let json = serde_json::to_string(result).expect("serialize");
581 let parsed: HookResult = serde_json::from_str(&json).expect("deserialize");
582 assert_eq!(&parsed, result);
583 }
584 }
585
586 #[test]
587 fn hook_set_serde_roundtrip() {
588 let mut set = HookSet::new();
589 set.push(HookEntry {
590 name: "gate".to_owned(),
591 point: HookPoint::PreTurn,
592 callback_id: "cb_1".to_owned(),
593 })
594 .unwrap();
595 set.push(HookEntry {
596 name: "logger".to_owned(),
597 point: HookPoint::PostToolCall,
598 callback_id: "cb_2".to_owned(),
599 })
600 .unwrap();
601 let json = serde_json::to_string(&set).expect("serialize");
602 let parsed: HookSet = serde_json::from_str(&json).expect("deserialize");
603 assert_eq!(parsed.len(), 2);
604 let names: Vec<&str> = parsed.iter().map(|e| e.name.as_str()).collect();
605 assert_eq!(names, vec!["gate", "logger"]);
606 }
607
608 #[test]
609 fn hook_set_from_conversions() {
610 let mut set = HookSet::new();
611 set.push(HookEntry {
612 name: "gate".to_owned(),
613 point: HookPoint::PreTurn,
614 callback_id: "cb_1".to_owned(),
615 })
616 .unwrap();
617
618 let vec_from_owned: Vec<HookEntry> = Vec::from(set.clone());
619 assert_eq!(vec_from_owned.len(), 1);
620 assert_eq!(vec_from_owned[0].name, "gate");
621
622 let vec_from_ref: Vec<HookEntry> = Vec::from(&set);
623 assert_eq!(vec_from_ref.len(), 1);
624 assert_eq!(vec_from_ref[0].name, "gate");
625
626 let entry = HookEntry {
627 name: "gate".to_owned(),
628 point: HookPoint::PreTurn,
629 callback_id: "cb_1".to_owned(),
630 };
631 let set_from_arr = HookSet::from([entry.clone()]);
632 assert_eq!(set_from_arr.len(), 1);
633
634 let set_from_vec = HookSet::from(vec![entry]);
635 assert_eq!(set_from_vec.len(), 1);
636 }
637
638 #[test]
639 fn empty_hook_set_iteration_at_each_point() {
640 let set = HookSet::new();
641 for point in [
642 HookPoint::PreTurn,
643 HookPoint::PostTurn,
644 HookPoint::PreToolCallDecide,
645 HookPoint::PostToolCall,
646 HookPoint::OnCompaction,
647 HookPoint::OnSessionStart,
648 HookPoint::OnSessionEnd,
649 HookPoint::OnToolError,
650 HookPoint::OnInteraction,
651 ] {
652 assert_eq!(
653 set.at_point(point).count(),
654 0,
655 "Empty HookSet should have 0 hooks at {point:?}"
656 );
657 }
658 }
659
660 #[test]
661 fn hook_point_serde_roundtrip() {
662 let points = [
663 HookPoint::PreTurn,
664 HookPoint::PostTurn,
665 HookPoint::PreToolCallDecide,
666 HookPoint::PostToolCall,
667 HookPoint::OnCompaction,
668 HookPoint::OnSessionStart,
669 HookPoint::OnSessionEnd,
670 HookPoint::OnToolError,
671 HookPoint::OnInteraction,
672 ];
673 for point in points {
674 let json = serde_json::to_string(&point).expect("serialize");
675 let parsed: HookPoint = serde_json::from_str(&json).expect("deserialize");
676 assert_eq!(parsed, point);
677 }
678 }
679
680 #[test]
681 fn hook_set_default_is_empty() {
682 let set = HookSet::default();
683 assert!(set.is_empty());
684 assert_eq!(set.len(), 0);
685 }
686
687 #[test]
688 fn hook_set_multiple_hooks_at_same_point() {
689 let mut set = HookSet::new();
690 for i in 0..5 {
691 set.push(HookEntry {
692 name: format!("hook_{i}"),
693 point: HookPoint::PreToolCallDecide,
694 callback_id: format!("cb_{i}"),
695 })
696 .unwrap();
697 }
698 assert_eq!(set.len(), 5);
699 assert_eq!(set.at_point(HookPoint::PreToolCallDecide).count(), 5);
700 assert_eq!(set.at_point(HookPoint::PreTurn).count(), 0);
701 }
702
703 #[test]
704 fn hook_result_deny_with_string_owned() {
705 let reason = String::from("policy violation detected");
706 let r = HookResult::deny(reason.clone());
707 assert!(!r.allow);
708 assert_eq!(r.message, reason);
709 }
710
711 #[test]
712 fn hook_entry_with_new_hook_points() {
713 let new_points = [
714 (HookPoint::OnCompaction, "compaction_hook"),
715 (HookPoint::OnSessionStart, "session_start_hook"),
716 (HookPoint::OnSessionEnd, "session_end_hook"),
717 (HookPoint::OnToolError, "tool_error_hook"),
718 (HookPoint::OnInteraction, "interaction_hook"),
719 ];
720 let mut set = HookSet::new();
721 for (point, name) in &new_points {
722 set.push(HookEntry {
723 name: (*name).to_owned(),
724 point: *point,
725 callback_id: format!("cb_{name}"),
726 })
727 .unwrap();
728 }
729 assert_eq!(set.len(), 5);
730 for (point, name) in &new_points {
731 let hooks: Vec<&str> = set.at_point(*point).map(|e| e.name.as_str()).collect();
732 assert_eq!(hooks, vec![*name], "expected hook at {point:?}");
733 }
734 }
735
736 #[test]
737 fn hook_entry_serde_roundtrip_new_points() {
738 let new_points = [
739 HookPoint::OnCompaction,
740 HookPoint::OnSessionStart,
741 HookPoint::OnSessionEnd,
742 HookPoint::OnToolError,
743 HookPoint::OnInteraction,
744 ];
745 for point in new_points {
746 let entry = HookEntry {
747 name: format!("test_{}", point.label()),
748 point,
749 callback_id: format!("cb_{}", point.label()),
750 };
751 let json = serde_json::to_string(&entry).expect("serialize");
752 let parsed: HookEntry = serde_json::from_str(&json).expect("deserialize");
753 assert_eq!(parsed.name, entry.name);
754 assert_eq!(parsed.point, entry.point);
755 assert_eq!(parsed.callback_id, entry.callback_id);
756 }
757 }
758
759 #[test]
762 fn session_context_clone() {
763 let ctx = SessionContext {
764 session_id: "sess-1".into(),
765 agent_id: 42,
766 started_at: Instant::now(),
767 };
768 let cloned = ctx;
769 assert_eq!(cloned.session_id, "sess-1");
770 assert_eq!(cloned.agent_id, 42);
771 }
772
773 #[test]
774 fn session_context_debug_format() {
775 let ctx = SessionContext {
776 session_id: "sess-debug".into(),
777 agent_id: 1,
778 started_at: Instant::now(),
779 };
780 let dbg = format!("{ctx:?}");
781 assert!(dbg.contains("sess-debug"));
782 assert!(dbg.contains("agent_id: 1"));
783 }
784
785 #[test]
788 fn hook_entry_new_valid() {
789 let entry = HookEntry::new("safety_gate", HookPoint::PreToolCallDecide, "cb_safety")
790 .expect("valid entry");
791 assert_eq!(entry.name, "safety_gate");
792 assert_eq!(entry.point, HookPoint::PreToolCallDecide);
793 assert_eq!(entry.callback_id, "cb_safety");
794 }
795
796 #[test]
797 fn hook_entry_new_rejects_empty_name() {
798 let result = HookEntry::new("", HookPoint::PreTurn, "cb_1");
799 assert!(result.is_err(), "should reject empty name");
800 }
801
802 #[test]
803 fn hook_entry_new_rejects_whitespace_name() {
804 let result = HookEntry::new(" ", HookPoint::PreTurn, "cb_1");
805 assert!(result.is_err(), "should reject whitespace-only name");
806 }
807
808 #[test]
809 fn hook_entry_new_rejects_empty_callback_id() {
810 let result = HookEntry::new("my_hook", HookPoint::PreTurn, "");
811 assert!(result.is_err(), "should reject empty callback_id");
812 }
813
814 #[test]
815 fn hook_entry_new_rejects_whitespace_callback_id() {
816 let result = HookEntry::new("my_hook", HookPoint::PostTurn, " ");
817 assert!(result.is_err(), "should reject whitespace-only callback_id");
818 }
819
820 #[test]
821 fn pre_tool_call_decide_context_serde_aliases() {
822 let json_std = r#"{"tool_name":"my_tool","tool_args":{"foo":"bar"}}"#;
823 let parsed_std: PreToolCallDecideContext = serde_json::from_str(json_std).unwrap();
824 assert_eq!(parsed_std.tool_name, "my_tool");
825 assert_eq!(parsed_std.tool_args["foo"], "bar");
826
827 let json_alias = r#"{"name":"my_tool","args":{"foo":"bar"}}"#;
828 let parsed_alias: PreToolCallDecideContext = serde_json::from_str(json_alias).unwrap();
829 assert_eq!(parsed_alias.tool_name, "my_tool");
830 assert_eq!(parsed_alias.tool_args["foo"], "bar");
831 }
832
833 #[test]
834 fn pre_tool_call_decide_context_serde_default() {
835 let json_no_args = r#"{"name":"my_tool"}"#;
836 let parsed_no_args: PreToolCallDecideContext = serde_json::from_str(json_no_args).unwrap();
837 assert_eq!(parsed_no_args.tool_name, "my_tool");
838 assert_eq!(parsed_no_args.tool_args, serde_json::Value::Null);
839 }
840
841 #[test]
842 fn post_tool_call_context_serde_aliases_and_default() {
843 let json_std = r#"{"tool_name":"my_tool","tool_args":{"foo":"bar"},"result":"success"}"#;
844 let parsed_std: PostToolCallContext = serde_json::from_str(json_std).unwrap();
845 assert_eq!(parsed_std.tool_name, "my_tool");
846 assert_eq!(parsed_std.tool_args["foo"], "bar");
847 assert_eq!(parsed_std.result, "success");
848
849 let json_alias = r#"{"name":"my_tool","args":{"foo":"bar"},"result":"success"}"#;
850 let parsed_alias: PostToolCallContext = serde_json::from_str(json_alias).unwrap();
851 assert_eq!(parsed_alias.tool_name, "my_tool");
852 assert_eq!(parsed_alias.tool_args["foo"], "bar");
853 assert_eq!(parsed_alias.result, "success");
854
855 let json_no_args = r#"{"name":"my_tool","result":"success"}"#;
856 let parsed_no_args: PostToolCallContext = serde_json::from_str(json_no_args).unwrap();
857 assert_eq!(parsed_no_args.tool_name, "my_tool");
858 assert_eq!(parsed_no_args.tool_args, serde_json::Value::Null);
859 assert_eq!(parsed_no_args.result, "success");
860 }
861
862 #[test]
863 fn on_tool_error_context_serde_aliases_and_default() {
864 let json_std = r#"{"tool_name":"my_tool","tool_args":{"foo":"bar"},"error":"failed"}"#;
865 let parsed_std: OnToolErrorContext = serde_json::from_str(json_std).unwrap();
866 assert_eq!(parsed_std.tool_name, "my_tool");
867 assert_eq!(parsed_std.tool_args["foo"], "bar");
868 assert_eq!(parsed_std.error, "failed");
869
870 let json_alias = r#"{"name":"my_tool","args":{"foo":"bar"},"error":"failed"}"#;
871 let parsed_alias: OnToolErrorContext = serde_json::from_str(json_alias).unwrap();
872 assert_eq!(parsed_alias.tool_name, "my_tool");
873 assert_eq!(parsed_alias.tool_args["foo"], "bar");
874 assert_eq!(parsed_alias.error, "failed");
875
876 let json_no_args = r#"{"name":"my_tool","error":"failed"}"#;
877 let parsed_no_args: OnToolErrorContext = serde_json::from_str(json_no_args).unwrap();
878 assert_eq!(parsed_no_args.tool_name, "my_tool");
879 assert_eq!(parsed_no_args.tool_args, serde_json::Value::Null);
880 assert_eq!(parsed_no_args.error, "failed");
881
882 let json_no_name = r#"{"error":"failed"}"#;
883 let parsed_no_name: Result<OnToolErrorContext, _> = serde_json::from_str(json_no_name);
884 assert!(parsed_no_name.is_err());
885 }
886}