use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use crate::error::OxideError;
use crate::types::ToolDefinition;
/// Pinned, heap-allocated future produced by a type-erased tool handler.
///
/// It resolves to the tool's JSON output, or an [`OxideError`] on failure.
type AsyncResult =
    Pin<Box<dyn Future<Output = Result<serde_json::Value, OxideError>> + Send + 'static>>;
/// Shared, type-erased tool handler: takes the call's JSON arguments and yields an
/// [`AsyncResult`].
type HandlerFn = Arc<dyn Fn(serde_json::Value) -> AsyncResult + Send + Sync + 'static>;
/// One entry of a [`ToolRegistry`]: the tool's advertised definition together with
/// the handler invoked when the tool is dispatched.
struct RegisteredTool {
    /// Definition returned (cloned) by [`ToolRegistry::definitions`].
    definition: ToolDefinition,
    /// Callback awaited by [`ToolRegistry::dispatch`].
    handler: HandlerFn,
}
/// Name-indexed collection of callable tools.
///
/// Each tool pairs a [`ToolDefinition`] with an async handler; tools are added with
/// [`ToolRegistry::register`] and invoked with [`ToolRegistry::dispatch`].
#[derive(Default)]
pub struct ToolRegistry {
    /// Registered tools, keyed by `definition.function.name`.
    tools: HashMap<String, RegisteredTool>,
}
impl ToolRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `handler` as the implementation of the tool described by `definition`.
    ///
    /// The tool is keyed by `definition.function.name`. Registering another tool with
    /// the same name replaces the earlier definition and handler.
    pub fn register<F, Fut>(&mut self, definition: ToolDefinition, handler: F)
    where
        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<serde_json::Value, OxideError>> + Send + 'static,
    {
        let name = definition.function.name.clone();
        self.tools.insert(
            name,
            RegisteredTool {
                definition,
                // Erase the concrete closure/future types so tools with different
                // handler types can be stored in the same map.
                handler: Arc::new(move |args| Box::pin(handler(args))),
            },
        );
    }

    /// Returns clones of all registered tool definitions, sorted by tool name.
    ///
    /// `HashMap` iteration order is unspecified and may vary between runs. Sorting
    /// gives callers that serialize this list the same output every time.
    pub fn definitions(&self) -> Vec<ToolDefinition> {
        let mut defs: Vec<ToolDefinition> =
            self.tools.values().map(|t| t.definition.clone()).collect();
        // Names are unique map keys, so an unstable sort is still fully deterministic.
        defs.sort_unstable_by(|a, b| a.function.name.cmp(&b.function.name));
        defs
    }

    /// Invokes the handler registered under `tool_name` with `args` and awaits it.
    ///
    /// # Errors
    ///
    /// Returns [`OxideError::Other`] if no tool named `tool_name` is registered.
    /// Otherwise returns whatever the handler's future resolves to.
    pub async fn dispatch(
        &self,
        tool_name: &str,
        args: serde_json::Value,
    ) -> Result<serde_json::Value, OxideError> {
        let tool = self
            .tools
            .get(tool_name)
            .ok_or_else(|| OxideError::Other(format!("unknown tool: {tool_name}")))?;
        (tool.handler)(args).await
    }

    /// Returns `true` if a tool named `name` is registered.
    pub fn contains(&self, name: &str) -> bool {
        self.tools.contains_key(name)
    }

    /// Returns the number of registered tools.
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// Returns `true` if no tools are registered.
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }
}
/// Fluent builder for a function-style [`ToolDefinition`] whose parameters are
/// described as a JSON-Schema `object`.
#[derive(Debug, Clone)]
#[must_use = "a ToolBuilder does nothing until `build` is called"]
pub struct ToolBuilder {
    /// Tool name, copied into the built definition.
    name: String,
    /// Human-readable tool description.
    description: String,
    /// Schema of each parameter, keyed by parameter name.
    properties: serde_json::Map<String, serde_json::Value>,
    /// Names of parameters emitted in the schema's `required` array.
    required: Vec<String>,
}
impl ToolBuilder {
    /// Starts a builder for a tool named `name` with the given `description` and
    /// no parameters.
    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            description: description.into(),
            properties: serde_json::Map::new(),
            required: Vec::new(),
        }
    }

    /// Adds a parameter of JSON-Schema type `"string"`.
    ///
    /// If `required` is true, the parameter is listed in the schema's `required` array.
    /// Re-declaring an existing parameter replaces its schema and required status.
    pub fn string_param(
        self,
        name: impl Into<String>,
        description: impl Into<String>,
        required: bool,
    ) -> Self {
        self.typed_param(name.into(), "string", description.into(), required)
    }

    /// Adds a parameter of JSON-Schema type `"number"`.
    ///
    /// If `required` is true, the parameter is listed in the schema's `required` array.
    /// Re-declaring an existing parameter replaces its schema and required status.
    pub fn number_param(
        self,
        name: impl Into<String>,
        description: impl Into<String>,
        required: bool,
    ) -> Self {
        self.typed_param(name.into(), "number", description.into(), required)
    }

    /// Adds a parameter of JSON-Schema type `"boolean"`.
    ///
    /// If `required` is true, the parameter is listed in the schema's `required` array.
    /// Re-declaring an existing parameter replaces its schema and required status.
    pub fn bool_param(
        self,
        name: impl Into<String>,
        description: impl Into<String>,
        required: bool,
    ) -> Self {
        self.typed_param(name.into(), "boolean", description.into(), required)
    }

    /// Inserts or replaces property `name` with schema type `json_type`, and keeps the
    /// `required` list consistent with the `required` flag.
    fn typed_param(
        mut self,
        name: String,
        json_type: &str,
        description: String,
        required: bool,
    ) -> Self {
        self.properties.insert(
            name.clone(),
            serde_json::json!({"type": json_type, "description": description}),
        );
        // JSON Schema requires `required` entries to be unique. A re-declared required
        // parameter must not be pushed twice, and one re-declared as optional must be
        // removed. Existing entries keep their position.
        match (required, self.required.iter().position(|r| *r == name)) {
            (true, None) => self.required.push(name),
            (false, Some(idx)) => {
                self.required.remove(idx);
            }
            _ => {}
        }
        self
    }

    /// Consumes the builder and produces a `"function"` [`ToolDefinition`].
    ///
    /// Its `parameters` is a JSON-Schema object holding the accumulated `properties`
    /// and the `required` list.
    pub fn build(self) -> ToolDefinition {
        use crate::types::FunctionDefinition;
        ToolDefinition {
            kind: "function".into(),
            function: FunctionDefinition {
                name: self.name,
                description: self.description,
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": serde_json::Value::Object(self.properties),
                    "required": self.required,
                }),
            },
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A registered handler receives the dispatched arguments and its output is returned.
    #[tokio::test]
    async fn registry_dispatch_calls_handler() {
        let adder = ToolBuilder::new("add", "Add two numbers")
            .number_param("a", "First operand", true)
            .number_param("b", "Second operand", true)
            .build();
        let mut tools = ToolRegistry::new();
        tools.register(adder, |args| async move {
            let lhs = args["a"].as_f64().unwrap_or(0.0);
            let rhs = args["b"].as_f64().unwrap_or(0.0);
            Ok(serde_json::json!(lhs + rhs))
        });

        let sum = tools
            .dispatch("add", serde_json::json!({"a": 3.0, "b": 4.0}))
            .await
            .expect("`add` is registered and its handler never fails");
        assert_eq!(sum, serde_json::json!(7.0));
    }

    /// Dispatching a name that was never registered yields `OxideError::Other`.
    #[tokio::test]
    async fn unknown_tool_returns_error() {
        let empty = ToolRegistry::new();
        let outcome = empty.dispatch("nonexistent", serde_json::json!({})).await;
        assert!(matches!(outcome, Err(OxideError::Other(_))));
    }

    /// A registered tool's definition is reported by `definitions`.
    #[test]
    fn definitions_are_returned() {
        let mut tools = ToolRegistry::new();
        let greeter = ToolBuilder::new("greet", "Say hello").build();
        tools.register(greeter, |_| async move { Ok(serde_json::json!("hello")) });

        let defs = tools.definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].function.name, "greet");
    }
}