llmkit-rs 0.1.0

Unified multi-provider async LLM client for Rust — OpenAI, Anthropic, Ollama, with Tower middleware
Documentation
//! Automatic tool-execution loop.
//!
//! [`ChatBuilder`] wraps a single chat request and a registry of typed tool
//! handlers. Awaiting it runs the request; if the model asks for a registered
//! tool, the handler is invoked, the result is fed back, and the request is
//! re-issued — looping until the model returns a final answer (or `max_turns`
//! is hit). Tools register their JSON Schema automatically via [`ToolSchema`].

use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;

use llmkit_core::{
    ChatRequest, ChatResponse, FinishReason, LlmError, LlmProvider, LlmResult, Message, Tool,
    ToolSchema,
};
use serde::de::DeserializeOwned;

/// Pinned, boxed, `Send` future. Used as the return type of tool handlers and
/// as `ChatBuilder`'s `IntoFuture` type.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
/// Type-erased tool handler: takes the model-supplied JSON input and resolves
/// to the tool's textual result (or an error).
type Handler = Box<dyn Fn(serde_json::Value) -> BoxFuture<LlmResult<String>> + Send + Sync>;

/// Default cap on tool-execution rounds, applied by `ChatBuilder::new`.
const DEFAULT_MAX_TURNS: usize = 8;

/// A chat request with optional registered tools, executed when awaited.
///
/// Construct one, chain [`ChatBuilder::with_tool`] / [`ChatBuilder::max_turns`]
/// as needed, then `.await` it. No provider call is made until it is awaited.
#[must_use = "a ChatBuilder sends nothing until it is `.await`ed"]
pub struct ChatBuilder {
    /// Backend that executes each chat round.
    provider: Arc<dyn LlmProvider>,
    /// Request re-sent on every round; tool-use and tool-result messages are
    /// appended to its `messages` as the loop progresses.
    req: ChatRequest,
    /// Registered tool handlers, keyed by tool name.
    handlers: HashMap<String, Handler>,
    /// Maximum number of tool-execution rounds (always at least 1).
    max_turns: usize,
}

impl ChatBuilder {
    pub(crate) fn new(provider: Arc<dyn LlmProvider>, req: ChatRequest) -> Self {
        Self { provider, req, handlers: HashMap::new(), max_turns: DEFAULT_MAX_TURNS }
    }

    /// Register a typed tool and its async handler.
    ///
    /// The tool's name, description, and JSON Schema come from `T`'s
    /// [`ToolSchema`] impl; the handler receives the deserialized input.
    ///
    /// Registering a tool whose name is already present replaces both the
    /// earlier handler and the earlier tool definition, so the request never
    /// advertises the same tool name twice.
    ///
    /// If the model's input does not deserialize into `T`, the handler is not
    /// called; the deserialization error is sent back to the model as the
    /// tool result instead.
    pub fn with_tool<T, F, Fut>(mut self, handler: F) -> Self
    where
        T: ToolSchema + DeserializeOwned + Send + 'static,
        F: Fn(T) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = LlmResult<String>> + Send + 'static,
    {
        let tool = Tool::from_schema::<T>();
        let name = tool.name.clone();
        let tools = self.req.tools.get_or_insert_with(Vec::new);
        // Keep the tool list in sync with `handlers`, which overwrites by name.
        tools.retain(|existing| existing.name != name);
        tools.push(tool);

        // Shared so each invocation's `'static` future can own a handle to it.
        let handler = Arc::new(handler);
        let boxed: Handler = Box::new(move |input: serde_json::Value| {
            let handler = Arc::clone(&handler);
            Box::pin(async move {
                let parsed: T = serde_json::from_value(input)
                    .map_err(|e| LlmError::serde(format!("tool input: {e}")))?;
                handler(parsed).await
            })
        });
        self.handlers.insert(name, boxed);
        self
    }

    /// Cap the number of tool-execution rounds (default 8).
    ///
    /// Each round runs every tool call from the latest response and re-issues
    /// the request, so at most `max_turns + 1` provider calls are made.
    /// Values below 1 are raised to 1.
    pub fn max_turns(mut self, max_turns: usize) -> Self {
        self.max_turns = max_turns.max(1);
        self
    }

    /// Whether `resp` ends the tool loop.
    ///
    /// The loop stops when the model requested no tools, or when no handlers
    /// are registered (any tool calls are then returned to the caller as-is).
    /// Tool calls are acted on whenever present, regardless of
    /// `finish_reason`.
    fn is_final(&self, resp: &ChatResponse) -> bool {
        resp.tool_calls.is_empty() || self.handlers.is_empty()
    }

    /// Run the request, executing tools automatically until a final answer.
    ///
    /// Handler failures are not fatal: they are formatted as `error: …` and
    /// sent to the model as the tool result so it can react.
    ///
    /// # Errors
    ///
    /// Propagates any provider error. Returns [`LlmError::Other`] if the model
    /// is still requesting tools after `max_turns` rounds.
    async fn run(mut self) -> LlmResult<ChatResponse> {
        let mut last = self.provider.chat(self.req.clone()).await?;

        for _ in 0..self.max_turns {
            if self.is_final(&last) {
                return Ok(last);
            }

            // Echo each tool-use call back, followed by its result.
            // NOTE(review): each call becomes its own assistant message
            // interleaved with its result, and any assistant text in `last` is
            // not echoed. Providers that expect all tool-use blocks of one turn
            // in a single assistant message may need this grouped — confirm.
            for call in &last.tool_calls {
                self.req.messages.push(Message {
                    role: llmkit_core::Role::Assistant,
                    content: llmkit_core::MessageContent::ToolUse {
                        id: call.id.clone(),
                        name: call.name.clone(),
                        input: call.input.clone(),
                    },
                });

                let result = match self.handlers.get(&call.name) {
                    Some(h) => match h(call.input.clone()).await {
                        Ok(out) => out,
                        Err(e) => format!("error: {e}"),
                    },
                    None => format!("error: no handler registered for tool `{}`", call.name),
                };
                self.req
                    .messages
                    .push(Message::tool_result(call.id.clone(), result));
            }

            last = self.provider.chat(self.req.clone()).await?;
        }

        // The response to the last allowed round has not been inspected yet;
        // it may be the final answer and must not be discarded.
        if self.is_final(&last) {
            return Ok(last);
        }

        Err(LlmError::Other(format!(
            "tool loop exceeded {} turns",
            self.max_turns
        )))
    }
}

/// Awaiting a [`ChatBuilder`] runs the request, including any tool rounds.
impl std::future::IntoFuture for ChatBuilder {
    type Output = LlmResult<ChatResponse>;
    type IntoFuture = BoxFuture<Self::Output>;

    fn into_future(self) -> Self::IntoFuture {
        // Box the loop's future so the associated type stays nameable.
        let pending = self.run();
        Box::pin(pending)
    }
}