use futures::StreamExt;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{collections::HashMap, fmt, future::Future, pin::Pin, sync::Arc};
use tracing::{span, Instrument, Level, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt;
use crate::{services::llm::message::Message, Agent, NotificationHandler};
use super::errors::ToolExecutionError;
/// Kind of tool that can be described in a [`Tool`] or requested in a
/// [`ToolCall`].
///
/// Variant names are (de)serialized in lowercase, so `Function` is written
/// as `"function"`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ToolType {
    /// A callable function tool; the only variant defined here.
    Function,
}
/// Shared, thread-safe handle to an async tool implementation.
///
/// The callable receives the tool arguments as a JSON [`Value`] and returns a
/// boxed, `Send` future that resolves to the tool's `String` output or a
/// [`ToolExecutionError`]. Wrapping it in an [`Arc`] lets the handle be cloned
/// without copying the underlying closure.
pub type AsyncToolFn = Arc<
dyn Fn(Value) -> Pin<Box<dyn Future<Output = Result<String, ToolExecutionError>> + Send>>
+ Send
+ Sync,
>;
/// Builds the placeholder executor used for the serde-skipped `executor`
/// field of a deserialized `Tool`.
///
/// # Panics
///
/// The returned executor ignores its arguments and panics as soon as its
/// future is polled: reaching it means the real executor was never
/// re-attached after deserialization.
fn default_executor() -> AsyncToolFn {
    Arc::new(|_args| {
        Box::pin(async move {
            panic!("Called a default, non-functional tool executor. The tool was not rehydrated after deserialization.");
        })
    })
}
/// A tool definition paired with the async callback that implements it.
///
/// Only `tool_type` (under the key `"type"`) and `function` take part in
/// (de)serialization. `executor` is skipped, so a freshly deserialized `Tool`
/// carries the placeholder produced by `default_executor` until a real
/// executor is assigned.
#[derive(Clone, Serialize, Deserialize)]
pub struct Tool {
    /// Kind of tool; serialized as `"type"`.
    #[serde(rename = "type")]
    pub tool_type: ToolType,
    /// Name, description and parameter description of the tool.
    pub function: Function,
    /// Callback invoked by [`Tool::execute`]; never serialized.
    #[serde(skip, default = "default_executor")]
    pub executor: AsyncToolFn,
}
/// Manual `Debug` implementation: the executor is a closure and cannot be
/// printed, so a fixed `"<async_fn>"` marker is shown in its place.
impl fmt::Debug for Tool {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Tool");
        builder.field("tool_type", &self.tool_type);
        builder.field("function", &self.function);
        builder.field("executor", &"<async_fn>");
        builder.finish()
    }
}
impl Tool {
    /// Invokes the tool's executor with `args` and awaits its result.
    ///
    /// # Errors
    ///
    /// Returns whatever [`ToolExecutionError`] the executor produces.
    pub async fn execute(&self, args: Value) -> Result<String, ToolExecutionError> {
        let pending = (self.executor)(args);
        pending.await
    }

    /// Returns the name of the function this tool describes.
    pub fn name(&self) -> &str {
        self.function.name.as_str()
    }
}
/// Serializable description of a function exposed as a tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Function {
    /// Name used to look the tool up when a call is dispatched.
    pub name: String,
    /// Free-text description of the function.
    pub description: String,
    /// Description of the arguments the function accepts.
    pub parameters: FunctionParameters,
}
/// Parameter description attached to a [`Function`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FunctionParameters {
    /// Type label of the parameter container; serialized as `"type"`.
    #[serde(rename = "type")]
    pub param_type: String,
    /// Individual parameters, keyed by name.
    pub properties: HashMap<String, Property>,
    /// Presumably the names of the properties that must be supplied
    /// (not enforced anywhere in this file).
    pub required: Vec<String>,
}
/// Description of a single entry in [`FunctionParameters::properties`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Property {
    /// Type label of the property; serialized as `"type"`.
    #[serde(rename = "type")]
    pub property_type: String,
    /// Free-text description of the property.
    pub description: String,
}
/// A request to invoke one tool with a set of arguments.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Optional call identifier. Missing input deserializes to `None`, and
    /// `None` is left out of the serialized form.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id: Option<String>,
    /// Kind of tool being called; serialized as `"type"`. Falls back to
    /// `default_tool_call_type` when absent and is omitted from the output
    /// when it equals that default.
    #[serde(
        rename = "type",
        default = "default_tool_call_type",
        skip_serializing_if = "is_default_tool_call_type"
    )]
    pub tool_type: ToolType,
    /// Name of the target function and the arguments to pass to it.
    pub function: ToolCallFunction,
}
/// Serde default for `ToolCall::tool_type`: a call without an explicit type
/// is treated as a function call.
fn default_tool_call_type() -> ToolType {
    ToolType::Function
}
/// Serde `skip_serializing_if` predicate for `ToolCall::tool_type`: returns
/// `true` when `tool_type` equals the value from `default_tool_call_type`.
// The by-reference parameter is dictated by serde, which calls
// `skip_serializing_if` functions with `&T`; hence the lint is silenced.
#[allow(clippy::trivially_copy_pass_by_ref)]
fn is_default_tool_call_type(tool_type: &ToolType) -> bool {
    let fallback = default_tool_call_type();
    tool_type == &fallback
}
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct ToolCallFunction {
pub name: String,
pub arguments: serde_json::Value,
}
/// Runs every requested tool call against the tools registered on `agent`
/// and returns one tool [`Message`] per call.
///
/// Calls are driven concurrently and collected in completion order, so the
/// returned messages are not guaranteed to follow the order of `tool_calls`.
/// Each call runs inside its own "Tool Call" tracing span that records the
/// input, the output and the resulting status as span attributes.
///
/// Every reply is tagged with the call's `id`, falling back to the function
/// name when the call has no id. This applies to failures as well, so each
/// reply can be matched to the call that produced it.
///
/// Outcomes that do not abort the batch:
/// * The agent has no tools configured: the problem is logged and reported
///   through `notify_tool_error`, and a single hint message is returned.
/// * A call names an unknown tool: a "Tool not found" message is returned
///   for that call.
/// * A tool returns an error: the error text is reported through
///   `notify_tool_error` and returned as the content of that call's message.
///
/// An empty `tool_calls` slice yields an empty vector.
pub async fn call_tools(agent: &Agent, tool_calls: &[ToolCall]) -> Vec<Message> {
    let Some(avail) = &agent.tools else {
        tracing::error!("No avalible tools specified");
        agent
            .notify_tool_error("Agent called tools, but no tools avalible to the model".into())
            .await;
        return vec![Message::tool(
            "If you want to use a tool specify the name of the available tool.",
            "Tool".to_string(),
        )];
    };

    // Run all calls at once, but never hand `buffer_unordered` a limit of 0:
    // with a zero limit it does not poll the underlying stream, which can
    // leave this function pending forever when `tool_calls` is empty.
    let concurrency = tool_calls.len().max(1);

    futures::stream::iter(tool_calls.iter().cloned())
        .map(|call| {
            let tool_span = span!(
                Level::INFO,
                "Tool Call",
                "langfuse.observation.type" = "tool",
                "langfuse.observation.metadata.tool_name" = call.function.name.as_str(),
                "langfuse.observation.id" = call.id.as_deref().unwrap_or("unknown"),
                "langfuse.observation.name" = &format!("Tool: {}", call.function.name.as_str()),
            );
            // Recording the input is best-effort: skip it if the arguments
            // cannot be rendered as JSON.
            if let Ok(input_str) = serde_json::to_string_pretty(&call.function.arguments) {
                tool_span.set_attribute("input.value", input_str);
            }

            async move {
                // Tag used on every reply for this call, success or failure.
                let reply_id = call
                    .id
                    .clone()
                    .unwrap_or_else(|| call.function.name.clone());

                let Some(tool) = avail.iter().find(|t| t.function.name == call.function.name)
                else {
                    Span::current().set_attribute("otel.status_code", "ERROR");
                    Span::current()
                        .set_attribute("langfuse.observation.status_message", "Tool not found");
                    return Message::tool("Tool not found", reply_id);
                };

                agent.notify_tool_request(call.clone()).await;

                match tool.execute(call.function.arguments.clone()).await {
                    Ok(output) => {
                        Span::current().set_attribute("output.value", output.clone());
                        Span::current().set_attribute("otel.status_code", "OK");
                        agent.notify_tool_success(output.clone()).await;
                        Message::tool(output, reply_id)
                    }
                    Err(e) => {
                        let err_msg = e.to_string();
                        Span::current().set_attribute("otel.status_code", "ERROR");
                        Span::current()
                            .set_attribute("langfuse.observation.status_message", err_msg.clone());
                        agent.notify_tool_error(err_msg.clone()).await;
                        Message::tool(err_msg, reply_id)
                    }
                }
            }
            // Entering the span on every poll makes `Span::current()` inside
            // the future resolve to this call's span.
            .instrument(tool_span)
        })
        .buffer_unordered(concurrency)
        .collect::<Vec<Message>>()
        .await
}