Skip to main content

a2a_rs/domain/core/
agent.rs

1use std::collections::HashMap;
2
3// Re-export generated types so downstream code gets them from `domain::core::agent`
4pub use crate::domain::generated::{
5    APIKeySecurityScheme, AgentCapabilities, AgentCard, AgentCardSignature, AgentExtension,
6    AgentInterface, AgentProvider, AgentSkill, AuthenticationInfo, AuthorizationCodeOAuthFlow,
7    ClientCredentialsOAuthFlow, DeviceCodeOAuthFlow, HTTPAuthSecurityScheme, ImplicitOAuthFlow,
8    MutualTlsSecurityScheme, OAuth2SecurityScheme, OAuthFlows, OpenIdConnectSecurityScheme,
9    PasswordOAuthFlow, SecurityRequirement, SecurityScheme, StringList, o_auth_flows,
10    security_scheme,
11};
12
/// Alias of the generated [`AuthenticationInfo`] type under a
/// push-notification-specific name.
///
/// NOTE(review): presumably kept so push-notification code (or older callers)
/// can keep using this name — confirm before removing.
pub type PushNotificationAuthenticationInfo = AuthenticationInfo;
14
15impl AgentSkill {
16    /// Create a new skill with the minimum required fields
17    pub fn new(id: String, name: String, description: String, tags: Vec<String>) -> Self {
18        Self {
19            id,
20            name,
21            description,
22            tags,
23            ..Default::default()
24        }
25    }
26
27    /// Add examples to the skill
28    pub fn with_examples(mut self, examples: Vec<String>) -> Self {
29        self.examples = examples;
30        self
31    }
32
33    /// Add input modes to the skill
34    pub fn with_input_modes(mut self, input_modes: Vec<String>) -> Self {
35        self.input_modes = input_modes;
36        self
37    }
38
39    /// Add output modes to the skill
40    pub fn with_output_modes(mut self, output_modes: Vec<String>) -> Self {
41        self.output_modes = output_modes;
42        self
43    }
44
45    /// Add security requirements to the skill
46    pub fn with_security(mut self, security: Vec<HashMap<String, Vec<String>>>) -> Self {
47        self.security_requirements = security
48            .into_iter()
49            .map(|req| {
50                let schemes = req
51                    .into_iter()
52                    .map(|(k, v)| {
53                        (
54                            k,
55                            StringList {
56                                list: v,
57                                ..Default::default()
58                            },
59                        )
60                    })
61                    .collect();
62                SecurityRequirement {
63                    schemes,
64                    ..Default::default()
65                }
66            })
67            .collect();
68        self
69    }
70
71    /// Create a comprehensive skill with all details in one call
72    #[allow(clippy::too_many_arguments)]
73    pub fn comprehensive(
74        id: String,
75        name: String,
76        description: String,
77        tags: Vec<String>,
78        examples: Option<Vec<String>>,
79        input_modes: Option<Vec<String>>,
80        output_modes: Option<Vec<String>>,
81        security: Option<Vec<HashMap<String, Vec<String>>>>,
82    ) -> Self {
83        let mut skill = Self::new(id, name, description, tags);
84        if let Some(ex) = examples {
85            skill = skill.with_examples(ex);
86        }
87        if let Some(im) = input_modes {
88            skill = skill.with_input_modes(im);
89        }
90        if let Some(om) = output_modes {
91            skill = skill.with_output_modes(om);
92        }
93        if let Some(sec) = security {
94            skill = skill.with_security(sec);
95        }
96        skill
97    }
98}
99
100impl SecurityScheme {
101    pub fn api_key(name: String, location: String, description: Option<String>) -> Self {
102        Self {
103            scheme: Some(security_scheme::Scheme::ApiKeySecurityScheme(Box::new(
104                APIKeySecurityScheme {
105                    name,
106                    location,
107                    description: description.unwrap_or_default(),
108                    ..Default::default()
109                },
110            ))),
111            ..Default::default()
112        }
113    }
114
115    pub fn http(
116        scheme_name: String,
117        bearer_format: Option<String>,
118        description: Option<String>,
119    ) -> Self {
120        Self {
121            scheme: Some(security_scheme::Scheme::HttpAuthSecurityScheme(Box::new(
122                HTTPAuthSecurityScheme {
123                    scheme: scheme_name,
124                    bearer_format: bearer_format.unwrap_or_default(),
125                    description: description.unwrap_or_default(),
126                    ..Default::default()
127                },
128            ))),
129            ..Default::default()
130        }
131    }
132
133    pub fn oauth2(
134        flows: OAuthFlows,
135        description: Option<String>,
136        oauth2_metadata_url: Option<String>,
137    ) -> Self {
138        Self {
139            scheme: Some(security_scheme::Scheme::Oauth2SecurityScheme(Box::new(
140                OAuth2SecurityScheme {
141                    flows: ::buffa::MessageField::some(flows),
142                    description: description.unwrap_or_default(),
143                    oauth2_metadata_url: oauth2_metadata_url.unwrap_or_default(),
144                    ..Default::default()
145                },
146            ))),
147            ..Default::default()
148        }
149    }
150
151    pub fn open_id_connect(open_id_connect_url: String, description: Option<String>) -> Self {
152        Self {
153            scheme: Some(security_scheme::Scheme::OpenIdConnectSecurityScheme(
154                Box::new(OpenIdConnectSecurityScheme {
155                    open_id_connect_url,
156                    description: description.unwrap_or_default(),
157                    ..Default::default()
158                }),
159            )),
160            ..Default::default()
161        }
162    }
163
164    pub fn mutual_tls(description: Option<String>) -> Self {
165        Self {
166            scheme: Some(security_scheme::Scheme::MtlsSecurityScheme(Box::new(
167                MutualTlsSecurityScheme {
168                    description: description.unwrap_or_default(),
169                    ..Default::default()
170                },
171            ))),
172            ..Default::default()
173        }
174    }
175}
176
177impl OAuthFlows {
178    pub fn authorization_code(flow: AuthorizationCodeOAuthFlow) -> Self {
179        Self {
180            flow: Some(o_auth_flows::Flow::AuthorizationCode(Box::new(flow))),
181            ..Default::default()
182        }
183    }
184
185    pub fn client_credentials(flow: ClientCredentialsOAuthFlow) -> Self {
186        Self {
187            flow: Some(o_auth_flows::Flow::ClientCredentials(Box::new(flow))),
188            ..Default::default()
189        }
190    }
191
192    pub fn device_code(flow: DeviceCodeOAuthFlow) -> Self {
193        Self {
194            flow: Some(o_auth_flows::Flow::DeviceCode(Box::new(flow))),
195            ..Default::default()
196        }
197    }
198}
199
200impl AgentCapabilities {
201    pub fn streaming(&self) -> bool {
202        self.streaming.unwrap_or(false)
203    }
204
205    pub fn push_notifications(&self) -> bool {
206        self.push_notifications.unwrap_or(false)
207    }
208
209    pub fn extended_agent_card(&self) -> bool {
210        self.extended_agent_card.unwrap_or(false)
211    }
212}
213
214impl AgentCardSignature {
215    pub fn new(
216        protected: String,
217        signature: String,
218        header: Option<::buffa_types::google::protobuf::Struct>,
219    ) -> Self {
220        Self {
221            protected,
222            signature,
223            header: header.into(),
224            ..Default::default()
225        }
226    }
227}
228
229impl AgentCard {
230    pub fn builder() -> AgentCardBuilder {
231        AgentCardBuilder::new()
232    }
233
234    pub fn url(&self) -> &str {
235        self.supported_interfaces
236            .first()
237            .map(|i| i.url.as_str())
238            .unwrap_or("")
239    }
240
241    pub fn protocol_version(&self) -> &str {
242        self.supported_interfaces
243            .first()
244            .map(|i| i.protocol_version.as_str())
245            .unwrap_or("1.0")
246    }
247
248    pub fn preferred_transport(&self) -> &str {
249        self.supported_interfaces
250            .first()
251            .map(|i| i.protocol_binding.as_str())
252            .unwrap_or("JSONRPC")
253    }
254
255    pub fn supports_extended_agent_card(&self) -> bool {
256        self.capabilities.extended_agent_card.unwrap_or(false)
257    }
258}
259
260pub struct AgentCardBuilder {
261    name: String,
262    description: String,
263    url: String,
264    provider: Option<AgentProvider>,
265    version: String,
266    protocol_version: Option<String>,
267    preferred_transport: Option<String>,
268    supported_interfaces: Vec<AgentInterface>,
269    icon_url: Option<String>,
270    documentation_url: Option<String>,
271    capabilities: Option<AgentCapabilities>,
272    security_schemes: HashMap<String, SecurityScheme>,
273    security_requirements: Vec<SecurityRequirement>,
274    default_input_modes: Vec<String>,
275    default_output_modes: Vec<String>,
276    skills: Vec<AgentSkill>,
277    signatures: Vec<AgentCardSignature>,
278}
279
280impl Default for AgentCardBuilder {
281    fn default() -> Self {
282        Self::new()
283    }
284}
285
286impl AgentCardBuilder {
287    pub fn new() -> Self {
288        Self {
289            name: String::new(),
290            description: String::new(),
291            url: String::new(),
292            provider: None,
293            version: String::new(),
294            protocol_version: None,
295            preferred_transport: None,
296            supported_interfaces: Vec::new(),
297            icon_url: None,
298            documentation_url: None,
299            capabilities: None,
300            security_schemes: HashMap::new(),
301            security_requirements: Vec::new(),
302            default_input_modes: vec!["text".to_string()],
303            default_output_modes: vec!["text".to_string()],
304            skills: Vec::new(),
305            signatures: Vec::new(),
306        }
307    }
308
309    pub fn name(mut self, name: String) -> Self {
310        self.name = name;
311        self
312    }
313
314    pub fn description(mut self, description: String) -> Self {
315        self.description = description;
316        self
317    }
318
319    pub fn url(mut self, url: String) -> Self {
320        self.url = url;
321        self
322    }
323
324    pub fn provider(mut self, provider: AgentProvider) -> Self {
325        self.provider = Some(provider);
326        self
327    }
328
329    pub fn version(mut self, version: String) -> Self {
330        self.version = version;
331        self
332    }
333
334    pub fn protocol_version(mut self, protocol_version: String) -> Self {
335        self.protocol_version = Some(protocol_version);
336        self
337    }
338
339    pub fn preferred_transport(mut self, preferred_transport: String) -> Self {
340        self.preferred_transport = Some(preferred_transport);
341        self
342    }
343
344    pub fn additional_interfaces(mut self, interfaces: Vec<AgentInterface>) -> Self {
345        self.supported_interfaces.extend(interfaces);
346        self
347    }
348
349    pub fn icon_url(mut self, icon_url: String) -> Self {
350        self.icon_url = Some(icon_url);
351        self
352    }
353
354    pub fn documentation_url(mut self, documentation_url: String) -> Self {
355        self.documentation_url = Some(documentation_url);
356        self
357    }
358
359    pub fn capabilities(mut self, capabilities: AgentCapabilities) -> Self {
360        self.capabilities = Some(capabilities);
361        self
362    }
363
364    pub fn security_schemes(mut self, security_schemes: HashMap<String, SecurityScheme>) -> Self {
365        self.security_schemes = security_schemes;
366        self
367    }
368
369    pub fn security(mut self, security: Vec<HashMap<String, Vec<String>>>) -> Self {
370        self.security_requirements = security
371            .into_iter()
372            .map(|req| {
373                let schemes = req
374                    .into_iter()
375                    .map(|(k, v)| {
376                        (
377                            k,
378                            StringList {
379                                list: v,
380                                ..Default::default()
381                            },
382                        )
383                    })
384                    .collect();
385                SecurityRequirement {
386                    schemes,
387                    ..Default::default()
388                }
389            })
390            .collect();
391        self
392    }
393
394    pub fn default_input_modes(mut self, default_input_modes: Vec<String>) -> Self {
395        self.default_input_modes = default_input_modes;
396        self
397    }
398
399    pub fn default_output_modes(mut self, default_output_modes: Vec<String>) -> Self {
400        self.default_output_modes = default_output_modes;
401        self
402    }
403
404    pub fn skills(mut self, skills: Vec<AgentSkill>) -> Self {
405        self.skills = skills;
406        self
407    }
408
409    pub fn signatures(mut self, signatures: Vec<AgentCardSignature>) -> Self {
410        self.signatures = signatures;
411        self
412    }
413
414    pub fn supports_extended_agent_card(mut self, val: bool) -> Self {
415        let caps = self
416            .capabilities
417            .get_or_insert_with(AgentCapabilities::default);
418        caps.extended_agent_card = Some(val);
419        self
420    }
421
422    pub fn build(self) -> AgentCard {
423        let mut supported_interfaces = self.supported_interfaces;
424        // Make sure the primary interface exists and is first
425        if !self.url.is_empty() {
426            let primary = AgentInterface {
427                url: self.url,
428                protocol_binding: self
429                    .preferred_transport
430                    .unwrap_or_else(|| "JSONRPC".to_string()),
431                protocol_version: self.protocol_version.unwrap_or_else(|| "1.0".to_string()),
432                ..Default::default()
433            };
434            supported_interfaces.insert(0, primary);
435        }
436
437        AgentCard {
438            name: self.name,
439            description: self.description,
440            supported_interfaces,
441            provider: self.provider.into(),
442            version: self.version,
443            documentation_url: self.documentation_url,
444            capabilities: self.capabilities.unwrap_or_default().into(),
445            security_schemes: self.security_schemes,
446            security_requirements: self.security_requirements,
447            default_input_modes: self.default_input_modes,
448            default_output_modes: self.default_output_modes,
449            skills: self.skills,
450            signatures: self.signatures,
451            icon_url: self.icon_url,
452            ..Default::default()
453        }
454    }
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_security_scheme_api_key_serialization() {
        let scheme = SecurityScheme::api_key(
            "X-API-Key".to_string(),
            "header".to_string(),
            Some("API Key authentication".to_string()),
        );

        let json_value = serde_json::to_value(&scheme).expect("Failed to serialize SecurityScheme");
        // Verify output matches protobuf JSON mappings
        assert_eq!(json_value["apiKeySecurityScheme"]["location"], "header");
        assert_eq!(json_value["apiKeySecurityScheme"]["name"], "X-API-Key");
    }

    #[test]
    fn test_security_scheme_http_serialization() {
        let scheme = SecurityScheme::http(
            "bearer".to_string(),
            Some("JWT".to_string()),
            Some("Bearer token authentication".to_string()),
        );

        let json_value = serde_json::to_value(&scheme).expect("Failed to serialize SecurityScheme");
        assert_eq!(json_value["httpAuthSecurityScheme"]["scheme"], "bearer");
        assert_eq!(json_value["httpAuthSecurityScheme"]["bearerFormat"], "JWT");
    }

    #[test]
    fn test_security_scheme_mtls_serialization() {
        let scheme = SecurityScheme::mutual_tls(Some("Mutual TLS authentication".to_string()));

        let json_value = serde_json::to_value(&scheme).expect("Failed to serialize SecurityScheme");
        assert_eq!(
            json_value["mtlsSecurityScheme"]["description"],
            "Mutual TLS authentication"
        );
    }

    #[test]
    fn test_security_scheme_oauth2_with_metadata() {
        let flows = OAuthFlows::authorization_code(AuthorizationCodeOAuthFlow {
            authorization_url: "https://example.com/oauth/authorize".to_string(),
            token_url: "https://example.com/oauth/token".to_string(),
            refresh_url: String::new(),
            scopes: HashMap::new(),
            ..Default::default()
        });

        let scheme = SecurityScheme::oauth2(
            flows,
            Some("OAuth2 authentication".to_string()),
            Some("https://example.com/.well-known/oauth-authorization-server".to_string()),
        );

        let json_value = serde_json::to_value(&scheme).expect("Failed to serialize SecurityScheme");
        assert_eq!(
            json_value["oauth2SecurityScheme"]["oauth2MetadataUrl"],
            "https://example.com/.well-known/oauth-authorization-server"
        );
    }

    #[test]
    fn test_security_scheme_open_id_connect_serialization() {
        let scheme = SecurityScheme::open_id_connect(
            "https://example.com/.well-known/openid-configuration".to_string(),
            None,
        );

        let json_value = serde_json::to_value(&scheme).expect("Failed to serialize SecurityScheme");
        assert_eq!(
            json_value["openIdConnectSecurityScheme"]["openIdConnectUrl"],
            "https://example.com/.well-known/openid-configuration"
        );
    }

    #[test]
    fn test_agent_capabilities_default_to_false() {
        let caps = AgentCapabilities::default();
        assert!(!caps.streaming());
        assert!(!caps.push_notifications());
        assert!(!caps.extended_agent_card());
    }

    #[test]
    fn test_skill_with_security_converts_requirements() {
        let mut requirement = HashMap::new();
        requirement.insert("oauth".to_string(), vec!["read".to_string()]);

        let skill = AgentSkill::new(
            "skill-1".to_string(),
            "Skill".to_string(),
            "A test skill".to_string(),
            vec!["test".to_string()],
        )
        .with_security(vec![requirement]);

        assert_eq!(skill.security_requirements.len(), 1);
        let values = skill.security_requirements[0]
            .schemes
            .get("oauth")
            .map(|list| list.list.clone());
        assert_eq!(values, Some(vec!["read".to_string()]));
    }

    #[test]
    fn test_builder_places_primary_interface_first() {
        let card = AgentCard::builder()
            .name("Test Agent".to_string())
            .url("https://example.com/a2a".to_string())
            .additional_interfaces(vec![AgentInterface {
                url: "https://example.com/grpc".to_string(),
                protocol_binding: "GRPC".to_string(),
                ..Default::default()
            }])
            .build();

        assert_eq!(card.supported_interfaces.len(), 2);
        assert_eq!(card.url(), "https://example.com/a2a");
        assert_eq!(card.preferred_transport(), "JSONRPC");
        assert_eq!(card.protocol_version(), "1.0");
        assert_eq!(card.supported_interfaces[1].url, "https://example.com/grpc");
    }

    #[test]
    fn test_builder_without_url_adds_no_interface() {
        let card = AgentCard::builder().name("Test Agent".to_string()).build();

        assert!(card.supported_interfaces.is_empty());
        assert_eq!(card.url(), "");
        assert_eq!(card.default_input_modes, vec!["text".to_string()]);
        assert_eq!(card.default_output_modes, vec!["text".to_string()]);
    }

    #[test]
    fn test_builder_supports_extended_agent_card() {
        let card = AgentCard::builder().supports_extended_agent_card(true).build();
        assert!(card.supports_extended_agent_card());

        let plain = AgentCard::builder().build();
        assert!(!plain.supports_extended_agent_card());
    }
}