distri_types/
ui_tool_render.rs1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::{collections::HashMap, sync::Arc};
5
6use crate::{Part, ToolCall, ToolResponse};
7
8#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
10pub struct ToolUiMessage {
11 pub message_type: ToolUiMessageType,
12 pub parts: Vec<Part>,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
17#[serde(rename_all = "snake_case")]
18pub enum ToolUiMessageType {
19 ToolStart,
21 ToolEnd,
23 ToolError,
25 ToolProgress,
27}
28
29#[derive(Debug, Clone)]
31pub struct ToolUiContext {
32 pub tool_call: ToolCall,
33 pub tool_response: Option<ToolResponse>,
34 pub error: Option<String>,
35 pub progress_info: Option<Value>,
36}
37
38pub trait UiToolRender: Send + Sync + std::fmt::Debug {
40 fn get_tool_name(&self) -> String;
42
43 fn render_tool_start(&self, context: &ToolUiContext) -> Result<ToolUiMessage>;
45
46 fn render_tool_end(&self, context: &ToolUiContext) -> Result<ToolUiMessage>;
48
49 fn render_tool_error(&self, context: &ToolUiContext) -> Result<ToolUiMessage>;
51
52 fn render_tool_progress(&self, _context: &ToolUiContext) -> Result<Option<ToolUiMessage>> {
54 Ok(None) }
56
57 fn supports_progress(&self) -> bool {
59 false }
61}
62
63#[derive(Debug, Default)]
65pub struct ToolUiRenderRegistry {
66 renderers: HashMap<String, Arc<dyn UiToolRender>>,
67}
68
69impl ToolUiRenderRegistry {
70 pub fn new() -> Self {
72 Self {
73 renderers: HashMap::new(),
74 }
75 }
76
77 pub fn register(&mut self, tool_name: String, renderer: Arc<dyn UiToolRender>) {
79 tracing::debug!("Registering UI renderer for tool: {}", tool_name);
80 self.renderers.insert(tool_name, renderer);
81 }
82
83 pub fn get_renderer(&self, tool_name: &str) -> Option<&Arc<dyn UiToolRender>> {
85 self.renderers.get(tool_name)
86 }
87
88 pub fn render_tool_start(&self, tool_call: &ToolCall) -> Result<ToolUiMessage> {
90 let context = ToolUiContext {
91 tool_call: tool_call.clone(),
92 tool_response: None,
93 error: None,
94 progress_info: None,
95 };
96
97 if let Some(renderer) = self.get_renderer(&tool_call.tool_name) {
98 renderer.render_tool_start(&context)
99 } else {
100 DefaultToolRenderer.render_tool_start(&context)
102 }
103 }
104
105 pub fn render_tool_end(
107 &self,
108 tool_call: &ToolCall,
109 tool_response: &ToolResponse,
110 ) -> Result<ToolUiMessage> {
111 let context = ToolUiContext {
112 tool_call: tool_call.clone(),
113 tool_response: Some(tool_response.clone()),
114 error: None,
115 progress_info: None,
116 };
117
118 if let Some(renderer) = self.get_renderer(&tool_call.tool_name) {
119 renderer.render_tool_end(&context)
120 } else {
121 DefaultToolRenderer.render_tool_end(&context)
123 }
124 }
125
126 pub fn render_tool_error(
128 &self,
129 tool_call: &ToolCall,
130 error: &anyhow::Error,
131 ) -> Result<ToolUiMessage> {
132 let context = ToolUiContext {
133 tool_call: tool_call.clone(),
134 tool_response: None,
135 error: Some(error.to_string()),
136 progress_info: None,
137 };
138
139 if let Some(renderer) = self.get_renderer(&tool_call.tool_name) {
140 renderer.render_tool_error(&context)
141 } else {
142 DefaultToolRenderer.render_tool_error(&context)
144 }
145 }
146
147 pub fn render_tool_progress(
149 &self,
150 tool_call: &ToolCall,
151 progress_info: Value,
152 ) -> Result<Option<ToolUiMessage>> {
153 let context = ToolUiContext {
154 tool_call: tool_call.clone(),
155 tool_response: None,
156 error: None,
157 progress_info: Some(progress_info),
158 };
159
160 if let Some(renderer) = self.get_renderer(&tool_call.tool_name)
161 && renderer.supports_progress() {
162 return renderer.render_tool_progress(&context);
163 }
164
165 Ok(None) }
167
168 pub fn list_registered_tools(&self) -> Vec<String> {
170 self.renderers.keys().cloned().collect()
171 }
172}
173
174#[derive(Debug)]
176pub struct DefaultToolRenderer;
177
178impl UiToolRender for DefaultToolRenderer {
179 fn get_tool_name(&self) -> String {
180 "default".to_string()
181 }
182
183 fn render_tool_start(&self, context: &ToolUiContext) -> Result<ToolUiMessage> {
184 let formatted_input =
185 if context.tool_call.input.is_object() || context.tool_call.input.is_array() {
186 serde_json::to_string_pretty(&context.tool_call.input)?
187 } else {
188 context.tool_call.input.to_string()
189 };
190
191 let message = format!(
192 "🔧 **{}**\n\n```json\n{}\n```",
193 context.tool_call.tool_name, formatted_input
194 );
195
196 Ok(ToolUiMessage {
197 message_type: ToolUiMessageType::ToolStart,
198 parts: vec![Part::Text(message)],
199 })
200 }
201
202 fn render_tool_end(&self, context: &ToolUiContext) -> Result<ToolUiMessage> {
203 let tool_response = context
204 .tool_response
205 .as_ref()
206 .ok_or_else(|| anyhow::anyhow!("Tool response required for tool_end message"))?;
207
208 let message = format!("✅ **{}** completed", context.tool_call.tool_name);
209
210 let mut parts = vec![Part::Text(message)];
211
212 parts.extend(tool_response.parts.clone());
214
215 Ok(ToolUiMessage {
216 message_type: ToolUiMessageType::ToolEnd,
217 parts,
218 })
219 }
220
221 fn render_tool_error(&self, context: &ToolUiContext) -> Result<ToolUiMessage> {
222 let error_msg = context
223 .error
224 .as_ref()
225 .cloned()
226 .unwrap_or_else(|| "Unknown error".to_string());
227
228 let message = format!(
229 "❌ **{}** failed\n\n```\n{}\n```",
230 context.tool_call.tool_name, error_msg
231 );
232
233 Ok(ToolUiMessage {
234 message_type: ToolUiMessageType::ToolError,
235 parts: vec![Part::Text(message)],
236 })
237 }
238}
239
240pub fn create_default_registry() -> ToolUiRenderRegistry {
242 let mut registry = ToolUiRenderRegistry::new();
243
244 crate::ui_tool_renderers::register_common_renderers(&mut registry);
246
247 registry
248}