use async_trait::async_trait;
use futures::future::BoxFuture;
use std::collections::HashMap;
use std::sync::Arc;
use crate::config::RunnableConfig;
use crate::llm::ToolDefinition;
use crate::state::{Message, State};
use crate::store::Store;
/// Error type for tool invocation and the tool-node machinery in this module.
///
/// Derives `thiserror::Error`; each variant's `Display` text comes from its `#[error]`
/// attribute. `Clone` lets holders of a `&Result<String, ToolError>` produce an owned copy
/// (as `NopToolInterceptor::post_execute` does).
#[derive(Debug, Clone, thiserror::Error)]
pub enum ToolError {
/// The input was rejected. Displays as `invalid input: <reason>`.
#[error("invalid input: {0}")]
InvalidInput(String),
/// The tool failed while running. Displays as `execution failed: <reason>`.
#[error("execution failed: {0}")]
ExecutionFailed(String),
/// The tool timed out; carries no payload. Displays as `timeout`.
#[error("timeout")]
Timeout,
/// No tool matched the request (payload presumably the requested name — confirm with the
/// producer). Displays as `tool not found: <payload>`.
#[error("tool not found: {0}")]
ToolNotFound(String),
/// Input failed validation; `validate_tool_input` returns this for missing required fields.
/// Displays as `validation error: <details>`.
#[error("validation error: {0}")]
ValidationError(String),
}
/// A named capability, described by a schema, that can be invoked with JSON input and
/// yields textual output.
///
/// Implementors supply `name`, `description`, `schema` and the async `invoke`.
/// `definition`, `requires_store` and `invoke_with_store` have default implementations.
#[async_trait]
pub trait Tool: Send + Sync + 'static {
    /// Identifier of the tool; `ToolNode` uses it as the key of its tool map.
    fn name(&self) -> &str;

    /// Free-text description of the tool.
    fn description(&self) -> &str;

    /// Description of the accepted input. `validate_tool_input` reads its top-level
    /// `required` array (presumably a JSON Schema object — confirm with implementors).
    fn schema(&self) -> serde_json::Value;

    /// Assembles a [`ToolDefinition`] from `name`, `description` and `schema`, in that order.
    fn definition(&self) -> ToolDefinition {
        let name = self.name().to_owned();
        let description = self.description().to_owned();
        let parameters = self.schema();
        ToolDefinition {
            name,
            description,
            parameters,
        }
    }

    /// Runs the tool on `input` and returns its textual output.
    ///
    /// # Errors
    /// Returns a [`ToolError`] describing why the invocation failed.
    async fn invoke(&self, input: serde_json::Value) -> Result<String, ToolError>;

    /// Returns `false` unless overridden; presumably consulted by callers to decide whether
    /// to route through `invoke_with_store` (caller not visible here).
    #[must_use]
    fn requires_store(&self) -> bool {
        false
    }

    /// Store-aware invocation.
    ///
    /// The default ignores `_store` and simply awaits [`Tool::invoke`]; tools that need the
    /// store override this.
    ///
    /// # Errors
    /// Propagates whatever [`Tool::invoke`] returns.
    fn invoke_with_store<'a>(
        &'a self,
        input: serde_json::Value,
        _store: &'a dyn crate::store::Store,
    ) -> BoxFuture<'a, Result<String, ToolError>>
    where
        Self: 'a,
    {
        Box::pin(async move { self.invoke(input).await })
    }
}
/// Per-call context handed to [`StatefulTool::invoke_with_state`].
///
/// It holds the state value, the id of the tool call being served, the run configuration,
/// an optional store and two optional event senders. When a sender is `None`, the
/// corresponding `emit_*` method does nothing.
#[allow(
missing_debug_implementations,
reason = "Contains dyn Store trait object which doesn't implement Debug"
)]
pub struct ToolRuntime<S: State> {
/// State value made available to the tool.
pub state: S,
/// Id of the tool call being executed; attached to every event this runtime emits.
pub tool_call_id: String,
/// Run configuration for this invocation.
pub config: RunnableConfig,
/// Optional shared store.
pub store: Option<Arc<dyn Store>>,
/// Sender for incremental output deltas (used by `emit_output_delta`).
stream_tx: Option<tokio::sync::mpsc::UnboundedSender<serde_json::Value>>,
/// Sender for tool lifecycle events (used by `emit_tool_started` / `emit_tool_finished`).
tools_event_tx: Option<tokio::sync::mpsc::UnboundedSender<crate::stream::ToolsEvent>>,
}
impl<S: State> ToolRuntime<S> {
    /// Assembles the runtime context for one tool call.
    ///
    /// Passing `None` for `stream_tx` or `tools_event_tx` turns the matching `emit_*` methods
    /// into no-ops.
    #[must_use]
    pub const fn new(
        state: S,
        tool_call_id: String,
        config: RunnableConfig,
        store: Option<Arc<dyn Store>>,
        stream_tx: Option<tokio::sync::mpsc::UnboundedSender<serde_json::Value>>,
        tools_event_tx: Option<tokio::sync::mpsc::UnboundedSender<crate::stream::ToolsEvent>>,
    ) -> Self {
        Self {
            state,
            tool_call_id,
            config,
            store,
            stream_tx,
            tools_event_tx,
        }
    }

    /// Sends `{"delta": <delta>, "tool_call_id": <id>}` on the output stream, if one is attached.
    ///
    /// This is best-effort: a send error (e.g. a dropped receiver) is discarded.
    pub fn emit_output_delta(&self, delta: &str) {
        let Some(sender) = self.stream_tx.as_ref() else {
            return;
        };
        let payload = serde_json::json!({
            "delta": delta,
            "tool_call_id": self.tool_call_id
        });
        // Streaming is an optional side channel; failure must not affect the tool.
        let _ = sender.send(payload);
    }

    /// Emits a `ToolsEvent::ToolStarted` stamped with the current UTC time, if an event
    /// sender is attached.
    ///
    /// This is best-effort: a send error is discarded.
    pub fn emit_tool_started(&self, tool_name: &str, node: &str, input: serde_json::Value) {
        let Some(sender) = self.tools_event_tx.as_ref() else {
            return;
        };
        let _ = sender.send(crate::stream::ToolsEvent::ToolStarted {
            tool_name: tool_name.to_owned(),
            tool_call_id: self.tool_call_id.clone(),
            node: node.to_owned(),
            input,
            timestamp: chrono::Utc::now(),
        });
    }

    /// Emits a `ToolsEvent::ToolFinished` for this call, if an event sender is attached.
    ///
    /// This is best-effort: a send error is discarded.
    pub fn emit_tool_finished(&self, output: serde_json::Value, duration_ms: u64, success: bool) {
        let Some(sender) = self.tools_event_tx.as_ref() else {
            return;
        };
        let _ = sender.send(crate::stream::ToolsEvent::ToolFinished {
            tool_call_id: self.tool_call_id.clone(),
            output,
            duration_ms,
            success,
        });
    }
}
/// A [`Tool`] that can also run with access to a [`ToolRuntime`]: state, config, store and
/// event senders.
#[async_trait]
pub trait StatefulTool<S: State>: Tool {
    /// Invokes the tool with the per-call `runtime` context.
    ///
    /// The returned future borrows only `self` (elided `'_`). Anything needed from `runtime`
    /// must therefore be cloned or extracted before the future is built.
    ///
    /// # Errors
    /// Implementations report failures as [`ToolError`].
    fn invoke_with_state(
        &self,
        input: serde_json::Value,
        runtime: &ToolRuntime<S>,
    ) -> BoxFuture<'_, Result<String, ToolError>>;

    /// Store-aware invocation. By default it defers to this type's
    /// [`Tool::invoke_with_store`].
    ///
    /// NOTE(review): this method has the same name as `Tool::invoke_with_store`. For a type
    /// implementing both traits, `x.invoke_with_store(..)` is ambiguous, so callers must use
    /// fully qualified syntax.
    fn invoke_with_store<'a>(
        &'a self,
        input: serde_json::Value,
        store: &'a dyn crate::store::Store,
    ) -> BoxFuture<'a, Result<String, ToolError>>
    where
        Self: 'a,
    {
        <Self as Tool>::invoke_with_store(self, input, store)
    }
}
/// Hook pair around tool execution.
///
/// Installed on a `ToolNode` via `with_interceptor` or `ToolNodeConfig::interceptor`. The
/// code that calls these hooks is not visible here; presumably `pre_execute` runs before a
/// tool call and `post_execute` after it (confirm against the executor).
///
/// Both returned futures borrow only `self` (elided `'_`). Data needed from the argument
/// references must be copied out before the future is built.
#[async_trait]
pub trait ToolInterceptor: Send + Sync + 'static {
/// Inspects a tool call and a JSON view of the state.
///
/// An `Err` presumably vetoes the call — confirm against the caller.
fn pre_execute(
&self,
tool_call: &crate::state::ToolCall,
state: &serde_json::Value,
) -> BoxFuture<'_, Result<(), ToolError>>;
/// Receives a tool call and its result and produces the (possibly rewritten) result.
fn post_execute(
&self,
tool_call: &crate::state::ToolCall,
result: &Result<String, ToolError>,
) -> BoxFuture<'_, Result<String, ToolError>>;
}
/// A [`ToolInterceptor`] that changes nothing.
///
/// Every call is allowed, and every result, success or error, is returned exactly as
/// produced.
#[derive(Debug)]
pub struct NopToolInterceptor;

#[async_trait]
impl ToolInterceptor for NopToolInterceptor {
    /// Always allows the call.
    fn pre_execute(
        &self,
        _tool_call: &crate::state::ToolCall,
        _state: &serde_json::Value,
    ) -> BoxFuture<'_, Result<(), ToolError>> {
        Box::pin(async { Ok(()) })
    }

    /// Returns an unchanged copy of `result`.
    ///
    /// Errors keep their original variant and message. A no-op interceptor must not
    /// rewrite them. Re-wrapping into `ExecutionFailed(e.to_string())` would turn `Timeout`
    /// into `ExecutionFailed("timeout")` and double-prefix `ExecutionFailed` messages.
    fn post_execute(
        &self,
        _tool_call: &crate::state::ToolCall,
        result: &Result<String, ToolError>,
    ) -> BoxFuture<'_, Result<String, ToolError>> {
        // Clone eagerly: the returned future may borrow only `self`, not `result`.
        Box::pin(std::future::ready(result.clone()))
    }
}
/// Mutates a tool call in place. Installed on a `ToolNode` via `with_transformer` or
/// `ToolNodeConfig::call_transformer`.
///
/// Presumably applied before the call is executed — the caller is not visible here.
pub trait ToolCallTransformer: Send + Sync + 'static {
/// Applies the transformation to `tool_call`.
///
/// # Errors
/// Returns a [`ToolError`] if the call cannot be transformed.
fn transform(&self, tool_call: &mut crate::state::ToolCall) -> Result<(), ToolError>;
}
/// Configuration consumed by `ToolNode::from_config`.
///
/// `Default` yields no tools, `handle_errors = true`, `validate_input = false`, and no
/// transformer, interceptor or routing condition.
#[allow(
missing_debug_implementations,
clippy::type_complexity,
reason = "Contains trait objects and Arc<dyn Fn> which don't implement Debug. Complex trait object type is required for dynamic tool configuration."
)]
pub struct ToolNodeConfig {
/// Tools to register; the node keys them by `Tool::name`, and a later duplicate name
/// replaces an earlier one.
pub tools: Vec<Box<dyn Tool>>,
/// Copied into `ToolNode::handle_errors`. Presumably controls whether tool errors are
/// handled by the node rather than propagated — confirm with the executor.
pub handle_errors: bool,
/// Copied into `ToolNode::validate_input`. Presumably enables `validate_tool_input`
/// before invocation — confirm with the executor.
pub validate_input: bool,
/// Optional transformer applied to tool calls.
pub call_transformer: Option<Box<dyn ToolCallTransformer>>,
/// Optional interceptor wrapping tool execution.
pub interceptor: Option<Arc<dyn ToolInterceptor>>,
/// Optional predicate over a `Message`.
/// NOTE(review): `ToolNode::from_config` does not carry this field into the node, so it is
/// currently ignored there.
pub tools_condition: Option<Arc<dyn Fn(&Message) -> bool + Send + Sync>>,
}
impl Default for ToolNodeConfig {
    /// No tools and no hooks.
    ///
    /// Error handling is enabled and input validation is disabled.
    fn default() -> Self {
        Self {
            tools: Vec::new(),
            call_transformer: None,
            interceptor: None,
            tools_condition: None,
            handle_errors: true,
            validate_input: false,
        }
    }
}
/// Graph node holding a set of named tools plus execution options.
///
/// Build with `ToolNode::new` or `ToolNode::from_config`, then adjust with the `with_*`
/// builder methods.
#[allow(
missing_debug_implementations,
reason = "Contains trait objects which don't implement Debug"
)]
pub struct ToolNode {
/// Registered tools keyed by `Tool::name`.
#[expect(dead_code, reason = "Used in tool execution")]
tools: HashMap<String, Box<dyn Tool>>,
/// Error-handling flag (see `with_error_handling`).
handle_errors: bool,
/// Input-validation flag (see `with_validation`).
validate_input: bool,
/// Optional tool-call transformer (see `with_transformer`).
call_transformer: Option<Box<dyn ToolCallTransformer>>,
/// Optional execution interceptor (see `with_interceptor`).
interceptor: Option<Arc<dyn ToolInterceptor>>,
}
impl ToolNode {
#[must_use]
pub fn new(tools: Vec<Box<dyn Tool>>) -> Self {
let tool_map = tools
.into_iter()
.map(|t| (t.name().to_string(), t))
.collect();
Self {
tools: tool_map,
handle_errors: true,
validate_input: false,
call_transformer: None,
interceptor: None,
}
}
#[must_use]
pub fn from_config(config: ToolNodeConfig) -> Self {
let tool_map = config
.tools
.into_iter()
.map(|t| (t.name().to_string(), t))
.collect();
Self {
tools: tool_map,
handle_errors: config.handle_errors,
validate_input: config.validate_input,
call_transformer: config.call_transformer,
interceptor: config.interceptor,
}
}
#[must_use]
pub const fn with_error_handling(mut self, handle: bool) -> Self {
self.handle_errors = handle;
self
}
#[must_use]
pub const fn with_validation(mut self, validate: bool) -> Self {
self.validate_input = validate;
self
}
#[must_use]
pub fn with_transformer(mut self, transformer: Box<dyn ToolCallTransformer>) -> Self {
self.call_transformer = Some(transformer);
self
}
#[must_use]
pub fn with_interceptor(mut self, interceptor: Arc<dyn ToolInterceptor>) -> Self {
self.interceptor = Some(interceptor);
self
}
}
/// Record describing a tool execution (presumably one attempt — confirm with the producer).
#[derive(Debug, Clone)]
pub struct ToolExecutionTrace {
/// Name of the tool that ran.
pub tool_name: String,
/// Id of the tool call this trace belongs to.
pub tool_call_id: String,
/// Attempt counter. Whether it is 0- or 1-based is not established here; confirm with
/// the producer.
pub attempt: usize,
/// UTC timestamp of the first attempt.
pub first_attempt_time: chrono::DateTime<chrono::Utc>,
/// Duration in milliseconds, per the field name. Whether it covers one attempt or all
/// attempts is not established here.
pub duration_ms: u64,
/// Whether the execution succeeded.
pub success: bool,
}
/// Checks that `input` supplies every field listed in the tool schema's top-level
/// `required` array.
///
/// Schemas without a `required` array (including non-object schemas) accept any input.
/// Non-string entries in `required` are ignored. When fields are required, `input` must be
/// a JSON object containing them. A non-object input (e.g. `null` or a string) cannot carry
/// any field, so it is rejected instead of silently passing.
///
/// # Errors
/// Returns [`ToolError::ValidationError`] naming the first missing required field.
#[expect(dead_code, reason = "Used in tool execution validation")]
fn validate_tool_input(tool: &dyn Tool, input: &serde_json::Value) -> Result<(), ToolError> {
    let schema = tool.schema();
    let Some(required) = schema.get("required").and_then(serde_json::Value::as_array) else {
        return Ok(());
    };
    // `None` for non-object input: every required field then counts as missing.
    let provided = input.as_object();
    let missing = required
        .iter()
        .filter_map(serde_json::Value::as_str)
        .find(|field| !provided.is_some_and(|obj| obj.contains_key(*field)));
    match missing {
        Some(field_name) => Err(ToolError::ValidationError(format!(
            "Missing required field: {field_name}"
        ))),
        None => Ok(()),
    }
}
/// Routing helper.
///
/// Returns `"tools"` when the most recent `"Ai"`-role message in `state[messages_field]`
/// carries at least one tool call, and [`crate::END`] otherwise. The state is serialized to
/// JSON on every call.
pub fn tools_condition<S: State + serde::Serialize>(
    state: &S,
    messages_field: &str,
) -> &'static str {
    let pending = has_pending_tool_calls(state, messages_field);
    if pending { "tools" } else { crate::END }
}
/// Serializes `state` to JSON and inspects the array stored under `messages_field`.
///
/// Returns whether the latest message whose `role` is the string `"Ai"` has a non-empty
/// `tool_calls` array.
///
/// Returns `false` in any of these cases:
/// - serialization fails;
/// - the field is missing or not an array;
/// - no `"Ai"` message exists;
/// - that message's `tool_calls` is absent, not an array, or empty.
///
/// NOTE(review): the backward scan skips past non-`"Ai"` messages (e.g. tool results). An
/// `"Ai"` message whose calls were already answered therefore still counts as pending.
/// Confirm this is intended; LangGraph's `tools_condition` inspects only the last message.
fn has_pending_tool_calls<S: serde::Serialize>(state: &S, messages_field: &str) -> bool {
    let Ok(json) = serde_json::to_value(state) else {
        return false;
    };
    json.get(messages_field)
        .and_then(serde_json::Value::as_array)
        .and_then(|history| {
            history.iter().rev().find(|message| {
                message.get("role").and_then(serde_json::Value::as_str) == Some("Ai")
            })
        })
        .and_then(|latest_ai| latest_ai.get("tool_calls"))
        .and_then(serde_json::Value::as_array)
        .is_some_and(|calls| !calls.is_empty())
}