use async_trait::async_trait;
use std::fmt;
/// Per-request data handed to a [`ToolExecutor`] when listing or running tools.
///
/// Start from [`RequestContext::new`] and chain the `with_*` builder methods.
#[derive(Debug, Default)]
pub struct RequestContext {
    /// Model name for the request, if one was set via [`RequestContext::with_model`].
    pub model: Option<String>,
    /// Type-keyed bag of extra values, populated via [`RequestContext::with_extension`].
    pub extensions: axum::http::Extensions,
}

impl RequestContext {
    /// Creates an empty context: no model and no extensions.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the model name, replacing any previously set one.
    ///
    /// Builder-style: consumes `self` and returns the updated context, so the
    /// return value must be used.
    #[must_use]
    pub fn with_model(mut self, model: impl Into<String>) -> Self {
        self.model = Some(model.into());
        self
    }

    /// Stores `val` in [`extensions`](Self::extensions).
    ///
    /// Extensions are keyed by type, so adding a second value of the same type
    /// `T` replaces the first one.
    ///
    /// Builder-style: consumes `self` and returns the updated context, so the
    /// return value must be used.
    #[must_use]
    pub fn with_extension<T: Clone + Send + Sync + 'static>(mut self, val: T) -> Self {
        // `insert` hands back any previous value of type `T`; the builder
        // intentionally drops it.
        self.extensions.insert(val);
        self
    }
}
/// Description of a single tool that a [`ToolExecutor`] can advertise.
#[derive(Debug, Clone, PartialEq)]
pub struct ToolSchema {
    /// Tool name; presumably the same identifier passed to `ToolExecutor::execute` — confirm.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// Parameter specification. NOTE(review): presumably a JSON Schema object — confirm with producers.
    pub parameters: serde_json::Value,
    /// Strict-mode flag. NOTE(review): semantics are defined by the consumer of this schema
    /// (presumably strict argument validation) — confirm.
    pub strict: bool,
}
/// Errors a [`ToolExecutor`] can report when running a tool.
///
/// Each variant carries a message (or, for `NotFound`, the tool name) that is
/// included in the `Display` output.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ToolError {
    /// No tool with the given name exists; holds the requested name.
    NotFound(String),
    /// The tool ran but failed; holds a description of the failure.
    ExecutionError(String),
    /// The supplied arguments were rejected; holds a description.
    InvalidArguments(String),
    /// The tool did not finish in time; holds a description.
    Timeout(String),
}

impl fmt::Display for ToolError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ToolError::NotFound(name) => write!(f, "Tool not found: {}", name),
            ToolError::ExecutionError(msg) => write!(f, "Tool execution error: {}", msg),
            ToolError::InvalidArguments(msg) => write!(f, "Invalid arguments: {}", msg),
            ToolError::Timeout(msg) => write!(f, "Tool timeout: {}", msg),
        }
    }
}

// No wrapped source error: each variant only carries a message string.
impl std::error::Error for ToolError {}
/// Source of tools that can be listed for a request and invoked by name.
///
/// Implementors must be `Send + Sync`. Both methods are `async` via the
/// `async_trait` macro.
#[async_trait]
pub trait ToolExecutor: Send + Sync {
    /// Lists the tools available for the request described by `ctx`.
    async fn tools(
        &self,
        ctx: &RequestContext,
    ) -> Vec<ToolSchema>;

    /// Runs the tool called `tool_name` with the given `arguments`.
    ///
    /// `tool_call_id` identifies this particular invocation; how it is used is
    /// up to the implementor.
    ///
    /// # Errors
    ///
    /// Returns a [`ToolError`] when the tool cannot be run or fails. Which
    /// variant is used is implementation-defined; presumably `NotFound` is
    /// used for unknown tool names.
    async fn execute(
        &self,
        tool_name: &str,
        tool_call_id: &str,
        arguments: &serde_json::Value,
        ctx: &RequestContext,
    ) -> Result<serde_json::Value, ToolError>;
}
/// A [`ToolExecutor`] that offers no tools.
///
/// `tools` always returns an empty list. Every `execute` call fails with
/// [`ToolError::NotFound`] carrying the requested tool name.
#[derive(Default, Clone, Debug)]
pub struct NoOpToolExecutor;

#[async_trait]
impl ToolExecutor for NoOpToolExecutor {
    async fn tools(&self, _ctx: &RequestContext) -> Vec<ToolSchema> {
        // Nothing to advertise.
        vec![]
    }

    async fn execute(
        &self,
        tool_name: &str,
        _tool_call_id: &str,
        _arguments: &serde_json::Value,
        _ctx: &RequestContext,
    ) -> Result<serde_json::Value, ToolError> {
        // No tool can ever match, so report the requested name as missing.
        let missing = ToolError::NotFound(tool_name.to_owned());
        Err(missing)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_noop_executor_returns_no_tools() {
        let executor = NoOpToolExecutor;
        let ctx = RequestContext::new();
        assert!(executor.tools(&ctx).await.is_empty());
    }

    #[tokio::test]
    async fn test_noop_executor_returns_not_found() {
        let executor = NoOpToolExecutor;
        let ctx = RequestContext::new();
        let result = executor
            .execute("test_tool", "call_123", &serde_json::json!({}), &ctx)
            .await;
        // The error must carry the requested tool name, not just the variant.
        assert!(matches!(
            result,
            Err(ToolError::NotFound(ref name)) if name == "test_tool"
        ));
    }

    #[test]
    fn test_request_context_builder() {
        let ctx = RequestContext::new()
            .with_model("gpt-4o")
            .with_extension(42u32);
        assert_eq!(ctx.model.as_deref(), Some("gpt-4o"));
        assert_eq!(ctx.extensions.get::<u32>(), Some(&42));
    }

    #[test]
    fn test_request_context_default_is_empty() {
        let ctx = RequestContext::new();
        assert!(ctx.model.is_none());
        assert!(ctx.extensions.get::<u32>().is_none());
    }

    #[test]
    fn test_request_context_extension_replaces_same_type() {
        let ctx = RequestContext::new()
            .with_extension(1u32)
            .with_extension(2u32);
        assert_eq!(ctx.extensions.get::<u32>(), Some(&2));
    }

    #[test]
    fn test_tool_error_display() {
        assert_eq!(
            ToolError::NotFound("x".to_string()).to_string(),
            "Tool not found: x"
        );
        assert_eq!(
            ToolError::ExecutionError("boom".to_string()).to_string(),
            "Tool execution error: boom"
        );
        assert_eq!(
            ToolError::InvalidArguments("bad".to_string()).to_string(),
            "Invalid arguments: bad"
        );
        assert_eq!(
            ToolError::Timeout("5s".to_string()).to_string(),
            "Tool timeout: 5s"
        );
    }
}