// oxide-agent 0.1.0
//
// Type-safe, high-performance Rust crate for building agentic systems on Ollama.
// Documentation
use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use crate::error::OxideError;
use crate::types::ToolDefinition;

// ── Handler type ──────────────────────────────────────────────────────────────

/// What a single tool invocation produces: a JSON result or an [`OxideError`].
type ToolOutput = Result<serde_json::Value, OxideError>;

/// Heap-allocated, type-erased future returned by every tool handler.
///
/// The future is `Send`, so it may be moved across threads while pending.
type AsyncResult = Pin<Box<dyn Future<Output = ToolOutput> + Send>>;

/// Shared, type-erased tool handler: receives the call's JSON arguments and
/// returns an [`AsyncResult`].
type HandlerFn = Arc<dyn Fn(serde_json::Value) -> AsyncResult + Send + Sync>;

// ── Registered tool ───────────────────────────────────────────────────────────

/// A single registry entry, pairing a tool's handler with the schema that
/// describes it to the model.
struct RegisteredTool {
    /// Type-erased async handler invoked by `ToolRegistry::dispatch`.
    handler: HandlerFn,
    /// Schema returned by `ToolRegistry::definitions`.
    definition: ToolDefinition,
}

// ── ToolRegistry ──────────────────────────────────────────────────────────────

/// Registry that maps tool names to their JSON Schema definitions and async
/// handler functions.
///
/// Pass the definitions to [`ChatRequest::tools`] when calling Ollama and call
/// [`ToolRegistry::dispatch`] when the model returns a tool-call in its reply.
#[derive(Default)]
pub struct ToolRegistry {
    tools: HashMap<String, RegisteredTool>,
}

impl ToolRegistry {
    pub fn new() -> Self {
        Self::default()
    }

    /// Register a tool with its JSON Schema definition and an async handler.
    ///
    /// The handler receives the `arguments` object from Ollama's tool-call
    /// response and must return a JSON value that will be fed back as the
    /// tool result.
    pub fn register<F, Fut>(
        &mut self,
        definition: ToolDefinition,
        handler: F,
    ) where
        F: Fn(serde_json::Value) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = Result<serde_json::Value, OxideError>> + Send + 'static,
    {
        let name = definition.function.name.clone();
        self.tools.insert(
            name,
            RegisteredTool {
                definition,
                handler: Arc::new(move |args| Box::pin(handler(args))),
            },
        );
    }

    /// All definitions, ready to pass directly to [`ChatRequest::tools`].
    pub fn definitions(&self) -> Vec<ToolDefinition> {
        self.tools.values().map(|t| t.definition.clone()).collect()
    }

    /// Execute the handler for `tool_name` with the given arguments.
    ///
    /// Returns `Err(OxideError::Other)` if the tool name is not registered.
    pub async fn dispatch(
        &self,
        tool_name: &str,
        args: serde_json::Value,
    ) -> Result<serde_json::Value, OxideError> {
        let tool = self.tools.get(tool_name).ok_or_else(|| {
            OxideError::Other(format!("unknown tool: {tool_name}"))
        })?;

        (tool.handler)(args).await
    }

    pub fn contains(&self, name: &str) -> bool {
        self.tools.contains_key(name)
    }

    pub fn len(&self) -> usize {
        self.tools.len()
    }

    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }
}

// ── Helper: build a ToolDefinition inline ────────────────────────────────────

/// Convenience builder for [`ToolDefinition`] without the `#[ollama_tool]`
/// macro.
///
/// The built `parameters` schema is a JSON object of the form
/// `{"type": "object", "properties": {...}, "required": [...]}`. Each
/// property is `{"type": <json type>, "description": <text>}`, and `required`
/// lists the parameters added with `required = true`.
///
/// Adding a parameter whose name is already present replaces the earlier
/// entry, including whether it is required. A name never appears twice in
/// `required`.
#[derive(Debug, Clone)]
#[must_use = "builders do nothing until `build` is called"]
pub struct ToolBuilder {
    name: String,
    description: String,
    properties: serde_json::Map<String, serde_json::Value>,
    required: Vec<String>,
}

impl ToolBuilder {
    /// Start a builder for a tool called `name` with a human-readable
    /// `description`, and no parameters yet.
    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            description: description.into(),
            properties: serde_json::Map::new(),
            required: Vec::new(),
        }
    }

    /// Add a string parameter.
    pub fn string_param(
        self,
        name: impl Into<String>,
        description: impl Into<String>,
        required: bool,
    ) -> Self {
        self.typed_param(name.into(), description.into(), "string", required)
    }

    /// Add a numeric parameter (JSON Schema `number`, which accepts integers
    /// and fractions).
    pub fn number_param(
        self,
        name: impl Into<String>,
        description: impl Into<String>,
        required: bool,
    ) -> Self {
        self.typed_param(name.into(), description.into(), "number", required)
    }

    /// Add an integer parameter (JSON Schema `integer`).
    pub fn integer_param(
        self,
        name: impl Into<String>,
        description: impl Into<String>,
        required: bool,
    ) -> Self {
        self.typed_param(name.into(), description.into(), "integer", required)
    }

    /// Add a boolean parameter.
    pub fn bool_param(
        self,
        name: impl Into<String>,
        description: impl Into<String>,
        required: bool,
    ) -> Self {
        self.typed_param(name.into(), description.into(), "boolean", required)
    }

    /// Insert or replace one property of primitive `json_type`, keeping the
    /// `required` list in sync.
    fn typed_param(
        mut self,
        name: String,
        description: String,
        json_type: &str,
        required: bool,
    ) -> Self {
        // Remove any earlier requirement for this name first. Re-adding a
        // parameter then neither duplicates it in `required` (JSON Schema
        // requires unique entries) nor leaves it required when it is
        // re-added as optional.
        self.required.retain(|existing| existing != &name);
        if required {
            self.required.push(name.clone());
        }
        self.properties.insert(
            name,
            serde_json::json!({ "type": json_type, "description": description }),
        );
        self
    }

    /// Finish the builder and produce a `"function"`-kind [`ToolDefinition`].
    #[must_use]
    pub fn build(self) -> ToolDefinition {
        use crate::types::FunctionDefinition;
        ToolDefinition {
            kind: "function".into(),
            function: FunctionDefinition {
                name: self.name,
                description: self.description,
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": serde_json::Value::Object(self.properties),
                    "required": self.required,
                }),
            },
        }
    }
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// `dispatch` forwards the arguments to the registered handler and
    /// returns the handler's output.
    #[tokio::test]
    async fn registry_dispatch_calls_handler() {
        let add_tool = ToolBuilder::new("add", "Add two numbers")
            .number_param("a", "First operand", true)
            .number_param("b", "Second operand", true)
            .build();

        let mut reg = ToolRegistry::new();
        reg.register(add_tool, |args| async move {
            let [lhs, rhs] = ["a", "b"].map(|key| args[key].as_f64().unwrap_or(0.0));
            Ok(json!(lhs + rhs))
        });

        let sum = reg
            .dispatch("add", json!({"a": 3.0, "b": 4.0}))
            .await
            .expect("`add` is registered and its handler never fails");

        assert_eq!(sum, json!(7.0));
    }

    /// Dispatching a name that was never registered yields `OxideError::Other`.
    #[tokio::test]
    async fn unknown_tool_returns_error() {
        let outcome = ToolRegistry::new()
            .dispatch("nonexistent", json!({}))
            .await;
        assert!(matches!(outcome, Err(OxideError::Other(_))));
    }

    /// A registered tool's definition is exposed through `definitions()`.
    #[test]
    fn definitions_are_returned() {
        let mut reg = ToolRegistry::new();
        reg.register(ToolBuilder::new("greet", "Say hello").build(), |_| async move {
            Ok(json!("hello"))
        });

        let defs = reg.definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].function.name, "greet");
    }
}