#![allow(clippy::doc_markdown)]
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
/// Security schemes indexed by name.
///
/// NOTE(review): the keys are presumably the same scheme names that
/// [`SecurityRequirement`] entries refer to — confirm against callers.
/// Being a `HashMap`, serialized key order is not stable between runs.
pub type NamedSecuritySchemes = HashMap<String, SecurityScheme>;
/// An ordered list of strings wrapped in an object.
///
/// On the wire this is `{"list": ["read", "write"]}` rather than a bare JSON
/// array. It is the value type of [`SecurityRequirement::schemes`].
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct StringList {
    /// The wrapped strings, in their original order.
    pub list: Vec<String>,
}
/// A security requirement: a mapping from scheme name to the scopes required
/// of that scheme.
///
/// Wire form: `{"schemes": {"oauth2": {"list": ["read", "write"]}}}`.
///
/// NOTE(review): the keys presumably name entries of a
/// [`NamedSecuritySchemes`] map — confirm against callers.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct SecurityRequirement {
    /// Scheme name → required scopes. Because this is a `HashMap`, the
    /// serialized key order is not stable between runs.
    pub schemes: HashMap<String, StringList>,
}
/// An authentication mechanism, one variant per scheme kind.
///
/// Uses serde internal tagging: the value is a single JSON object whose
/// `type` field selects the variant and whose remaining keys are the
/// variant's own fields, e.g. `{"type": "http", "scheme": "bearer"}`.
/// Unknown or missing `type` values fail to deserialize.
///
/// `#[non_exhaustive]` leaves room for new scheme kinds without breaking
/// downstream `match` expressions.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SecurityScheme {
    /// API key in a header, query parameter, or cookie (`"type": "apiKey"`).
    #[serde(rename = "apiKey")]
    ApiKey(ApiKeySecurityScheme),
    /// HTTP authentication such as bearer tokens (`"type": "http"`).
    #[serde(rename = "http")]
    Http(HttpAuthSecurityScheme),
    /// OAuth 2.0 (`"type": "oauth2"`). The payload is boxed, presumably to
    /// keep `SecurityScheme` itself small since this variant is the
    /// largest — confirm if the layout changes.
    #[serde(rename = "oauth2")]
    OAuth2(Box<OAuth2SecurityScheme>),
    /// OpenID Connect (`"type": "openIdConnect"`).
    #[serde(rename = "openIdConnect")]
    OpenIdConnect(OpenIdConnectSecurityScheme),
    /// Mutual TLS (`"type": "mutualTLS"`). The tag uses an upper-case `TLS`,
    /// so it is spelled out rather than derived from the variant name.
    #[serde(rename = "mutualTLS")]
    MutualTls(MutualTlsSecurityScheme),
}
/// Authentication via an API key sent with each request.
///
/// Wire form (inside a [`SecurityScheme`] tagged `"apiKey"`):
/// `{"in": "header", "name": "X-API-Key"}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ApiKeySecurityScheme {
    /// Where the key is sent. Serialized under the JSON key `in`; `in` is a
    /// Rust keyword, so the field carries a different name.
    #[serde(rename = "in")]
    pub location: ApiKeyLocation,
    /// Name of the header, query parameter, or cookie (per `location`) that
    /// carries the key, e.g. `"X-API-Key"`.
    pub name: String,
    /// Optional human-readable description; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}
/// Where an API key is transmitted.
///
/// Serialized as a lowercase JSON string (`"header"`, `"query"`, or
/// `"cookie"`). Matching on deserialization is exact, so e.g. `"Header"` is
/// rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ApiKeyLocation {
    /// Sent as a request header (`"header"`).
    Header,
    /// Sent as a URL query parameter (`"query"`).
    Query,
    /// Sent in a cookie (`"cookie"`).
    Cookie,
}
/// HTTP authentication, such as bearer tokens.
///
/// Wire form (inside a [`SecurityScheme`] tagged `"http"`):
/// `{"scheme": "bearer", "bearerFormat": "JWT"}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct HttpAuthSecurityScheme {
    /// Authentication scheme name, e.g. `"bearer"`. Stored verbatim; no case
    /// normalization is applied here.
    pub scheme: String,
    /// Optional hint about the bearer token format, e.g. `"JWT"`. Serialized
    /// as `bearerFormat` and omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub bearer_format: Option<String>,
    /// Optional human-readable description; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}
/// OAuth 2.0 authentication.
///
/// Wire form (inside a [`SecurityScheme`] tagged `"oauth2"`):
/// `{"flows": {"clientCredentials": {...}}, "oauth2MetadataUrl": "..."}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct OAuth2SecurityScheme {
    /// The OAuth flow this scheme uses (see [`OAuthFlows`]).
    pub flows: OAuthFlows,
    /// Authorization-server metadata URL, serialized as `oauth2MetadataUrl`
    /// and omitted when `None`.
    /// NOTE(review): presumably an RFC 8414 metadata endpoint — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub oauth2_metadata_url: Option<String>,
    /// Optional human-readable description; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}
/// The OAuth 2.0 flow configured for an [`OAuth2SecurityScheme`].
///
/// Uses serde's default external tagging with camelCase variant names, so a
/// value serializes as a single-key object such as
/// `{"clientCredentials": {"tokenUrl": "...", "scopes": {}}}`. Each value
/// holds exactly one flow.
///
/// `#[non_exhaustive]` leaves room for further flow kinds.
#[non_exhaustive]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub enum OAuthFlows {
    /// Authorization-code flow (`"authorizationCode"`).
    AuthorizationCode(AuthorizationCodeFlow),
    /// Client-credentials flow (`"clientCredentials"`).
    ClientCredentials(ClientCredentialsFlow),
    /// Device-authorization flow (`"deviceCode"`).
    DeviceCode(DeviceCodeFlow),
    /// Implicit flow (`"implicit"`).
    Implicit(ImplicitFlow),
    /// Resource-owner password flow (`"password"`).
    Password(PasswordOAuthFlow),
}
/// Settings for the OAuth 2.0 authorization-code flow.
///
/// Serialized under the `authorizationCode` key of [`OAuthFlows`]. URLs are
/// stored as given and are not validated here.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct AuthorizationCodeFlow {
    /// Authorization endpoint URL (`authorizationUrl`).
    pub authorization_url: String,
    /// Token endpoint URL (`tokenUrl`).
    pub token_url: String,
    /// Optional refresh endpoint URL (`refreshUrl`); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_url: Option<String>,
    /// Scope name → human-readable description. Must be present in JSON
    /// (there is no serde default) but may be empty; serialized key order is
    /// not stable because this is a `HashMap`.
    pub scopes: HashMap<String, String>,
    /// Whether PKCE is required (`pkceRequired`); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub pkce_required: Option<bool>,
}
/// Settings for the OAuth 2.0 client-credentials flow.
///
/// Serialized under the `clientCredentials` key of [`OAuthFlows`], e.g.
/// `{"tokenUrl": "https://auth.example.com/token", "scopes": {"read": "Read access"}}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ClientCredentialsFlow {
    /// Token endpoint URL (`tokenUrl`). Stored as given; not validated here.
    pub token_url: String,
    /// Optional refresh endpoint URL (`refreshUrl`); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_url: Option<String>,
    /// Scope name → human-readable description. Must be present in JSON
    /// (there is no serde default) but may be empty; serialized key order is
    /// not stable because this is a `HashMap`.
    pub scopes: HashMap<String, String>,
}
/// Settings for the OAuth 2.0 device-authorization flow.
///
/// Serialized under the `deviceCode` key of [`OAuthFlows`]. URLs are stored
/// as given and are not validated here.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct DeviceCodeFlow {
    /// Device authorization endpoint URL (`deviceAuthorizationUrl`).
    pub device_authorization_url: String,
    /// Token endpoint URL (`tokenUrl`).
    pub token_url: String,
    /// Optional refresh endpoint URL (`refreshUrl`); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_url: Option<String>,
    /// Scope name → human-readable description. Must be present in JSON
    /// (there is no serde default) but may be empty; serialized key order is
    /// not stable because this is a `HashMap`.
    pub scopes: HashMap<String, String>,
}
/// Settings for the OAuth 2.0 implicit flow.
///
/// Serialized under the `implicit` key of [`OAuthFlows`]. This flow has no
/// token endpoint, only an authorization endpoint.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ImplicitFlow {
    /// Authorization endpoint URL (`authorizationUrl`). Not validated here.
    pub authorization_url: String,
    /// Optional refresh endpoint URL (`refreshUrl`); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_url: Option<String>,
    /// Scope name → human-readable description. Must be present in JSON
    /// (there is no serde default) but may be empty; serialized key order is
    /// not stable because this is a `HashMap`.
    pub scopes: HashMap<String, String>,
}
/// Settings for the OAuth 2.0 resource-owner password flow.
///
/// Serialized under the `password` key of [`OAuthFlows`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct PasswordOAuthFlow {
    /// Token endpoint URL (`tokenUrl`). Stored as given; not validated here.
    pub token_url: String,
    /// Optional refresh endpoint URL (`refreshUrl`); omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub refresh_url: Option<String>,
    /// Scope name → human-readable description. Must be present in JSON
    /// (there is no serde default) but may be empty; serialized key order is
    /// not stable because this is a `HashMap`.
    pub scopes: HashMap<String, String>,
}
/// OpenID Connect authentication.
///
/// Wire form (inside a [`SecurityScheme`] tagged `"openIdConnect"`):
/// `{"openIdConnectUrl": "https://..."}`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct OpenIdConnectSecurityScheme {
    /// OpenID Connect URL (`openIdConnectUrl`).
    /// NOTE(review): presumably the provider's discovery document — confirm.
    pub open_id_connect_url: String,
    /// Optional human-readable description; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}
/// Mutual TLS (client-certificate) authentication.
///
/// The only data it carries is an optional description, so the minimal wire
/// form inside a [`SecurityScheme`] is just the tag: `{"type": "mutualTLS"}`.
/// `Default` yields that minimal form.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct MutualTlsSecurityScheme {
    /// Optional human-readable description; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};

    /// Serializes `value` to a JSON string and parses it back into a
    /// [`Value`], so tests can assert on exact keys instead of substrings.
    fn to_json_tree<T: serde::Serialize>(value: &T) -> (String, Value) {
        let text = serde_json::to_string(value).expect("serialize");
        let tree = serde_json::from_str(&text).expect("parse serialized JSON");
        (text, tree)
    }

    #[test]
    fn api_key_scheme_roundtrip() {
        let scheme = SecurityScheme::ApiKey(ApiKeySecurityScheme {
            location: ApiKeyLocation::Header,
            name: "X-API-Key".into(),
            description: None,
        });
        let (text, tree) = to_json_tree(&scheme);
        assert_eq!(tree["type"], json!("apiKey"), "tag must be present: {text}");
        assert_eq!(tree["in"], json!("header"), "location must use 'in': {text}");
        assert_eq!(tree["name"], json!("X-API-Key"), "{text}");
        assert!(tree.get("location").is_none(), "Rust field name must not leak: {text}");
        assert!(tree.get("description").is_none(), "None must be omitted: {text}");

        let back: SecurityScheme = serde_json::from_str(&text).expect("deserialize");
        match &back {
            SecurityScheme::ApiKey(s) => {
                assert_eq!(s.location, ApiKeyLocation::Header);
                assert_eq!(s.name, "X-API-Key");
                assert!(s.description.is_none());
            }
            _ => panic!("expected ApiKey variant"),
        }
    }

    #[test]
    fn http_bearer_scheme_roundtrip() {
        let scheme = SecurityScheme::Http(HttpAuthSecurityScheme {
            scheme: "bearer".into(),
            bearer_format: Some("JWT".into()),
            description: None,
        });
        let (text, tree) = to_json_tree(&scheme);
        assert_eq!(tree["type"], json!("http"), "{text}");
        assert_eq!(tree["scheme"], json!("bearer"), "{text}");
        assert_eq!(tree["bearerFormat"], json!("JWT"), "must be camelCase: {text}");

        let back: SecurityScheme = serde_json::from_str(&text).expect("deserialize");
        match back {
            SecurityScheme::Http(h) => {
                assert_eq!(h.scheme, "bearer");
                assert_eq!(h.bearer_format.as_deref(), Some("JWT"));
            }
            _ => panic!("expected Http variant"),
        }
    }

    #[test]
    fn oauth2_scheme_roundtrip() {
        const METADATA_URL: &str =
            "https://auth.example.com/.well-known/oauth-authorization-server";
        let scheme = SecurityScheme::OAuth2(Box::new(OAuth2SecurityScheme {
            flows: OAuthFlows::ClientCredentials(ClientCredentialsFlow {
                token_url: "https://auth.example.com/token".into(),
                refresh_url: None,
                scopes: HashMap::from([("read".into(), "Read access".into())]),
            }),
            oauth2_metadata_url: Some(METADATA_URL.into()),
            description: None,
        }));
        let (text, tree) = to_json_tree(&scheme);
        assert_eq!(tree["type"], json!("oauth2"), "{text}");
        assert_eq!(tree["oauth2MetadataUrl"], json!(METADATA_URL), "{text}");
        let flow = &tree["flows"]["clientCredentials"];
        assert_eq!(flow["tokenUrl"], json!("https://auth.example.com/token"), "{text}");
        assert_eq!(flow["scopes"], json!({"read": "Read access"}), "{text}");
        assert!(flow.get("refreshUrl").is_none(), "None must be omitted: {text}");

        let back: SecurityScheme = serde_json::from_str(&text).expect("deserialize");
        match &back {
            SecurityScheme::OAuth2(o) => {
                assert_eq!(o.oauth2_metadata_url.as_deref(), Some(METADATA_URL));
                match &o.flows {
                    OAuthFlows::ClientCredentials(cc) => {
                        assert_eq!(cc.token_url, "https://auth.example.com/token");
                        assert_eq!(
                            cc.scopes.get("read").map(String::as_str),
                            Some("Read access")
                        );
                    }
                    _ => panic!("expected ClientCredentials flow"),
                }
            }
            _ => panic!("expected OAuth2 variant"),
        }
    }

    #[test]
    fn open_id_connect_scheme_roundtrip() {
        const URL: &str = "https://auth.example.com/.well-known/openid-configuration";
        let scheme = SecurityScheme::OpenIdConnect(OpenIdConnectSecurityScheme {
            open_id_connect_url: URL.into(),
            description: Some("Company SSO".into()),
        });
        let (text, tree) = to_json_tree(&scheme);
        assert_eq!(tree["type"], json!("openIdConnect"), "{text}");
        assert_eq!(tree["openIdConnectUrl"], json!(URL), "{text}");
        assert_eq!(tree["description"], json!("Company SSO"), "{text}");

        let back: SecurityScheme = serde_json::from_str(&text).expect("deserialize");
        match &back {
            SecurityScheme::OpenIdConnect(o) => {
                assert_eq!(o.open_id_connect_url, URL);
                assert_eq!(o.description.as_deref(), Some("Company SSO"));
            }
            _ => panic!("expected OpenIdConnect variant"),
        }
    }

    #[test]
    fn mutual_tls_scheme_roundtrip() {
        let scheme = SecurityScheme::MutualTls(MutualTlsSecurityScheme { description: None });
        let (text, tree) = to_json_tree(&scheme);
        assert_eq!(tree, json!({"type": "mutualTLS"}), "only the tag is expected: {text}");

        let back: SecurityScheme = serde_json::from_str(&text).expect("deserialize");
        match &back {
            SecurityScheme::MutualTls(m) => assert!(m.description.is_none()),
            _ => panic!("expected MutualTls variant"),
        }
    }

    #[test]
    fn scheme_without_known_type_tag_is_rejected() {
        let unknown = serde_json::from_str::<SecurityScheme>(r#"{"type":"kerberos"}"#);
        assert!(unknown.is_err(), "unknown tags must not deserialize");
        let untagged =
            serde_json::from_str::<SecurityScheme>(r#"{"in":"header","name":"X-API-Key"}"#);
        assert!(untagged.is_err(), "the `type` tag is mandatory");
    }

    #[test]
    fn api_key_location_serialization() {
        for (location, wire) in [
            (ApiKeyLocation::Header, "\"header\""),
            (ApiKeyLocation::Query, "\"query\""),
            (ApiKeyLocation::Cookie, "\"cookie\""),
        ] {
            assert_eq!(serde_json::to_string(&location).expect("serialize"), wire);
            assert_eq!(
                serde_json::from_str::<ApiKeyLocation>(wire).expect("deserialize"),
                location
            );
        }
        assert!(
            serde_json::from_str::<ApiKeyLocation>("\"Header\"").is_err(),
            "wire values are case-sensitive"
        );
    }

    #[test]
    fn wire_format_security_requirement() {
        let req = SecurityRequirement {
            schemes: HashMap::from([(
                "oauth2".into(),
                StringList {
                    list: vec!["read".into(), "write".into()],
                },
            )]),
        };
        let (text, tree) = to_json_tree(&req);
        assert_eq!(
            tree["schemes"]["oauth2"]["list"],
            json!(["read", "write"]),
            "{text}"
        );
        let back: SecurityRequirement = serde_json::from_str(&text).expect("deserialize");
        assert_eq!(back.schemes["oauth2"].list, vec!["read", "write"]);
    }

    #[test]
    fn wire_format_password_oauth_flow() {
        let flows = OAuthFlows::Password(PasswordOAuthFlow {
            token_url: "https://auth.example.com/token".into(),
            refresh_url: None,
            scopes: HashMap::from([("read".into(), "Read access".into())]),
        });
        let (text, tree) = to_json_tree(&flows);
        let outer = tree.as_object().expect("flows must serialize as an object");
        assert_eq!(outer.len(), 1, "exactly one flow key expected: {text}");
        assert_eq!(
            tree["password"]["tokenUrl"],
            json!("https://auth.example.com/token"),
            "password flow must be the outer key: {text}"
        );

        let back: OAuthFlows = serde_json::from_str(&text).expect("deserialize");
        match back {
            OAuthFlows::Password(p) => {
                assert_eq!(p.token_url, "https://auth.example.com/token");
            }
            _ => panic!("expected Password flow"),
        }
    }

    #[test]
    fn wire_format_authorization_code_flow() {
        let flows = OAuthFlows::AuthorizationCode(AuthorizationCodeFlow {
            authorization_url: "https://auth.example.com/authorize".into(),
            token_url: "https://auth.example.com/token".into(),
            refresh_url: None,
            scopes: HashMap::new(),
            pkce_required: Some(true),
        });
        let (text, tree) = to_json_tree(&flows);
        let flow = &tree["authorizationCode"];
        assert_eq!(
            flow["authorizationUrl"],
            json!("https://auth.example.com/authorize"),
            "{text}"
        );
        assert_eq!(flow["tokenUrl"], json!("https://auth.example.com/token"), "{text}");
        assert_eq!(flow["pkceRequired"], json!(true), "{text}");
        assert_eq!(flow["scopes"], json!({}), "{text}");
        assert!(flow.get("refreshUrl").is_none(), "None must be omitted: {text}");

        let back: OAuthFlows = serde_json::from_str(&text).expect("deserialize");
        match back {
            OAuthFlows::AuthorizationCode(ac) => {
                assert_eq!(ac.pkce_required, Some(true));
                assert!(ac.refresh_url.is_none());
            }
            _ => panic!("expected AuthorizationCode flow"),
        }
    }

    #[test]
    fn wire_format_device_code_flow() {
        let flows = OAuthFlows::DeviceCode(DeviceCodeFlow {
            device_authorization_url: "https://auth.example.com/device".into(),
            token_url: "https://auth.example.com/token".into(),
            refresh_url: Some("https://auth.example.com/refresh".into()),
            scopes: HashMap::new(),
        });
        let (text, tree) = to_json_tree(&flows);
        let flow = &tree["deviceCode"];
        assert_eq!(
            flow["deviceAuthorizationUrl"],
            json!("https://auth.example.com/device"),
            "{text}"
        );
        assert_eq!(
            flow["refreshUrl"],
            json!("https://auth.example.com/refresh"),
            "{text}"
        );

        let back: OAuthFlows = serde_json::from_str(&text).expect("deserialize");
        match back {
            OAuthFlows::DeviceCode(d) => {
                assert_eq!(d.device_authorization_url, "https://auth.example.com/device");
                assert_eq!(
                    d.refresh_url.as_deref(),
                    Some("https://auth.example.com/refresh")
                );
            }
            _ => panic!("expected DeviceCode flow"),
        }
    }

    #[test]
    fn wire_format_implicit_flow() {
        let flows = OAuthFlows::Implicit(ImplicitFlow {
            authorization_url: "https://auth.example.com/authorize".into(),
            refresh_url: None,
            scopes: HashMap::new(),
        });
        let (text, tree) = to_json_tree(&flows);
        assert_eq!(
            tree,
            json!({
                "implicit": {
                    "authorizationUrl": "https://auth.example.com/authorize",
                    "scopes": {}
                }
            }),
            "{text}"
        );
    }

    #[test]
    fn flow_without_scopes_is_rejected() {
        let result = serde_json::from_str::<ClientCredentialsFlow>(
            r#"{"tokenUrl":"https://auth.example.com/token"}"#,
        );
        assert!(result.is_err(), "`scopes` has no serde default and must be present");
    }
}