use anyhow::Result;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::{collections::HashMap, sync::Arc};
use crate::{Part, ToolCall, ToolResponse};
/// A UI message produced for one stage of a tool invocation.
///
/// `message_type` says which stage the message describes; `parts` carries the
/// content to display, in order.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct ToolUiMessage {
    /// Lifecycle stage this message belongs to.
    pub message_type: ToolUiMessageType,
    /// Content parts of the message, in display order.
    pub parts: Vec<Part>,
}
/// Lifecycle stage of a tool call that a [`ToolUiMessage`] describes.
///
/// Serialized in `snake_case` (e.g. `ToolStart` becomes `"tool_start"`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolUiMessageType {
    /// Emitted by `render_tool_start` when a tool call begins.
    ToolStart,
    /// Emitted by `render_tool_end` when a tool call completes.
    ToolEnd,
    /// Emitted by `render_tool_error` when a tool call fails.
    ToolError,
    /// Intended for progress updates from renderers that support them.
    ToolProgress,
}
/// Everything a [`UiToolRender`] receives when asked to render one stage of a
/// tool call.
///
/// Only the fields relevant to the stage being rendered are populated by
/// [`ToolUiRenderRegistry`]; the rest are `None`.
#[derive(Clone, Debug)]
pub struct ToolUiContext {
    /// The tool call being rendered.
    pub tool_call: ToolCall,
    /// The tool's response; set by the registry when rendering the end stage.
    pub tool_response: Option<ToolResponse>,
    /// Error text; set by the registry when rendering the error stage.
    pub error: Option<String>,
    /// Progress payload; set by the registry when rendering a progress update.
    pub progress_info: Option<Value>,
}
/// Produces the UI messages shown over the lifecycle of a tool call.
///
/// Implementors are held as `Arc<dyn UiToolRender>` trait objects by
/// [`ToolUiRenderRegistry`]. `Debug` is required because the registry derives
/// `Debug`; `Send + Sync` allow a renderer to be shared across threads.
pub trait UiToolRender: Send + Sync + std::fmt::Debug {
    /// Name identifying this renderer.
    fn get_tool_name(&self) -> String;

    /// Builds the message shown when the tool call starts.
    fn render_tool_start(&self, context: &ToolUiContext) -> Result<ToolUiMessage>;

    /// Builds the message shown when the tool call finishes.
    ///
    /// When invoked through the registry, `context.tool_response` is `Some`.
    fn render_tool_end(&self, context: &ToolUiContext) -> Result<ToolUiMessage>;

    /// Builds the message shown when the tool call fails.
    ///
    /// When invoked through the registry, `context.error` is `Some`.
    fn render_tool_error(&self, context: &ToolUiContext) -> Result<ToolUiMessage>;

    /// Builds an optional progress update.
    ///
    /// The registry only calls this when [`supports_progress`](Self::supports_progress)
    /// returns `true`. The default implementation emits nothing.
    fn render_tool_progress(&self, _context: &ToolUiContext) -> Result<Option<ToolUiMessage>> {
        Ok(None)
    }

    /// Whether this renderer produces progress updates; `false` by default.
    fn supports_progress(&self) -> bool {
        false
    }
}
/// Lookup table from tool name to the [`UiToolRender`] that displays its calls.
///
/// Tools without a registered renderer are rendered by [`DefaultToolRenderer`]
/// for start/end/error messages and produce no progress messages.
#[derive(Default, Debug)]
pub struct ToolUiRenderRegistry {
    /// Registered renderers keyed by tool name.
    renderers: HashMap<String, Arc<dyn UiToolRender>>,
}
impl ToolUiRenderRegistry {
pub fn new() -> Self {
Self {
renderers: HashMap::new(),
}
}
pub fn register(&mut self, tool_name: String, renderer: Arc<dyn UiToolRender>) {
tracing::debug!("Registering UI renderer for tool: {}", tool_name);
self.renderers.insert(tool_name, renderer);
}
pub fn get_renderer(&self, tool_name: &str) -> Option<&Arc<dyn UiToolRender>> {
self.renderers.get(tool_name)
}
pub fn render_tool_start(&self, tool_call: &ToolCall) -> Result<ToolUiMessage> {
let context = ToolUiContext {
tool_call: tool_call.clone(),
tool_response: None,
error: None,
progress_info: None,
};
if let Some(renderer) = self.get_renderer(&tool_call.tool_name) {
renderer.render_tool_start(&context)
} else {
DefaultToolRenderer.render_tool_start(&context)
}
}
pub fn render_tool_end(
&self,
tool_call: &ToolCall,
tool_response: &ToolResponse,
) -> Result<ToolUiMessage> {
let context = ToolUiContext {
tool_call: tool_call.clone(),
tool_response: Some(tool_response.clone()),
error: None,
progress_info: None,
};
if let Some(renderer) = self.get_renderer(&tool_call.tool_name) {
renderer.render_tool_end(&context)
} else {
DefaultToolRenderer.render_tool_end(&context)
}
}
pub fn render_tool_error(
&self,
tool_call: &ToolCall,
error: &anyhow::Error,
) -> Result<ToolUiMessage> {
let context = ToolUiContext {
tool_call: tool_call.clone(),
tool_response: None,
error: Some(error.to_string()),
progress_info: None,
};
if let Some(renderer) = self.get_renderer(&tool_call.tool_name) {
renderer.render_tool_error(&context)
} else {
DefaultToolRenderer.render_tool_error(&context)
}
}
pub fn render_tool_progress(
&self,
tool_call: &ToolCall,
progress_info: Value,
) -> Result<Option<ToolUiMessage>> {
let context = ToolUiContext {
tool_call: tool_call.clone(),
tool_response: None,
error: None,
progress_info: Some(progress_info),
};
if let Some(renderer) = self.get_renderer(&tool_call.tool_name)
&& renderer.supports_progress()
{
return renderer.render_tool_progress(&context);
}
Ok(None) }
pub fn list_registered_tools(&self) -> Vec<String> {
self.renderers.keys().cloned().collect()
}
}
/// Fallback [`UiToolRender`] used for tools that have no dedicated renderer
/// registered.
///
/// It is a stateless unit struct, so it can be constructed wherever needed.
#[derive(Debug)]
pub struct DefaultToolRenderer;
impl UiToolRender for DefaultToolRenderer {
    /// Returns the placeholder name `"default"`.
    fn get_tool_name(&self) -> String {
        "default".to_string()
    }

    /// Shows the tool name in bold, followed by its input in a `json` code block.
    ///
    /// Objects and arrays are pretty-printed; other values use their compact form.
    ///
    /// # Errors
    /// Fails if `serde_json::to_string_pretty` cannot serialize the input.
    fn render_tool_start(&self, context: &ToolUiContext) -> Result<ToolUiMessage> {
        let input = &context.tool_call.input;
        let formatted_input = if input.is_object() || input.is_array() {
            serde_json::to_string_pretty(input)?
        } else {
            input.to_string()
        };
        let message = format!(
            "🔧 **{}**\n\n{}",
            context.tool_call.tool_name,
            fenced_code_block("json", &formatted_input)
        );
        Ok(ToolUiMessage {
            message_type: ToolUiMessageType::ToolStart,
            parts: vec![Part::Text(message)],
        })
    }

    /// Shows a completion header followed by the parts of the tool response.
    ///
    /// # Errors
    /// Fails if `context.tool_response` is `None`.
    fn render_tool_end(&self, context: &ToolUiContext) -> Result<ToolUiMessage> {
        let tool_response = context
            .tool_response
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Tool response required for tool_end message"))?;
        let message = format!("✅ **{}** completed", context.tool_call.tool_name);
        let mut parts = vec![Part::Text(message)];
        parts.extend(tool_response.parts.clone());
        Ok(ToolUiMessage {
            message_type: ToolUiMessageType::ToolEnd,
            parts,
        })
    }

    /// Shows a failure header followed by the error text in a code block.
    ///
    /// Uses `"Unknown error"` when `context.error` is `None`.
    fn render_tool_error(&self, context: &ToolUiContext) -> Result<ToolUiMessage> {
        // Borrow rather than clone: the text is only interpolated into `message`.
        let error_msg = context.error.as_deref().unwrap_or("Unknown error");
        let message = format!(
            "❌ **{}** failed\n\n{}",
            context.tool_call.tool_name,
            fenced_code_block("", error_msg)
        );
        Ok(ToolUiMessage {
            message_type: ToolUiMessageType::ToolError,
            parts: vec![Part::Text(message)],
        })
    }
}

/// Wraps `body` in a Markdown fenced code block whose info string is `info`
/// (may be empty).
///
/// The fence is one backtick longer than the longest backtick run inside
/// `body`, with a minimum of three. Content that itself contains "```" (such as
/// JSON strings or error text quoting Markdown) therefore cannot close the block
/// early and corrupt the rest of the message. When `body` has no backticks, the
/// result is exactly "```{info}\n{body}\n```".
fn fenced_code_block(info: &str, body: &str) -> String {
    // Splitting on every non-backtick char leaves only the backtick runs;
    // a backtick is one byte, so `len` is the run length.
    let longest_backtick_run = body
        .split(|c: char| c != '`')
        .map(str::len)
        .max()
        .unwrap_or(0);
    let fence = "`".repeat(longest_backtick_run.max(2) + 1);
    format!("{fence}{info}\n{body}\n{fence}")
}
pub fn create_default_registry() -> ToolUiRenderRegistry {
let mut registry = ToolUiRenderRegistry::new();
crate::ui_tool_renderers::register_common_renderers(&mut registry);
registry
}