Skip to main content

a2a_rust/types/
security.rs

1use std::collections::BTreeMap;
2
3use serde::{Deserialize, Deserializer, Serialize};
4use serde_json::Value;
5
6/// Wrapper used by proto JSON for repeated string values in maps.
7#[derive(Debug, Clone, Default, Serialize, Deserialize)]
8#[serde(rename_all = "camelCase")]
9pub struct StringList {
10    #[serde(default, skip_serializing_if = "Vec::is_empty")]
11    /// Ordered string values.
12    pub list: Vec<String>,
13}
14
15/// Security requirement mapping from scheme name to scopes.
16#[derive(Debug, Clone, Default, Serialize, Deserialize)]
17#[serde(rename_all = "camelCase")]
18pub struct SecurityRequirement {
19    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
20    /// Required schemes and scope lists.
21    pub schemes: BTreeMap<String, StringList>,
22}
23
24/// Supported security scheme variants.
25#[derive(Debug, Clone, Serialize)]
26#[serde(rename_all = "camelCase")]
27pub enum SecurityScheme {
28    #[serde(rename = "apiKeySecurityScheme")]
29    /// API key security scheme.
30    ApiKeySecurityScheme(ApiKeySecurityScheme),
31    #[serde(rename = "httpAuthSecurityScheme")]
32    /// HTTP auth security scheme.
33    HttpAuthSecurityScheme(HttpAuthSecurityScheme),
34    #[serde(rename = "oauth2SecurityScheme")]
35    /// OAuth 2.0 security scheme.
36    OAuth2SecurityScheme(OAuth2SecurityScheme),
37    #[serde(rename = "openIdConnectSecurityScheme")]
38    /// OpenID Connect discovery scheme.
39    OpenIdConnectSecurityScheme(OpenIdConnectSecurityScheme),
40    #[serde(rename = "mtlsSecurityScheme")]
41    /// Mutual TLS security scheme.
42    MutualTlsSecurityScheme(MutualTlsSecurityScheme),
43}
44
45impl<'de> Deserialize<'de> for SecurityScheme {
46    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
47    where
48        D: Deserializer<'de>,
49    {
50        let value = Value::deserialize(deserializer)?;
51        deserialize_security_scheme(value).map_err(serde::de::Error::custom)
52    }
53}
54
55/// API key security scheme definition.
56#[derive(Debug, Clone, Serialize, Deserialize)]
57#[serde(rename_all = "camelCase")]
58pub struct ApiKeySecurityScheme {
59    #[serde(default, skip_serializing_if = "Option::is_none")]
60    /// Optional description for human readers.
61    pub description: Option<String>,
62    /// Location of the API key, such as `header` or `query`.
63    pub location: String,
64    /// Header or parameter name carrying the key.
65    pub name: String,
66}
67
68/// HTTP auth security scheme definition.
69#[derive(Debug, Clone, Serialize, Deserialize)]
70#[serde(rename_all = "camelCase")]
71pub struct HttpAuthSecurityScheme {
72    #[serde(default, skip_serializing_if = "Option::is_none")]
73    /// Optional description for human readers.
74    pub description: Option<String>,
75    /// Authentication scheme, such as `basic` or `bearer`.
76    pub scheme: String,
77    #[serde(default, skip_serializing_if = "Option::is_none")]
78    /// Optional bearer token format hint.
79    pub bearer_format: Option<String>,
80}
81
82/// OAuth 2.0 security scheme definition.
83#[derive(Debug, Clone, Serialize, Deserialize)]
84#[serde(rename_all = "camelCase")]
85pub struct OAuth2SecurityScheme {
86    #[serde(default, skip_serializing_if = "Option::is_none")]
87    /// Optional description for human readers.
88    pub description: Option<String>,
89    /// Supported OAuth flow.
90    pub flows: OAuthFlows,
91    #[serde(default, skip_serializing_if = "Option::is_none")]
92    /// Optional metadata discovery URL.
93    pub oauth2_metadata_url: Option<String>,
94}
95
96/// OpenID Connect security scheme definition.
97#[derive(Debug, Clone, Serialize, Deserialize)]
98#[serde(rename_all = "camelCase")]
99pub struct OpenIdConnectSecurityScheme {
100    #[serde(default, skip_serializing_if = "Option::is_none")]
101    /// Optional description for human readers.
102    pub description: Option<String>,
103    /// OpenID Connect discovery URL.
104    pub open_id_connect_url: String,
105}
106
107/// Mutual TLS security scheme definition.
108#[derive(Debug, Clone, Serialize, Deserialize)]
109#[serde(rename_all = "camelCase")]
110pub struct MutualTlsSecurityScheme {
111    #[serde(default, skip_serializing_if = "Option::is_none")]
112    /// Optional description for human readers.
113    pub description: Option<String>,
114}
115
116/// Supported OAuth 2.0 flow variants.
117#[derive(Debug, Clone, Serialize)]
118#[serde(rename_all = "camelCase")]
119pub enum OAuthFlows {
120    /// Authorization code flow.
121    AuthorizationCode(AuthorizationCodeOAuthFlow),
122    /// Client credentials flow.
123    ClientCredentials(ClientCredentialsOAuthFlow),
124    /// Implicit flow.
125    Implicit(ImplicitOAuthFlow),
126    /// Resource owner password flow.
127    Password(PasswordOAuthFlow),
128    /// Device code flow.
129    DeviceCode(DeviceCodeOAuthFlow),
130}
131
132impl<'de> Deserialize<'de> for OAuthFlows {
133    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
134    where
135        D: Deserializer<'de>,
136    {
137        let value = Value::deserialize(deserializer)?;
138        deserialize_oauth_flows(value).map_err(serde::de::Error::custom)
139    }
140}
141
142/// Authorization code flow settings.
143#[derive(Debug, Clone, Serialize, Deserialize)]
144#[serde(rename_all = "camelCase")]
145pub struct AuthorizationCodeOAuthFlow {
146    /// Authorization endpoint URL.
147    pub authorization_url: String,
148    /// Token endpoint URL.
149    pub token_url: String,
150    #[serde(default, skip_serializing_if = "Option::is_none")]
151    /// Optional refresh endpoint URL.
152    pub refresh_url: Option<String>,
153    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
154    /// OAuth scopes and their descriptions.
155    pub scopes: BTreeMap<String, String>,
156    #[serde(default, skip_serializing_if = "crate::types::is_false")]
157    /// Whether PKCE is required for this flow.
158    pub pkce_required: bool,
159}
160
161/// Client credentials flow settings.
162#[derive(Debug, Clone, Serialize, Deserialize)]
163#[serde(rename_all = "camelCase")]
164pub struct ClientCredentialsOAuthFlow {
165    /// Token endpoint URL.
166    pub token_url: String,
167    #[serde(default, skip_serializing_if = "Option::is_none")]
168    /// Optional refresh endpoint URL.
169    pub refresh_url: Option<String>,
170    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
171    /// OAuth scopes and their descriptions.
172    pub scopes: BTreeMap<String, String>,
173}
174
175/// Implicit flow settings.
176#[derive(Debug, Clone, Serialize, Deserialize)]
177#[serde(rename_all = "camelCase")]
178pub struct ImplicitOAuthFlow {
179    /// Authorization endpoint URL.
180    pub authorization_url: String,
181    #[serde(default, skip_serializing_if = "Option::is_none")]
182    /// Optional refresh endpoint URL.
183    pub refresh_url: Option<String>,
184    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
185    /// OAuth scopes and their descriptions.
186    pub scopes: BTreeMap<String, String>,
187}
188
189/// Password flow settings.
190#[derive(Debug, Clone, Serialize, Deserialize)]
191#[serde(rename_all = "camelCase")]
192pub struct PasswordOAuthFlow {
193    /// Token endpoint URL.
194    pub token_url: String,
195    #[serde(default, skip_serializing_if = "Option::is_none")]
196    /// Optional refresh endpoint URL.
197    pub refresh_url: Option<String>,
198    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
199    /// OAuth scopes and their descriptions.
200    pub scopes: BTreeMap<String, String>,
201}
202
203/// Device code flow settings.
204#[derive(Debug, Clone, Serialize, Deserialize)]
205#[serde(rename_all = "camelCase")]
206pub struct DeviceCodeOAuthFlow {
207    /// Device authorization endpoint URL.
208    pub device_authorization_url: String,
209    /// Token endpoint URL.
210    pub token_url: String,
211    #[serde(default, skip_serializing_if = "Option::is_none")]
212    /// Optional refresh endpoint URL.
213    pub refresh_url: Option<String>,
214    #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
215    /// OAuth scopes and their descriptions.
216    pub scopes: BTreeMap<String, String>,
217}
218
219fn deserialize_security_scheme(value: Value) -> Result<SecurityScheme, String> {
220    let Value::Object(mut object) = value else {
221        return Err("security scheme must be a JSON object".to_owned());
222    };
223
224    if object.len() == 1 {
225        let (key, value) = object
226            .into_iter()
227            .next()
228            .ok_or_else(|| "security scheme object cannot be empty".to_owned())?;
229        return match key.as_str() {
230            "apiKeySecurityScheme" => {
231                deserialize_variant(value, SecurityScheme::ApiKeySecurityScheme)
232            }
233            "httpAuthSecurityScheme" => {
234                deserialize_variant(value, SecurityScheme::HttpAuthSecurityScheme)
235            }
236            "oauth2SecurityScheme" => {
237                deserialize_variant(value, SecurityScheme::OAuth2SecurityScheme)
238            }
239            "openIdConnectSecurityScheme" => {
240                deserialize_variant(value, SecurityScheme::OpenIdConnectSecurityScheme)
241            }
242            "mtlsSecurityScheme" => {
243                deserialize_variant(value, SecurityScheme::MutualTlsSecurityScheme)
244            }
245            _ => Err(format!("unknown security scheme variant: {key}")),
246        };
247    }
248
249    let type_name = object
250        .remove("type")
251        .and_then(|value| match value {
252            Value::String(value) => Some(value),
253            _ => None,
254        })
255        .ok_or_else(|| "security scheme must contain either a proto oneof tag or a Python SDK 'type' discriminator".to_owned())?;
256
257    match type_name.as_str() {
258        "apiKey" => {
259            if let Some(location) = object.remove("in") {
260                object.insert("location".to_owned(), location);
261            }
262            deserialize_variant(Value::Object(object), SecurityScheme::ApiKeySecurityScheme)
263        }
264        "http" => deserialize_variant(
265            Value::Object(object),
266            SecurityScheme::HttpAuthSecurityScheme,
267        ),
268        "oauth2" => {
269            deserialize_variant(Value::Object(object), SecurityScheme::OAuth2SecurityScheme)
270        }
271        "openIdConnect" => deserialize_variant(
272            Value::Object(object),
273            SecurityScheme::OpenIdConnectSecurityScheme,
274        ),
275        "mutualTLS" | "mutualTls" | "mtls" => deserialize_variant(
276            Value::Object(object),
277            SecurityScheme::MutualTlsSecurityScheme,
278        ),
279        other => Err(format!(
280            "unsupported security scheme type discriminator: {other}"
281        )),
282    }
283}
284
285fn deserialize_oauth_flows(value: Value) -> Result<OAuthFlows, String> {
286    let Value::Object(mut object) = value else {
287        return Err("oauth flows must be a JSON object".to_owned());
288    };
289
290    let mut chosen: Option<(&'static str, Value)> = None;
291    for key in [
292        "authorizationCode",
293        "clientCredentials",
294        "implicit",
295        "password",
296        "deviceCode",
297    ] {
298        match object.remove(key) {
299            Some(Value::Null) | None => {}
300            Some(value) => {
301                if chosen.is_some() {
302                    return Err("oauth flows must contain exactly one flow variant".to_owned());
303                }
304                chosen = Some((key, value));
305            }
306        }
307    }
308
309    if !object.is_empty() {
310        let mut keys = object.keys().cloned().collect::<Vec<_>>();
311        keys.sort();
312        return Err(format!(
313            "oauth flows contained unexpected keys: {}",
314            keys.join(", ")
315        ));
316    }
317
318    let Some((key, value)) = chosen else {
319        return Err("oauth flows must contain exactly one flow variant".to_owned());
320    };
321
322    match key {
323        "authorizationCode" => deserialize_variant(value, OAuthFlows::AuthorizationCode),
324        "clientCredentials" => deserialize_variant(value, OAuthFlows::ClientCredentials),
325        "implicit" => deserialize_variant(value, OAuthFlows::Implicit),
326        "password" => deserialize_variant(value, OAuthFlows::Password),
327        "deviceCode" => deserialize_variant(value, OAuthFlows::DeviceCode),
328        _ => Err(format!("unsupported oauth flow variant: {key}")),
329    }
330}
331
332fn deserialize_variant<T, U>(value: Value, constructor: impl FnOnce(T) -> U) -> Result<U, String>
333where
334    T: serde::de::DeserializeOwned,
335{
336    serde_json::from_value(value)
337        .map(constructor)
338        .map_err(|error| error.to_string())
339}
340
#[cfg(test)]
mod tests {
    //! Serialization and deserialization tests covering both the proto oneof
    //! shape and the flat, `type`-discriminated shape.

    use std::collections::BTreeMap;

    use super::{
        ApiKeySecurityScheme, AuthorizationCodeOAuthFlow, HttpAuthSecurityScheme,
        OAuth2SecurityScheme, OAuthFlows, OpenIdConnectSecurityScheme, SecurityScheme,
    };

    #[test]
    fn security_scheme_serializes_as_externally_tagged_enum() {
        let scheme = SecurityScheme::ApiKeySecurityScheme(ApiKeySecurityScheme {
            description: None,
            location: "header".to_owned(),
            name: "X-API-Key".to_owned(),
        });

        let json = serde_json::to_string(&scheme).expect("scheme should serialize");
        assert_eq!(
            json,
            r#"{"apiKeySecurityScheme":{"location":"header","name":"X-API-Key"}}"#
        );
    }

    #[test]
    fn oauth_flows_serializes_with_variant_name() {
        let mut scopes = BTreeMap::new();
        scopes.insert("read".to_owned(), "Read access".to_owned());

        let scheme = OAuth2SecurityScheme {
            description: None,
            flows: OAuthFlows::AuthorizationCode(AuthorizationCodeOAuthFlow {
                authorization_url: "https://example.com/authorize".to_owned(),
                token_url: "https://example.com/token".to_owned(),
                refresh_url: None,
                scopes,
                pkce_required: true,
            }),
            oauth2_metadata_url: None,
        };

        let json = serde_json::to_string(&scheme).expect("oauth2 scheme should serialize");
        assert!(json.contains(
            r#""authorizationCode":{"authorizationUrl":"https://example.com/authorize""#
        ));
        assert!(json.contains(r#""pkceRequired":true"#));
    }

    #[test]
    fn security_scheme_deserializes_python_sdk_api_key_shape() {
        let json = serde_json::json!({
            "type": "apiKey",
            "description": "Header auth",
            "in": "header",
            "name": "X-API-Key"
        });

        let scheme: SecurityScheme =
            serde_json::from_value(json).expect("scheme should deserialize");

        match &scheme {
            SecurityScheme::ApiKeySecurityScheme(scheme) => {
                assert_eq!(scheme.location, "header");
                assert_eq!(scheme.name, "X-API-Key");
            }
            _ => panic!("expected api key scheme"),
        }

        let reserialized = serde_json::to_string(&scheme).expect("scheme should serialize");
        assert_eq!(
            reserialized,
            r#"{"apiKeySecurityScheme":{"description":"Header auth","location":"header","name":"X-API-Key"}}"#
        );
    }

    #[test]
    fn security_scheme_deserializes_python_sdk_http_shape() {
        let json = serde_json::json!({
            "type": "http",
            "scheme": "bearer",
            "bearerFormat": "JWT"
        });

        let scheme: SecurityScheme =
            serde_json::from_value(json).expect("scheme should deserialize");

        assert!(matches!(
            scheme,
            SecurityScheme::HttpAuthSecurityScheme(HttpAuthSecurityScheme { scheme, .. }) if scheme == "bearer"
        ));
    }

    #[test]
    fn security_scheme_deserializes_python_sdk_openid_shape() {
        let json = serde_json::json!({
            "type": "openIdConnect",
            "openIdConnectUrl": "https://example.com/.well-known/openid-configuration"
        });

        let scheme: SecurityScheme =
            serde_json::from_value(json).expect("scheme should deserialize");

        assert!(matches!(
            scheme,
            SecurityScheme::OpenIdConnectSecurityScheme(OpenIdConnectSecurityScheme { open_id_connect_url, .. })
                if open_id_connect_url == "https://example.com/.well-known/openid-configuration"
        ));
    }

    #[test]
    fn oauth_flows_deserialize_python_sdk_object_shape() {
        let json = serde_json::json!({
            "authorizationCode": {
                "authorizationUrl": "https://example.com/authorize",
                "tokenUrl": "https://example.com/token",
                "scopes": {
                    "read": "Read access"
                },
                "pkceRequired": true
            }
        });

        let flows: OAuthFlows = serde_json::from_value(json).expect("flows should deserialize");
        assert!(matches!(
            flows,
            OAuthFlows::AuthorizationCode(AuthorizationCodeOAuthFlow {
                pkce_required: true,
                ..
            })
        ));
    }

    #[test]
    fn security_scheme_deserializes_python_sdk_oauth2_shape() {
        let json = serde_json::json!({
            "type": "oauth2",
            "flows": {
                "authorizationCode": {
                    "authorizationUrl": "https://example.com/authorize",
                    "tokenUrl": "https://example.com/token",
                    "scopes": {
                        "read": "Read access"
                    }
                }
            }
        });

        let scheme: SecurityScheme =
            serde_json::from_value(json).expect("scheme should deserialize");

        assert!(matches!(
            scheme,
            SecurityScheme::OAuth2SecurityScheme(OAuth2SecurityScheme {
                flows: OAuthFlows::AuthorizationCode(_),
                ..
            })
        ));
    }

    #[test]
    fn security_scheme_deserializes_python_sdk_mutual_tls_shape() {
        let json = serde_json::json!({
            "type": "mutualTLS",
            "description": "mTLS client cert"
        });

        let scheme: SecurityScheme =
            serde_json::from_value(json).expect("scheme should deserialize");

        assert!(matches!(scheme, SecurityScheme::MutualTlsSecurityScheme(_)));
    }

    // The proto oneof shape must deserialize and re-serialize to the same JSON.
    #[test]
    fn security_scheme_round_trips_proto_oneof_shape() {
        let json = serde_json::json!({
            "httpAuthSecurityScheme": {
                "scheme": "bearer",
                "bearerFormat": "JWT"
            }
        });

        let scheme: SecurityScheme =
            serde_json::from_value(json).expect("scheme should deserialize");
        let reserialized = serde_json::to_string(&scheme).expect("scheme should serialize");
        assert_eq!(
            reserialized,
            r#"{"httpAuthSecurityScheme":{"scheme":"bearer","bearerFormat":"JWT"}}"#
        );
    }

    // An unrecognised `type` discriminator is an error naming the offending value.
    #[test]
    fn security_scheme_rejects_unknown_type_discriminator() {
        let json = serde_json::json!({
            "type": "digest",
            "description": "unsupported"
        });

        let error = serde_json::from_value::<SecurityScheme>(json)
            .expect_err("unknown discriminator should be rejected");
        assert!(error.to_string().contains("digest"));
    }

    // A flow key explicitly set to `null` is ignored rather than counted.
    #[test]
    fn oauth_flows_skips_null_variants() {
        let json = serde_json::json!({
            "authorizationCode": null,
            "clientCredentials": {
                "tokenUrl": "https://example.com/token"
            }
        });

        let flows: OAuthFlows = serde_json::from_value(json).expect("flows should deserialize");
        assert!(matches!(flows, OAuthFlows::ClientCredentials(_)));
    }

    // More than one non-null flow cannot be represented and must be rejected.
    #[test]
    fn oauth_flows_rejects_multiple_variants() {
        let json = serde_json::json!({
            "implicit": { "authorizationUrl": "https://example.com/authorize" },
            "password": { "tokenUrl": "https://example.com/token" }
        });

        let error = serde_json::from_value::<OAuthFlows>(json)
            .expect_err("two flows should be rejected");
        assert!(error.to_string().contains("exactly one flow variant"));
    }

    // Unrecognised keys are rejected and reported in sorted order.
    #[test]
    fn oauth_flows_rejects_unexpected_keys() {
        let json = serde_json::json!({
            "implicit": { "authorizationUrl": "https://example.com/authorize" },
            "zeta": 1,
            "alpha": true
        });

        let error = serde_json::from_value::<OAuthFlows>(json)
            .expect_err("unknown keys should be rejected");
        assert!(error.to_string().contains("unexpected keys: alpha, zeta"));
    }
}