use anyhow::Result;
use rig::completion::ToolDefinition;
use rig::tool::Tool;
use serde::{Deserialize, Serialize};
use std::fmt::Debug;
use std::marker::PhantomData;
use crate::agents::debug;
/// Wrapper around a [`Tool`] that forwards every operation to the wrapped
/// tool while emitting debug output for each call.
///
/// The wrapper keeps the wrapped tool's name, argument, output and error
/// types, so it can be used wherever the inner tool could be used.
pub struct DebugTool<T: Tool> {
    /// The wrapped tool that does the real work.
    inner: T,
    /// Marker tying the wrapper to `T`; carries no data.
    _phantom: PhantomData<T>,
}
impl<T> DebugTool<T>
where
T: Tool,
{
pub fn new(tool: T) -> Self {
Self {
inner: tool,
_phantom: PhantomData,
}
}
}
impl<T> Clone for DebugTool<T>
where
T: Tool + Clone,
{
fn clone(&self) -> Self {
Self {
inner: self.inner.clone(),
_phantom: PhantomData,
}
}
}
/// Formats as a struct named `DebugTool` with a single `inner` field; the
/// marker field is left out of the output.
impl<T: Tool + Debug> Debug for DebugTool<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = f.debug_struct("DebugTool");
        repr.field("inner", &self.inner);
        repr.finish()
    }
}
impl<T> Serialize for DebugTool<T>
where
T: Tool + Serialize,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.inner.serialize(serializer)
}
}
impl<'de, T> Deserialize<'de> for DebugTool<T>
where
T: Tool + Deserialize<'de>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let inner = T::deserialize(deserializer)?;
Ok(Self {
inner,
_phantom: PhantomData,
})
}
}
/// [`Tool`] implementation that delegates to the wrapped tool and reports
/// each call through the `debug` helpers.
///
/// The name, argument, output and error types are taken from the wrapped
/// tool unchanged.
impl<T> Tool for DebugTool<T>
where
    T: Tool + Send + Sync,
    T::Args: Debug + Send + Sync,
    T::Output: Debug + Send + Sync,
    T::Error: Send + Sync,
{
    const NAME: &'static str = T::NAME;
    type Error = T::Error;
    type Args = T::Args;
    type Output = T::Output;

    /// Returns the wrapped tool's definition for `prompt`, unmodified.
    async fn definition(&self, prompt: String) -> ToolDefinition {
        self.inner.definition(prompt).await
    }

    /// Calls the wrapped tool with `args`, logging the call and its outcome.
    ///
    /// The arguments are logged (via their `Debug` form) before the call.
    /// Afterwards either the `Debug` form of the output together with the
    /// measured call duration, or the error, is logged. The wrapped tool's
    /// result is returned unchanged.
    ///
    /// # Errors
    ///
    /// Returns whatever error the wrapped tool's `call` returns.
    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
        // Render the arguments up front: `args` is moved into the inner call.
        let args_str = format!("{:?}", args);
        debug::debug_tool_call(Self::NAME, &args_str);

        let timer = debug::DebugTimer::start(&format!("Tool: {}", Self::NAME));
        // Measure the call ourselves so the response log carries the real
        // duration instead of a hard-coded zero.
        let started = std::time::Instant::now();
        let result = self.inner.call(args).await;
        let elapsed = started.elapsed();
        timer.finish();

        match &result {
            Ok(output) => {
                let output_str = format!("{:?}", output);
                debug::debug_tool_response(Self::NAME, &output_str, elapsed);
            }
            Err(e) => {
                debug::debug_error(&format!("Tool {} failed: {:?}", Self::NAME, e));
            }
        }

        result
    }
}