1use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10
11use rs_genai::prelude::FunctionCall;
12
13use crate::error::ToolError;
14
15#[derive(Debug, Clone)]
17pub enum BeforeToolResult {
18 Continue,
20 Skip(serde_json::Value),
22 Deny(String),
24}
25
26#[derive(Debug, Clone)]
28pub struct ToolCallResult {
29 pub call: FunctionCall,
31 pub result: Result<serde_json::Value, ToolError>,
33 pub duration: std::time::Duration,
35}
36
37pub type BeforeToolCallback = Arc<
42 dyn Fn(&FunctionCall) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + '_>>
43 + Send
44 + Sync,
45>;
46
47pub type AfterToolCallback =
52 Arc<dyn Fn(&ToolCallResult) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> + Send + Sync>;
53
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Minimal `FunctionCall` fixture shared by the `ToolCallResult` tests.
    fn sample_call() -> FunctionCall {
        FunctionCall {
            name: "test".into(),
            args: serde_json::json!({}),
            id: None,
        }
    }

    #[test]
    fn before_tool_result_variants() {
        // `Continue` has no payload. Pin its derived `Debug` rendering instead of
        // a self-matching `matches!`, which could never fail.
        assert_eq!(format!("{:?}", BeforeToolResult::Continue), "Continue");

        // `Skip` must carry its JSON payload through unchanged, including across `Clone`.
        let skip = BeforeToolResult::Skip(serde_json::json!({"cached": true}));
        match skip.clone() {
            BeforeToolResult::Skip(value) => {
                assert_eq!(value, serde_json::json!({"cached": true}))
            }
            other => panic!("expected Skip, got {other:?}"),
        }

        // `Deny` must keep the exact reason string.
        match BeforeToolResult::Deny("not allowed".into()) {
            BeforeToolResult::Deny(reason) => assert_eq!(reason, "not allowed"),
            other => panic!("expected Deny, got {other:?}"),
        }
    }

    #[test]
    fn tool_call_result_ok() {
        let result = ToolCallResult {
            call: sample_call(),
            result: Ok(serde_json::json!({"success": true})),
            duration: Duration::from_millis(42),
        };

        assert_eq!(result.call.name, "test");
        assert_eq!(result.call.args, serde_json::json!({}));
        assert_eq!(result.duration, Duration::from_millis(42));
        match &result.result {
            Ok(value) => assert_eq!(value, &serde_json::json!({"success": true})),
            Err(err) => panic!("expected Ok, got {err:?}"),
        }
    }

    #[test]
    fn tool_call_result_err() {
        let result = ToolCallResult {
            call: sample_call(),
            result: Err(ToolError::ExecutionFailed("boom".into())),
            duration: Duration::from_millis(1),
        };

        assert!(matches!(result.result, Err(ToolError::ExecutionFailed(_))));
        assert_eq!(result.duration, Duration::from_millis(1));

        // The derived `Clone` must preserve the error variant.
        let copy = result.clone();
        assert!(matches!(copy.result, Err(ToolError::ExecutionFailed(_))));
    }
}