Skip to main content

llmkit/
tool_loop.rs

1//! Automatic tool-execution loop.
2//!
3//! [`ChatBuilder`] wraps a single chat request and a registry of typed tool
4//! handlers. Awaiting it runs the request; if the model asks for a registered
5//! tool, the handler is invoked, the result is fed back, and the request is
6//! re-issued — looping until the model returns a final answer (or `max_turns`
7//! is hit). Tools register their JSON Schema automatically via [`ToolSchema`].
8
9use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13
14use llmkit_core::{
15    ChatRequest, ChatResponse, FinishReason, LlmError, LlmProvider, LlmResult, Message, Tool,
16    ToolSchema,
17};
18use serde::de::DeserializeOwned;
19
20type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
21type Handler = Box<dyn Fn(serde_json::Value) -> BoxFuture<LlmResult<String>> + Send + Sync>;
22
/// Tool round-trip cap used when `ChatBuilder::max_turns` is never called.
const DEFAULT_MAX_TURNS: usize = 8;
24
25/// A chat request with optional registered tools, executed when awaited.
26pub struct ChatBuilder {
27    provider: Arc<dyn LlmProvider>,
28    req: ChatRequest,
29    handlers: HashMap<String, Handler>,
30    max_turns: usize,
31}
32
33impl ChatBuilder {
34    pub(crate) fn new(provider: Arc<dyn LlmProvider>, req: ChatRequest) -> Self {
35        Self { provider, req, handlers: HashMap::new(), max_turns: DEFAULT_MAX_TURNS }
36    }
37
38    /// Register a typed tool and its async handler.
39    ///
40    /// The tool's name, description, and JSON Schema come from `T`'s
41    /// [`ToolSchema`] impl; the handler receives the deserialized input.
42    pub fn with_tool<T, F, Fut>(mut self, handler: F) -> Self
43    where
44        T: ToolSchema + DeserializeOwned + Send + 'static,
45        F: Fn(T) -> Fut + Send + Sync + 'static,
46        Fut: Future<Output = LlmResult<String>> + Send + 'static,
47    {
48        let tool = Tool::from_schema::<T>();
49        let name = tool.name.clone();
50        self.req.tools.get_or_insert_with(Vec::new).push(tool);
51
52        let handler = Arc::new(handler);
53        let boxed: Handler = Box::new(move |input: serde_json::Value| {
54            let handler = handler.clone();
55            Box::pin(async move {
56                let parsed: T = serde_json::from_value(input)
57                    .map_err(|e| LlmError::serde(format!("tool input: {e}")))?;
58                handler(parsed).await
59            })
60        });
61        self.handlers.insert(name, boxed);
62        self
63    }
64
65    /// Cap the number of model/tool round-trips (default 8).
66    pub fn max_turns(mut self, max_turns: usize) -> Self {
67        self.max_turns = max_turns.max(1);
68        self
69    }
70
71    /// Run the request, executing tools automatically until a final answer.
72    async fn run(mut self) -> LlmResult<ChatResponse> {
73        let mut last = self.provider.chat(self.req.clone()).await?;
74
75        for _ in 0..self.max_turns {
76            if last.tool_calls.is_empty() || self.handlers.is_empty() {
77                return Ok(last);
78            }
79            if !matches!(last.finish_reason, FinishReason::ToolUse) && last.tool_calls.is_empty() {
80                return Ok(last);
81            }
82
83            // Echo the assistant's tool-use turn back, then append each result.
84            for call in &last.tool_calls {
85                self.req.messages.push(Message {
86                    role: llmkit_core::Role::Assistant,
87                    content: llmkit_core::MessageContent::ToolUse {
88                        id: call.id.clone(),
89                        name: call.name.clone(),
90                        input: call.input.clone(),
91                    },
92                });
93
94                let result = match self.handlers.get(&call.name) {
95                    Some(h) => match h(call.input.clone()).await {
96                        Ok(out) => out,
97                        Err(e) => format!("error: {e}"),
98                    },
99                    None => format!("error: no handler registered for tool `{}`", call.name),
100                };
101                self.req
102                    .messages
103                    .push(Message::tool_result(call.id.clone(), result));
104            }
105
106            last = self.provider.chat(self.req.clone()).await?;
107        }
108
109        Err(LlmError::Other(format!(
110            "tool loop exceeded {} turns",
111            self.max_turns
112        )))
113    }
114}
115
116impl std::future::IntoFuture for ChatBuilder {
117    type Output = LlmResult<ChatResponse>;
118    type IntoFuture = BoxFuture<LlmResult<ChatResponse>>;
119
120    fn into_future(self) -> Self::IntoFuture {
121        Box::pin(self.run())
122    }
123}