use std::{collections::HashSet, time::Duration};
use jsonwebtoken::{Algorithm, jwk::JwkSet};
use mockall_double::double;
use reqwest::Client;
use url::Url;
use crate::{error::StartupError, oidc::OidcConfig, validation::ClaimsValidationSpec};
#[double]
use crate::oidc::OidcDiscovery;
/// Returns the set of signature algorithms accepted when a tenant does not configure its own.
///
/// The set contains only asymmetric algorithms: RSA PKCS#1 v1.5 (`RS*`), ECDSA (`ES*`),
/// RSA-PSS (`PS*`) and `EdDSA`. No HMAC (`HS*`) algorithm is included.
pub fn default_allowed_algorithms() -> HashSet<Algorithm> {
    let rsa_pkcs1 = [Algorithm::RS256, Algorithm::RS384, Algorithm::RS512];
    let ecdsa = [Algorithm::ES256, Algorithm::ES384];
    let rsa_pss = [Algorithm::PS256, Algorithm::PS384, Algorithm::PS512];
    let edwards = [Algorithm::EdDSA];

    rsa_pkcs1
        .into_iter()
        .chain(ecdsa)
        .chain(rsa_pss)
        .chain(edwards)
        .collect()
}
/// Where a tenant obtains the JWK set used to verify token signatures.
#[derive(Clone, Debug)]
pub(crate) enum TenantKind {
    /// A JWK set supplied up front, parsed from JSON by `TenantStaticConfigurationBuilder`.
    Static {
        jwks: JwkSet,
    },
    /// A JWK set located at `jwks_url`, given either explicitly or through OIDC discovery.
    ///
    /// `jwks_refresh_interval` is the interval configured on the builder (60 s by default).
    /// `http_client` is the client the builder used or created. How both are consumed is up to
    /// the users of this type.
    JwksUrl {
        jwks_url: Url,
        jwks_refresh_interval: Duration,
        http_client: Client,
    },
}
/// Fully resolved configuration for one token-issuing tenant.
///
/// There are two ways to create one:
/// - [`TenantConfiguration::builder`] loads keys from a JWKS URL, optionally discovered via OIDC.
/// - [`TenantConfiguration::static_builder`] takes keys supplied inline as JWKS JSON.
#[derive(Clone, Debug)]
pub struct TenantConfiguration {
    /// Tenant name. It defaults to the issuer URL (URL builder) or `"static"` (static builder).
    pub(crate) identifier: String,
    /// Claims checks applied to tokens of this tenant.
    pub(crate) claims_validation_spec: ClaimsValidationSpec,
    /// Signature algorithms this tenant accepts.
    pub(crate) allowed_algorithms: HashSet<Algorithm>,
    /// Source of the verification keys.
    pub(crate) kind: TenantKind,
}
impl TenantConfiguration {
    /// Starts configuring a tenant whose keys are fetched over HTTP.
    ///
    /// `issuer_url` must be a valid URL. Unless a JWKS URL is set on the builder, it is also the
    /// base for OIDC discovery performed by [`TenantConfigurationBuilder::build`].
    pub fn builder(issuer_url: impl Into<String>) -> TenantConfigurationBuilder {
        let issuer_url: String = issuer_url.into();
        TenantConfigurationBuilder::new(issuer_url)
    }

    /// Starts configuring a tenant whose keys are given inline.
    ///
    /// `jwks` is the JSON text of a JWK set. It is parsed by
    /// [`TenantStaticConfigurationBuilder::build`].
    pub fn static_builder(jwks: impl Into<String>) -> TenantStaticConfigurationBuilder {
        let jwks: String = jwks.into();
        TenantStaticConfigurationBuilder::new(jwks)
    }
}
/// Builder for a [`TenantConfiguration`] whose keys are loaded from a JWKS URL.
///
/// Obtain one with [`TenantConfiguration::builder`] and finish it with
/// [`TenantConfigurationBuilder::build`]. Every setter consumes and returns the builder, so a
/// dropped return value silently loses the setting; hence `#[must_use]`.
#[must_use = "a builder does nothing until `build` is called"]
pub struct TenantConfigurationBuilder {
    /// Issuer URL as supplied by the caller; validated in `build`.
    issuer_url: String,
    /// Optional tenant name; `build` falls back to `issuer_url`.
    identifier: Option<String>,
    /// Optional explicit JWKS URL; when present, OIDC discovery is skipped.
    jwks_url: Option<String>,
    /// Optional HTTP client; `build` creates a default one when absent.
    http_client: Option<Client>,
    /// Audiences fed into the recommended claims spec.
    audiences: Vec<String>,
    /// Optional JWKS refresh interval; `build` defaults it to 60 seconds.
    jwk_set_refresh_interval: Option<Duration>,
    /// Optional caller-supplied claims spec that replaces the recommended one.
    claims_validation_spec: Option<ClaimsValidationSpec>,
    /// Optional algorithm allow-list; `build` defaults to `default_allowed_algorithms()`.
    allowed_algorithms: Option<HashSet<Algorithm>>,
}
impl TenantConfigurationBuilder {
    /// Creates a builder for tokens issued by `issuer_url`, with every optional setting unset.
    fn new(issuer_url: impl Into<String>) -> Self {
        Self {
            issuer_url: issuer_url.into(),
            identifier: None,
            jwks_url: None,
            http_client: None,
            audiences: Vec::new(),
            jwk_set_refresh_interval: None,
            claims_validation_spec: None,
            allowed_algorithms: None,
        }
    }

    /// Sets the tenant identifier. When not set, the issuer URL string is used.
    pub fn identifier(mut self, identifier: &str) -> Self {
        self.identifier = Some(identifier.to_string());
        self
    }

    /// Sets the JWKS URL explicitly. OIDC discovery is then skipped by [`build`](Self::build).
    pub fn jwks_url(mut self, jwks_url: impl Into<String>) -> Self {
        self.jwks_url = Some(jwks_url.into());
        self
    }

    /// Supplies the HTTP client used for OIDC discovery and stored on the resulting tenant.
    /// When not set, a default `reqwest::Client` is built.
    pub fn http_client(mut self, http_client: Client) -> Self {
        self.http_client = Some(http_client);
        self
    }

    /// Sets the expected audiences. They are only used by the recommended claims spec, i.e.
    /// when no custom spec is set via [`claims_validation`](Self::claims_validation).
    pub fn audiences(mut self, audiences: &[impl ToString]) -> Self {
        self.audiences = audiences.iter().map(|aud| aud.to_string()).collect();
        self
    }

    /// Sets the JWKS refresh interval stored on the tenant. Defaults to 60 seconds.
    pub fn jwks_refresh_interval(mut self, jwk_set_refresh_interval: Duration) -> Self {
        self.jwk_set_refresh_interval = Some(jwk_set_refresh_interval);
        self
    }

    /// Replaces the recommended claims validation spec entirely.
    pub fn claims_validation(mut self, claims_validation: ClaimsValidationSpec) -> Self {
        self.claims_validation_spec = Some(claims_validation);
        self
    }

    /// Restricts the accepted signature algorithms. Defaults to [`default_allowed_algorithms`].
    pub fn allowed_algorithms(mut self, algorithms: &[Algorithm]) -> Self {
        self.allowed_algorithms = Some(algorithms.iter().copied().collect());
        self
    }

    /// Validates the settings and produces a [`TenantConfiguration`] that loads keys from a URL.
    ///
    /// When no JWKS URL was set, OIDC discovery is performed against the issuer URL and the
    /// discovered `jwks_uri` is used. When a JWKS URL *was* set, discovery is skipped, so the
    /// recommended claims spec checks only `exp` (and `aud` if audiences were given). Neither
    /// `iss` nor `nbf` is checked unless a custom spec is supplied.
    ///
    /// # Errors
    /// - [`StartupError::InvalidParameter`] if the issuer URL or the JWKS URL cannot be parsed.
    /// - [`StartupError::OidcDiscoveryFailed`] if OIDC discovery fails.
    ///
    /// # Panics
    /// Panics if no HTTP client was supplied and the default `reqwest::Client` cannot be built.
    pub async fn build(self) -> Result<TenantConfiguration, StartupError> {
        let identifier = self
            .identifier
            .unwrap_or_else(|| self.issuer_url.clone());
        let issuer_url = Url::parse(&self.issuer_url)
            .map_err(|_| StartupError::InvalidParameter("Invalid issuer_url format".to_string()))?;
        let explicit_jwks_url = self
            .jwks_url
            .as_deref()
            .map(|jwks_url| {
                Url::parse(jwks_url).map_err(|_| {
                    StartupError::InvalidParameter("Invalid jwks_url format".to_string())
                })
            })
            .transpose()?;
        let http_client = self.http_client.unwrap_or_else(|| {
            Client::builder()
                .build()
                .expect("Could not create reqwest client")
        });

        // Resolve the JWKS URL together with the (optional) discovery document. Binding them as
        // a pair removes the impossible "no JWKS URL and no discovery" state that previously
        // needed a dead error branch.
        let (jwks_url, oidc_config) = match explicit_jwks_url {
            Some(jwks_url) => (jwks_url, None),
            None => {
                let config = OidcDiscovery::discover(&issuer_url, http_client.clone())
                    .await
                    .map_err(|e| StartupError::OidcDiscoveryFailed(e.to_string()))?;
                (config.jwks_uri.clone(), Some(config))
            }
        };

        // Build the recommended spec lazily: it is discarded when a custom spec was supplied.
        let claims_validation_spec = self
            .claims_validation_spec
            .unwrap_or_else(|| recommended_claims_spec(&self.audiences, &oidc_config));
        let allowed_algorithms = self
            .allowed_algorithms
            .unwrap_or_else(default_allowed_algorithms);
        let kind = TenantKind::JwksUrl {
            jwks_url,
            jwks_refresh_interval: self
                .jwk_set_refresh_interval
                .unwrap_or(Duration::from_secs(60)),
            http_client,
        };
        Ok(TenantConfiguration {
            identifier,
            claims_validation_spec,
            allowed_algorithms,
            kind,
        })
    }
}
/// Builder for a [`TenantConfiguration`] whose keys come from an inline JWK set.
///
/// Obtain one with [`TenantConfiguration::static_builder`] and finish it with
/// [`TenantStaticConfigurationBuilder::build`]. Every setter consumes and returns the builder,
/// hence `#[must_use]`.
#[must_use = "a builder does nothing until `build` is called"]
pub struct TenantStaticConfigurationBuilder {
    /// Optional tenant name; `build` falls back to `"static"`.
    identifier: Option<String>,
    /// Audiences fed into the recommended claims spec.
    audiences: Vec<String>,
    /// Optional caller-supplied claims spec that replaces the recommended one.
    claims_validation_spec: Option<ClaimsValidationSpec>,
    /// Optional algorithm allow-list; `build` defaults to `default_allowed_algorithms()`.
    allowed_algorithms: Option<HashSet<Algorithm>>,
    /// Raw JSON text of the JWK set; parsed in `build`.
    jwks: String,
}
impl TenantStaticConfigurationBuilder {
    /// Creates a builder around the JSON text of a JWK set. The JSON is not parsed until
    /// [`build`](Self::build).
    fn new(jwks: impl Into<String>) -> Self {
        Self {
            jwks: jwks.into(),
            identifier: None,
            audiences: Vec::new(),
            claims_validation_spec: None,
            allowed_algorithms: None,
        }
    }

    /// Sets the tenant identifier. When not set, `"static"` is used.
    pub fn identifier(mut self, identifier: &str) -> Self {
        self.identifier = Some(identifier.to_string());
        self
    }

    /// Sets the expected audiences. They are only used by the recommended claims spec, i.e.
    /// when no custom spec is set via [`claims_validation`](Self::claims_validation).
    pub fn audiences(mut self, audiences: &[impl ToString]) -> Self {
        self.audiences = audiences.iter().map(|aud| aud.to_string()).collect();
        self
    }

    /// Replaces the recommended claims validation spec entirely.
    pub fn claims_validation(mut self, claims_validation: ClaimsValidationSpec) -> Self {
        self.claims_validation_spec = Some(claims_validation);
        self
    }

    /// Restricts the accepted signature algorithms. Defaults to [`default_allowed_algorithms`].
    pub fn allowed_algorithms(mut self, algorithms: &[Algorithm]) -> Self {
        self.allowed_algorithms = Some(algorithms.iter().copied().collect());
        self
    }

    /// Parses the JWK set and produces a [`TenantConfiguration`] backed by those static keys.
    ///
    /// No OIDC metadata is available here, so the recommended claims spec checks only `exp`
    /// (and `aud` if audiences were given). Supply a custom spec to check `iss` or `nbf`.
    ///
    /// # Errors
    /// [`StartupError::InvalidParameter`] if the JWKS JSON cannot be parsed as a JWK set.
    pub fn build(self) -> Result<TenantConfiguration, StartupError> {
        let identifier = self.identifier.unwrap_or_else(|| String::from("static"));
        // Parse first so malformed input fails before any other work is done.
        let jwks = serde_json::from_str(&self.jwks)
            .map_err(|e| StartupError::InvalidParameter(format!("Failed to parse JWKS: {e}")))?;
        // Build the recommended spec lazily: it is discarded when a custom spec was supplied.
        let claims_validation_spec = self
            .claims_validation_spec
            .unwrap_or_else(|| recommended_claims_spec(&self.audiences, &None));
        let allowed_algorithms = self
            .allowed_algorithms
            .unwrap_or_else(default_allowed_algorithms);
        let kind = TenantKind::Static { jwks };
        Ok(TenantConfiguration {
            identifier,
            claims_validation_spec,
            allowed_algorithms,
            kind,
        })
    }
}
/// Builds the claims validation spec used when the caller does not supply one.
///
/// - `exp` is always validated.
/// - `aud` is validated against `audiences` when that list is non-empty.
/// - With OIDC metadata, `nbf` is validated if the provider lists `"nbf"` in `claims_supported`.
/// - With OIDC metadata, `iss` is pinned to the discovered issuer.
fn recommended_claims_spec(
    audiences: &Vec<String>,
    oidc_config: &Option<OidcConfig>,
) -> ClaimsValidationSpec {
    let spec = ClaimsValidationSpec::new().exp(true);
    let spec = if audiences.is_empty() {
        spec
    } else {
        spec.aud(audiences)
    };

    // Without discovery metadata there is nothing more to pin.
    let Some(config) = oidc_config else {
        return spec;
    };

    let advertises_nbf = config
        .claims_supported
        .as_ref()
        .is_some_and(|claims| claims.iter().any(|claim| claim == "nbf"));
    let spec = if advertises_nbf { spec.nbf(true) } else { spec };

    spec.iss(config.issuer.as_str())
}
// Unit tests for both builders. Under `cfg(test)`, the `#[double]` import at the top of the file
// resolves `OidcDiscovery` to `MockOidcDiscovery`, so discovery in `build()` is driven by the
// expectations each test installs rather than by real HTTP requests.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::oidc::{MockOidcDiscovery, OidcConfig};
    use std::sync::Mutex;
    // Serializes the tests that install expectations via `MockOidcDiscovery::discover_context()`.
    // The context for a mocked static method is shared rather than per-test, so these tests must
    // not run concurrently.
    static MTX: Mutex<()> = Mutex::new(());
    // With no JWKS URL configured, `build()` must call discovery exactly once.
    #[tokio::test]
    async fn test_should_perform_oidc_discovery() {
        // The `LockResult` is kept alive (not unwrapped). Both `Ok` and a poisoned `Err` carry
        // the guard, so the lock is held for the whole test either way.
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect()
            .returning(|_, _| Ok(default_oidc_config()))
            .once();
        let result = TenantConfigurationBuilder::new("http://some-issuer.com")
            .build()
            .await;
        assert!(result.is_ok());
    }
    // An explicit JWKS URL must bypass discovery entirely.
    #[tokio::test]
    async fn test_should_skip_oidc_discovery_if_jwks_url_set() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().never();
        let result = TenantConfigurationBuilder::new("http://some-issuer.com")
            .jwks_url("https://some-issuer.com/jwks")
            .build()
            .await;
        assert!(result.is_ok());
    }
    // Without `.identifier(..)`, the issuer URL string becomes the tenant identifier.
    #[tokio::test]
    async fn test_should_use_issuer_as_identifier() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect()
            .returning(|_, _| Ok(default_oidc_config()))
            .once();
        let result = TenantConfigurationBuilder::new("http://some-issuer.com")
            .build()
            .await;
        assert!(result.is_ok());
        let tenant = result.unwrap();
        assert_eq!(tenant.identifier, "http://some-issuer.com");
    }
    // An explicit identifier takes precedence over the issuer URL.
    #[tokio::test]
    async fn test_custom_identifier_overrides_issuer() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect()
            .returning(|_, _| Ok(default_oidc_config()))
            .once();
        let result = TenantConfigurationBuilder::new("http://some-issuer.com")
            .identifier("custom-identifier")
            .build()
            .await;
        assert!(result.is_ok());
        let tenant = result.unwrap();
        assert_eq!(tenant.identifier, "custom-identifier");
    }
    // An unparsable issuer URL is rejected before any discovery attempt.
    #[tokio::test]
    async fn test_valid_issuer_url_required() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().never();
        let result = TenantConfigurationBuilder::new("not-a-url").build().await;
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err(),
            StartupError::InvalidParameter("Invalid issuer_url format".to_owned())
        )
    }
    // An unparsable explicit JWKS URL is rejected, again without discovery.
    #[tokio::test]
    async fn test_valid_jwks_url_required() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect().never();
        let result = TenantConfigurationBuilder::new("https://some-issuer.com")
            .jwks_url("not-a-url")
            .build()
            .await;
        assert!(result.is_err());
        assert_eq!(
            result.unwrap_err(),
            StartupError::InvalidParameter("Invalid jwks_url format".to_owned())
        )
    }
    // The recommended spec combines `exp`, the discovered issuer and the configured audiences.
    // The expected `iss` is the mocked discovery document's `http://` issuer, not the `https://`
    // URL passed to the builder. The mock's `claims_supported` is `None`, so no `nbf` check.
    #[tokio::test]
    async fn test_provides_recommended_claims_validation_spec() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect()
            .returning(|_, _| Ok(default_oidc_config()))
            .once();
        let result = TenantConfigurationBuilder::new("https://some-issuer.com")
            .audiences(&["https://some-resource-server.com"])
            .build()
            .await;
        assert!(result.is_ok());
        assert_eq!(
            result.unwrap().claims_validation_spec,
            ClaimsValidationSpec::new()
                .exp(true)
                .iss("http://some-issuer.com")
                .aud(&vec!["https://some-resource-server.com".to_owned()])
        );
    }
    // A caller-supplied spec is used verbatim, even when audiences are also set.
    #[tokio::test]
    async fn test_custom_claims_validation_spec_overrides_recommended() {
        let _m = MTX.lock();
        let ctx = MockOidcDiscovery::discover_context();
        ctx.expect()
            .returning(|_, _| Ok(default_oidc_config()))
            .once();
        let claims_validation = ClaimsValidationSpec::new().exp(false);
        let result = TenantConfigurationBuilder::new("https://some-issuer.com")
            .audiences(&["https://some-resource-server.com"])
            .claims_validation(claims_validation.clone())
            .build()
            .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().claims_validation_spec, claims_validation);
    }
    // Static tenants default to the "static" identifier and the `Static` kind.
    // (The static builder never touches discovery, so these tests do not take `MTX`.)
    #[test]
    fn test_static_build() {
        let jwks = mock_jwks();
        let t = TenantStaticConfigurationBuilder::new(jwks).build().unwrap();
        assert_eq!(t.identifier, "static");
        assert!(matches!(t.kind, TenantKind::Static { .. }));
    }
    // JSON without a `keys` array fails to parse as a JWK set and is reported as a bad parameter.
    #[test]
    fn test_static_build_invalid_jwks() {
        let jwks = " {}";
        let e = TenantStaticConfigurationBuilder::new(jwks)
            .build()
            .unwrap_err();
        assert!(matches!(e, StartupError::InvalidParameter { .. }))
    }
    // An explicit identifier replaces the "static" default.
    #[test]
    fn test_static_build_custom_identifier() {
        let jwks = mock_jwks();
        let t = TenantStaticConfigurationBuilder::new(jwks)
            .identifier("custom")
            .build()
            .unwrap();
        assert_eq!(t.identifier, "custom");
        assert!(matches!(t.kind, TenantKind::Static { .. }));
    }
    // Without OIDC metadata, the recommended spec has only `exp` and `aud` (no `iss`).
    #[test]
    fn test_static_provides_recommended_claims_validation_spec() {
        let jwks = mock_jwks();
        let t = TenantStaticConfigurationBuilder::new(jwks)
            .audiences(&["https://some-resource-server.com"])
            .build()
            .unwrap();
        assert_eq!(
            t.claims_validation_spec,
            ClaimsValidationSpec::new()
                .exp(true)
                .aud(&vec!["https://some-resource-server.com".to_owned()])
        );
    }
    // JSON text of a JWK set with a single RSA public key (base64url modulus, exponent "AQAB").
    fn mock_jwks() -> String {
        let modulus = "oEz_RrupHP9d9XiFbXLoJMwG-75Z18t4ziBy2PHTZHxkHOep7aFeNj-13NmIcL4ooj-2nxrLhWbgA2iBaWr95wKkf5peTsc-5Q6-B2uCcn9xPSQK08Y_jNVhtly3mAOdsT4Y9mQIO_oqaqEyzutypZBEu-18NkbGVwkNhG9sxvUjFXHvMoJs5iwILaDA2FhuEioIDzOy-ZjD8p928ye2v8CdPWl1xPxoBXd2KIe3RkocRDxLeeBg3wH8a9tQ5Z7fOmiXiAI8_lN57zYf078yazvLUlKzCo1pQoR25MU51d7zgI_I7H2Fb5PZGcCmfvN1Up41OfEQyMLL6JYyoP23XQ";
        let exponent = "AQAB";
        serde_json::json!({
            "keys": [{
                "kty": "RSA",
                "kid": "test-kid",
                "n": modulus,
                "e": exponent
            }]
        })
        .to_string()
    }
    // Discovery document returned by the mock: `http://` issuer, no `claims_supported`.
    fn default_oidc_config() -> OidcConfig {
        OidcConfig {
            jwks_uri: "http://some-issuer.com/jwks".parse::<Url>().unwrap(),
            issuer: "http://some-issuer.com".to_owned(),
            claims_supported: None,
        }
    }
}