use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use llmkit_core::{
ChatRequest, ChatResponse, FinishReason, LlmError, LlmProvider, LlmResult, Message, Tool,
ToolSchema,
};
use serde::de::DeserializeOwned;
/// A pinned, heap-allocated, `Send` future with an erased concrete type that resolves to `T`.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
/// Type-erased tool handler: takes the raw JSON input of a tool call and returns a boxed
/// future that resolves to the tool's textual output or an error.
type Handler = Box<dyn Fn(serde_json::Value) -> BoxFuture<LlmResult<String>> + Send + Sync>;
/// Turn limit a freshly constructed `ChatBuilder` starts with.
const DEFAULT_MAX_TURNS: usize = 8;
/// Builder for a chat request that can execute registered tools on the model's behalf.
///
/// Tools are registered with `with_tool`; awaiting the builder (via `IntoFuture`) sends the
/// request and drives the tool-call loop.
pub struct ChatBuilder {
// Backend that every chat request is sent to.
provider: Arc<dyn LlmProvider>,
// Request being assembled; tool definitions and tool-call messages are appended to it.
req: ChatRequest,
// Registered tool handlers, keyed by tool name.
handlers: HashMap<String, Handler>,
// Bound on how many rounds of tool execution the loop may perform.
max_turns: usize,
}
impl ChatBuilder {
    /// Creates a builder that sends `req` through `provider`.
    ///
    /// The builder starts with no tool handlers and a turn limit of `DEFAULT_MAX_TURNS`.
    pub(crate) fn new(provider: Arc<dyn LlmProvider>, req: ChatRequest) -> Self {
        Self {
            provider,
            req,
            handlers: HashMap::new(),
            max_turns: DEFAULT_MAX_TURNS,
        }
    }

    /// Registers a tool whose definition is derived from `T` and whose calls are served by
    /// `handler`.
    ///
    /// The tool definition is added to the outgoing request. When the model calls the tool,
    /// the call's JSON input is deserialized into `T` and passed to `handler`; a
    /// deserialization failure is reported as a serde error prefixed with `tool input:`.
    ///
    /// Registering a second tool with the same name replaces both the earlier handler and the
    /// earlier definition, so the request never advertises the same tool name twice.
    pub fn with_tool<T, F, Fut>(mut self, handler: F) -> Self
    where
        T: ToolSchema + DeserializeOwned + Send + 'static,
        F: Fn(T) -> Fut + Send + Sync + 'static,
        Fut: Future<Output = LlmResult<String>> + Send + 'static,
    {
        let tool = Tool::from_schema::<T>();
        let name = tool.name.clone();

        // Keep the advertised tool list in step with the handler map: the map insert below is
        // last-registration-wins, so drop any earlier definition carrying the same name
        // instead of sending a duplicate.
        let tools = self.req.tools.get_or_insert_with(Vec::new);
        tools.retain(|existing| existing.name != name);
        tools.push(tool);

        // The handler is shared through an `Arc` so that every invocation can hand an owned
        // reference to the `'static` future it returns.
        let handler = Arc::new(handler);
        let boxed: Handler = Box::new(move |input: serde_json::Value| {
            let handler = Arc::clone(&handler);
            Box::pin(async move {
                let parsed: T = serde_json::from_value(input)
                    .map_err(|e| LlmError::serde(format!("tool input: {e}")))?;
                handler(parsed).await
            })
        });
        self.handlers.insert(name, boxed);
        self
    }

    /// Sets the maximum number of tool-execution rounds; values below 1 are raised to 1.
    pub fn max_turns(mut self, max_turns: usize) -> Self {
        self.max_turns = max_turns.max(1);
        self
    }

    /// Sends the request and runs the tool loop until the model stops asking for tools.
    ///
    /// Each round appends, for every tool call in the latest response, an assistant tool-use
    /// message followed by the tool result, then asks the provider again. A handler error or
    /// a call to an unregistered tool is not fatal: it is sent back to the model as an
    /// `error: ...` tool result.
    ///
    /// The response is returned as soon as it contains no tool calls, or immediately when no
    /// handlers are registered (tool calls are then left for the caller to process). At most
    /// `max_turns` rounds are executed, i.e. at most `max_turns + 1` provider calls.
    ///
    /// # Errors
    ///
    /// Propagates any error from the provider, and returns `LlmError::Other` when the model
    /// still requests tools after `max_turns` rounds.
    async fn run(mut self) -> LlmResult<ChatResponse> {
        let mut last = self.provider.chat(self.req.clone()).await?;
        let mut rounds = 0usize;

        loop {
            // Check for a terminal response *before* checking the budget, so an answer that
            // arrives on the last permitted round is returned rather than discarded.
            if last.tool_calls.is_empty() || self.handlers.is_empty() {
                return Ok(last);
            }
            // NOTE(review): unreachable as written -- the guard above already returned when
            // `tool_calls` is empty. Kept verbatim because it is the only use of
            // `FinishReason` visible in this file; confirm the intended condition before
            // removing it.
            if !matches!(last.finish_reason, FinishReason::ToolUse) && last.tool_calls.is_empty() {
                return Ok(last);
            }
            if rounds == self.max_turns {
                return Err(LlmError::Other(format!(
                    "tool loop exceeded {} turns",
                    self.max_turns
                )));
            }
            rounds += 1;

            for call in &last.tool_calls {
                self.req.messages.push(Message {
                    role: llmkit_core::Role::Assistant,
                    content: llmkit_core::MessageContent::ToolUse {
                        id: call.id.clone(),
                        name: call.name.clone(),
                        input: call.input.clone(),
                    },
                });
                // Failures are reported to the model as text so it can react to them.
                let result = match self.handlers.get(&call.name) {
                    Some(h) => match h(call.input.clone()).await {
                        Ok(out) => out,
                        Err(e) => format!("error: {e}"),
                    },
                    None => format!("error: no handler registered for tool `{}`", call.name),
                };
                self.req
                    .messages
                    .push(Message::tool_result(call.id.clone(), result));
            }

            last = self.provider.chat(self.req.clone()).await?;
        }
    }
}
impl std::future::IntoFuture for ChatBuilder {
type Output = LlmResult<ChatResponse>;
type IntoFuture = BoxFuture<LlmResult<ChatResponse>>;
fn into_future(self) -> Self::IntoFuture {
Box::pin(self.run())
}
}