use std::sync::Arc;
use futures::stream::{self, BoxStream, StreamExt};
use tokio::task::JoinSet;
use wesichain_core::Tool;
use wesichain_core::{Runnable, StreamEvent, WesichainError};
use wesichain_llm::{Message, Role, ToolCall};
use crate::{GraphState, StateSchema, StateUpdate};
/// State types that expose pending tool calls and can record tool results.
///
/// Required by [`ToolNode`], which reads the calls, runs them, and hands each
/// result back through [`HasToolCalls::push_tool_result`].
pub trait HasToolCalls {
    /// Stores one tool-result message in the state.
    ///
    /// `ToolNode` calls this once per executed tool call, in the order the calls
    /// were listed by [`HasToolCalls::tool_calls`].
    fn push_tool_result(&mut self, result: Message);

    /// Returns the tool calls that should be executed.
    fn tool_calls(&self) -> &Vec<ToolCall>;
}
pub struct ToolNode {
tools: Vec<Arc<dyn Tool>>,
}
impl ToolNode {
pub fn new(tools: Vec<Arc<dyn Tool>>) -> Self {
Self { tools }
}
pub async fn invoke<S>(&self, input: GraphState<S>) -> Result<StateUpdate<S>, WesichainError>
where
S: StateSchema<Update = S> + HasToolCalls,
{
<Self as Runnable<GraphState<S>, StateUpdate<S>>>::invoke(self, input).await
}
}
#[async_trait::async_trait]
impl<S> Runnable<GraphState<S>, StateUpdate<S>> for ToolNode
where
S: StateSchema<Update = S> + HasToolCalls,
{
async fn invoke(&self, input: GraphState<S>) -> Result<StateUpdate<S>, WesichainError> {
let calls: Vec<ToolCall> = input.data.tool_calls().clone();
let mut join_set: JoinSet<(usize, String, Result<String, WesichainError>)> =
JoinSet::new();
for (index, call) in calls.iter().enumerate() {
let tool = self
.tools
.iter()
.find(|t| t.name() == call.name)
.ok_or_else(|| WesichainError::ToolCallFailed {
tool_name: call.name.clone(),
reason: "not found".to_string(),
})?;
let tool = tool.clone();
let args = call.args.clone();
let call_id = call.id.clone();
let tool_name = call.name.clone();
join_set.spawn(async move {
let result = tool.invoke(args).await.map(|v| v.to_string()).map_err(|e| {
WesichainError::ToolCallFailed {
tool_name,
reason: e.to_string(),
}
});
(index, call_id, result)
});
}
let mut results: Vec<(usize, String, Result<String, WesichainError>)> =
Vec::with_capacity(calls.len());
while let Some(res) = join_set.join_next().await {
results.push(res.map_err(|e| WesichainError::Custom(format!("task panicked: {e}")))?);
}
results.sort_by_key(|(idx, _, _)| *idx);
let mut next = input.data.clone();
for (_, call_id, output) in results {
next.push_tool_result(Message {
role: Role::Tool,
content: output?.into(),
tool_call_id: Some(call_id),
tool_calls: Vec::new(),
});
}
Ok(StateUpdate::new(next))
}
fn stream(&self, _input: GraphState<S>) -> BoxStream<'_, Result<StreamEvent, WesichainError>> {
stream::empty().boxed()
}
}