use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use url::Url;
/// OAuth client metadata document as it is serialized to and from JSON.
///
/// NOTE(review): the field names appear to mirror the client metadata members
/// of OAuth 2.0 Dynamic Client Registration (RFC 7591) — confirm against the
/// specification this crate targets.
///
/// Every `Option` field is left out of the serialized output when it is
/// `None`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ClientMetadata {
    /// Client identifier. [`ClientMetadata::validate`] requires it to parse
    /// as a URL whose scheme is `https`.
    pub client_id: String,

    /// Optional `client_name` member.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_name: Option<String>,

    /// Optional `client_uri` member; must parse as a URL to pass validation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub client_uri: Option<String>,

    /// Optional `logo_uri` member; must parse as a URL to pass validation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logo_uri: Option<String>,

    /// Registered redirect URIs. Validation rejects an empty list and any
    /// entry that does not parse as a URL.
    pub redirect_uris: Vec<String>,

    /// Optional `grant_types` member. When `None`,
    /// [`ClientMetadata::grant_types`] falls back to `["authorization_code"]`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub grant_types: Option<Vec<String>>,

    /// Optional `response_types` member. When `None`,
    /// [`ClientMetadata::response_types`] falls back to `["code"]`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_types: Option<Vec<String>>,

    /// Optional `token_endpoint_auth_method` member. `None` and `"none"` are
    /// both treated as a public client by
    /// [`ClientMetadata::is_public_client`].
    #[serde(skip_serializing_if = "Option::is_none")]
    pub token_endpoint_auth_method: Option<String>,

    /// Optional `jwks_uri` member; must parse as a URL to pass validation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks_uri: Option<String>,

    /// Optional `jwks` member, kept as untyped JSON.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub jwks: Option<serde_json::Value>,

    /// Optional `contacts` member.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub contacts: Option<Vec<String>>,

    /// Optional `software_id` member.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub software_id: Option<String>,

    /// Optional `software_version` member.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub software_version: Option<String>,

    /// Optional `software_statement` member.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub software_statement: Option<String>,

    /// Optional `scope` member, stored as a single string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub scope: Option<String>,

    /// Optional `tos_uri` member; must parse as a URL to pass validation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tos_uri: Option<String>,

    /// Optional `policy_uri` member; must parse as a URL to pass validation.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub policy_uri: Option<String>,

    /// Any JSON members not matched by the named fields above. They are
    /// flattened back into the top-level object on serialization.
    #[serde(flatten)]
    pub additional_fields: HashMap<String, serde_json::Value>,
}
impl ClientMetadata {
    /// Creates metadata holding only the two required members.
    ///
    /// Every optional field starts as `None` and `additional_fields` starts
    /// empty. No validation is performed here; call
    /// [`ClientMetadata::validate`] for that.
    pub fn new(client_id: String, redirect_uris: Vec<String>) -> Self {
        Self {
            client_id,
            redirect_uris,
            client_name: None,
            client_uri: None,
            logo_uri: None,
            grant_types: None,
            response_types: None,
            token_endpoint_auth_method: None,
            jwks_uri: None,
            jwks: None,
            contacts: None,
            software_id: None,
            software_version: None,
            software_statement: None,
            scope: None,
            tos_uri: None,
            policy_uri: None,
            additional_fields: HashMap::new(),
        }
    }

    /// Checks the structural rules this type enforces on a metadata document.
    ///
    /// Checks run in this order and stop at the first failure:
    /// 1. `client_id` parses as a URL and uses the `https` scheme.
    /// 2. `redirect_uris` is non-empty.
    /// 3. Every redirect URI parses as a URL.
    /// 4. `client_uri`, `logo_uri`, `jwks_uri`, `tos_uri` and `policy_uri`,
    ///    when present, each parse as a URL.
    ///
    /// # Errors
    ///
    /// * [`ClientMetadataError::InvalidClientId`] for rule 1.
    /// * [`ClientMetadataError::MissingRedirectUris`] for rule 2.
    /// * [`ClientMetadataError::InvalidRedirectUri`] for rule 3; the payload
    ///   names the offending URI followed by the parser's message.
    /// * [`ClientMetadataError::InvalidField`] for rule 4.
    pub fn validate(&self) -> Result<(), ClientMetadataError> {
        let client_id_url = Url::parse(&self.client_id)
            .map_err(|e| ClientMetadataError::InvalidClientId(format!("Invalid URL: {}", e)))?;
        // NOTE(review): only the scheme of client_id is checked. If this type
        // models the OAuth Client ID Metadata Document draft, that draft also
        // restricts fragments, userinfo and dot path segments — confirm
        // whether those checks belong here.
        if client_id_url.scheme() != "https" {
            return Err(ClientMetadataError::InvalidClientId(
                "client_id MUST use https scheme".to_string(),
            ));
        }

        if self.redirect_uris.is_empty() {
            return Err(ClientMetadataError::MissingRedirectUris);
        }
        for uri in &self.redirect_uris {
            // The variant's Display already starts with "Invalid redirect
            // URI: ", so the payload carries the offending value instead of
            // repeating that prefix.
            Url::parse(uri)
                .map_err(|e| ClientMetadataError::InvalidRedirectUri(format!("{}: {}", uri, e)))?;
        }

        // Optional URL-valued members, checked in declaration order so the
        // first reported failure is deterministic.
        let optional_uris = [
            ("client_uri", self.client_uri.as_deref()),
            ("logo_uri", self.logo_uri.as_deref()),
            ("jwks_uri", self.jwks_uri.as_deref()),
            ("tos_uri", self.tos_uri.as_deref()),
            ("policy_uri", self.policy_uri.as_deref()),
        ];
        for &(field, value) in &optional_uris {
            if let Some(uri) = value {
                Url::parse(uri).map_err(|e| {
                    ClientMetadataError::InvalidField(format!("Invalid {}: {}", field, e))
                })?;
            }
        }

        Ok(())
    }

    /// Returns `true` when `token_endpoint_auth_method` is absent or is the
    /// literal `"none"`.
    pub fn is_public_client(&self) -> bool {
        matches!(
            self.token_endpoint_auth_method.as_deref(),
            None | Some("none")
        )
    }

    /// Returns `true` when `token_endpoint_auth_method` is exactly
    /// `"private_key_jwt"`.
    pub fn uses_private_key_jwt(&self) -> bool {
        self.token_endpoint_auth_method.as_deref() == Some("private_key_jwt")
    }

    /// Returns a copy of the declared grant types, or
    /// `["authorization_code"]` when the member is absent.
    pub fn grant_types(&self) -> Vec<String> {
        self.grant_types
            .clone()
            .unwrap_or_else(|| vec!["authorization_code".to_string()])
    }

    /// Returns a copy of the declared response types, or `["code"]` when the
    /// member is absent.
    pub fn response_types(&self) -> Vec<String> {
        self.response_types
            .clone()
            .unwrap_or_else(|| vec!["code".to_string()])
    }
}
/// Failures reported while validating or accepting client metadata.
#[derive(Debug, Clone, thiserror::Error)]
pub enum ClientMetadataError {
    /// `client_id` did not parse as a URL or does not use the `https` scheme.
    #[error("Invalid client_id: {0}")]
    InvalidClientId(String),

    /// `redirect_uris` contained no entries.
    #[error("redirect_uris is required and cannot be empty")]
    MissingRedirectUris,

    /// A redirect URI failed to parse, or a requested redirect URI is not
    /// among the registered ones.
    #[error("Invalid redirect URI: {0}")]
    InvalidRedirectUri(String),

    /// One of the optional URL-valued members failed to parse.
    #[error("Invalid metadata field: {0}")]
    InvalidField(String),

    /// The `client_id` inside the document differs from the URL the document
    /// was associated with.
    #[error("client_id in document ({document}) does not match URL ({url})")]
    ClientIdMismatch { document: String, url: String },

    /// The metadata document could not be parsed. Not constructed anywhere in
    /// this file; presumably produced by the code that fetches and decodes
    /// the document.
    #[error("Failed to parse metadata: {0}")]
    ParseError(String),
}
/// Client metadata that has passed [`ClientMetadata::validate`] and whose
/// `client_id` equals the URL it was associated with.
///
/// The fields are private, so values are obtained through
/// [`ValidatedClientMetadata::new`], which performs both checks.
#[derive(Debug, Clone)]
pub struct ValidatedClientMetadata {
    /// The validated document.
    metadata: ClientMetadata,
    /// URL the document was associated with; equal to `metadata.client_id`.
    source_url: String,
    /// System time captured when this value was constructed.
    fetched_at: std::time::SystemTime,
}
impl ValidatedClientMetadata {
    /// Validates `metadata` and binds it to `source_url`.
    ///
    /// The document's `client_id` must equal `source_url` under exact string
    /// comparison; no URL normalization is applied. `fetched_at` is set to
    /// the current system time on success.
    ///
    /// # Errors
    ///
    /// * Any error returned by [`ClientMetadata::validate`].
    /// * [`ClientMetadataError::ClientIdMismatch`] when `client_id` and
    ///   `source_url` differ.
    pub fn new(metadata: ClientMetadata, source_url: String) -> Result<Self, ClientMetadataError> {
        metadata.validate()?;
        if metadata.client_id != source_url {
            // `metadata` is owned and dropped on this path, so the id can be
            // moved into the error instead of cloned.
            return Err(ClientMetadataError::ClientIdMismatch {
                document: metadata.client_id,
                url: source_url,
            });
        }
        Ok(Self {
            metadata,
            source_url,
            fetched_at: std::time::SystemTime::now(),
        })
    }

    /// Returns the validated metadata document.
    pub fn metadata(&self) -> &ClientMetadata {
        &self.metadata
    }

    /// Returns the URL this document was associated with.
    pub fn source_url(&self) -> &str {
        &self.source_url
    }

    /// Returns the system time recorded when this value was constructed.
    pub fn fetched_at(&self) -> std::time::SystemTime {
        self.fetched_at
    }

    /// Returns `true` when `redirect_uri` exactly matches one of the
    /// registered redirect URIs (byte-for-byte string comparison).
    pub fn is_redirect_uri_allowed(&self, redirect_uri: &str) -> bool {
        // Compare by borrow; building a temporary `String` for
        // `Vec::contains` would allocate on every call.
        self.metadata
            .redirect_uris
            .iter()
            .any(|allowed| allowed.as_str() == redirect_uri)
    }

    /// Like [`ValidatedClientMetadata::is_redirect_uri_allowed`], but reports
    /// a rejection as an error.
    ///
    /// # Errors
    ///
    /// Returns [`ClientMetadataError::InvalidRedirectUri`] naming the
    /// rejected URI and the client when the URI is not registered.
    pub fn validate_redirect_uri(&self, redirect_uri: &str) -> Result<(), ClientMetadataError> {
        if self.is_redirect_uri_allowed(redirect_uri) {
            Ok(())
        } else {
            Err(ClientMetadataError::InvalidRedirectUri(format!(
                "{} is not an allowed redirect URI for client {}",
                redirect_uri, self.metadata.client_id
            )))
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Client identifier used by most tests; also serves as the document URL.
    const CLIENT_ID: &str = "https://example.com/client-metadata.json";
    /// Redirect URI registered by the sample client.
    const LOCALHOST_CALLBACK: &str = "http://localhost:3000/callback";
    /// Second redirect URI used by the allow-list test.
    const LOOPBACK_IP_CALLBACK: &str = "http://127.0.0.1:3000/callback";

    /// Builds metadata for `client_id` with the single localhost redirect URI.
    fn metadata_with_id(client_id: &str) -> ClientMetadata {
        ClientMetadata::new(
            client_id.to_string(),
            vec![LOCALHOST_CALLBACK.to_string()],
        )
    }

    /// Builds the default sample: `CLIENT_ID` plus the localhost redirect URI.
    fn sample_metadata() -> ClientMetadata {
        metadata_with_id(CLIENT_ID)
    }

    #[test]
    fn test_client_metadata_creation() {
        let metadata = sample_metadata();
        assert_eq!(metadata.client_id, CLIENT_ID);
        assert_eq!(metadata.redirect_uris.len(), 1);
    }

    #[test]
    fn test_client_metadata_validation_success() {
        assert!(sample_metadata().validate().is_ok());
    }

    #[test]
    fn test_client_metadata_validation_requires_https() {
        let outcome = metadata_with_id("http://example.com/client-metadata.json").validate();
        assert!(matches!(
            outcome,
            Err(ClientMetadataError::InvalidClientId(_))
        ));
    }

    #[test]
    fn test_client_metadata_validation_requires_redirect_uris() {
        let outcome = ClientMetadata::new(CLIENT_ID.to_string(), Vec::new()).validate();
        assert!(matches!(
            outcome,
            Err(ClientMetadataError::MissingRedirectUris)
        ));
    }

    #[test]
    fn test_validated_client_metadata_client_id_match() {
        let outcome = ValidatedClientMetadata::new(sample_metadata(), CLIENT_ID.to_string());
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validated_client_metadata_client_id_mismatch() {
        let outcome = ValidatedClientMetadata::new(
            sample_metadata(),
            "https://attacker.com/fake.json".to_string(),
        );
        assert!(matches!(
            outcome,
            Err(ClientMetadataError::ClientIdMismatch { .. })
        ));
    }

    #[test]
    fn test_redirect_uri_validation() {
        let registered = vec![
            LOCALHOST_CALLBACK.to_string(),
            LOOPBACK_IP_CALLBACK.to_string(),
        ];
        let metadata = ClientMetadata::new(CLIENT_ID.to_string(), registered);
        let validated = ValidatedClientMetadata::new(metadata, CLIENT_ID.to_string()).unwrap();

        assert!(validated.is_redirect_uri_allowed(LOCALHOST_CALLBACK));
        assert!(validated.is_redirect_uri_allowed(LOOPBACK_IP_CALLBACK));
        assert!(!validated.is_redirect_uri_allowed("http://attacker.com/callback"));
    }

    #[test]
    fn test_is_public_client() {
        let mut metadata = sample_metadata();
        // No auth method declared counts as public.
        assert!(metadata.is_public_client());

        metadata.token_endpoint_auth_method = Some(String::from("none"));
        assert!(metadata.is_public_client());

        metadata.token_endpoint_auth_method = Some(String::from("client_secret_basic"));
        assert!(!metadata.is_public_client());
    }

    #[test]
    fn test_grant_types_defaults() {
        assert_eq!(
            sample_metadata().grant_types(),
            vec!["authorization_code"]
        );
    }

    #[test]
    fn test_response_types_defaults() {
        assert_eq!(sample_metadata().response_types(), vec!["code"]);
    }

    #[test]
    fn test_serde_roundtrip() {
        // Start from the sample (all optionals `None`, no extra fields) and
        // fill in the members this round-trip exercises.
        let metadata = ClientMetadata {
            client_name: Some("Test Client".to_string()),
            client_uri: Some("https://example.com".to_string()),
            logo_uri: Some("https://example.com/logo.png".to_string()),
            grant_types: Some(vec!["authorization_code".to_string()]),
            response_types: Some(vec!["code".to_string()]),
            token_endpoint_auth_method: Some("none".to_string()),
            contacts: Some(vec!["admin@example.com".to_string()]),
            software_id: Some("test-client-v1".to_string()),
            software_version: Some("1.0.0".to_string()),
            scope: Some("read write".to_string()),
            tos_uri: Some("https://example.com/tos".to_string()),
            policy_uri: Some("https://example.com/privacy".to_string()),
            ..sample_metadata()
        };

        let encoded = serde_json::to_string(&metadata).unwrap();
        let decoded: ClientMetadata = serde_json::from_str(&encoded).unwrap();
        assert_eq!(metadata, decoded);
    }
}