llmoxide-tools 0.1.0

Tool-calling runner for llmoxide (schemas, dispatch, streaming callbacks)
Documentation
use crate::runner::ToolError;
use async_trait::async_trait;
use llmoxide::types::{ToolCall, ToolSpec};
use schemars::{JsonSchema, schema_for};
use serde::{Serialize, de::DeserializeOwned};
use std::collections::BTreeMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

/// Heap-allocated, type-erased future that may borrow data for `'a`.
///
/// No `Send` bound is placed on the future, so handlers are free to hold non-`Send`
/// state across `.await` points.
type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + 'a>>;

/// Name and optional description of a tool; both are copied into the tool's `ToolSpec`.
#[derive(Debug, Clone)]
pub struct ToolMeta {
    /// Tool name; the registry also uses it as the dispatch key.
    pub name: String,
    /// Optional human-readable description included in the tool spec.
    pub description: Option<String>,
}

impl ToolMeta {
    /// Create metadata with the given name and no description.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            description: None,
        }
    }

    /// Builder-style setter for the description.
    ///
    /// Consumes `self` and returns the updated value. Discarding the result would silently
    /// lose the description, hence `#[must_use]`.
    #[must_use]
    pub fn description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }
}

/// Object-safe, type-erased interface over a registered tool.
///
/// Lets the registry store tools with different argument/result types behind one `dyn`
/// type. The futures produced by `call` are not required to be `Send` (`?Send`).
#[async_trait(?Send)]
trait DynTool: Send + Sync {
    /// Name of the tool.
    fn name(&self) -> &str;

    /// Provider-facing description of the tool (name, description, parameter schema).
    fn spec(&self) -> ToolSpec;

    /// Execute the tool for `call`, producing its JSON-encoded result.
    async fn call(&self, call: &ToolCall) -> Result<serde_json::Value, ToolError>;
}

/// Adapter pairing a tool's metadata with its typed async handler.
///
/// `TArgs` is decoded from the call's JSON arguments and `TResult` is encoded back to
/// JSON by this type's `DynTool` implementation.
struct ToolImpl<TArgs, TResult> {
    meta: ToolMeta,
    handler: Arc<dyn Fn(TArgs) -> BoxFuture<'static, Result<TResult, ToolError>> + Send + Sync>,
    // `fn(TArgs) -> TResult` records the type relationship without claiming that a
    // `ToolImpl` owns a `TArgs`/`TResult` value. As a result, `ToolImpl`'s `Send`/`Sync`
    // are governed by its real fields (the handler is already `Send + Sync`).
    _phantom: std::marker::PhantomData<fn(TArgs) -> TResult>,
}

impl<TArgs, TResult> ToolImpl<TArgs, TResult>
where
    TArgs: DeserializeOwned + JsonSchema + Send + Sync + 'static,
    TResult: Serialize + Send + Sync + 'static,
{
    /// JSON Schema for `TArgs`, used as the tool's `parameters`.
    ///
    /// The whole root schema is serialized so its `definitions` map is kept. Nested
    /// `JsonSchema` types are emitted as `{"$ref": "#/definitions/..."}`, so serializing
    /// only the top-level schema object would leave those references dangling. The `$schema`
    /// meta-schema URI is removed because it is document metadata, not part of the
    /// parameter shape. Falls back to `null` if serialization fails.
    fn schema_json() -> serde_json::Value {
        let root = schema_for!(TArgs);
        let mut value = serde_json::to_value(&root).unwrap_or(serde_json::Value::Null);
        if let serde_json::Value::Object(fields) = &mut value {
            fields.remove("$schema");
        }
        value
    }
}

#[async_trait(?Send)]
impl<TArgs, TResult> DynTool for ToolImpl<TArgs, TResult>
where
    TArgs: DeserializeOwned + JsonSchema + Send + Sync + 'static,
    TResult: Serialize + Send + Sync + 'static,
{
    fn spec(&self) -> ToolSpec {
        ToolSpec {
            name: self.meta.name.clone(),
            description: self.meta.description.clone(),
            parameters: Self::schema_json(),
        }
    }

    fn name(&self) -> &str {
        &self.meta.name
    }

    async fn call(&self, call: &ToolCall) -> Result<serde_json::Value, ToolError> {
        let args: TArgs = serde_json::from_value(call.arguments.clone()).map_err(|e| {
            ToolError::InvalidArguments {
                tool: self.meta.name.clone(),
                details: e.to_string(),
            }
        })?;

        let res = (self.handler)(args).await?;
        serde_json::to_value(res).map_err(|e| ToolError::Handler {
            tool: self.meta.name.clone(),
            details: e.to_string(),
        })
    }
}

/// Registry of typed tools, convertible to provider tool schemas.
///
/// Clones share one reference-counted map. A later `register` on any clone copies the
/// map first when it is shared, so other clones are unaffected.
#[derive(Clone, Default)]
pub struct ToolRegistry {
    tools: Arc<BTreeMap<String, Arc<dyn DynTool>>>,
}

impl ToolRegistry {
    /// Create an empty registry.
    pub fn new() -> Self {
        Self::default()
    }

    /// Register a tool with typed args and typed result.
    ///
    /// - `TArgs` must implement `JsonSchema` so we can produce a provider tool schema.
    /// - The handler returns a typed result which is serialized to JSON for tool responses.
    ///
    /// Registering a name that is already present replaces the earlier tool. Returns
    /// `&mut Self` so registrations can be chained.
    pub fn register<TArgs, TResult, Fut, F>(&mut self, meta: ToolMeta, handler: F) -> &mut Self
    where
        TArgs: DeserializeOwned + JsonSchema + Send + Sync + 'static,
        TResult: Serialize + Send + Sync + 'static,
        Fut: Future<Output = Result<TResult, ToolError>> + 'static,
        F: Fn(TArgs) -> Fut + Send + Sync + 'static,
    {
        let name = meta.name.clone();
        // Erase the concrete future type so every tool's handler has the same shape.
        let handler = Arc::new(
            move |args: TArgs| -> BoxFuture<'static, Result<TResult, ToolError>> {
                Box::pin(handler(args))
            },
        );
        let tool = ToolImpl::<TArgs, TResult> {
            meta,
            handler,
            _phantom: std::marker::PhantomData,
        };
        // Copy-on-write: the map is cloned only when another `ToolRegistry` still shares
        // it. Cloning unconditionally would make N registrations cost O(N^2).
        Arc::make_mut(&mut self.tools).insert(name, Arc::new(tool));
        self
    }

    /// Specs for every registered tool, ordered by tool name.
    pub fn specs(&self) -> Vec<ToolSpec> {
        self.tools.values().map(|t| t.spec()).collect()
    }

    /// Route `call` to the tool registered under `call.name` and run it.
    ///
    /// Returns the tool's name together with its JSON output.
    ///
    /// # Errors
    /// - `ToolError::UnknownTool` if no tool is registered under `call.name`.
    /// - Any error produced by the tool's `call`.
    pub(crate) async fn dispatch(
        &self,
        call: &ToolCall,
    ) -> Result<(String, serde_json::Value), ToolError> {
        let Some(tool) = self.tools.get(&call.name) else {
            return Err(ToolError::UnknownTool {
                tool: call.name.clone(),
            });
        };
        let out = tool.call(call).await?;
        Ok((tool.name().to_string(), out))
    }
}