use crate::action::{ActionAgent, AgentAction, AgentFinish, AgentStep};
use async_trait::async_trait;
use futures::StreamExt;
use serde_json::json;
use wesichain_core::{LlmRequest, LlmResponse, Runnable, Tool, ToolCallingLlm, WesichainError};
use wesichain_prompt::ChatPromptTemplate;
/// Agent runnable that queries a tool-calling LLM and converts its reply into an
/// [`AgentStep`] (either a tool action or a final answer).
///
/// Built by [`create_tool_calling_agent`].
pub struct ToolCallingAgentRunnable {
    /// Prompt template handed over at construction time.
    ///
    /// NOTE(review): this field is written but never read in this module — confirm
    /// whether the prompt is meant to be applied to the request before the LLM call.
    _prompt: ChatPromptTemplate,
    /// The LLM runnable that `invoke` sends each request to.
    llm: Box<dyn Runnable<LlmRequest, LlmResponse> + Send + Sync>,
}
#[async_trait]
impl Runnable<LlmRequest, AgentStep> for ToolCallingAgentRunnable {
    /// Runs the wrapped LLM once and maps its response onto an [`AgentStep`].
    ///
    /// * If the response carries at least one tool call, the first call becomes an
    ///   [`AgentStep::Action`]. Any further tool calls in the same response are not
    ///   represented in the returned step.
    /// * Otherwise the response text becomes an [`AgentStep::Finish`], with
    ///   `{"output": <content>}` as its return values and the content itself as its log.
    ///
    /// # Errors
    ///
    /// Returns any error produced by the wrapped LLM's `invoke`.
    async fn invoke(&self, input: LlmRequest) -> Result<AgentStep, WesichainError> {
        let response = self.llm.invoke(input).await?;

        let step = match response.tool_calls.first() {
            Some(call) => {
                let log = format!("Invoking tool {}", call.name);
                AgentStep::Action(AgentAction {
                    tool: call.name.clone(),
                    tool_input: call.args.clone(),
                    log,
                })
            }
            None => {
                // `json!` serializes from a reference, so `content` can still be
                // moved into `log` afterwards.
                let return_values = json!({ "output": response.content });
                AgentStep::Finish(AgentFinish {
                    return_values,
                    log: response.content,
                })
            }
        };

        Ok(step)
    }

    /// Streaming is not implemented for this agent. The returned stream ends
    /// immediately without yielding any event, and `input` is discarded.
    ///
    /// NOTE(review): callers that stream receive neither output nor an error.
    /// Consider forwarding `self.llm.stream(input)` or emitting an explicit error.
    fn stream(
        &self,
        _input: LlmRequest,
    ) -> futures::stream::BoxStream<'_, Result<wesichain_core::StreamEvent, WesichainError>> {
        let nothing = futures::stream::empty();
        nothing.boxed()
    }
}
// Empty impl: no items are overridden here, so everything comes from the items
// `ActionAgent` itself provides.
// NOTE(review): presumably the trait builds on the `Runnable<LlmRequest, AgentStep>`
// impl for this type — confirm against the trait definition.
#[async_trait]
impl ActionAgent for ToolCallingAgentRunnable {}
/// Builds an agent that lets `llm` pick among `tools` using native tool calling.
///
/// Each tool's name, description and schema are converted into a
/// [`wesichain_core::ToolSpec`]. Those specs are attached to every request sent to
/// `llm`, on both the `invoke` and `stream` paths. A spec is skipped when the request
/// already contains a tool with the same name, so the provider never receives
/// duplicate tool definitions.
///
/// The `tools` themselves are only read to build the specs and are not kept: the
/// returned agent decides which tool to call but does not execute it.
///
/// `_prompt` is stored on the agent but is not currently applied to requests.
pub fn create_tool_calling_agent(
    llm: Box<dyn ToolCallingLlm>,
    tools: Vec<Box<dyn Tool>>,
    _prompt: ChatPromptTemplate,
) -> impl ActionAgent {
    let tool_specs: Vec<wesichain_core::ToolSpec> = tools
        .iter()
        .map(|tool| wesichain_core::ToolSpec {
            name: tool.name().to_string(),
            description: tool.description().to_string(),
            parameters: tool.schema(),
        })
        .collect();

    /// Wraps a tool-calling LLM and attaches a fixed set of tool specs to each request.
    struct ToolBindingLlm {
        inner: Box<dyn ToolCallingLlm>,
        tools: Vec<wesichain_core::ToolSpec>,
    }

    impl ToolBindingLlm {
        /// Appends the bound specs to `request.tools`, skipping any spec whose name the
        /// request already carries. Specs already on the request take precedence.
        fn bind_tools(&self, request: &mut LlmRequest) {
            // Collect first: the name check borrows `request.tools` immutably, and the
            // later `extend` needs it mutably.
            let missing: Vec<wesichain_core::ToolSpec> = self
                .tools
                .iter()
                .filter(|spec| !request.tools.iter().any(|bound| bound.name == spec.name))
                .cloned()
                .collect();
            request.tools.extend(missing);
        }
    }

    #[async_trait]
    impl Runnable<LlmRequest, LlmResponse> for ToolBindingLlm {
        async fn invoke(&self, mut request: LlmRequest) -> Result<LlmResponse, WesichainError> {
            self.bind_tools(&mut request);
            self.inner.invoke(request).await
        }

        fn stream(
            &self,
            mut input: LlmRequest,
        ) -> futures::stream::BoxStream<'_, Result<wesichain_core::StreamEvent, WesichainError>>
        {
            self.bind_tools(&mut input);
            self.inner.stream(input)
        }
    }

    let bound_llm = ToolBindingLlm {
        inner: llm,
        tools: tool_specs,
    };

    ToolCallingAgentRunnable {
        _prompt,
        llm: Box::new(bound_llm),
    }
}