axum_oidc_layer/
validation.rs1use std::sync::Arc;
4
5use jsonwebtoken::decode_header;
6use serde::de::DeserializeOwned;
7
8use crate::{
9 cache::{ConfigCacheKey, JwkCacheKey, JwksCache, TokenCacheKey},
10 config::{AuthenticationConfigProvider, OidcConfiguration},
11 error::OidcError,
12 jwks::{fetch_and_cache_jwks, validate_token_with_jwk},
13 token::extract_token_ttl,
14};
15
16pub async fn get_oidc_config(
30 cache: &Arc<dyn JwksCache>,
31 config_url: &str,
32 config: &(impl AuthenticationConfigProvider + Send + Sync),
33) -> Result<OidcConfiguration, OidcError> {
34 let config_cache_key = ConfigCacheKey::from_url(config_url);
35
36 if let Some(cached_config) = cache.get(config_cache_key.as_str()) {
38 return serde_json::from_str(&cached_config)
39 .map_err(|e| OidcError::CacheError(format!("Failed to deserialize cached config: {e}")));
40 }
41
42 let response = reqwest::get(config_url)
44 .await
45 .map_err(|e| OidcError::ConfigurationError(format!("Failed to fetch OIDC config: {e}")))?;
46
47 let config_doc: OidcConfiguration = response
48 .json()
49 .await
50 .map_err(|e| OidcError::ConfigurationError(format!("Failed to parse OIDC config: {e}")))?;
51
52 if let Ok(serialized) = serde_json::to_string(&config_doc) {
54 cache.set(config_cache_key.as_str(), serialized, config.get_config_cache_ttl());
55 }
56
57 Ok(config_doc)
58}
59
60pub async fn validate_and_cache_token<T>(
77 cache: &Arc<dyn JwksCache>,
78 token: &str,
79 jwks_uri: &str,
80 config: &(impl AuthenticationConfigProvider + Send + Sync),
81) -> Result<T, OidcError>
82where
83 T: DeserializeOwned + serde::Serialize + Clone + Send + Sync + 'static,
84{
85 let token_cache_key = TokenCacheKey::from_token(token);
86
87 if let Some(cached_claims) = cache.get(token_cache_key.as_str()) {
89 return serde_json::from_str(&cached_claims)
90 .map_err(|e| OidcError::CacheError(format!("Failed to deserialize cached token: {e}")));
91 }
92
93 let header = decode_header(token)
95 .map_err(|e| OidcError::InvalidToken(format!("Failed to decode JWT header: {e}")))?;
96
97 let kid = header.kid.ok_or(OidcError::MissingKid)?;
98
99 let jwk_cache_key = JwkCacheKey::from_jwks_uri_and_kid(jwks_uri, &kid);
101 let jwk = match cache.get(jwk_cache_key.as_str()) {
102 Some(cached_jwk) => serde_json::from_str(&cached_jwk)
103 .map_err(|e| OidcError::CacheError(format!("Failed to deserialize cached JWK: {e}")))?,
104 None => {
105 fetch_and_cache_jwks(cache, jwks_uri, config, &kid).await?
107 }
108 };
109
110 let validated_claims = validate_token_with_jwk::<T>(token, &jwk, header.alg)
112 .map_err(OidcError::ValidationError)?;
113
114 if let Ok(serialized) = serde_json::to_string(&validated_claims) {
116 let token_ttl = extract_token_ttl(token);
117 cache.set(token_cache_key.as_str(), serialized, token_ttl);
118 }
119
120 Ok(validated_claims)
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{AuthenticationConfigProvider, DEFAULT_CONFIG_TTL, DEFAULT_JWKS_TTL};
    use std::time::Duration;

    /// Minimal provider: fixed provider URL, no explicit discovery URL, and the
    /// crate's default cache TTLs.
    struct TestConfig;

    impl AuthenticationConfigProvider for TestConfig {
        fn get_provider_url(&self) -> String {
            "https://example.com".to_string()
        }

        fn get_openid_configuration_url(&self) -> Option<String> {
            None
        }

        fn get_jwks_cache_ttl(&self) -> Duration {
            DEFAULT_JWKS_TTL
        }

        fn get_config_cache_ttl(&self) -> Duration {
            DEFAULT_CONFIG_TTL
        }
    }

    #[test]
    fn test_token_cache_key_consistency() {
        let token = "test.token.here";
        let key1 = TokenCacheKey::from_token(token);
        let key2 = TokenCacheKey::from_token(token);
        assert_eq!(key1, key2);
    }

    /// Distinct tokens must never share a cache slot. Otherwise one token could be
    /// answered with another token's cached claims.
    #[test]
    fn test_token_cache_key_distinguishes_tokens() {
        let key1 = TokenCacheKey::from_token("aaa.bbb.ccc");
        let key2 = TokenCacheKey::from_token("xxx.yyy.zzz");
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_jwk_cache_key_distinguishes_kid_and_uri() {
        let base = JwkCacheKey::from_jwks_uri_and_kid("https://a.example.com/jwks", "kid-1");
        let other_kid = JwkCacheKey::from_jwks_uri_and_kid("https://a.example.com/jwks", "kid-2");
        let other_uri = JwkCacheKey::from_jwks_uri_and_kid("https://b.example.com/jwks", "kid-1");
        assert_ne!(base.as_str(), other_kid.as_str());
        assert_ne!(base.as_str(), other_uri.as_str());
    }

    #[test]
    fn test_config_cache_key_distinguishes_urls() {
        let key1 = ConfigCacheKey::from_url("https://a.example.com/.well-known/openid-configuration");
        let key2 = ConfigCacheKey::from_url("https://b.example.com/.well-known/openid-configuration");
        assert_ne!(key1.as_str(), key2.as_str());
    }

    #[test]
    fn test_config_provider_reports_default_ttls() {
        let config = TestConfig;
        assert_eq!(config.get_jwks_cache_ttl(), DEFAULT_JWKS_TTL);
        assert_eq!(config.get_config_cache_ttl(), DEFAULT_CONFIG_TTL);
        assert_eq!(config.get_openid_configuration_url(), None);
        assert_eq!(config.get_provider_url(), "https://example.com");
    }
}