distri_types/
ui_tool_render.rs

1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::{collections::HashMap, sync::Arc};
5
6use crate::{Part, ToolCall, ToolResponse};
7
8/// UI messages that can be generated for tool execution
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
10pub struct ToolUiMessage {
11    pub message_type: ToolUiMessageType,
12    pub parts: Vec<Part>,
13}
14
15/// Different types of UI messages for tool execution
16#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
17#[serde(rename_all = "snake_case")]
18pub enum ToolUiMessageType {
19    /// Message shown when tool starts executing
20    ToolStart,
21    /// Message shown when tool completes successfully
22    ToolEnd,
23    /// Message shown when tool fails
24    ToolError,
25    /// Progress message during tool execution (for long-running tools)
26    ToolProgress,
27}
28
29/// Context for rendering UI messages
30#[derive(Debug, Clone)]
31pub struct ToolUiContext {
32    pub tool_call: ToolCall,
33    pub tool_response: Option<ToolResponse>,
34    pub error: Option<String>,
35    pub progress_info: Option<Value>,
36}
37
38/// Trait for rendering UI messages for tools
39pub trait UiToolRender: Send + Sync + std::fmt::Debug {
40    /// Get the tool name this renderer handles
41    fn get_tool_name(&self) -> String;
42
43    /// Generate tool start message
44    fn render_tool_start(&self, context: &ToolUiContext) -> Result<ToolUiMessage>;
45
46    /// Generate tool end message
47    fn render_tool_end(&self, context: &ToolUiContext) -> Result<ToolUiMessage>;
48
49    /// Generate tool error message
50    fn render_tool_error(&self, context: &ToolUiContext) -> Result<ToolUiMessage>;
51
52    /// Generate tool progress message (optional, for long-running tools)
53    fn render_tool_progress(&self, _context: &ToolUiContext) -> Result<Option<ToolUiMessage>> {
54        Ok(None) // Default: no progress messages
55    }
56
57    /// Check if this tool supports progress messages
58    fn supports_progress(&self) -> bool {
59        false // Default: no progress support
60    }
61}
62
63/// Registry for tool UI renderers
64#[derive(Debug, Default)]
65pub struct ToolUiRenderRegistry {
66    renderers: HashMap<String, Arc<dyn UiToolRender>>,
67}
68
69impl ToolUiRenderRegistry {
70    /// Create a new registry
71    pub fn new() -> Self {
72        Self {
73            renderers: HashMap::new(),
74        }
75    }
76
77    /// Register a UI renderer for a specific tool
78    pub fn register(&mut self, tool_name: String, renderer: Arc<dyn UiToolRender>) {
79        tracing::debug!("Registering UI renderer for tool: {}", tool_name);
80        self.renderers.insert(tool_name, renderer);
81    }
82
83    /// Get a UI renderer for a specific tool
84    pub fn get_renderer(&self, tool_name: &str) -> Option<&Arc<dyn UiToolRender>> {
85        self.renderers.get(tool_name)
86    }
87
88    /// Render tool start message
89    pub fn render_tool_start(&self, tool_call: &ToolCall) -> Result<ToolUiMessage> {
90        let context = ToolUiContext {
91            tool_call: tool_call.clone(),
92            tool_response: None,
93            error: None,
94            progress_info: None,
95        };
96
97        if let Some(renderer) = self.get_renderer(&tool_call.tool_name) {
98            renderer.render_tool_start(&context)
99        } else {
100            // Default rendering
101            DefaultToolRenderer.render_tool_start(&context)
102        }
103    }
104
105    /// Render tool end message
106    pub fn render_tool_end(
107        &self,
108        tool_call: &ToolCall,
109        tool_response: &ToolResponse,
110    ) -> Result<ToolUiMessage> {
111        let context = ToolUiContext {
112            tool_call: tool_call.clone(),
113            tool_response: Some(tool_response.clone()),
114            error: None,
115            progress_info: None,
116        };
117
118        if let Some(renderer) = self.get_renderer(&tool_call.tool_name) {
119            renderer.render_tool_end(&context)
120        } else {
121            // Default rendering
122            DefaultToolRenderer.render_tool_end(&context)
123        }
124    }
125
126    /// Render tool error message
127    pub fn render_tool_error(
128        &self,
129        tool_call: &ToolCall,
130        error: &anyhow::Error,
131    ) -> Result<ToolUiMessage> {
132        let context = ToolUiContext {
133            tool_call: tool_call.clone(),
134            tool_response: None,
135            error: Some(error.to_string()),
136            progress_info: None,
137        };
138
139        if let Some(renderer) = self.get_renderer(&tool_call.tool_name) {
140            renderer.render_tool_error(&context)
141        } else {
142            // Default rendering
143            DefaultToolRenderer.render_tool_error(&context)
144        }
145    }
146
147    /// Render tool progress message (if supported)
148    pub fn render_tool_progress(
149        &self,
150        tool_call: &ToolCall,
151        progress_info: Value,
152    ) -> Result<Option<ToolUiMessage>> {
153        let context = ToolUiContext {
154            tool_call: tool_call.clone(),
155            tool_response: None,
156            error: None,
157            progress_info: Some(progress_info),
158        };
159
160        if let Some(renderer) = self.get_renderer(&tool_call.tool_name) {
161            if renderer.supports_progress() {
162                return renderer.render_tool_progress(&context);
163            }
164        }
165
166        Ok(None) // No progress rendering available
167    }
168
169    /// List all registered tool names
170    pub fn list_registered_tools(&self) -> Vec<String> {
171        self.renderers.keys().cloned().collect()
172    }
173}
174
175/// Default tool renderer for tools without custom rendering
176#[derive(Debug)]
177pub struct DefaultToolRenderer;
178
179impl UiToolRender for DefaultToolRenderer {
180    fn get_tool_name(&self) -> String {
181        "default".to_string()
182    }
183
184    fn render_tool_start(&self, context: &ToolUiContext) -> Result<ToolUiMessage> {
185        let formatted_input =
186            if context.tool_call.input.is_object() || context.tool_call.input.is_array() {
187                serde_json::to_string_pretty(&context.tool_call.input)?
188            } else {
189                context.tool_call.input.to_string()
190            };
191
192        let message = format!(
193            "🔧 **{}**\n\n```json\n{}\n```",
194            context.tool_call.tool_name, formatted_input
195        );
196
197        Ok(ToolUiMessage {
198            message_type: ToolUiMessageType::ToolStart,
199            parts: vec![Part::Text(message)],
200        })
201    }
202
203    fn render_tool_end(&self, context: &ToolUiContext) -> Result<ToolUiMessage> {
204        let tool_response = context
205            .tool_response
206            .as_ref()
207            .ok_or_else(|| anyhow::anyhow!("Tool response required for tool_end message"))?;
208
209        let message = format!("✅ **{}** completed", context.tool_call.tool_name);
210
211        let mut parts = vec![Part::Text(message)];
212
213        // Add tool response parts
214        parts.extend(tool_response.parts.clone());
215
216        Ok(ToolUiMessage {
217            message_type: ToolUiMessageType::ToolEnd,
218            parts,
219        })
220    }
221
222    fn render_tool_error(&self, context: &ToolUiContext) -> Result<ToolUiMessage> {
223        let error_msg = context
224            .error
225            .as_ref()
226            .cloned()
227            .unwrap_or_else(|| "Unknown error".to_string());
228
229        let message = format!(
230            "❌ **{}** failed\n\n```\n{}\n```",
231            context.tool_call.tool_name, error_msg
232        );
233
234        Ok(ToolUiMessage {
235            message_type: ToolUiMessageType::ToolError,
236            parts: vec![Part::Text(message)],
237        })
238    }
239}
240
241/// Helper function to create a global registry instance
242pub fn create_default_registry() -> ToolUiRenderRegistry {
243    let mut registry = ToolUiRenderRegistry::new();
244
245    // Register built-in tool renderers
246    crate::ui_tool_renderers::register_common_renderers(&mut registry);
247
248    registry
249}