use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use rs_genai::prelude::FunctionCall;
use crate::error::ToolError;
/// Outcome produced by a [`BeforeToolCallback`] for a pending [`FunctionCall`].
///
/// Derives `PartialEq` so callers and tests can compare decisions directly
/// (e.g. `assert_eq!(decision, BeforeToolResult::Continue)`).
///
/// NOTE(review): the per-variant semantics below are inferred from the variant
/// names and payloads; the executor that consumes this enum is not in this file —
/// confirm against it.
#[derive(Debug, Clone, PartialEq)]
pub enum BeforeToolResult {
    /// Let the tool call proceed (presumably unchanged).
    Continue,
    /// Do not run the tool; the JSON payload is presumably used in place of the
    /// tool's output (e.g. a cached response).
    Skip(serde_json::Value),
    /// Refuse the tool call; the string is presumably a human-readable reason.
    Deny(String),
}
/// A tool invocation paired with its outcome and the time it took.
///
/// This is the value an [`AfterToolCallback`] receives by reference.
#[derive(Debug, Clone)]
pub struct ToolCallResult {
    /// The function call this result belongs to.
    pub call: FunctionCall,
    /// The tool's JSON output on success, or the [`ToolError`] it failed with.
    pub result: Result<serde_json::Value, ToolError>,
    /// Elapsed time for the call.
    /// NOTE(review): what exactly is timed (tool body only vs. including hooks)
    /// is decided by whoever fills this in — not visible here; confirm.
    pub duration: std::time::Duration,
}
/// Shared async hook invoked with a pending [`FunctionCall`]; its future resolves
/// to a [`BeforeToolResult`].
///
/// The lifetime is spelled out as a higher-ranked bound. For every borrow `'a` of
/// the call, the hook returns a boxed, `Send` future that may keep borrowing the
/// call for `'a`. The hook is stored behind `Arc`, so it must be `Send + Sync`.
pub type BeforeToolCallback = Arc<
    dyn for<'a> Fn(&'a FunctionCall) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + 'a>>
        + Send
        + Sync,
>;
/// Shared async hook invoked with a [`ToolCallResult`]; its future resolves to `()`.
///
/// The lifetime is spelled out as a higher-ranked bound. For every borrow `'a` of
/// the result, the hook returns a boxed, `Send` future that may keep borrowing it
/// for `'a`. The hook is stored behind `Arc`, so it must be `Send + Sync`.
pub type AfterToolCallback = Arc<
    dyn for<'a> Fn(&'a ToolCallResult) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>
        + Send
        + Sync,
>;
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Minimal `FunctionCall` fixture shared by the tests below.
    fn sample_call() -> FunctionCall {
        FunctionCall {
            name: "test".into(),
            args: serde_json::json!({}),
            id: None,
        }
    }

    /// Plain fn with the `BeforeToolCallback` shape that always lets the call proceed.
    fn allow_all(
        _call: &FunctionCall,
    ) -> Pin<Box<dyn Future<Output = BeforeToolResult> + Send + '_>> {
        Box::pin(async { BeforeToolResult::Continue })
    }

    /// Plain fn with the `AfterToolCallback` shape that ignores the result.
    fn ignore_result(_result: &ToolCallResult) -> Pin<Box<dyn Future<Output = ()> + Send + '_>> {
        Box::pin(async {})
    }

    #[test]
    fn before_tool_result_variants() {
        let cont = BeforeToolResult::Continue;
        assert!(matches!(cont, BeforeToolResult::Continue));

        // Check that payloads are carried through, not just the variant tag.
        match BeforeToolResult::Skip(serde_json::json!({"cached": true})) {
            BeforeToolResult::Skip(value) => {
                assert_eq!(value, serde_json::json!({"cached": true}));
            }
            other => panic!("expected Skip, got {other:?}"),
        }

        match BeforeToolResult::Deny("not allowed".into()) {
            BeforeToolResult::Deny(reason) => assert_eq!(reason, "not allowed"),
            other => panic!("expected Deny, got {other:?}"),
        }
    }

    #[test]
    fn tool_call_result_ok() {
        let result = ToolCallResult {
            call: sample_call(),
            result: Ok(serde_json::json!({"success": true})),
            duration: Duration::from_millis(42),
        };
        let value = result.result.as_ref().expect("result should be Ok");
        assert_eq!(*value, serde_json::json!({"success": true}));
        assert_eq!(result.duration, Duration::from_millis(42));
    }

    #[test]
    fn tool_call_result_err() {
        let result = ToolCallResult {
            call: sample_call(),
            result: Err(ToolError::ExecutionFailed("boom".into())),
            duration: Duration::from_millis(1),
        };
        let err = result.result.as_ref().expect_err("result should be Err");
        assert!(matches!(err, ToolError::ExecutionFailed(_)));
        assert_eq!(result.duration, Duration::from_millis(1));
    }

    #[test]
    fn tool_call_result_clone_preserves_fields() {
        let original = ToolCallResult {
            call: sample_call(),
            result: Ok(serde_json::json!({"n": 7})),
            duration: Duration::from_millis(5),
        };
        let copy = original.clone();
        assert_eq!(copy.duration, original.duration);
        assert_eq!(copy.result.as_ref().ok(), original.result.as_ref().ok());
    }

    #[test]
    fn callbacks_accept_plain_async_fns() {
        // Compile-time check: ordinary fns with elided lifetimes coerce into the
        // shared callback aliases. Building the Arcs is enough; no executor needed.
        let _before: BeforeToolCallback = Arc::new(allow_all);
        let _after: AfterToolCallback = Arc::new(ignore_result);
    }
}