pub mod dispatcher;
pub mod simple;
pub mod typed;
pub use dispatcher::*;
pub use simple::*;
pub use typed::*;
use std::sync::Arc;
use std::time::Duration;
use async_trait::async_trait;
use tokio::sync::{broadcast, mpsc};
use tokio::task::JoinHandle;
use tokio_util::sync::CancellationToken;
use crate::agent_session::InputEvent;
use crate::error::ToolError;
/// A one-shot tool: it receives JSON arguments and produces a single JSON result.
///
/// Implementors must be `Send + Sync + 'static` so they can be stored behind an
/// `Arc<dyn ToolFunction>` and shared across tasks.
#[async_trait]
pub trait ToolFunction: 'static + Send + Sync {
    /// Executes the tool with the supplied JSON arguments.
    ///
    /// # Errors
    ///
    /// Returns a [`ToolError`] when the tool cannot produce a result.
    async fn call(&self, args: serde_json::Value) -> Result<serde_json::Value, ToolError>;

    /// Name under which the tool is registered and invoked.
    fn name(&self) -> &str;

    /// Human-readable description of what the tool does.
    fn description(&self) -> &str;

    /// JSON schema describing the accepted arguments, or `None` when no schema
    /// is advertised.
    fn parameters(&self) -> Option<serde_json::Value>;
}
/// A tool that emits a sequence of JSON values instead of one result.
///
/// Implementors must be `Send + Sync + 'static` so they can be stored behind an
/// `Arc<dyn StreamingTool>` and shared across tasks.
#[async_trait]
pub trait StreamingTool: 'static + Send + Sync {
    /// Runs the tool, sending each produced value through `yield_tx`.
    ///
    /// # Errors
    ///
    /// Returns a [`ToolError`] when the tool fails.
    async fn run(
        &self,
        args: serde_json::Value,
        yield_tx: mpsc::Sender<serde_json::Value>,
    ) -> Result<(), ToolError>;

    /// Name under which the tool is registered and invoked.
    fn name(&self) -> &str;

    /// Human-readable description of what the tool does.
    fn description(&self) -> &str;

    /// JSON schema describing the accepted arguments, or `None` when no schema
    /// is advertised.
    fn parameters(&self) -> Option<serde_json::Value>;
}
/// A streaming tool that also listens to [`InputEvent`]s while it runs.
///
/// Implementors must be `Send + Sync + 'static` so they can be stored behind an
/// `Arc<dyn InputStreamingTool>` and shared across tasks.
#[async_trait]
pub trait InputStreamingTool: 'static + Send + Sync {
    /// Runs the tool, reading input events from `input_rx` and sending each
    /// produced value through `yield_tx`.
    ///
    /// # Errors
    ///
    /// Returns a [`ToolError`] when the tool fails.
    async fn run(
        &self,
        args: serde_json::Value,
        input_rx: broadcast::Receiver<InputEvent>,
        yield_tx: mpsc::Sender<serde_json::Value>,
    ) -> Result<(), ToolError>;

    /// Name under which the tool is registered and invoked.
    fn name(&self) -> &str;

    /// Human-readable description of what the tool does.
    fn description(&self) -> &str;

    /// JSON schema describing the accepted arguments, or `None` when no schema
    /// is advertised.
    fn parameters(&self) -> Option<serde_json::Value>;
}
/// Execution category of a registered tool, without the tool itself.
///
/// The enum is fieldless and `Copy`, and it derives `Hash` in addition to
/// `Eq`, so it can be compared cheaply and used as a `HashMap`/`HashSet` key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolClass {
    /// A one-shot tool implementing [`ToolFunction`].
    Regular,
    /// A tool implementing [`StreamingTool`].
    Streaming,
    /// A tool implementing [`InputStreamingTool`].
    InputStream,
}
/// A registered tool, tagged by which tool trait it implements.
///
/// Every variant holds only an `Arc`, so cloning a `ToolKind` is a reference
/// count bump. This lets a caller hand an owned handle to a spawned task while
/// the registry keeps its own copy.
#[derive(Clone)]
pub enum ToolKind {
    /// One-shot request/response tool.
    Function(Arc<dyn ToolFunction>),
    /// Tool that streams values through an `mpsc` channel.
    Streaming(Arc<dyn StreamingTool>),
    /// Streaming tool that also receives broadcast input events.
    InputStream(Arc<dyn InputStreamingTool>),
}
/// A spawned streaming-tool task paired with the token used to cancel it.
///
/// Both field types implement `Debug`, so the handle derives it. This makes it
/// printable in logs and usable inside other `Debug` types.
#[derive(Debug)]
pub struct ActiveStreamingTool {
    /// Join handle of the spawned task running the tool.
    pub task: JoinHandle<()>,
    /// Token used to request cooperative cancellation of the tool.
    pub cancel: CancellationToken,
}
/// Default timeout for a tool call: thirty seconds (30 000 ms).
pub(crate) const DEFAULT_TOOL_TIMEOUT: Duration = Duration::from_millis(30_000);
#[cfg(test)]
mod tests {
    use super::*;
    use rs_genai::prelude::FunctionCall;
    use serde_json::json;

    /// Minimal [`ToolFunction`] that ignores its arguments and returns `{"result": "ok"}`.
    struct MockTool;

    #[async_trait]
    impl ToolFunction for MockTool {
        fn name(&self) -> &str {
            "mock_tool"
        }

        fn description(&self) -> &str {
            "A mock tool"
        }

        fn parameters(&self) -> Option<serde_json::Value> {
            None
        }

        async fn call(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            Ok(json!({"result": "ok"}))
        }
    }

    /// A registered function tool can be invoked by name.
    #[tokio::test]
    async fn register_and_call_function_tool() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(MockTool));
        let result = dispatcher
            .call_function("mock_tool", json!({}))
            .await
            .unwrap();
        assert_eq!(result["result"], "ok");
    }

    /// Calling a name that was never registered yields an error.
    #[tokio::test]
    async fn call_unknown_tool_returns_error() {
        let dispatcher = ToolDispatcher::new();
        let result = dispatcher.call_function("nonexistent", json!({})).await;
        assert!(result.is_err());
    }

    /// One registered tool produces one declaration.
    #[test]
    fn to_tool_declarations() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(MockTool));
        let decls = dispatcher.to_tool_declarations();
        assert_eq!(decls.len(), 1);
    }

    /// Function tools classify as `Regular`; unknown names classify as `None`.
    #[test]
    fn classify_tool() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(MockTool));
        assert_eq!(dispatcher.classify("mock_tool"), Some(ToolClass::Regular));
        assert_eq!(dispatcher.classify("nonexistent"), None);
    }

    /// A fresh dispatcher is empty and declares nothing.
    #[test]
    fn empty_dispatcher() {
        let dispatcher = ToolDispatcher::new();
        assert!(dispatcher.is_empty());
        assert_eq!(dispatcher.len(), 0);
        assert!(dispatcher.to_tool_declarations().is_empty());
    }

    /// A successful result is copied into the response payload under the call's name.
    #[test]
    fn build_response_success() {
        let call = FunctionCall {
            name: "test".to_string(),
            args: json!({}),
            id: Some("call-1".to_string()),
        };
        let resp = ToolDispatcher::build_response(&call, Ok(json!({"ok": true})));
        assert_eq!(resp.name, "test");
        assert_eq!(resp.response["ok"], true);
    }

    /// An error result is surfaced as an `"error"` string containing the message.
    #[test]
    fn build_response_error() {
        let call = FunctionCall {
            name: "test".to_string(),
            args: json!({}),
            id: Some("call-1".to_string()),
        };
        let resp = ToolDispatcher::build_response(
            &call,
            Err(ToolError::ExecutionFailed("boom".to_string())),
        );
        assert!(resp.response["error"].as_str().unwrap().contains("boom"));
    }

    /// The dispatcher exposes its declarations through the `ToolProvider` trait.
    #[test]
    fn tool_dispatcher_implements_tool_provider() {
        use rs_genai::prelude::ToolProvider;
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(MockTool));
        let decls = dispatcher.declarations();
        assert_eq!(decls.len(), 1);
    }

    /// A closure-backed `SimpleTool` can be registered and called through the dispatcher.
    #[tokio::test]
    async fn simple_tool_closure() {
        let tool = SimpleTool::new(
            "add",
            "Add two numbers",
            Some(json!({
                "type": "object",
                "properties": {"a": {"type": "number"}, "b": {"type": "number"}}
            })),
            |args| async move {
                let a = args["a"].as_f64().unwrap_or(0.0);
                let b = args["b"].as_f64().unwrap_or(0.0);
                Ok(json!({"sum": a + b}))
            },
        );
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(tool));
        let result = dispatcher
            .call_function("add", json!({"a": 3, "b": 4}))
            .await
            .unwrap();
        assert_eq!(result["sum"], 7.0);
    }

    /// Typed arguments used by the `TypedTool` tests; `units` is optional.
    #[derive(serde::Deserialize, schemars::JsonSchema)]
    struct WeatherArgs {
        city: String,
        #[serde(default = "default_units")]
        units: String,
    }

    /// Serde default for [`WeatherArgs::units`].
    fn default_units() -> String {
        "celsius".to_string()
    }

    /// `TypedTool` derives a JSON schema listing every field and marking `city` required.
    #[test]
    fn typed_tool_auto_generates_schema() {
        let tool = TypedTool::new(
            "get_weather",
            "Get current weather for a city",
            |_args: WeatherArgs| async move { Ok(json!({})) },
        );
        let params = tool.parameters().expect("should have parameters");
        // Borrow the nested object; this was previously mis-encoded as `¶ms`,
        // which does not compile.
        let props = &params["properties"];
        assert!(
            props.get("city").is_some(),
            "schema should contain 'city' property"
        );
        assert!(
            props.get("units").is_some(),
            "schema should contain 'units' property"
        );
        let required = params["required"]
            .as_array()
            .expect("should have required array");
        let required_names: Vec<&str> = required.iter().filter_map(|v| v.as_str()).collect();
        assert!(required_names.contains(&"city"), "city should be required");
    }

    /// JSON arguments are deserialized into the typed struct before the handler runs.
    #[tokio::test]
    async fn typed_tool_deserializes_args() {
        let tool = TypedTool::new(
            "get_weather",
            "Get current weather for a city",
            |args: WeatherArgs| async move {
                Ok(json!({
                    "temp": 22,
                    "city": args.city,
                    "units": args.units,
                }))
            },
        );
        let result = tool
            .call(json!({"city": "London", "units": "fahrenheit"}))
            .await
            .unwrap();
        assert_eq!(result["city"], "London");
        assert_eq!(result["units"], "fahrenheit");
        assert_eq!(result["temp"], 22);
    }

    /// Missing or wrongly typed arguments are rejected; a missing field yields `InvalidArgs`.
    #[tokio::test]
    async fn typed_tool_invalid_args_returns_error() {
        let tool = TypedTool::new(
            "get_weather",
            "Get current weather for a city",
            |_args: WeatherArgs| async move { Ok(json!({})) },
        );
        let result = tool.call(json!({"units": "celsius"})).await;
        assert!(result.is_err(), "should fail with missing required field");
        let err = result.unwrap_err();
        match &err {
            ToolError::InvalidArgs(msg) => {
                assert!(
                    msg.contains("city"),
                    "error message should mention the missing field: {msg}"
                );
            }
            other => panic!("expected ToolError::InvalidArgs, got: {other:?}"),
        }
        let result = tool.call(json!({"city": 12345})).await;
        assert!(result.is_err(), "should fail with wrong type");
    }

    /// A `TypedTool` registers as a regular function tool and is callable by name.
    #[tokio::test]
    async fn typed_tool_registers_in_dispatcher() {
        let tool = TypedTool::new(
            "get_weather",
            "Get current weather for a city",
            |args: WeatherArgs| async move { Ok(json!({"city": args.city})) },
        );
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(tool));
        assert_eq!(dispatcher.classify("get_weather"), Some(ToolClass::Regular));
        assert_eq!(dispatcher.len(), 1);
        let result = dispatcher
            .call_function("get_weather", json!({"city": "Paris"}))
            .await
            .unwrap();
        assert_eq!(result["city"], "Paris");
        let decls = dispatcher.to_tool_declarations();
        assert_eq!(decls.len(), 1);
    }

    /// [`ToolFunction`] that sleeps for an hour, used to exercise timeouts and cancellation.
    struct SlowTool;

    #[async_trait]
    impl ToolFunction for SlowTool {
        fn name(&self) -> &str {
            "slow_tool"
        }

        fn description(&self) -> &str {
            "A tool that never completes"
        }

        fn parameters(&self) -> Option<serde_json::Value> {
            None
        }

        async fn call(&self, _args: serde_json::Value) -> Result<serde_json::Value, ToolError> {
            tokio::time::sleep(Duration::from_secs(3600)).await;
            Ok(json!({"result": "should never reach here"}))
        }
    }

    /// An explicit timeout shorter than the tool's runtime yields `Timeout` carrying that duration.
    #[tokio::test]
    async fn tool_timeout_returns_error() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(SlowTool));
        let timeout = Duration::from_millis(50);
        let result = dispatcher
            .call_function_with_timeout("slow_tool", json!({}), timeout)
            .await;
        match result {
            Err(ToolError::Timeout(d)) => assert_eq!(d, timeout),
            other => panic!("expected ToolError::Timeout, got: {other:?}"),
        }
    }

    /// A fast tool finishes normally under a generous timeout.
    #[tokio::test]
    async fn tool_completes_before_timeout() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(MockTool));
        let result = dispatcher
            .call_function_with_timeout("mock_tool", json!({}), Duration::from_secs(5))
            .await
            .unwrap();
        assert_eq!(result["result"], "ok");
    }

    /// Triggering the cancellation token mid-call yields `Cancelled`.
    #[tokio::test]
    async fn tool_cancelled_returns_error() {
        let mut dispatcher = ToolDispatcher::new();
        dispatcher.register_function(Arc::new(SlowTool));
        let cancel = CancellationToken::new();
        let cancel_clone = cancel.clone();
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(50)).await;
            cancel_clone.cancel();
        });
        let result = dispatcher
            .call_function_with_cancel("slow_tool", json!({}), cancel)
            .await;
        match result {
            Err(ToolError::Cancelled) => {}
            other => panic!("expected ToolError::Cancelled, got: {other:?}"),
        }
    }

    /// A fresh dispatcher reports a 30-second default timeout.
    #[test]
    fn default_timeout_is_30s() {
        let dispatcher = ToolDispatcher::new();
        assert_eq!(dispatcher.default_timeout(), Duration::from_secs(30));
    }

    /// `with_timeout` replaces the default timeout.
    #[test]
    fn with_timeout_overrides_default() {
        let dispatcher = ToolDispatcher::new().with_timeout(Duration::from_secs(10));
        assert_eq!(dispatcher.default_timeout(), Duration::from_secs(10));
    }

    /// `call_function` applies the dispatcher's configured default timeout.
    #[tokio::test]
    async fn call_function_uses_default_timeout() {
        let mut dispatcher = ToolDispatcher::new().with_timeout(Duration::from_millis(50));
        dispatcher.register_function(Arc::new(SlowTool));
        let result = dispatcher.call_function("slow_tool", json!({})).await;
        match result {
            Err(ToolError::Timeout(d)) => assert_eq!(d, Duration::from_millis(50)),
            other => panic!("expected ToolError::Timeout, got: {other:?}"),
        }
    }
}