Skip to main content

adk_ui/tools/
render_kit.rs

1use crate::kit::{KitArtifacts, KitGenerator, KitSpec};
2use crate::tools::LegacyProtocolOptions;
3use adk_core::{Result, Tool, ToolContext};
4use async_trait::async_trait;
5use schemars::JsonSchema;
6use serde::{Deserialize, Serialize};
7use serde_json::{Value, json};
8use std::sync::Arc;
9
/// Parameters for the render_kit tool.
// NOTE: the `///` doc comments on this struct and its fields are read by the
// `JsonSchema` derive and presumably end up as descriptions in the tool's parameter
// schema (via `generate_gemini_schema`), so treat them as user-facing text.
// Plain `//` comments like this one are not part of the schema.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct RenderKitParams {
    // The kit specification's fields are flattened into the top level of the
    // arguments object; callers do not nest them under a `spec` key.
    #[serde(flatten)]
    pub spec: KitSpec,
    /// Optional output format; "json" (default) or "catalog_only"
    // Only the exact string "catalog_only" is special-cased by `format_output`;
    // absence or any other value produces the full artifact object.
    #[serde(default)]
    pub output: Option<String>,
    /// Optional protocol output configuration.
    // Also flattened, so protocol keys (e.g. `"protocol": "mcp_apps"`) sit alongside
    // the spec keys in the same arguments object.
    #[serde(flatten)]
    pub protocol: LegacyProtocolOptions,
}
22
/// Tool for generating a UI kit (catalog + tokens + templates + theme).
///
/// The tool is a unit struct with no configuration, so every instance behaves
/// identically and copying it is free.
#[derive(Debug, Clone, Copy)]
pub struct RenderKitTool;

impl RenderKitTool {
    /// Creates a new `RenderKitTool`.
    ///
    /// Usable in `const` contexts because the type carries no state.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
31
32impl Default for RenderKitTool {
33    fn default() -> Self {
34        Self::new()
35    }
36}
37
38#[async_trait]
39impl Tool for RenderKitTool {
40    fn name(&self) -> &str {
41        "render_kit"
42    }
43
44    fn description(&self) -> &str {
45        "Generate a UI kit from a KitSpec. Returns catalog, tokens, templates, and theme CSS."
46    }
47
48    fn parameters_schema(&self) -> Option<Value> {
49        Some(super::generate_gemini_schema::<RenderKitParams>())
50    }
51
52    async fn execute(&self, _ctx: Arc<dyn ToolContext>, args: Value) -> Result<Value> {
53        let params: RenderKitParams = serde_json::from_value(args.clone()).map_err(|e| {
54            adk_core::AdkError::Tool(format!("Invalid parameters: {}. Got: {}", e, args))
55        })?;
56
57        let generator = KitGenerator::new();
58        let artifacts = generator.generate(&params.spec);
59        let payload = format_output(&artifacts, params.output.as_deref());
60
61        Ok(match params.protocol.protocol {
62            Some(protocol) => {
63                let protocol = serde_json::to_value(protocol).unwrap_or_else(|_| json!("a2ui"));
64                json!({
65                    "protocol": protocol,
66                    "surface_id": params.protocol.resolved_surface_id("kit"),
67                    "payload": payload
68                })
69            }
70            None => payload,
71        })
72    }
73}
74
75fn format_output(artifacts: &KitArtifacts, output: Option<&str>) -> Value {
76    match output {
77        Some("catalog_only") => artifacts.catalog.clone(),
78        _ => json!({
79            "catalog": artifacts.catalog,
80            "tokens": artifacts.tokens,
81            "templates": artifacts.templates,
82            "theme_css": artifacts.theme_css,
83        }),
84    }
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use adk_core::{Content, EventActions, ReadonlyContext};
    use async_trait::async_trait;
    use std::sync::{Arc, Mutex};

    /// Minimal in-memory context used to drive `RenderKitTool::execute` in tests.
    struct TestContext {
        content: Content,
        actions: Mutex<EventActions>,
    }

    impl TestContext {
        fn new() -> Self {
            let content = Content::new("user");
            let actions = Mutex::new(EventActions::default());
            Self { content, actions }
        }
    }

    #[async_trait]
    impl ReadonlyContext for TestContext {
        fn invocation_id(&self) -> &str {
            "test"
        }
        fn agent_name(&self) -> &str {
            "test"
        }
        fn user_id(&self) -> &str {
            "user"
        }
        fn app_name(&self) -> &str {
            "app"
        }
        fn session_id(&self) -> &str {
            "session"
        }
        fn branch(&self) -> &str {
            ""
        }
        fn user_content(&self) -> &Content {
            &self.content
        }
    }

    #[async_trait]
    impl adk_core::CallbackContext for TestContext {
        fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
            None
        }
    }

    #[async_trait]
    impl ToolContext for TestContext {
        fn function_call_id(&self) -> &str {
            "call-123"
        }
        fn actions(&self) -> EventActions {
            self.actions.lock().unwrap().clone()
        }
        fn set_actions(&self, actions: EventActions) {
            *self.actions.lock().unwrap() = actions;
        }
        async fn search_memory(&self, _query: &str) -> Result<Vec<adk_core::MemoryEntry>> {
            Ok(vec![])
        }
    }

    /// Kit spec shared by every test; individual tests add extra keys as needed.
    fn fintech_spec() -> Value {
        serde_json::json!({
            "name": "Fintech Pro",
            "version": "0.1.0",
            "brand": { "vibe": "trustworthy", "industry": "fintech" },
            "colors": { "primary": "#2F6BFF" },
            "typography": { "family": "Source Sans 3" }
        })
    }

    /// Executes a fresh `RenderKitTool` against a fresh context and unwraps the result.
    async fn run_tool(args: Value) -> Value {
        let ctx: Arc<dyn ToolContext> = Arc::new(TestContext::new());
        RenderKitTool::new().execute(ctx, args).await.unwrap()
    }

    #[tokio::test]
    async fn render_kit_emits_catalog() {
        let mut args = fintech_spec();
        args["templates"] = serde_json::json!(["auth_login"]);

        let value = run_tool(args).await;
        assert!(value.get("catalog").is_some());
        assert!(value.get("tokens").is_some());
    }

    #[tokio::test]
    async fn render_kit_emits_protocol_envelope() {
        let mut args = fintech_spec();
        args["protocol"] = serde_json::json!("mcp_apps");

        let value = run_tool(args).await;
        assert_eq!(value["protocol"], "mcp_apps");
        assert!(value["payload"]["catalog"].is_object());
    }
}