distri_types/
ui_tool_render.rs1use anyhow::Result;
2use serde::{Deserialize, Serialize};
3use serde_json::Value;
4use std::{collections::HashMap, sync::Arc};
5
6use crate::{Part, ToolCall, ToolResponse};
7
8#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
10pub struct ToolUiMessage {
11 pub message_type: ToolUiMessageType,
12 pub parts: Vec<Part>,
13}
14
15#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
17#[serde(rename_all = "snake_case")]
18pub enum ToolUiMessageType {
19 ToolStart,
21 ToolEnd,
23 ToolError,
25 ToolProgress,
27}
28
29#[derive(Debug, Clone)]
31pub struct ToolUiContext {
32 pub tool_call: ToolCall,
33 pub tool_response: Option<ToolResponse>,
34 pub error: Option<String>,
35 pub progress_info: Option<Value>,
36}
37
38pub trait UiToolRender: Send + Sync + std::fmt::Debug {
40 fn get_tool_name(&self) -> String;
42
43 fn render_tool_start(&self, context: &ToolUiContext) -> Result<ToolUiMessage>;
45
46 fn render_tool_end(&self, context: &ToolUiContext) -> Result<ToolUiMessage>;
48
49 fn render_tool_error(&self, context: &ToolUiContext) -> Result<ToolUiMessage>;
51
52 fn render_tool_progress(&self, _context: &ToolUiContext) -> Result<Option<ToolUiMessage>> {
54 Ok(None) }
56
57 fn supports_progress(&self) -> bool {
59 false }
61}
62
63#[derive(Debug, Default)]
65pub struct ToolUiRenderRegistry {
66 renderers: HashMap<String, Arc<dyn UiToolRender>>,
67}
68
69impl ToolUiRenderRegistry {
70 pub fn new() -> Self {
72 Self {
73 renderers: HashMap::new(),
74 }
75 }
76
77 pub fn register(&mut self, tool_name: String, renderer: Arc<dyn UiToolRender>) {
79 tracing::debug!("Registering UI renderer for tool: {}", tool_name);
80 self.renderers.insert(tool_name, renderer);
81 }
82
83 pub fn get_renderer(&self, tool_name: &str) -> Option<&Arc<dyn UiToolRender>> {
85 self.renderers.get(tool_name)
86 }
87
88 pub fn render_tool_start(&self, tool_call: &ToolCall) -> Result<ToolUiMessage> {
90 let context = ToolUiContext {
91 tool_call: tool_call.clone(),
92 tool_response: None,
93 error: None,
94 progress_info: None,
95 };
96
97 if let Some(renderer) = self.get_renderer(&tool_call.tool_name) {
98 renderer.render_tool_start(&context)
99 } else {
100 DefaultToolRenderer.render_tool_start(&context)
102 }
103 }
104
105 pub fn render_tool_end(
107 &self,
108 tool_call: &ToolCall,
109 tool_response: &ToolResponse,
110 ) -> Result<ToolUiMessage> {
111 let context = ToolUiContext {
112 tool_call: tool_call.clone(),
113 tool_response: Some(tool_response.clone()),
114 error: None,
115 progress_info: None,
116 };
117
118 if let Some(renderer) = self.get_renderer(&tool_call.tool_name) {
119 renderer.render_tool_end(&context)
120 } else {
121 DefaultToolRenderer.render_tool_end(&context)
123 }
124 }
125
126 pub fn render_tool_error(
128 &self,
129 tool_call: &ToolCall,
130 error: &anyhow::Error,
131 ) -> Result<ToolUiMessage> {
132 let context = ToolUiContext {
133 tool_call: tool_call.clone(),
134 tool_response: None,
135 error: Some(error.to_string()),
136 progress_info: None,
137 };
138
139 if let Some(renderer) = self.get_renderer(&tool_call.tool_name) {
140 renderer.render_tool_error(&context)
141 } else {
142 DefaultToolRenderer.render_tool_error(&context)
144 }
145 }
146
147 pub fn render_tool_progress(
149 &self,
150 tool_call: &ToolCall,
151 progress_info: Value,
152 ) -> Result<Option<ToolUiMessage>> {
153 let context = ToolUiContext {
154 tool_call: tool_call.clone(),
155 tool_response: None,
156 error: None,
157 progress_info: Some(progress_info),
158 };
159
160 if let Some(renderer) = self.get_renderer(&tool_call.tool_name) {
161 if renderer.supports_progress() {
162 return renderer.render_tool_progress(&context);
163 }
164 }
165
166 Ok(None) }
168
169 pub fn list_registered_tools(&self) -> Vec<String> {
171 self.renderers.keys().cloned().collect()
172 }
173}
174
/// Fallback renderer used by the registry for tools that have no renderer of
/// their own. It carries no state.
#[derive(Debug)]
pub struct DefaultToolRenderer;
178
179impl UiToolRender for DefaultToolRenderer {
180 fn get_tool_name(&self) -> String {
181 "default".to_string()
182 }
183
184 fn render_tool_start(&self, context: &ToolUiContext) -> Result<ToolUiMessage> {
185 let formatted_input =
186 if context.tool_call.input.is_object() || context.tool_call.input.is_array() {
187 serde_json::to_string_pretty(&context.tool_call.input)?
188 } else {
189 context.tool_call.input.to_string()
190 };
191
192 let message = format!(
193 "🔧 **{}**\n\n```json\n{}\n```",
194 context.tool_call.tool_name, formatted_input
195 );
196
197 Ok(ToolUiMessage {
198 message_type: ToolUiMessageType::ToolStart,
199 parts: vec![Part::Text(message)],
200 })
201 }
202
203 fn render_tool_end(&self, context: &ToolUiContext) -> Result<ToolUiMessage> {
204 let tool_response = context
205 .tool_response
206 .as_ref()
207 .ok_or_else(|| anyhow::anyhow!("Tool response required for tool_end message"))?;
208
209 let message = format!("✅ **{}** completed", context.tool_call.tool_name);
210
211 let mut parts = vec![Part::Text(message)];
212
213 parts.extend(tool_response.parts.clone());
215
216 Ok(ToolUiMessage {
217 message_type: ToolUiMessageType::ToolEnd,
218 parts,
219 })
220 }
221
222 fn render_tool_error(&self, context: &ToolUiContext) -> Result<ToolUiMessage> {
223 let error_msg = context
224 .error
225 .as_ref()
226 .cloned()
227 .unwrap_or_else(|| "Unknown error".to_string());
228
229 let message = format!(
230 "❌ **{}** failed\n\n```\n{}\n```",
231 context.tool_call.tool_name, error_msg
232 );
233
234 Ok(ToolUiMessage {
235 message_type: ToolUiMessageType::ToolError,
236 parts: vec![Part::Text(message)],
237 })
238 }
239}
240
241pub fn create_default_registry() -> ToolUiRenderRegistry {
243 let mut registry = ToolUiRenderRegistry::new();
244
245 crate::ui_tool_renderers::register_common_renderers(&mut registry);
247
248 registry
249}