Skip to main content

rs_adk/
callback.rs

1//! Callback types for tool execution lifecycle.
2//!
3//! Callbacks provide a lightweight alternative to plugins for simple
4//! before/after tool interception. They are closures registered on the
5//! agent builder.
6
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10
11use rs_genai::prelude::FunctionCall;
12
13use crate::error::ToolError;
14
15/// The result of a before-tool callback.
16#[derive(Debug, Clone)]
17pub enum BeforeToolResult {
18    /// Continue with the tool call.
19    Continue,
20    /// Skip the tool call and use this value as the result.
21    Skip(serde_json::Value),
22    /// Deny the tool call with a reason.
23    Deny(String),
24}
25
26/// The result of a tool call, passed to after-tool callbacks.
27#[derive(Debug, Clone)]
28pub struct ToolCallResult {
29    /// The function call that was executed.
30    pub call: FunctionCall,
31    /// The result (Ok = tool output, Err = tool error).
32    pub result: Result<serde_json::Value, ToolError>,
33    /// How long the tool call took.
34    pub duration: std::time::Duration,
35}
36
37/// A before-tool callback function type.
38///
39/// Receives the function call about to be executed and returns a decision
40/// about whether to proceed.
41pub type BeforeToolCallback = Arc<
42    dyn Fn(&FunctionCall) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + '_>>
43        + Send
44        + Sync,
45>;
46
47/// An after-tool callback function type.
48///
49/// Receives the tool call result for observation/logging purposes.
50/// Cannot modify the result.
51pub type AfterToolCallback =
52    Arc<dyn Fn(&ToolCallResult) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> + Send + Sync>;
53
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the argument-less `test` call shared by the result tests.
    fn sample_call() -> FunctionCall {
        FunctionCall {
            name: "test".into(),
            args: serde_json::json!({}),
            id: None,
        }
    }

    /// Each `BeforeToolResult` variant can be constructed and matched.
    #[test]
    fn before_tool_result_variants() {
        let proceed = BeforeToolResult::Continue;
        let cached = BeforeToolResult::Skip(serde_json::json!({"cached": true}));
        let refused = BeforeToolResult::Deny("not allowed".into());

        assert!(matches!(proceed, BeforeToolResult::Continue));
        assert!(matches!(cached, BeforeToolResult::Skip(_)));
        assert!(matches!(refused, BeforeToolResult::Deny(_)));
    }

    /// A `ToolCallResult` built from a successful output reports `Ok`.
    #[test]
    fn tool_call_result_ok() {
        let outcome = ToolCallResult {
            call: sample_call(),
            result: Ok(serde_json::json!({"success": true})),
            duration: std::time::Duration::from_millis(42),
        };

        assert!(matches!(outcome.result, Ok(_)));
    }

    /// A `ToolCallResult` built from a tool error reports `Err`.
    #[test]
    fn tool_call_result_err() {
        let outcome = ToolCallResult {
            call: sample_call(),
            result: Err(ToolError::ExecutionFailed("boom".into())),
            duration: std::time::Duration::from_millis(1),
        };

        assert!(matches!(outcome.result, Err(_)));
    }
}