use super::Callable;
use crate::providers::{ChatMessage, ChatRequest, ModelProvider};
use crate::tool::{DynTool, Tool};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::Arc;
/// A single tool invocation requested by the model.
///
/// `LlmCallable` builds these from a `{"tool_call": {"name": ..., "arguments": ...}}`
/// JSON reply, generating a fresh `call_<uuid>` id for each one.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
/// Identifier for this call (`call_<uuid v4>` when produced by `LlmCallable`).
pub id: String,
/// Name of the tool to run; matched against `Tool::name()` when executed.
pub name: String,
/// Arguments handed to the tool's `execute`; `Value::Null` when the model omitted them.
pub arguments: Value,
}
/// Serializable description of a tool, shaped as `{"type": ..., "function": {...}}`.
#[derive(Debug, Clone, Serialize)]
pub struct ToolSchema {
/// Tool kind, serialized under the JSON key `type`; `ToolSchema::from_tool` sets `"function"`.
#[serde(rename = "type")]
pub tool_type: String,
/// Name, description and parameter schema of the described tool.
pub function: FunctionSchema,
}
/// The `function` part of a [`ToolSchema`].
#[derive(Debug, Clone, Serialize)]
pub struct FunctionSchema {
/// Tool name; `ToolSchema::from_tool` copies it from `Tool::name()`.
pub name: String,
/// Tool description; `ToolSchema::from_tool` copies it from `Tool::description()`.
pub description: String,
/// Parameter schema from `Tool::parameters_schema()` — presumably a JSON Schema object; confirm against `Tool` impls.
pub parameters: Value,
}
impl ToolSchema {
    /// Describes `tool` as a `"function"`-typed schema, copying its name, description
    /// and parameter schema into owned values.
    pub fn from_tool(tool: &dyn Tool) -> Self {
        let function = FunctionSchema {
            name: tool.name().to_string(),
            description: tool.description().to_string(),
            parameters: tool.parameters_schema(),
        };
        Self {
            tool_type: String::from("function"),
            function,
        }
    }
}
/// A [`Callable`] that answers by chatting with a model provider, optionally letting the
/// model invoke registered tools through a JSON `tool_call` reply protocol.
pub struct LlmCallable {
/// Name reported by `Callable::name`.
name: String,
/// Optional description reported by `Callable::description`.
description: Option<String>,
/// Base system prompt; tool-usage instructions are appended to it when tools are registered.
system_prompt: String,
/// Backend that receives every chat request.
provider: Arc<dyn ModelProvider>,
/// Tools the model may call; looked up by name.
tools: Vec<DynTool>,
/// Maximum number of chat rounds in one `run` before it gives up (10 by default).
max_iterations: usize,
}
impl LlmCallable {
    /// Creates a callable named `name` that talks to `provider` using `system_prompt`.
    ///
    /// The new callable has no description, no tools, and allows 10 chat rounds per
    /// [`Callable::run`].
    pub fn with_provider(
        name: impl Into<String>,
        system_prompt: impl Into<String>,
        provider: Arc<dyn ModelProvider>,
    ) -> Self {
        Self {
            name: name.into(),
            description: None,
            system_prompt: system_prompt.into(),
            provider,
            tools: Vec::new(),
            max_iterations: 10,
        }
    }

    /// Sets the description reported by [`Callable::description`].
    pub fn with_description(mut self, description: impl Into<String>) -> Self {
        self.description = Some(description.into());
        self
    }

    /// Registers one tool the model may call.
    pub fn add_tool(mut self, tool: impl Tool + 'static) -> Self {
        self.tools.push(Arc::new(tool));
        self
    }

    /// Registers several already-shared tools, appending them in order.
    pub fn add_tools(mut self, tools: Vec<DynTool>) -> Self {
        self.tools.extend(tools);
        self
    }

    /// Sets the maximum number of chat rounds per run.
    ///
    /// With `0`, a run fails immediately without contacting the provider.
    pub fn max_iterations(mut self, max: usize) -> Self {
        self.max_iterations = max;
        self
    }

    /// Runs the first registered tool whose name equals `name`, passing it `args`.
    ///
    /// # Errors
    ///
    /// Returns an error if no tool has that name, or the tool's own error if it fails.
    async fn execute_tool(&self, name: &str, args: Value) -> anyhow::Result<Value> {
        let tool = self
            .tools
            .iter()
            .find(|t| t.name() == name)
            .ok_or_else(|| anyhow::anyhow!("Tool '{}' not found", name))?;
        tool.execute(args).await
    }

    /// Builds a [`ToolSchema`] for every registered tool, in registration order.
    // The helper may have no caller in this crate; the allow keeps it available without a lint warning.
    #[allow(dead_code)]
    fn get_tool_schemas(&self) -> Vec<ToolSchema> {
        self.tools
            .iter()
            .map(|t| ToolSchema::from_tool(t.as_ref()))
            .collect()
    }

    /// Parses a model reply of the form `{"tool_call": {"name": ..., "arguments": ...}}`.
    ///
    /// Surrounding whitespace and one enclosing Markdown code fence (for example
    /// ```` ```json ... ``` ````) are ignored, because models often fence JSON even when told
    /// to reply with the bare object.
    ///
    /// Returns `None` in either of these cases:
    /// - the reply is not such an object;
    /// - `name` is missing or is not a string.
    ///
    /// A missing `arguments` becomes `Value::Null`. At most one call is returned, tagged
    /// with a fresh `call_<uuid>` id.
    fn extract_tool_calls(&self, content: &str) -> Option<Vec<ToolCall>> {
        let parsed: Value = serde_json::from_str(Self::strip_code_fence(content)).ok()?;
        let tool_call = parsed.get("tool_call")?;
        let name = tool_call.get("name")?.as_str()?.to_string();
        let arguments = tool_call.get("arguments").cloned().unwrap_or(Value::Null);
        Some(vec![ToolCall {
            id: format!("call_{}", uuid::Uuid::new_v4()),
            name,
            arguments,
        }])
    }

    /// Trims `content`. If it is wrapped in a triple-backtick fence, returns only the fenced
    /// body, without the fence markers or an optional language tag.
    fn strip_code_fence(content: &str) -> &str {
        let trimmed = content.trim();
        match trimmed
            .strip_prefix("```")
            .and_then(|rest| rest.strip_suffix("```"))
        {
            Some(inner) => match inner.split_once('\n') {
                // The opening fence line carries an optional (possibly empty) tag such as `json`.
                Some((tag, body)) if tag.trim().chars().all(|c| c.is_ascii_alphanumeric()) => {
                    body.trim()
                }
                _ => inner.trim(),
            },
            None => trimmed,
        }
    }
}
#[async_trait]
impl Callable for LlmCallable {
    fn name(&self) -> &str {
        &self.name
    }

    fn description(&self) -> Option<&str> {
        self.description.as_deref()
    }

    /// Answers `input` with a prompt-driven tool-use loop.
    ///
    /// Each round sends the whole conversation to the provider and then acts on the reply:
    /// - If the reply is a `{"tool_call": {...}}` JSON object, the requested tool runs.
    ///   Its JSON result is appended as a user message and the model is asked again.
    /// - Any other reply is returned as the final answer.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - the provider call fails;
    /// - a requested tool is unknown or fails;
    /// - a tool result cannot be serialized;
    /// - `max_iterations` rounds pass without a final answer.
    async fn run(&self, input: &str) -> anyhow::Result<String> {
        let system_message = if self.tools.is_empty() {
            ChatMessage::system(&self.system_prompt)
        } else {
            let tool_names: Vec<&str> = self.tools.iter().map(|t| t.name()).collect();
            // Names alone don't tell the model what `arguments` object to send, so describe
            // each tool and its parameter schema as well.
            let tool_details: Vec<String> = self
                .tools
                .iter()
                .map(|t| {
                    format!(
                        "- {}: {} (parameters JSON Schema: {})",
                        t.name(),
                        t.description(),
                        t.parameters_schema()
                    )
                })
                .collect();
            ChatMessage::system(format!(
                "{}\n\nYou have access to these tools: {:?}. To use a tool, respond with ONLY a JSON object: {{\"tool_call\": {{\"name\": \"tool_name\", \"arguments\": {{...}}}}}}\n\nTool details:\n{}",
                self.system_prompt,
                tool_names,
                tool_details.join("\n")
            ))
        };
        let mut messages = vec![system_message, ChatMessage::user(input)];

        for iteration in 0..self.max_iterations {
            tracing::debug!(iteration, "Callable iteration");
            let request = ChatRequest {
                messages: messages.clone(),
                max_tokens: Some(4096),
                temperature: Some(0.7),
            };
            let response = self.provider.chat(request).await?;
            let content = response
                .choices
                .first()
                .map(|c| c.message.content.clone())
                .unwrap_or_default();

            let tool_calls = match self.extract_tool_calls(&content) {
                Some(calls) => calls,
                None => return Ok(content),
            };

            // Record the model's request once for this turn, then answer each call in order.
            messages.push(ChatMessage::assistant(&content));
            for call in tool_calls {
                tracing::debug!(tool = %call.name, "Executing tool");
                let result = self.execute_tool(&call.name, call.arguments).await?;
                let tool_message = format!("Tool result: {}", serde_json::to_string(&result)?);
                messages.push(ChatMessage::user(&tool_message));
            }
        }
        anyhow::bail!("Max iterations ({}) reached", self.max_iterations)
    }
}