use std::collections::BTreeMap;
use std::fmt::Debug;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use dyn_clone::DynClone;
use llmy_client::req::{
ChatCompletionRequestMessageRaw, ChatCompletionRequestToolMessageContent,
ChatCompletionRequestToolMessageRaw, ChatCompletionTool, ChatCompletionToolRaw,
ChatCompletionTools, ChatCompletionToolsRaw, FunctionObjectRaw,
};
use llmy_types::error::{GeneralToolCall, LLMYError};
use llmy_types::other::WithOtherFields;
use schemars::schema_for;
use serde::de::DeserializeOwned;
use tokio::task::JoinSet;
use tracing::debug;
/// Object-safe interface for a tool that can be exposed to an LLM.
///
/// Implementors supply the name, description, argument schema and the async
/// [`ToolDyn::run`] body. The remaining methods have default implementations built on
/// top of those.
pub trait ToolDyn: DynClone + Debug + Send + Sync + std::any::Any {
    /// Name the tool is registered and invoked under.
    fn name(&self) -> String;

    /// Optional human-readable description passed along to the model.
    fn description(&self) -> Option<String>;

    /// JSON schema of the tool's arguments.
    fn schema(&self) -> schemars::Schema;

    /// Value of the `strict` flag in the OpenAI function definition (`false` unless overridden).
    fn strict(&self) -> bool {
        false
    }

    /// Builds the OpenAI chat-completion function-tool definition for this tool.
    ///
    /// The method name keeps its original spelling so existing callers keep compiling.
    ///
    /// # Panics
    /// Panics if [`ToolDyn::schema`] cannot be serialized to JSON.
    fn to_openai_obejct(&self) -> ChatCompletionTool {
        WithOtherFields::new(ChatCompletionToolRaw {
            function: WithOtherFields::new(FunctionObjectRaw {
                name: self.name(),
                description: self.description(),
                parameters: Some(
                    serde_json::to_value(self.schema()).expect("Fail to serialize schema"),
                ),
                strict: Some(self.strict()),
            }),
        })
    }

    /// Builds the MCP (`rmcp`) tool definition for this tool.
    ///
    /// If the serialized schema is not a JSON object, an empty object is used as the
    /// input schema.
    ///
    /// # Panics
    /// Panics if [`ToolDyn::schema`] cannot be serialized to JSON.
    fn to_mcp_tool(&self) -> rmcp::model::Tool {
        let schema = serde_json::to_value(self.schema()).expect("Fail to serialize schema");
        // Move the map out of the temporary value instead of cloning it.
        let input_schema = match schema {
            serde_json::Value::Object(map) => map,
            _ => serde_json::Map::new(),
        };
        rmcp::model::Tool::new_with_raw(
            self.name(),
            self.description().map(Into::into),
            Arc::new(input_schema),
        )
    }

    /// Parses `arguments` as JSON text and forwards the parsed value to [`ToolDyn::run`].
    ///
    /// # Errors
    /// Returns [`LLMYError::IncorrectToolCall`] if `arguments` is not valid JSON. That error
    /// carries the tool name, the raw argument text and the schema. Otherwise it returns
    /// whatever `run` returns.
    fn call(
        &self,
        arguments: String,
    ) -> Pin<Box<dyn Future<Output = Result<String, LLMYError>> + Send + '_>> {
        Box::pin(async move {
            match serde_json::from_str::<serde_json::Value>(&arguments) {
                Ok(value) => self.run(value).await,
                Err(_) => Err(LLMYError::IncorrectToolCall(
                    self.name(),
                    arguments,
                    self.schema(),
                )),
            }
        })
    }

    /// Executes the tool with already-parsed JSON arguments and returns its textual output.
    fn run(
        &self,
        arguments: serde_json::Value,
    ) -> Pin<Box<dyn Future<Output = Result<String, LLMYError>> + Send + '_>>;
}
pub fn downcast_tool<T: 'static>(tool: &dyn ToolDyn) -> &T {
(tool as &dyn std::any::Any)
.downcast_ref::<T>()
.expect("can not downcast")
}
// Generates `Clone` for `Box<dyn ToolDyn>` through `dyn_clone`; `ToolDyn` has `DynClone`
// as a supertrait, which is what makes boxed tool objects clonable.
dyn_clone::clone_trait_object!(ToolDyn);
/// Statically-typed tool definition.
///
/// `ARGUMENTS` is both the source of the advertised JSON schema and the type incoming
/// JSON is deserialized into. Any `Tool + 'static` is also usable as a [`ToolDyn`]
/// through the blanket implementation in this module.
pub trait Tool: Debug + DynClone + Send + Sync {
    /// Name under which the tool is registered and called.
    const NAME: &str;
    /// Optional description shown to the model.
    const DESCRIPTION: Option<&str>;
    /// Whether the OpenAI function definition is marked `strict`; `false` by default.
    const STRICT: bool = false;

    /// Typed arguments accepted by [`Tool::invoke`].
    type ARGUMENTS: schemars::JsonSchema + DeserializeOwned + Send + Sized;

    /// Runs the tool with already-deserialized arguments and yields its textual output.
    fn invoke(&self, arguments: Self::ARGUMENTS)
        -> impl Future<Output = Result<String, LLMYError>> + Send;
}
impl<T: Tool + DynClone + 'static> ToolDyn for T {
fn name(&self) -> String {
Self::NAME.to_string()
}
fn description(&self) -> Option<String> {
Self::DESCRIPTION.map(|v| v.to_string())
}
fn schema(&self) -> schemars::Schema {
schema_for!(T::ARGUMENTS)
}
fn strict(&self) -> bool {
T::STRICT
}
fn run(
&self,
arguments: serde_json::Value,
) -> Pin<Box<dyn Future<Output = Result<String, LLMYError>> + Send + '_>> {
Box::pin(async move {
match serde_json::from_value::<T::ARGUMENTS>(arguments.clone()) {
Ok(args) => self.invoke(args).await,
Err(_) => Err(LLMYError::IncorrectToolCall(
T::NAME.to_string(),
arguments.to_string(),
schema_for!(T::ARGUMENTS),
)),
}
})
}
}
/// A collection of tools keyed by tool name.
///
/// Each entry is held in an `Arc`, so cloning a `ToolBox` shares the underlying tool
/// instances rather than duplicating them.
#[derive(Default, Clone, Debug)]
pub struct ToolBox {
    // Tool name -> tool. A `BTreeMap` makes iteration (rendering, exports) follow name order.
    tools: BTreeMap<String, Arc<Box<dyn ToolDyn>>>,
}
impl ToolBox {
    /// Creates an empty toolbox.
    pub fn new() -> Self {
        Self::default()
    }

    /// Number of registered tools.
    pub fn len(&self) -> usize {
        self.tools.len()
    }

    /// Returns `true` if no tools are registered.
    pub fn is_empty(&self) -> bool {
        self.tools.is_empty()
    }

    /// Renders one entry per tool, in name order.
    ///
    /// With `details`, each entry is `` `name`: <description> ``. The description is
    /// `Debug`-formatted, so it appears quoted and escaped. Tools without a description
    /// get a placeholder. Without `details`, each entry is just the tool name.
    pub fn render_tools(&self, details: bool) -> Vec<String> {
        self.tools
            .iter()
            .map(|(name, tool)| {
                if details {
                    format!(
                        "`{}`: {:?}",
                        name,
                        tool.description()
                            .unwrap_or_else(|| "no description is provided".to_string())
                    )
                } else {
                    name.clone()
                }
            })
            .collect()
    }

    /// Moves every tool from `rhs` into `self`; on a name clash the tool from `rhs` wins.
    pub fn extend(&mut self, rhs: Self) {
        self.tools.extend(rhs.tools);
    }

    /// Returns `true` if a tool named `tool` is registered.
    pub fn has_tool(&self, tool: &str) -> bool {
        self.tools.contains_key(tool)
    }

    /// MCP definitions of all tools, in name order.
    pub fn mcp_tools(&self) -> Vec<rmcp::model::Tool> {
        self.tools.values().map(|t| t.to_mcp_tool()).collect()
    }

    /// OpenAI chat-completion tool definitions of all tools, in name order.
    pub fn openai_objects(&self) -> Vec<ChatCompletionTools> {
        self.tools
            .values()
            .map(|tool| {
                WithOtherFields::new(ChatCompletionToolsRaw::Function(tool.to_openai_obejct()))
            })
            .collect()
    }

    /// Registers a statically-typed tool, replacing any tool with the same name.
    pub fn add_tool<T: Tool + 'static>(&mut self, tool: T) {
        self.add_dyn_tool(Box::new(tool));
    }

    /// Registers a boxed tool under its [`ToolDyn::name`], replacing any tool with that name.
    pub fn add_dyn_tool(&mut self, tool: Box<dyn ToolDyn>) {
        self.tools.insert(tool.name(), Arc::new(tool));
    }

    /// Calls `tool` with raw JSON-text `arguments`. Returns `None` when no tool was found.
    async fn call_found_tool(
        tool: Option<&Arc<Box<dyn ToolDyn>>>,
        tool_name: &str,
        arguments: String,
    ) -> Option<Result<String, LLMYError>> {
        let tool = tool?;
        debug!("Invoking tool {} with arguments {}", tool_name, &arguments);
        Some(tool.call(arguments).await)
    }

    /// Invokes `tool_name` with raw JSON-text `arguments`.
    ///
    /// Returns `None` if no such tool is registered, otherwise the tool's result.
    pub async fn invoke(
        &self,
        tool_name: String,
        arguments: String,
    ) -> Option<Result<String, LLMYError>> {
        Self::call_found_tool(self.tools.get(&tool_name), &tool_name, arguments).await
    }

    /// Invokes `tool_name` with already-parsed JSON `arguments`.
    ///
    /// Returns `None` if no such tool is registered, otherwise the tool's result.
    pub async fn invoke_value(
        &self,
        tool_name: String,
        arguments: serde_json::Value,
    ) -> Option<Result<String, LLMYError>> {
        let tool = self.tools.get(&tool_name)?;
        debug!("Invoking tool {} with arguments {}", &tool_name, &arguments);
        Some(tool.run(arguments).await)
    }

    /// Runs all `calls` concurrently on the tokio runtime. Results are returned in the same
    /// order as `calls`, each paired with its call.
    ///
    /// # Panics
    /// Panics if called outside a tokio runtime, or if any spawned call panics
    /// (propagated by `JoinSet::join_all`).
    pub async fn invoke_many(
        &self,
        calls: Vec<GeneralToolCall>,
    ) -> Vec<(GeneralToolCall, Option<Result<String, LLMYError>>)> {
        let mut js = JoinSet::new();
        for (index, call) in calls.into_iter().enumerate() {
            // Only the matching tool (a cheap `Arc` clone) moves into the task; there is no
            // need to clone the whole toolbox per call.
            let tool = self.tools.get(&call.tool_name).cloned();
            js.spawn(async move {
                tracing::info!("Calling {}", &call);
                let result =
                    Self::call_found_tool(tool.as_ref(), &call.tool_name, call.tool_args.clone())
                        .await;
                (index, call, result)
            });
        }
        // `join_all` yields results in completion order. Restore the input order so the
        // output is deterministic and matches `invoke_many_sequential`.
        let mut finished = js.join_all().await;
        finished.sort_unstable_by_key(|entry| entry.0);
        finished
            .into_iter()
            .map(|(_, call, result)| (call, result))
            .collect()
    }

    /// Runs `calls` one after another, returning results in the same order as `calls`.
    pub async fn invoke_many_sequential(
        &self,
        calls: Vec<GeneralToolCall>,
    ) -> Vec<(GeneralToolCall, Option<Result<String, LLMYError>>)> {
        let mut out = Vec::with_capacity(calls.len());
        for call in calls {
            tracing::info!("Calling {}", &call);
            let result = self
                .invoke(call.tool_name.clone(), call.tool_args.clone())
                .await;
            out.push((call, result));
        }
        out
    }

    /// Like [`ToolBox::invoke_many`], but wraps each successful output in a tool-role chat
    /// message addressed to the call's `tool_id`.
    pub async fn agent_invoke_many(
        &self,
        calls: Vec<GeneralToolCall>,
    ) -> Vec<(
        GeneralToolCall,
        Option<Result<ChatCompletionRequestMessageRaw, LLMYError>>,
    )> {
        let invokes = self.invoke_many(calls).await;
        Self::agent_messages_from_invokes(invokes)
    }

    /// Like [`ToolBox::invoke_many_sequential`], but wraps each successful output in a
    /// tool-role chat message addressed to the call's `tool_id`.
    pub async fn agent_invoke_many_sequential(
        &self,
        calls: Vec<GeneralToolCall>,
    ) -> Vec<(
        GeneralToolCall,
        Option<Result<ChatCompletionRequestMessageRaw, LLMYError>>,
    )> {
        let invokes = self.invoke_many_sequential(calls).await;
        Self::agent_messages_from_invokes(invokes)
    }

    /// Converts each successful text output into a tool message whose `tool_call_id` is the
    /// call's `tool_id`. Missing tools (`None`) and errors pass through unchanged.
    fn agent_messages_from_invokes(
        invokes: Vec<(GeneralToolCall, Option<Result<String, LLMYError>>)>,
    ) -> Vec<(
        GeneralToolCall,
        Option<Result<ChatCompletionRequestMessageRaw, LLMYError>>,
    )> {
        invokes
            .into_iter()
            .map(|(call, result)| {
                let id = call.tool_id.clone();
                let message = result.map(|outcome| {
                    outcome.map(|text| {
                        ChatCompletionRequestMessageRaw::Tool(WithOtherFields::new(
                            ChatCompletionRequestToolMessageRaw {
                                content: ChatCompletionRequestToolMessageContent::Text(text),
                                tool_call_id: id,
                            },
                        ))
                    })
                });
                (call, message)
            })
            .collect()
    }
}