use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;
use url::Url;
/// Errors produced while validating an OAuth 2.0 / OpenID Connect discovery document.
#[derive(Debug, Clone, Error)]
pub enum DiscoveryError {
/// The `issuer` value could not be parsed or breaks an issuer rule. When displayed,
/// the payload is appended after the fixed "Invalid issuer URL: " prefix.
#[error("Invalid issuer URL: {0}")]
InvalidIssuer(String),
/// A required field is absent or empty. When displayed, the payload follows the
/// "Missing required field: " prefix.
#[error("Missing required field: {0}")]
MissingField(String),
/// A field is present but its value is unacceptable (for example, an unparsable URL).
#[error("Invalid field value for {field}: {reason}")]
InvalidField { field: String, reason: String },
/// The `issuer` inside the document differs from the issuer the caller expected.
#[error("Issuer in document ({document}) does not match expected issuer ({expected})")]
IssuerMismatch { document: String, expected: String },
}
/// OAuth 2.0 authorization server metadata document.
///
/// Field names match the metadata parameter names of RFC 8414 §2, so serde maps them
/// one-to-one onto JSON keys. Every `Option` field is omitted from serialized output
/// when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct AuthorizationServerMetadata {
/// Issuer identifier. `validate` requires it to parse as a URL using the `https` scheme.
pub issuer: String,
/// Non-optional, so deserialization fails when the key is absent. `validate` checks
/// that it parses as a URL.
pub authorization_endpoint: String,
/// Optional. `validate` checks that it parses as a URL when present.
#[serde(skip_serializing_if = "Option::is_none")]
pub token_endpoint: Option<String>,
/// Optional. `validate` checks that it parses as a URL when present.
#[serde(skip_serializing_if = "Option::is_none")]
pub jwks_uri: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub registration_endpoint: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub scopes_supported: Option<Vec<String>>,
/// Non-optional. `validate` rejects an empty list.
pub response_types_supported: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub response_modes_supported: Option<Vec<String>>,
/// When `None`, `grant_types()` falls back to `["authorization_code", "implicit"]`.
#[serde(skip_serializing_if = "Option::is_none")]
pub grant_types_supported: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub token_endpoint_auth_methods_supported: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub token_endpoint_auth_signing_alg_values_supported: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub service_documentation: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub ui_locales_supported: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub op_policy_uri: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub op_tos_uri: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub revocation_endpoint: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub revocation_endpoint_auth_methods_supported: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub introspection_endpoint: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub introspection_endpoint_auth_methods_supported: Option<Vec<String>>,
/// PKCE code challenge methods. A value of `None` or an empty list makes
/// `supports_pkce()` return `false`.
#[serde(skip_serializing_if = "Option::is_none")]
pub code_challenge_methods_supported: Option<Vec<String>>,
/// Top-level keys not matched by any named field. They are captured via
/// `#[serde(flatten)]`, so they survive a deserialize/serialize round trip.
#[serde(flatten)]
pub additional_fields: HashMap<String, serde_json::Value>,
}
impl AuthorizationServerMetadata {
    /// Checks the document against the structural rules this crate enforces.
    ///
    /// # Errors
    ///
    /// - [`DiscoveryError::InvalidIssuer`] if `issuer` is not an absolute URL, does not
    ///   use the `https` scheme, or carries a query or fragment component
    ///   (RFC 8414 §2).
    /// - [`DiscoveryError::InvalidField`] if `authorization_endpoint` fails to parse as
    ///   a URL. The same applies to any of `token_endpoint`, `jwks_uri`,
    ///   `registration_endpoint`, `revocation_endpoint` or `introspection_endpoint`
    ///   that is present.
    /// - [`DiscoveryError::MissingField`] if `response_types_supported` is empty.
    pub fn validate(&self) -> Result<(), DiscoveryError> {
        self.validate_issuer()?;

        Self::check_url("authorization_endpoint", &self.authorization_endpoint)?;

        // Optional endpoint URLs: validated only when the document advertises them.
        let optional_endpoints = [
            ("token_endpoint", &self.token_endpoint),
            ("jwks_uri", &self.jwks_uri),
            ("registration_endpoint", &self.registration_endpoint),
            ("revocation_endpoint", &self.revocation_endpoint),
            ("introspection_endpoint", &self.introspection_endpoint),
        ];
        for (field, value) in optional_endpoints {
            if let Some(url) = value {
                Self::check_url(field, url)?;
            }
        }

        if self.response_types_supported.is_empty() {
            return Err(DiscoveryError::MissingField(
                "response_types_supported cannot be empty".to_string(),
            ));
        }
        Ok(())
    }

    /// Parses `issuer` and enforces the RFC 8414 §2 issuer-identifier rules:
    /// `https` scheme, no query, no fragment.
    fn validate_issuer(&self) -> Result<(), DiscoveryError> {
        // The `InvalidIssuer` Display impl already prints "Invalid issuer URL: ", so
        // only the underlying parse error goes into the payload (avoids a doubled prefix).
        let issuer_url =
            Url::parse(&self.issuer).map_err(|e| DiscoveryError::InvalidIssuer(e.to_string()))?;
        if issuer_url.scheme() != "https" {
            return Err(DiscoveryError::InvalidIssuer(
                "Issuer MUST use https scheme".to_string(),
            ));
        }
        // Even an empty "?" or "#" counts as a component, and `url` reports it as
        // `Some("")`, so checking `is_some()` rejects those too.
        if issuer_url.query().is_some() {
            return Err(DiscoveryError::InvalidIssuer(
                "Issuer MUST NOT contain a query component".to_string(),
            ));
        }
        if issuer_url.fragment().is_some() {
            return Err(DiscoveryError::InvalidIssuer(
                "Issuer MUST NOT contain a fragment component".to_string(),
            ));
        }
        Ok(())
    }

    /// Parses `value` as an absolute URL and reports any failure against `field`.
    fn check_url(field: &str, value: &str) -> Result<(), DiscoveryError> {
        Url::parse(value).map_err(|e| DiscoveryError::InvalidField {
            field: field.to_string(),
            reason: format!("Invalid URL: {}", e),
        })?;
        Ok(())
    }

    /// Returns the advertised grant types.
    ///
    /// When `grant_types_supported` is omitted, this returns the RFC 8414 default of
    /// `["authorization_code", "implicit"]`.
    pub fn grant_types(&self) -> Vec<String> {
        self.grant_types_supported
            .clone()
            .unwrap_or_else(|| vec!["authorization_code".to_string(), "implicit".to_string()])
    }

    /// Returns `true` if the server advertises at least one PKCE code challenge method.
    ///
    /// An omitted or empty `code_challenge_methods_supported` means PKCE is not
    /// advertised.
    pub fn supports_pkce(&self) -> bool {
        self.code_challenge_methods_supported
            .as_ref()
            .map(|methods| !methods.is_empty())
            .unwrap_or(false)
    }

    /// Returns `true` if `method` (for example `"S256"`) appears exactly, with a
    /// case-sensitive match, in `code_challenge_methods_supported`.
    pub fn supports_pkce_method(&self, method: &str) -> bool {
        self.code_challenge_methods_supported
            .as_ref()
            .map(|methods| methods.iter().any(|m| m == method))
            .unwrap_or(false)
    }
}
/// OpenID Connect provider metadata.
///
/// Contains the OAuth 2.0 authorization server fields, flattened into the same JSON
/// object, plus the OpenID-specific fields below. Every `Option` field is omitted
/// from serialized output when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct OIDCProviderMetadata {
/// OAuth 2.0 metadata fields. They share the top-level JSON object with this
/// struct's own fields.
#[serde(flatten)]
pub oauth2: AuthorizationServerMetadata,
/// Non-optional, so documents lacking this key fail to deserialize.
/// NOTE(review): OpenID Connect Discovery 1.0 lists `userinfo_endpoint` as
/// RECOMMENDED rather than REQUIRED — confirm that rejecting such providers is intended.
pub userinfo_endpoint: String,
#[serde(skip_serializing_if = "Option::is_none")]
pub acr_values_supported: Option<Vec<String>>,
/// Non-optional. `OIDCProviderMetadata::validate` rejects an empty list.
pub subject_types_supported: Vec<String>,
/// Non-optional. `OIDCProviderMetadata::validate` rejects an empty list.
pub id_token_signing_alg_values_supported: Vec<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub id_token_encryption_alg_values_supported: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub id_token_encryption_enc_values_supported: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub userinfo_signing_alg_values_supported: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub userinfo_encryption_alg_values_supported: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub userinfo_encryption_enc_values_supported: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub request_object_signing_alg_values_supported: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub display_values_supported: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub claim_types_supported: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub claims_supported: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub claims_parameter_supported: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub request_parameter_supported: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub request_uri_parameter_supported: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
pub require_request_uri_registration: Option<bool>,
}
impl OIDCProviderMetadata {
    /// Validates the embedded OAuth 2.0 metadata, then the OpenID-specific fields.
    ///
    /// # Errors
    ///
    /// - Any error returned by [`AuthorizationServerMetadata::validate`].
    /// - [`DiscoveryError::MissingField`] with payload `"jwks_uri"` if `jwks_uri` is
    ///   absent. OpenID Connect Discovery 1.0 §3 makes it REQUIRED, whereas
    ///   RFC 8414 does not.
    /// - [`DiscoveryError::InvalidField`] if `userinfo_endpoint` fails to parse as a URL.
    /// - [`DiscoveryError::MissingField`] if `subject_types_supported` or
    ///   `id_token_signing_alg_values_supported` is empty.
    pub fn validate(&self) -> Result<(), DiscoveryError> {
        self.oauth2.validate()?;

        // Relying parties need the OP's key set to verify ID token signatures, so an
        // OpenID Provider document without `jwks_uri` is unusable.
        if self.oauth2.jwks_uri.is_none() {
            return Err(DiscoveryError::MissingField("jwks_uri".to_string()));
        }

        Url::parse(&self.userinfo_endpoint).map_err(|e| DiscoveryError::InvalidField {
            field: "userinfo_endpoint".to_string(),
            reason: format!("Invalid URL: {}", e),
        })?;

        if self.subject_types_supported.is_empty() {
            return Err(DiscoveryError::MissingField(
                "subject_types_supported cannot be empty".to_string(),
            ));
        }
        if self.id_token_signing_alg_values_supported.is_empty() {
            return Err(DiscoveryError::MissingField(
                "id_token_signing_alg_values_supported cannot be empty".to_string(),
            ));
        }
        Ok(())
    }
}
/// Discovery metadata that passed validation and whose `issuer` matched the expected one.
///
/// The fields are private, so code outside this module can only obtain a value
/// through `ValidatedDiscoveryMetadata::new_oauth2` / `new_oidc`.
#[derive(Debug, Clone)]
pub struct ValidatedDiscoveryMetadata {
metadata: DiscoveryMetadata,
// The expected issuer passed to the constructor; equal to the document's `issuer`.
issuer: String,
// Set with `SystemTime::now()` when the value is constructed, which is not
// necessarily the moment the document was fetched over the network.
fetched_at: std::time::SystemTime,
}
/// Either a plain OAuth 2.0 authorization server document or an OpenID Connect
/// provider document.
///
/// Both payloads are boxed, presumably to keep the enum itself pointer-sized despite
/// the large structs.
#[derive(Debug, Clone)]
pub enum DiscoveryMetadata {
OAuth2(Box<AuthorizationServerMetadata>),
OIDC(Box<OIDCProviderMetadata>),
}
impl ValidatedDiscoveryMetadata {
    /// Validates an OAuth 2.0 document and pins it to the expected `issuer`.
    ///
    /// # Errors
    ///
    /// - Any error from [`AuthorizationServerMetadata::validate`].
    /// - [`DiscoveryError::IssuerMismatch`] when the document's `issuer` is not
    ///   byte-for-byte equal to `issuer`.
    pub fn new_oauth2(
        metadata: AuthorizationServerMetadata,
        issuer: String,
    ) -> Result<Self, DiscoveryError> {
        metadata.validate()?;
        let expected = Self::ensure_issuer_matches(&metadata.issuer, issuer)?;
        Ok(Self::stamp(
            DiscoveryMetadata::OAuth2(Box::new(metadata)),
            expected,
        ))
    }

    /// Validates an OpenID Connect provider document and pins it to the expected `issuer`.
    ///
    /// # Errors
    ///
    /// - Any error from [`OIDCProviderMetadata::validate`].
    /// - [`DiscoveryError::IssuerMismatch`] when the document's `issuer` is not
    ///   byte-for-byte equal to `issuer`.
    pub fn new_oidc(
        metadata: OIDCProviderMetadata,
        issuer: String,
    ) -> Result<Self, DiscoveryError> {
        metadata.validate()?;
        let expected = Self::ensure_issuer_matches(&metadata.oauth2.issuer, issuer)?;
        Ok(Self::stamp(
            DiscoveryMetadata::OIDC(Box::new(metadata)),
            expected,
        ))
    }

    /// Hands back `expected` when it equals `document`; otherwise reports both values.
    fn ensure_issuer_matches(document: &str, expected: String) -> Result<String, DiscoveryError> {
        // Exact string comparison: no URL normalisation is applied on either side.
        if document == expected {
            Ok(expected)
        } else {
            Err(DiscoveryError::IssuerMismatch {
                document: document.to_owned(),
                expected,
            })
        }
    }

    /// Wraps already-checked metadata, recording the current wall-clock time.
    fn stamp(metadata: DiscoveryMetadata, issuer: String) -> Self {
        Self {
            metadata,
            issuer,
            fetched_at: std::time::SystemTime::now(),
        }
    }

    /// Borrows the underlying OAuth 2.0 or OIDC document.
    pub fn metadata(&self) -> &DiscoveryMetadata {
        &self.metadata
    }

    /// The issuer this metadata was validated against.
    pub fn issuer(&self) -> &str {
        self.issuer.as_str()
    }

    /// Wall-clock time at which this value was constructed.
    pub fn fetched_at(&self) -> std::time::SystemTime {
        self.fetched_at
    }

    /// The OAuth 2.0 portion of the document, regardless of which variant is stored.
    pub fn oauth2(&self) -> &AuthorizationServerMetadata {
        match self.metadata {
            DiscoveryMetadata::OIDC(ref provider) => &provider.oauth2,
            DiscoveryMetadata::OAuth2(ref server) => server,
        }
    }

    /// The OIDC provider document, or `None` for plain OAuth 2.0 metadata.
    pub fn oidc(&self) -> Option<&OIDCProviderMetadata> {
        if let DiscoveryMetadata::OIDC(provider) = &self.metadata {
            Some(provider)
        } else {
            None
        }
    }

    /// `true` when the stored document is OpenID Connect provider metadata.
    pub fn is_oidc(&self) -> bool {
        self.oidc().is_some()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const SERVER_ISSUER: &str = "https://server.example.com";

    /// Minimal fixture: required fields plus a token endpoint, everything else unset.
    fn base_metadata(issuer: &str) -> AuthorizationServerMetadata {
        AuthorizationServerMetadata {
            issuer: issuer.to_string(),
            authorization_endpoint: "https://server.example.com/authorize".to_string(),
            token_endpoint: Some("https://server.example.com/token".to_string()),
            jwks_uri: None,
            registration_endpoint: None,
            scopes_supported: None,
            response_types_supported: vec!["code".to_string()],
            response_modes_supported: None,
            grant_types_supported: None,
            token_endpoint_auth_methods_supported: None,
            token_endpoint_auth_signing_alg_values_supported: None,
            service_documentation: None,
            ui_locales_supported: None,
            op_policy_uri: None,
            op_tos_uri: None,
            revocation_endpoint: None,
            revocation_endpoint_auth_methods_supported: None,
            introspection_endpoint: None,
            introspection_endpoint_auth_methods_supported: None,
            code_challenge_methods_supported: None,
            additional_fields: HashMap::new(),
        }
    }

    #[test]
    fn test_oauth2_metadata_validation_success() {
        let metadata = AuthorizationServerMetadata {
            jwks_uri: Some("https://server.example.com/jwks".to_string()),
            scopes_supported: Some(vec!["openid".to_string(), "profile".to_string()]),
            grant_types_supported: Some(vec!["authorization_code".to_string()]),
            token_endpoint_auth_methods_supported: Some(vec!["client_secret_basic".to_string()]),
            code_challenge_methods_supported: Some(vec!["S256".to_string()]),
            ..base_metadata(SERVER_ISSUER)
        };
        assert!(metadata.validate().is_ok());
    }

    #[test]
    fn test_oauth2_metadata_validation_requires_https() {
        let outcome = base_metadata("http://server.example.com").validate();
        assert!(matches!(outcome, Err(DiscoveryError::InvalidIssuer(_))));
    }

    #[test]
    fn test_pkce_support_detection() {
        let mut metadata = base_metadata(SERVER_ISSUER);
        assert!(!metadata.supports_pkce());
        assert!(!metadata.supports_pkce_method("S256"));

        metadata.code_challenge_methods_supported = Some(vec!["S256".to_string()]);
        assert!(metadata.supports_pkce());
        assert!(metadata.supports_pkce_method("S256"));
        assert!(!metadata.supports_pkce_method("plain"));
    }

    #[test]
    fn test_validated_metadata_issuer_match() {
        let outcome = ValidatedDiscoveryMetadata::new_oauth2(
            base_metadata(SERVER_ISSUER),
            SERVER_ISSUER.to_string(),
        );
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validated_metadata_issuer_mismatch() {
        let outcome = ValidatedDiscoveryMetadata::new_oauth2(
            base_metadata(SERVER_ISSUER),
            "https://attacker.com".to_string(),
        );
        assert!(matches!(
            outcome,
            Err(DiscoveryError::IssuerMismatch { .. })
        ));
    }
}