use crate::message::Context;
use crate::runtime::{is_tool_request, tool_response, MessageHandler, Result};
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::RwLock;
/// Decorator around a [`MessageHandler`] that pairs the wrapped handler with
/// a flag recording whether it should be treated as exposing a tool.
pub struct ToolInterceptor {
/// The wrapped handler, stored behind a reference-counted async `RwLock`;
/// every delegation in this file takes the write lock.
original_handler: Arc<RwLock<Box<dyn MessageHandler>>>,
/// When `true`, `on_message` checks incoming contexts with
/// `is_tool_request` and routes matches through the tool-request path.
has_tool: bool,
}
impl ToolInterceptor {
pub fn new(handler: Box<dyn MessageHandler>, has_tool: bool) -> Self {
Self {
original_handler: Arc::new(RwLock::new(handler)),
has_tool,
}
}
pub fn has_tool(&self) -> bool {
self.has_tool
}
async fn handle_tool_request(&self, ctx: &mut Context) -> Result<()> {
let result = {
let mut handler = self.original_handler.write().await;
handler.on_message(ctx).await
};
if !has_tool_response_been_sent(ctx) {
if let Err(ref e) = result {
tool_response(ctx, "error", None, &e.to_string()).await?;
} else {
let output_data = self.collect_output_variables(ctx);
tool_response(ctx, "success", Some(output_data), "").await?;
}
}
result
}
fn collect_output_variables(
&self,
_ctx: &Context,
) -> std::collections::HashMap<String, serde_json::Value> {
let mut output_data = std::collections::HashMap::new();
output_data.insert("status".to_string(), serde_json::json!("completed"));
output_data
}
}
#[async_trait]
impl MessageHandler for ToolInterceptor {
    /// Forwards the create hook to the wrapped handler.
    async fn on_create(&mut self) -> Result<()> {
        let mut inner = self.original_handler.write().await;
        inner.on_create().await
    }

    /// Dispatches a message.
    ///
    /// When `has_tool` is set and `is_tool_request(ctx)` is true, the message
    /// goes through `handle_tool_request`; otherwise it is passed straight to
    /// the wrapped handler.
    async fn on_message(&mut self, ctx: &mut Context) -> Result<()> {
        // `&&` short-circuits: `is_tool_request` is only consulted when the
        // flag is set.
        let intercept = self.has_tool && is_tool_request(ctx);
        if intercept {
            self.handle_tool_request(ctx).await
        } else {
            let mut inner = self.original_handler.write().await;
            inner.on_message(ctx).await
        }
    }

    /// Forwards the close hook to the wrapped handler.
    async fn on_close(&mut self) -> Result<()> {
        let mut inner = self.original_handler.write().await;
        inner.on_close().await
    }
}
/// Reports whether a tool response is considered already sent for `ctx`.
///
/// Returns `true` when `ctx.is_empty()` is true or when `ctx.get_raw()`
/// yields an error; `get_raw` is only called if the context is non-empty.
// NOTE(review): presumably sending a response consumes or invalidates the
// context's payload, which is why these two conditions are used as the
// signal — confirm against `Context` and `tool_response`.
fn has_tool_response_been_sent(ctx: &Context) -> bool {
    if ctx.is_empty() {
        return true;
    }
    ctx.get_raw().is_err()
}
/// Conditionally wraps `handler` in a [`ToolInterceptor`].
///
/// When `has_tool` is `false` the handler is returned untouched; otherwise it
/// is boxed inside a new interceptor constructed with its tool flag set.
pub fn wrap_with_tool_interceptor(
    handler: Box<dyn MessageHandler>,
    has_tool: bool,
) -> Box<dyn MessageHandler> {
    if !has_tool {
        return handler;
    }
    Box::new(ToolInterceptor::new(handler, true))
}