1use std::collections::HashMap;
10use std::future::Future;
11use std::pin::Pin;
12use std::sync::Arc;
13
14use llmkit_core::{
15 ChatRequest, ChatResponse, FinishReason, LlmError, LlmProvider, LlmResult, Message, Tool,
16 ToolSchema,
17};
18use serde::de::DeserializeOwned;
19
20type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
21type Handler = Box<dyn Fn(serde_json::Value) -> BoxFuture<LlmResult<String>> + Send + Sync>;
22
23const DEFAULT_MAX_TURNS: usize = 8;
24
25pub struct ChatBuilder {
27 provider: Arc<dyn LlmProvider>,
28 req: ChatRequest,
29 handlers: HashMap<String, Handler>,
30 max_turns: usize,
31}
32
33impl ChatBuilder {
34 pub(crate) fn new(provider: Arc<dyn LlmProvider>, req: ChatRequest) -> Self {
35 Self { provider, req, handlers: HashMap::new(), max_turns: DEFAULT_MAX_TURNS }
36 }
37
38 pub fn with_tool<T, F, Fut>(mut self, handler: F) -> Self
43 where
44 T: ToolSchema + DeserializeOwned + Send + 'static,
45 F: Fn(T) -> Fut + Send + Sync + 'static,
46 Fut: Future<Output = LlmResult<String>> + Send + 'static,
47 {
48 let tool = Tool::from_schema::<T>();
49 let name = tool.name.clone();
50 self.req.tools.get_or_insert_with(Vec::new).push(tool);
51
52 let handler = Arc::new(handler);
53 let boxed: Handler = Box::new(move |input: serde_json::Value| {
54 let handler = handler.clone();
55 Box::pin(async move {
56 let parsed: T = serde_json::from_value(input)
57 .map_err(|e| LlmError::serde(format!("tool input: {e}")))?;
58 handler(parsed).await
59 })
60 });
61 self.handlers.insert(name, boxed);
62 self
63 }
64
65 pub fn max_turns(mut self, max_turns: usize) -> Self {
67 self.max_turns = max_turns.max(1);
68 self
69 }
70
71 async fn run(mut self) -> LlmResult<ChatResponse> {
73 let mut last = self.provider.chat(self.req.clone()).await?;
74
75 for _ in 0..self.max_turns {
76 if last.tool_calls.is_empty() || self.handlers.is_empty() {
77 return Ok(last);
78 }
79 if !matches!(last.finish_reason, FinishReason::ToolUse) && last.tool_calls.is_empty() {
80 return Ok(last);
81 }
82
83 for call in &last.tool_calls {
85 self.req.messages.push(Message {
86 role: llmkit_core::Role::Assistant,
87 content: llmkit_core::MessageContent::ToolUse {
88 id: call.id.clone(),
89 name: call.name.clone(),
90 input: call.input.clone(),
91 },
92 });
93
94 let result = match self.handlers.get(&call.name) {
95 Some(h) => match h(call.input.clone()).await {
96 Ok(out) => out,
97 Err(e) => format!("error: {e}"),
98 },
99 None => format!("error: no handler registered for tool `{}`", call.name),
100 };
101 self.req
102 .messages
103 .push(Message::tool_result(call.id.clone(), result));
104 }
105
106 last = self.provider.chat(self.req.clone()).await?;
107 }
108
109 Err(LlmError::Other(format!(
110 "tool loop exceeded {} turns",
111 self.max_turns
112 )))
113 }
114}
115
116impl std::future::IntoFuture for ChatBuilder {
117 type Output = LlmResult<ChatResponse>;
118 type IntoFuture = BoxFuture<LlmResult<ChatResponse>>;
119
120 fn into_future(self) -> Self::IntoFuture {
121 Box::pin(self.run())
122 }
123}