1use std::collections::BTreeMap;
2
3use serde::{Deserialize, Deserializer, Serialize};
4use serde_json::Value;
5
6#[derive(Debug, Clone, Default, Serialize, Deserialize)]
8#[serde(rename_all = "camelCase")]
9pub struct StringList {
10 #[serde(default, skip_serializing_if = "Vec::is_empty")]
11 pub list: Vec<String>,
13}
14
15#[derive(Debug, Clone, Default, Serialize, Deserialize)]
17#[serde(rename_all = "camelCase")]
18pub struct SecurityRequirement {
19 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
20 pub schemes: BTreeMap<String, StringList>,
22}
23
24#[derive(Debug, Clone, Serialize)]
26#[serde(rename_all = "camelCase")]
27pub enum SecurityScheme {
28 #[serde(rename = "apiKeySecurityScheme")]
29 ApiKeySecurityScheme(ApiKeySecurityScheme),
31 #[serde(rename = "httpAuthSecurityScheme")]
32 HttpAuthSecurityScheme(HttpAuthSecurityScheme),
34 #[serde(rename = "oauth2SecurityScheme")]
35 OAuth2SecurityScheme(OAuth2SecurityScheme),
37 #[serde(rename = "openIdConnectSecurityScheme")]
38 OpenIdConnectSecurityScheme(OpenIdConnectSecurityScheme),
40 #[serde(rename = "mtlsSecurityScheme")]
41 MutualTlsSecurityScheme(MutualTlsSecurityScheme),
43}
44
45impl<'de> Deserialize<'de> for SecurityScheme {
46 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
47 where
48 D: Deserializer<'de>,
49 {
50 let value = Value::deserialize(deserializer)?;
51 deserialize_security_scheme(value).map_err(serde::de::Error::custom)
52 }
53}
54
55#[derive(Debug, Clone, Serialize, Deserialize)]
57#[serde(rename_all = "camelCase")]
58pub struct ApiKeySecurityScheme {
59 #[serde(default, skip_serializing_if = "Option::is_none")]
60 pub description: Option<String>,
62 pub location: String,
64 pub name: String,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
70#[serde(rename_all = "camelCase")]
71pub struct HttpAuthSecurityScheme {
72 #[serde(default, skip_serializing_if = "Option::is_none")]
73 pub description: Option<String>,
75 pub scheme: String,
77 #[serde(default, skip_serializing_if = "Option::is_none")]
78 pub bearer_format: Option<String>,
80}
81
82#[derive(Debug, Clone, Serialize, Deserialize)]
84#[serde(rename_all = "camelCase")]
85pub struct OAuth2SecurityScheme {
86 #[serde(default, skip_serializing_if = "Option::is_none")]
87 pub description: Option<String>,
89 pub flows: OAuthFlows,
91 #[serde(default, skip_serializing_if = "Option::is_none")]
92 pub oauth2_metadata_url: Option<String>,
94}
95
96#[derive(Debug, Clone, Serialize, Deserialize)]
98#[serde(rename_all = "camelCase")]
99pub struct OpenIdConnectSecurityScheme {
100 #[serde(default, skip_serializing_if = "Option::is_none")]
101 pub description: Option<String>,
103 pub open_id_connect_url: String,
105}
106
107#[derive(Debug, Clone, Serialize, Deserialize)]
109#[serde(rename_all = "camelCase")]
110pub struct MutualTlsSecurityScheme {
111 #[serde(default, skip_serializing_if = "Option::is_none")]
112 pub description: Option<String>,
114}
115
116#[derive(Debug, Clone, Serialize)]
118#[serde(rename_all = "camelCase")]
119pub enum OAuthFlows {
120 AuthorizationCode(AuthorizationCodeOAuthFlow),
122 ClientCredentials(ClientCredentialsOAuthFlow),
124 Implicit(ImplicitOAuthFlow),
126 Password(PasswordOAuthFlow),
128 DeviceCode(DeviceCodeOAuthFlow),
130}
131
132impl<'de> Deserialize<'de> for OAuthFlows {
133 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
134 where
135 D: Deserializer<'de>,
136 {
137 let value = Value::deserialize(deserializer)?;
138 deserialize_oauth_flows(value).map_err(serde::de::Error::custom)
139 }
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
144#[serde(rename_all = "camelCase")]
145pub struct AuthorizationCodeOAuthFlow {
146 pub authorization_url: String,
148 pub token_url: String,
150 #[serde(default, skip_serializing_if = "Option::is_none")]
151 pub refresh_url: Option<String>,
153 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
154 pub scopes: BTreeMap<String, String>,
156 #[serde(default, skip_serializing_if = "crate::types::is_false")]
157 pub pkce_required: bool,
159}
160
161#[derive(Debug, Clone, Serialize, Deserialize)]
163#[serde(rename_all = "camelCase")]
164pub struct ClientCredentialsOAuthFlow {
165 pub token_url: String,
167 #[serde(default, skip_serializing_if = "Option::is_none")]
168 pub refresh_url: Option<String>,
170 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
171 pub scopes: BTreeMap<String, String>,
173}
174
175#[derive(Debug, Clone, Serialize, Deserialize)]
177#[serde(rename_all = "camelCase")]
178pub struct ImplicitOAuthFlow {
179 pub authorization_url: String,
181 #[serde(default, skip_serializing_if = "Option::is_none")]
182 pub refresh_url: Option<String>,
184 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
185 pub scopes: BTreeMap<String, String>,
187}
188
189#[derive(Debug, Clone, Serialize, Deserialize)]
191#[serde(rename_all = "camelCase")]
192pub struct PasswordOAuthFlow {
193 pub token_url: String,
195 #[serde(default, skip_serializing_if = "Option::is_none")]
196 pub refresh_url: Option<String>,
198 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
199 pub scopes: BTreeMap<String, String>,
201}
202
203#[derive(Debug, Clone, Serialize, Deserialize)]
205#[serde(rename_all = "camelCase")]
206pub struct DeviceCodeOAuthFlow {
207 pub device_authorization_url: String,
209 pub token_url: String,
211 #[serde(default, skip_serializing_if = "Option::is_none")]
212 pub refresh_url: Option<String>,
214 #[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
215 pub scopes: BTreeMap<String, String>,
217}
218
219fn deserialize_security_scheme(value: Value) -> Result<SecurityScheme, String> {
220 let Value::Object(mut object) = value else {
221 return Err("security scheme must be a JSON object".to_owned());
222 };
223
224 if object.len() == 1 {
225 let (key, value) = object
226 .into_iter()
227 .next()
228 .ok_or_else(|| "security scheme object cannot be empty".to_owned())?;
229 return match key.as_str() {
230 "apiKeySecurityScheme" => {
231 deserialize_variant(value, SecurityScheme::ApiKeySecurityScheme)
232 }
233 "httpAuthSecurityScheme" => {
234 deserialize_variant(value, SecurityScheme::HttpAuthSecurityScheme)
235 }
236 "oauth2SecurityScheme" => {
237 deserialize_variant(value, SecurityScheme::OAuth2SecurityScheme)
238 }
239 "openIdConnectSecurityScheme" => {
240 deserialize_variant(value, SecurityScheme::OpenIdConnectSecurityScheme)
241 }
242 "mtlsSecurityScheme" => {
243 deserialize_variant(value, SecurityScheme::MutualTlsSecurityScheme)
244 }
245 _ => Err(format!("unknown security scheme variant: {key}")),
246 };
247 }
248
249 let type_name = object
250 .remove("type")
251 .and_then(|value| match value {
252 Value::String(value) => Some(value),
253 _ => None,
254 })
255 .ok_or_else(|| "security scheme must contain either a proto oneof tag or a Python SDK 'type' discriminator".to_owned())?;
256
257 match type_name.as_str() {
258 "apiKey" => {
259 if let Some(location) = object.remove("in") {
260 object.insert("location".to_owned(), location);
261 }
262 deserialize_variant(Value::Object(object), SecurityScheme::ApiKeySecurityScheme)
263 }
264 "http" => deserialize_variant(
265 Value::Object(object),
266 SecurityScheme::HttpAuthSecurityScheme,
267 ),
268 "oauth2" => {
269 deserialize_variant(Value::Object(object), SecurityScheme::OAuth2SecurityScheme)
270 }
271 "openIdConnect" => deserialize_variant(
272 Value::Object(object),
273 SecurityScheme::OpenIdConnectSecurityScheme,
274 ),
275 "mutualTLS" | "mutualTls" | "mtls" => deserialize_variant(
276 Value::Object(object),
277 SecurityScheme::MutualTlsSecurityScheme,
278 ),
279 other => Err(format!(
280 "unsupported security scheme type discriminator: {other}"
281 )),
282 }
283}
284
285fn deserialize_oauth_flows(value: Value) -> Result<OAuthFlows, String> {
286 let Value::Object(mut object) = value else {
287 return Err("oauth flows must be a JSON object".to_owned());
288 };
289
290 let mut chosen: Option<(&'static str, Value)> = None;
291 for key in [
292 "authorizationCode",
293 "clientCredentials",
294 "implicit",
295 "password",
296 "deviceCode",
297 ] {
298 match object.remove(key) {
299 Some(Value::Null) | None => {}
300 Some(value) => {
301 if chosen.is_some() {
302 return Err("oauth flows must contain exactly one flow variant".to_owned());
303 }
304 chosen = Some((key, value));
305 }
306 }
307 }
308
309 if !object.is_empty() {
310 let mut keys = object.keys().cloned().collect::<Vec<_>>();
311 keys.sort();
312 return Err(format!(
313 "oauth flows contained unexpected keys: {}",
314 keys.join(", ")
315 ));
316 }
317
318 let Some((key, value)) = chosen else {
319 return Err("oauth flows must contain exactly one flow variant".to_owned());
320 };
321
322 match key {
323 "authorizationCode" => deserialize_variant(value, OAuthFlows::AuthorizationCode),
324 "clientCredentials" => deserialize_variant(value, OAuthFlows::ClientCredentials),
325 "implicit" => deserialize_variant(value, OAuthFlows::Implicit),
326 "password" => deserialize_variant(value, OAuthFlows::Password),
327 "deviceCode" => deserialize_variant(value, OAuthFlows::DeviceCode),
328 _ => Err(format!("unsupported oauth flow variant: {key}")),
329 }
330}
331
332fn deserialize_variant<T, U>(value: Value, constructor: impl FnOnce(T) -> U) -> Result<U, String>
333where
334 T: serde::de::DeserializeOwned,
335{
336 serde_json::from_value(value)
337 .map(constructor)
338 .map_err(|error| error.to_string())
339}
340
#[cfg(test)]
mod tests {
    use std::collections::BTreeMap;

    use super::{
        ApiKeySecurityScheme, AuthorizationCodeOAuthFlow, ClientCredentialsOAuthFlow,
        HttpAuthSecurityScheme, OAuth2SecurityScheme, OAuthFlows, OpenIdConnectSecurityScheme,
        SecurityScheme,
    };

    #[test]
    fn security_scheme_serializes_as_externally_tagged_enum() {
        let scheme = SecurityScheme::ApiKeySecurityScheme(ApiKeySecurityScheme {
            description: None,
            location: "header".to_owned(),
            name: "X-API-Key".to_owned(),
        });

        let json = serde_json::to_string(&scheme).expect("scheme should serialize");
        assert_eq!(
            json,
            r#"{"apiKeySecurityScheme":{"location":"header","name":"X-API-Key"}}"#
        );
    }

    #[test]
    fn oauth_flows_serializes_with_variant_name() {
        let mut scopes = BTreeMap::new();
        scopes.insert("read".to_owned(), "Read access".to_owned());

        let scheme = OAuth2SecurityScheme {
            description: None,
            flows: OAuthFlows::AuthorizationCode(AuthorizationCodeOAuthFlow {
                authorization_url: "https://example.com/authorize".to_owned(),
                token_url: "https://example.com/token".to_owned(),
                refresh_url: None,
                scopes,
                pkce_required: true,
            }),
            oauth2_metadata_url: None,
        };

        let json = serde_json::to_string(&scheme).expect("oauth2 scheme should serialize");
        assert!(json.contains(
            r#""authorizationCode":{"authorizationUrl":"https://example.com/authorize""#
        ));
        assert!(json.contains(r#""pkceRequired":true"#));
    }

    #[test]
    fn security_scheme_deserializes_python_sdk_api_key_shape() {
        let json = serde_json::json!({
            "type": "apiKey",
            "description": "Header auth",
            "in": "header",
            "name": "X-API-Key"
        });

        let scheme: SecurityScheme =
            serde_json::from_value(json).expect("scheme should deserialize");

        match &scheme {
            SecurityScheme::ApiKeySecurityScheme(scheme) => {
                assert_eq!(scheme.location, "header");
                assert_eq!(scheme.name, "X-API-Key");
            }
            _ => panic!("expected api key scheme"),
        }

        let reserialized = serde_json::to_string(&scheme).expect("scheme should serialize");
        assert_eq!(
            reserialized,
            r#"{"apiKeySecurityScheme":{"description":"Header auth","location":"header","name":"X-API-Key"}}"#
        );
    }

    #[test]
    fn security_scheme_deserializes_python_sdk_http_shape() {
        let json = serde_json::json!({
            "type": "http",
            "scheme": "bearer",
            "bearerFormat": "JWT"
        });

        let scheme: SecurityScheme =
            serde_json::from_value(json).expect("scheme should deserialize");

        assert!(matches!(
            scheme,
            SecurityScheme::HttpAuthSecurityScheme(HttpAuthSecurityScheme { scheme, .. }) if scheme == "bearer"
        ));
    }

    #[test]
    fn security_scheme_deserializes_python_sdk_openid_shape() {
        let json = serde_json::json!({
            "type": "openIdConnect",
            "openIdConnectUrl": "https://example.com/.well-known/openid-configuration"
        });

        let scheme: SecurityScheme =
            serde_json::from_value(json).expect("scheme should deserialize");

        assert!(matches!(
            scheme,
            SecurityScheme::OpenIdConnectSecurityScheme(OpenIdConnectSecurityScheme { open_id_connect_url, .. })
                if open_id_connect_url == "https://example.com/.well-known/openid-configuration"
        ));
    }

    #[test]
    fn oauth_flows_deserialize_python_sdk_object_shape() {
        let json = serde_json::json!({
            "authorizationCode": {
                "authorizationUrl": "https://example.com/authorize",
                "tokenUrl": "https://example.com/token",
                "scopes": {
                    "read": "Read access"
                },
                "pkceRequired": true
            }
        });

        let flows: OAuthFlows = serde_json::from_value(json).expect("flows should deserialize");
        assert!(matches!(
            flows,
            OAuthFlows::AuthorizationCode(AuthorizationCodeOAuthFlow {
                pkce_required: true,
                ..
            })
        ));
    }

    #[test]
    fn security_scheme_deserializes_python_sdk_oauth2_shape() {
        let json = serde_json::json!({
            "type": "oauth2",
            "flows": {
                "authorizationCode": {
                    "authorizationUrl": "https://example.com/authorize",
                    "tokenUrl": "https://example.com/token",
                    "scopes": {
                        "read": "Read access"
                    }
                }
            }
        });

        let scheme: SecurityScheme =
            serde_json::from_value(json).expect("scheme should deserialize");

        assert!(matches!(
            scheme,
            SecurityScheme::OAuth2SecurityScheme(OAuth2SecurityScheme {
                flows: OAuthFlows::AuthorizationCode(_),
                ..
            })
        ));
    }

    #[test]
    fn security_scheme_deserializes_python_sdk_mutual_tls_shape() {
        let json = serde_json::json!({
            "type": "mutualTLS",
            "description": "mTLS client cert"
        });

        let scheme: SecurityScheme =
            serde_json::from_value(json).expect("scheme should deserialize");

        assert!(matches!(scheme, SecurityScheme::MutualTlsSecurityScheme(_)));
    }

    /// The externally tagged (proto oneof) shape produced by `Serialize` must be accepted
    /// back by the hand-written `Deserialize`.
    #[test]
    fn security_scheme_round_trips_proto_oneof_shape() {
        let original = SecurityScheme::HttpAuthSecurityScheme(HttpAuthSecurityScheme {
            description: None,
            scheme: "bearer".to_owned(),
            bearer_format: Some("JWT".to_owned()),
        });

        let json = serde_json::to_string(&original).expect("scheme should serialize");
        assert_eq!(
            json,
            r#"{"httpAuthSecurityScheme":{"scheme":"bearer","bearerFormat":"JWT"}}"#
        );

        let decoded: SecurityScheme =
            serde_json::from_str(&json).expect("proto oneof shape should deserialize");
        match decoded {
            SecurityScheme::HttpAuthSecurityScheme(scheme) => {
                assert_eq!(scheme.scheme, "bearer");
                assert_eq!(scheme.bearer_format.as_deref(), Some("JWT"));
                assert!(scheme.description.is_none());
            }
            _ => panic!("expected http auth scheme"),
        }
    }

    #[test]
    fn security_scheme_rejects_unknown_oneof_tag() {
        let json = serde_json::json!({ "bogusSecurityScheme": {} });

        let error = serde_json::from_value::<SecurityScheme>(json)
            .expect_err("unknown oneof tag should be rejected");
        assert!(error
            .to_string()
            .contains("unknown security scheme variant: bogusSecurityScheme"));
    }

    #[test]
    fn security_scheme_flat_shape_requires_type_discriminator() {
        let json = serde_json::json!({ "in": "header", "name": "X-API-Key" });

        let error = serde_json::from_value::<SecurityScheme>(json)
            .expect_err("flat shape without `type` should be rejected");
        assert!(error.to_string().contains("'type' discriminator"));
    }

    #[test]
    fn security_scheme_rejects_unsupported_type_discriminator() {
        let json = serde_json::json!({ "type": "bogus", "description": "nope" });

        let error = serde_json::from_value::<SecurityScheme>(json)
            .expect_err("unsupported discriminator should be rejected");
        assert!(error
            .to_string()
            .contains("unsupported security scheme type discriminator: bogus"));
    }

    #[test]
    fn oauth_flows_rejects_multiple_flows() {
        let json = serde_json::json!({
            "clientCredentials": { "tokenUrl": "https://example.com/token" },
            "password": { "tokenUrl": "https://example.com/token" }
        });

        let error = serde_json::from_value::<OAuthFlows>(json)
            .expect_err("two flows should be rejected");
        assert!(error.to_string().contains("exactly one flow variant"));
    }

    #[test]
    fn oauth_flows_rejects_empty_object() {
        let error = serde_json::from_value::<OAuthFlows>(serde_json::json!({}))
            .expect_err("no flow should be rejected");
        assert!(error.to_string().contains("exactly one flow variant"));
    }

    #[test]
    fn oauth_flows_ignores_null_flow_entries() {
        let json = serde_json::json!({
            "authorizationCode": null,
            "clientCredentials": { "tokenUrl": "https://example.com/token" },
            "implicit": null
        });

        let flows: OAuthFlows = serde_json::from_value(json).expect("flows should deserialize");
        assert!(matches!(
            flows,
            OAuthFlows::ClientCredentials(ClientCredentialsOAuthFlow { token_url, .. })
                if token_url == "https://example.com/token"
        ));
    }

    #[test]
    fn oauth_flows_reports_unexpected_keys_sorted() {
        let json = serde_json::json!({
            "zeta": 1,
            "clientCredentials": { "tokenUrl": "https://example.com/token" },
            "alpha": true
        });

        let error = serde_json::from_value::<OAuthFlows>(json)
            .expect_err("unknown keys should be rejected");
        assert!(error
            .to_string()
            .contains("oauth flows contained unexpected keys: alpha, zeta"));
    }
}
512}