Skip to main content

adk_ui/tools/
render_screen.rs

1use crate::a2ui::{A2uiSchemaVersion, A2uiValidator};
2use crate::catalog_registry::CatalogRegistry;
3use crate::interop::{
4    A2uiAdapter, AgUiAdapter, McpAppsAdapter, UiProtocol, UiProtocolAdapter, UiSurface,
5};
6use crate::tools::SurfaceProtocolOptions;
7use adk_core::{Result, Tool, ToolContext};
8use async_trait::async_trait;
9use schemars::JsonSchema;
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use std::sync::Arc;
13
/// Serde default for `surface_id`: the primary surface is named `"main"`.
fn default_surface_id() -> String {
    const MAIN_SURFACE: &str = "main";
    MAIN_SURFACE.to_owned()
}

/// Serde default for `send_data_model`: the data model is echoed in action
/// metadata unless the caller opts out.
fn default_send_data_model() -> bool {
    let enabled = true;
    enabled
}

/// Serde default for `validate`: A2UI schema validation runs unless disabled.
fn default_validate() -> bool {
    let enabled = true;
    enabled
}
25
26/// Parameters for the render_screen tool.
27#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
28pub struct RenderScreenParams {
29    /// Surface id (default: "main")
30    #[serde(default = "default_surface_id")]
31    pub surface_id: String,
32    /// Catalog id (defaults to the embedded ADK catalog)
33    #[serde(default)]
34    pub catalog_id: Option<String>,
35    /// A2UI component definitions (must include a component with id "root")
36    pub components: Vec<Value>,
37    /// Optional initial data model (sent via updateDataModel at path "/")
38    #[serde(default)]
39    pub data_model: Option<Value>,
40    /// Optional theme object for createSurface
41    #[serde(default)]
42    pub theme: Option<Value>,
43    /// If true, the client should include the data model in action metadata (default: true)
44    #[serde(default = "default_send_data_model")]
45    pub send_data_model: bool,
46    /// Validate generated messages against the A2UI v0.9 schema (default: true)
47    #[serde(default = "default_validate")]
48    pub validate: bool,
49    /// Shared protocol output options.
50    #[serde(flatten)]
51    pub protocol_options: SurfaceProtocolOptions,
52}
53
/// Tool for rendering a single screen (surface) from a list of A2UI components.
///
/// The components are wrapped with the standard A2UI envelope messages:
/// - createSurface
/// - updateDataModel (optional)
/// - updateComponents
///
/// The resulting surface is emitted as A2UI, AG-UI or MCP Apps output depending on
/// the `protocol` option in the tool arguments.
///
/// The tool holds no state, so it is `Copy` and can be built with either
/// [`RenderScreenTool::new`] or [`Default::default`].
#[derive(Debug, Clone, Copy, Default)]
pub struct RenderScreenTool;

impl RenderScreenTool {
    /// Creates a new `render_screen` tool.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
73
74#[async_trait]
75impl Tool for RenderScreenTool {
76    fn name(&self) -> &str {
77        "render_screen"
78    }
79
80    fn description(&self) -> &str {
81        r#"Emit A2UI JSONL for a single screen (surface). Input must include A2UI component objects with ids, including a root component with id "root".
82Returns a JSONL string with createSurface/updateDataModel/updateComponents messages."#
83    }
84
85    fn parameters_schema(&self) -> Option<Value> {
86        Some(super::generate_gemini_schema::<RenderScreenParams>())
87    }
88
89    async fn execute(&self, _ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
90        let params: RenderScreenParams = serde_json::from_value(args.clone()).map_err(|e| {
91            adk_core::AdkError::Tool(format!("Invalid parameters: {}. Got: {}", e, args))
92        })?;
93
94        if params.components.is_empty() {
95            return Err(adk_core::AdkError::Tool(
96                "Invalid parameters: components must not be empty.".to_string(),
97            ));
98        }
99
100        let has_root = params.components.iter().any(|component| {
101            component.get("id").and_then(Value::as_str).map(|id| id == "root").unwrap_or(false)
102        });
103
104        if !has_root {
105            return Err(adk_core::AdkError::Tool(
106                "Invalid parameters: components must include a root component with id \"root\"."
107                    .to_string(),
108            ));
109        }
110
111        let registry = CatalogRegistry::new();
112        let catalog_id =
113            params.catalog_id.unwrap_or_else(|| registry.default_catalog_id().to_string());
114
115        let surface =
116            UiSurface::new(params.surface_id.clone(), catalog_id, params.components.clone())
117                .with_data_model(params.data_model.clone())
118                .with_theme(params.theme.clone())
119                .with_send_data_model(params.send_data_model);
120
121        match params.protocol_options.protocol {
122            UiProtocol::A2ui => {
123                let messages = surface.to_a2ui_messages();
124                if params.validate {
125                    let validator = A2uiValidator::new().map_err(|e| {
126                        adk_core::AdkError::Tool(format!(
127                            "Failed to initialize A2UI validator: {}",
128                            e
129                        ))
130                    })?;
131                    for message in &messages {
132                        if let Err(errors) =
133                            validator.validate_message(message, A2uiSchemaVersion::V0_9)
134                        {
135                            let details = errors
136                                .iter()
137                                .map(|err| format!("{} at {}", err.message, err.instance_path))
138                                .collect::<Vec<_>>()
139                                .join("; ");
140                            return Err(adk_core::AdkError::Tool(format!(
141                                "A2UI validation failed: {}",
142                                details
143                            )));
144                        }
145                    }
146                }
147
148                let adapter = A2uiAdapter;
149                let payload = adapter.to_protocol_payload(&surface)?;
150                adapter.validate(&payload)?;
151                Ok(payload)
152            }
153            UiProtocol::AgUi => {
154                let thread_id =
155                    params.protocol_options.resolved_ag_ui_thread_id(&params.surface_id);
156                let run_id = params.protocol_options.resolved_ag_ui_run_id(&params.surface_id);
157                let adapter = AgUiAdapter::new(thread_id, run_id);
158                adapter.to_protocol_payload(&surface)
159            }
160            UiProtocol::McpApps => {
161                let options = params.protocol_options.parse_mcp_options()?;
162                let adapter = McpAppsAdapter::new(options);
163                adapter.to_protocol_payload(&surface)
164            }
165        }
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use adk_core::{Content, EventActions, ReadonlyContext};
    use async_trait::async_trait;
    use std::sync::{Arc, Mutex};

    /// Minimal `ToolContext` stub; `render_screen` ignores its context argument.
    struct TestContext {
        content: Content,
        actions: Mutex<EventActions>,
    }

    impl TestContext {
        fn new() -> Self {
            Self { content: Content::new("user"), actions: Mutex::new(EventActions::default()) }
        }
    }

    #[async_trait]
    impl ReadonlyContext for TestContext {
        fn invocation_id(&self) -> &str {
            "test"
        }
        fn agent_name(&self) -> &str {
            "test"
        }
        fn user_id(&self) -> &str {
            "user"
        }
        fn app_name(&self) -> &str {
            "app"
        }
        fn session_id(&self) -> &str {
            "session"
        }
        fn branch(&self) -> &str {
            ""
        }
        fn user_content(&self) -> &Content {
            &self.content
        }
    }

    #[async_trait]
    impl adk_core::CallbackContext for TestContext {
        fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
            None
        }
    }

    #[async_trait]
    impl ToolContext for TestContext {
        fn function_call_id(&self) -> &str {
            "call-123"
        }
        fn actions(&self) -> EventActions {
            self.actions.lock().unwrap().clone()
        }
        fn set_actions(&self, actions: EventActions) {
            *self.actions.lock().unwrap() = actions;
        }
        async fn search_memory(&self, _query: &str) -> Result<Vec<adk_core::MemoryEntry>> {
            Ok(vec![])
        }
    }

    /// Runs `render_screen` with `args` and returns the message of the expected
    /// `AdkError::Tool`, panicking on success or on any other error variant.
    async fn execute_expecting_tool_error(args: Value) -> String {
        let tool = RenderScreenTool::new();
        let ctx: Arc<dyn ToolContext> = Arc::new(TestContext::new());
        match tool.execute(ctx, args).await {
            Err(adk_core::AdkError::Tool(message)) => message,
            Err(other) => panic!("expected AdkError::Tool, got {:?}", other),
            Ok(value) => panic!("expected an error, got {}", value),
        }
    }

    #[tokio::test]
    async fn render_screen_emits_jsonl() {
        use crate::a2ui::{column, text};

        let tool = RenderScreenTool::new();
        let args = serde_json::json!({
            "components": [
                text("title", "Hello World", Some("h1")),
                text("desc", "Welcome", None),
                column("root", vec!["title", "desc"])
            ],
            "data_model": { "title": "Hello" }
        });

        let ctx: Arc<dyn ToolContext> = Arc::new(TestContext::new());
        let value = tool.execute(ctx, args).await.unwrap();

        // The tool returns a JSON object with surface_id, components, and jsonl
        assert!(value.is_object());
        assert!(value.get("surface_id").is_some());
        assert!(value.get("components").is_some());
        assert!(value.get("jsonl").is_some());

        // Verify JSONL is still generated
        let jsonl = value["jsonl"].as_str().unwrap();
        let lines: Vec<Value> =
            jsonl.trim_end().lines().map(|line| serde_json::from_str(line).unwrap()).collect();

        assert_eq!(lines.len(), 3);
        assert!(lines[0].get("createSurface").is_some());
        assert!(lines[1].get("updateDataModel").is_some());
        assert!(lines[2].get("updateComponents").is_some());

        // Verify component structure in the returned JSON
        let components = value["components"].as_array().unwrap();
        assert_eq!(components.len(), 3);
        let root = &components[2];
        assert_eq!(root["id"], "root");
        assert_eq!(root["component"], "Column");
    }

    #[tokio::test]
    async fn render_screen_emits_ag_ui_events() {
        use crate::a2ui::{column, text};

        let tool = RenderScreenTool::new();
        let args = serde_json::json!({
            "protocol": "ag_ui",
            "components": [
                text("title", "Hello World", Some("h1")),
                column("root", vec!["title"])
            ]
        });

        let ctx: Arc<dyn ToolContext> = Arc::new(TestContext::new());
        let value = tool.execute(ctx, args).await.unwrap();

        assert_eq!(value["protocol"], "ag_ui");
        let events = value["events"].as_array().unwrap();
        assert_eq!(events.len(), 3);
        assert_eq!(events[0]["type"], "RUN_STARTED");
        assert_eq!(events[1]["type"], "CUSTOM");
        assert_eq!(events[2]["type"], "RUN_FINISHED");
    }

    #[tokio::test]
    async fn render_screen_emits_mcp_apps_payload() {
        use crate::a2ui::{column, text};

        let tool = RenderScreenTool::new();
        let args = serde_json::json!({
            "protocol": "mcp_apps",
            "components": [
                text("title", "Hello World", Some("h1")),
                column("root", vec!["title"])
            ],
            "mcp_apps": {
                "resource_uri": "ui://tests/screen"
            }
        });

        let ctx: Arc<dyn ToolContext> = Arc::new(TestContext::new());
        let value = tool.execute(ctx, args).await.unwrap();

        assert_eq!(value["protocol"], "mcp_apps");
        assert_eq!(value["payload"]["resource"]["uri"], "ui://tests/screen");
        assert_eq!(value["payload"]["toolMeta"]["_meta"]["ui"]["resourceUri"], "ui://tests/screen");
    }

    #[tokio::test]
    async fn render_screen_rejects_missing_components() {
        let message =
            execute_expecting_tool_error(serde_json::json!({ "surface_id": "main" })).await;
        assert!(message.starts_with("Invalid parameters:"), "unexpected message: {}", message);
    }

    #[tokio::test]
    async fn render_screen_rejects_empty_components() {
        let message = execute_expecting_tool_error(serde_json::json!({ "components": [] })).await;
        assert_eq!(message, "Invalid parameters: components must not be empty.");
    }

    #[tokio::test]
    async fn render_screen_requires_root_component() {
        use crate::a2ui::{column, text};

        let message = execute_expecting_tool_error(serde_json::json!({
            "components": [
                text("title", "Hello World", None),
                column("body", vec!["title"])
            ]
        }))
        .await;
        assert_eq!(
            message,
            "Invalid parameters: components must include a root component with id \"root\"."
        );
    }
}