use std::{collections::HashMap, sync::Arc, time::Duration as StdDuration};
use chrono::{DateTime, Duration, Utc};
use serde::{Deserialize, Serialize};
use super::jwks::JwksCache;
/// Token payload returned by an OAuth2 token endpoint.
///
/// `OAuth2Client` deserializes the JSON body of a successful token request
/// into this type.
///
/// NOTE(review): the derived `Debug` prints the token strings verbatim; avoid
/// logging this value with `{:?}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenResponse {
/// Access token string issued by the provider.
pub access_token: String,
/// Refresh token, if the response contained one.
pub refresh_token: Option<String>,
/// Token type as reported by the provider (the tests in this file use `"Bearer"`).
pub token_type: String,
/// Access-token lifetime in seconds; `expiry_time` adds it to the current time.
pub expires_in: u64,
/// ID token string, if the response contained one.
pub id_token: Option<String>,
/// Scope string, if the response contained one.
pub scope: Option<String>,
}
impl TokenResponse {
    /// Creates a response holding only the mandatory fields; `refresh_token`,
    /// `id_token` and `scope` start out as `None`.
    pub fn new(access_token: String, token_type: String, expires_in: u64) -> Self {
        Self {
            access_token,
            refresh_token: None,
            token_type,
            expires_in,
            id_token: None,
            scope: None,
        }
    }

    /// Returns the instant at which the token expires, computed as
    /// "now + `expires_in` seconds".
    ///
    /// The struct does not record when the token was issued, so the result is
    /// relative to the moment of the call. Call this once, right after the
    /// response is received, and store the result.
    ///
    /// `expires_in` comes from the provider, so it is clamped to
    /// `u32::MAX` seconds (~136 years). Without the clamp a huge value would
    /// either wrap to a negative lifetime through the integer cast or make the
    /// date arithmetic panic.
    pub fn expiry_time(&self) -> DateTime<Utc> {
        // Lossless widening; `i64::from` is not usable in a const context.
        const MAX_LIFETIME_SECS: i64 = u32::MAX as i64;

        let lifetime_secs = i64::try_from(self.expires_in)
            .unwrap_or(MAX_LIFETIME_SECS)
            .min(MAX_LIFETIME_SECS);
        Utc::now() + Duration::seconds(lifetime_secs)
    }

    /// Reports whether `expiry_time()` is not after the current time.
    ///
    /// Because `expiry_time()` is itself measured from "now", this can only be
    /// `true` when `expires_in` is `0`. To track real expiry, persist the value
    /// of `expiry_time()` at receipt (for example in `OAuthSession`).
    pub fn is_expired(&self) -> bool {
        self.expiry_time() <= Utc::now()
    }
}
/// Claim set decoded from an ID token by `OIDCClient::verify_id_token`.
///
/// NOTE(review): `aud` is modelled as a single string; a token whose `aud`
/// claim is a JSON array would presumably fail to deserialize — confirm
/// against the providers in use.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdTokenClaims {
/// Issuer; `verify_id_token` validates it against the configured issuer.
pub iss: String,
/// Subject claim.
pub sub: String,
/// Audience; `verify_id_token` validates it against the client id.
pub aud: String,
/// Expiry, compared against `Utc::now().timestamp()` (Unix seconds).
pub exp: i64,
/// `iat` claim; presumably Unix seconds like `exp` — confirm.
pub iat: i64,
/// Optional `auth_time` claim.
pub auth_time: Option<i64>,
/// Optional nonce; `verify_id_token` compares it with the expected nonce when one is supplied.
pub nonce: Option<String>,
/// Optional `email` claim.
pub email: Option<String>,
/// Optional `email_verified` claim.
pub email_verified: Option<bool>,
/// Optional `name` claim.
pub name: Option<String>,
/// Optional `picture` claim.
pub picture: Option<String>,
/// Optional `locale` claim.
pub locale: Option<String>,
}
impl IdTokenClaims {
    /// Creates a claim set with the five mandatory claims; every optional
    /// claim starts out as `None`.
    pub fn new(iss: String, sub: String, aud: String, exp: i64, iat: i64) -> Self {
        Self {
            iss,
            sub,
            aud,
            exp,
            iat,
            auth_time: None,
            nonce: None,
            email: None,
            email_verified: None,
            name: None,
            picture: None,
            locale: None,
        }
    }

    /// Returns `true` once the current Unix time has reached `exp`.
    pub fn is_expired(&self) -> bool {
        self.exp <= Utc::now().timestamp()
    }

    /// Returns `true` when the token expires within `grace_seconds` from now
    /// (or has already expired).
    ///
    /// The addition saturates, so an extreme `grace_seconds` cannot overflow
    /// (which would panic in debug builds and wrap in release builds).
    pub fn is_expiring_soon(&self, grace_seconds: i64) -> bool {
        self.exp <= Utc::now().timestamp().saturating_add(grace_seconds)
    }
}
/// User profile parsed from the JSON body of a userinfo endpoint response
/// (see `OIDCClient::get_userinfo`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserInfo {
/// Subject identifier; the only mandatory field.
pub sub: String,
/// Optional `email` attribute.
pub email: Option<String>,
/// Optional `email_verified` attribute.
pub email_verified: Option<bool>,
/// Optional `name` attribute.
pub name: Option<String>,
/// Optional `picture` attribute.
pub picture: Option<String>,
/// Optional `locale` attribute.
pub locale: Option<String>,
}
impl UserInfo {
    /// Builds a profile that carries only the subject identifier; all
    /// optional attributes are left unset.
    pub fn new(sub: String) -> Self {
        let (email, name, picture, locale) = (None, None, None, None);
        Self {
            sub,
            email,
            email_verified: None,
            name,
            picture,
            locale,
        }
    }
}
/// Endpoint configuration for one OIDC provider, consumed by `OIDCClient`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct OIDCProviderConfig {
/// Issuer identifier; ID tokens are validated against it.
pub issuer: String,
/// Authorization endpoint URL (stored only; not read by code in this file).
pub authorization_endpoint: String,
/// Token endpoint URL (stored only; not read by code in this file).
pub token_endpoint: String,
/// Userinfo endpoint URL; `OIDCClient::get_userinfo` fails when this is `None`.
pub userinfo_endpoint: Option<String>,
/// JWKS URL handed to `JwksCache` when an `OIDCClient` is built with `new`.
pub jwks_uri: String,
/// Scopes listed for the provider; `new` fills in `openid`, `profile`, `email`.
pub scopes_supported: Vec<String>,
/// Response types listed for the provider; `new` fills in `code`.
pub response_types_supported: Vec<String>,
}
impl OIDCProviderConfig {
    /// Builds a configuration from the four mandatory URLs.
    ///
    /// Defaults: no userinfo endpoint, scopes `openid`/`profile`/`email`, and
    /// the single response type `code`.
    pub fn new(
        issuer: String,
        authorization_endpoint: String,
        token_endpoint: String,
        jwks_uri: String,
    ) -> Self {
        let scopes_supported = ["openid", "profile", "email"]
            .iter()
            .map(|scope| (*scope).to_owned())
            .collect();
        let response_types_supported = vec![String::from("code")];

        Self {
            issuer,
            authorization_endpoint,
            token_endpoint,
            userinfo_endpoint: None,
            jwks_uri,
            scopes_supported,
            response_types_supported,
        }
    }
}
/// OAuth2 client that builds authorization URLs and talks to the token
/// endpoint.
///
/// `Debug` is implemented by hand so that the client secret never ends up in
/// logs or panic messages.
#[derive(Clone)]
pub struct OAuth2Client {
    /// Client identifier sent to the provider.
    pub client_id: String,
    /// Client secret sent in token requests; redacted in `Debug` output.
    client_secret: String,
    /// Base URL of the authorization endpoint.
    pub authorization_endpoint: String,
    /// URL that token requests are POSTed to.
    token_endpoint: String,
    /// Scopes requested in the authorization URL.
    pub scopes: Vec<String>,
    /// PKCE preference flag set through `with_pkce`.
    pub use_pkce: bool,
    /// HTTP client used for token requests.
    http_client: reqwest::Client,
}

impl std::fmt::Debug for OAuth2Client {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OAuth2Client")
            .field("client_id", &self.client_id)
            // Never print the real secret.
            .field("client_secret", &"<redacted>")
            .field("authorization_endpoint", &self.authorization_endpoint)
            .field("token_endpoint", &self.token_endpoint)
            .field("scopes", &self.scopes)
            .field("use_pkce", &self.use_pkce)
            .field("http_client", &self.http_client)
            .finish()
    }
}
impl OAuth2Client {
    /// Creates a client with the default scopes `openid`, `profile`, `email`
    /// and PKCE disabled.
    pub fn new(
        client_id: impl Into<String>,
        client_secret: impl Into<String>,
        authorization_endpoint: impl Into<String>,
        token_endpoint: impl Into<String>,
    ) -> Self {
        Self {
            client_id: client_id.into(),
            client_secret: client_secret.into(),
            authorization_endpoint: authorization_endpoint.into(),
            token_endpoint: token_endpoint.into(),
            scopes: vec![
                "openid".to_string(),
                "profile".to_string(),
                "email".to_string(),
            ],
            use_pkce: false,
            http_client: reqwest::Client::new(),
        }
    }

    /// Replaces the requested scopes.
    pub fn with_scopes(mut self, scopes: Vec<String>) -> Self {
        self.scopes = scopes;
        self
    }

    /// Sets the PKCE preference flag.
    ///
    /// NOTE(review): no method of this type reads `use_pkce`; callers that
    /// need PKCE must add the challenge and verifier parameters themselves.
    pub fn with_pkce(mut self, enabled: bool) -> Self {
        self.use_pkce = enabled;
        self
    }

    /// Builds an authorization URL with a freshly generated random `state`.
    ///
    /// The generated state is only available inside the returned URL. Prefer
    /// [`Self::authorization_url_with_state`] so the state can be stored and
    /// compared when the provider redirects back.
    pub fn authorization_url(&self, redirect_uri: &str) -> Result<String, String> {
        let state = uuid::Uuid::new_v4().to_string();
        self.authorization_url_with_state(redirect_uri, &state)
    }

    /// Builds an authorization URL (`response_type=code`) carrying the
    /// caller-supplied `state`.
    ///
    /// All parameter values are percent-encoded. If the configured
    /// authorization endpoint already contains a query string, the parameters
    /// are appended with `&` instead of starting a second `?`.
    ///
    /// # Errors
    /// Returns an error when `state` is empty.
    pub fn authorization_url_with_state(
        &self,
        redirect_uri: &str,
        state: &str,
    ) -> Result<String, String> {
        if state.is_empty() {
            return Err("State parameter must not be empty".to_string());
        }

        let endpoint = self.authorization_endpoint.as_str();
        let separator = if endpoint.ends_with('?') || endpoint.ends_with('&') {
            ""
        } else if endpoint.contains('?') {
            "&"
        } else {
            "?"
        };
        let scope = self.scopes.join(" ");

        Ok(format!(
            "{endpoint}{separator}client_id={}&redirect_uri={}&response_type=code&scope={}&state={}",
            urlencoding::encode(&self.client_id),
            urlencoding::encode(redirect_uri),
            urlencoding::encode(&scope),
            urlencoding::encode(state),
        ))
    }

    /// POSTs `params` as a form to the token endpoint and parses the JSON
    /// reply.
    ///
    /// # Errors
    /// Returns a message when the request fails, when the endpoint answers
    /// with a non-success status (the response body is included), or when the
    /// body cannot be parsed as a `TokenResponse`.
    async fn post_token_request(&self, params: &[(&str, &str)]) -> Result<TokenResponse, String> {
        let response = self
            .http_client
            .post(&self.token_endpoint)
            .form(params)
            .send()
            .await
            .map_err(|e| format!("Token request failed: {e}"))?;

        if !response.status().is_success() {
            // Best effort: an unreadable body becomes an empty string.
            let body = response.text().await.unwrap_or_default();
            return Err(format!("Token endpoint returned error: {body}"));
        }

        response
            .json::<TokenResponse>()
            .await
            .map_err(|e| format!("Failed to parse token response: {e}"))
    }

    /// Exchanges an authorization `code` for tokens
    /// (`grant_type=authorization_code`).
    ///
    /// `redirect_uri` is sent along with the client id and secret in the form
    /// body.
    pub async fn exchange_code(
        &self,
        code: &str,
        redirect_uri: &str,
    ) -> Result<TokenResponse, String> {
        let params = [
            ("grant_type", "authorization_code"),
            ("code", code),
            ("client_id", self.client_id.as_str()),
            ("client_secret", self.client_secret.as_str()),
            ("redirect_uri", redirect_uri),
        ];
        self.post_token_request(&params).await
    }

    /// Requests new tokens using a refresh token
    /// (`grant_type=refresh_token`).
    pub async fn refresh_token(&self, refresh_token: &str) -> Result<TokenResponse, String> {
        let params = [
            ("grant_type", "refresh_token"),
            ("refresh_token", refresh_token),
            ("client_id", self.client_id.as_str()),
            ("client_secret", self.client_secret.as_str()),
        ];
        self.post_token_request(&params).await
    }
}
/// OIDC client that verifies ID tokens against a JWKS cache and fetches user
/// profiles from the userinfo endpoint.
///
/// `Debug` is implemented by hand so that the client secret never ends up in
/// logs or panic messages.
pub struct OIDCClient {
    /// Provider endpoints and issuer.
    pub config: OIDCProviderConfig,
    /// Client identifier; ID tokens must list it as audience.
    pub client_id: String,
    // Stored but not read by any code in this file, hence the lint allowance.
    #[allow(dead_code)]
    client_secret: String,
    /// Cache used to look up signing keys by `kid`.
    pub jwks_cache: Arc<JwksCache>,
    /// HTTP client used for userinfo requests.
    http_client: reqwest::Client,
}

impl std::fmt::Debug for OIDCClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OIDCClient")
            .field("config", &self.config)
            .field("client_id", &self.client_id)
            // Never print the real secret.
            .field("client_secret", &"<redacted>")
            .field("jwks_cache", &self.jwks_cache)
            .field("http_client", &self.http_client)
            .finish()
    }
}
impl OIDCClient {
pub fn new(
config: OIDCProviderConfig,
client_id: impl Into<String>,
client_secret: impl Into<String>,
) -> Self {
let jwks_cache = Arc::new(JwksCache::new(&config.jwks_uri, StdDuration::from_secs(3600)));
Self {
config,
client_id: client_id.into(),
client_secret: client_secret.into(),
jwks_cache,
http_client: reqwest::Client::new(),
}
}
pub fn with_jwks_cache(
config: OIDCProviderConfig,
client_id: impl Into<String>,
client_secret: impl Into<String>,
jwks_cache: Arc<JwksCache>,
) -> Self {
Self {
config,
client_id: client_id.into(),
client_secret: client_secret.into(),
jwks_cache,
http_client: reqwest::Client::new(),
}
}
pub async fn verify_id_token(
&self,
id_token: &str,
expected_nonce: Option<&str>,
) -> Result<IdTokenClaims, String> {
let header = jsonwebtoken::decode_header(id_token)
.map_err(|e| format!("Invalid JWT header: {e}"))?;
let kid = header.kid.ok_or("JWT missing 'kid' in header")?;
let key = self
.jwks_cache
.get_key(&kid)
.await
.map_err(|e| format!("JWKS fetch error: {e}"))?
.ok_or_else(|| format!("No key found for kid '{kid}'"))?;
let mut validation = jsonwebtoken::Validation::new(header.alg);
validation.set_issuer(&[&self.config.issuer]);
validation.set_audience(&[&self.client_id]);
validation.set_required_spec_claims(&["exp", "iat", "iss", "aud", "sub"]);
let token_data = jsonwebtoken::decode::<IdTokenClaims>(id_token, &key, &validation)
.map_err(|e| format!("ID token validation failed: {e}"))?;
if let Some(expected) = expected_nonce {
if token_data.claims.nonce.as_deref() != Some(expected) {
return Err("Nonce mismatch".to_string());
}
}
Ok(token_data.claims)
}
pub async fn get_userinfo(&self, access_token: &str) -> Result<UserInfo, String> {
let endpoint = self
.config
.userinfo_endpoint
.as_ref()
.ok_or("No userinfo endpoint configured for this provider")?;
let response = self
.http_client
.get(endpoint)
.bearer_auth(access_token)
.send()
.await
.map_err(|e| format!("Userinfo request failed: {e}"))?;
if !response.status().is_success() {
return Err(format!("Userinfo endpoint returned {}", response.status()));
}
response
.json::<UserInfo>()
.await
.map_err(|e| format!("Failed to parse userinfo response: {e}"))
}
}
/// Protocol flavour of an external authentication provider.
///
/// The `Display` implementation renders the variants as `oauth2` and `oidc`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProviderType {
/// Provider tagged as OAuth2.
OAuth2,
/// Provider tagged as OIDC.
OIDC,
}
impl std::fmt::Display for ProviderType {
    /// Writes the lowercase label of the variant (`oauth2` or `oidc`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::OAuth2 => "oauth2",
            Self::OIDC => "oidc",
        };
        f.write_str(label)
    }
}
/// A user's session with an external provider, including the current tokens.
///
/// NOTE(review): the derived `Debug` and `Serialize` expose the token strings
/// verbatim; take care where instances are logged or persisted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthSession {
/// Session identifier; `new` assigns a random UUID v4 string.
pub id: String,
/// User identifier supplied by the caller of `new`.
pub user_id: String,
/// Protocol flavour of the provider.
pub provider_type: ProviderType,
/// Provider name supplied by the caller of `new`.
pub provider_name: String,
/// Provider-side user identifier supplied by the caller of `new`.
pub provider_user_id: String,
/// Current access token; replaced by `refresh_tokens`.
pub access_token: String,
/// Refresh token, if any; `new` leaves it as `None`.
pub refresh_token: Option<String>,
/// Instant at which the access token expires.
pub token_expiry: DateTime<Utc>,
/// Creation time, set by `new`.
pub created_at: DateTime<Utc>,
/// Time of the most recent `refresh_tokens` call, if any.
pub last_refreshed: Option<DateTime<Utc>>,
}
impl OAuthSession {
    /// Creates a session with a random UUID v4 id, the current time as
    /// `created_at`, and no refresh token or refresh history.
    pub fn new(
        user_id: String,
        provider_type: ProviderType,
        provider_name: String,
        provider_user_id: String,
        access_token: String,
        token_expiry: DateTime<Utc>,
    ) -> Self {
        let id = uuid::Uuid::new_v4().to_string();
        let created_at = Utc::now();
        Self {
            id,
            user_id,
            provider_type,
            provider_name,
            provider_user_id,
            access_token,
            refresh_token: None,
            token_expiry,
            created_at,
            last_refreshed: None,
        }
    }

    /// Returns `true` once the current time has reached `token_expiry`.
    pub fn is_expired(&self) -> bool {
        let now = Utc::now();
        now >= self.token_expiry
    }

    /// Returns `true` when the token expires no later than `grace_seconds`
    /// from now.
    pub fn is_expiring_soon(&self, grace_seconds: i64) -> bool {
        let horizon = Utc::now() + Duration::seconds(grace_seconds);
        horizon >= self.token_expiry
    }

    /// Stores a new access token and expiry and stamps `last_refreshed` with
    /// the current time. The refresh token is left untouched.
    pub fn refresh_tokens(&mut self, access_token: String, token_expiry: DateTime<Utc>) {
        self.access_token = access_token;
        self.token_expiry = token_expiry;
        let refreshed_at = Utc::now();
        self.last_refreshed = Some(refreshed_at);
    }
}
/// Registration record for one external authentication provider, as stored in
/// `ProviderRegistry`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ExternalAuthProvider {
/// Record identifier; `new` assigns a random UUID v4 string.
pub id: String,
/// Protocol flavour of the provider.
pub provider_type: ProviderType,
/// Provider name; `ProviderRegistry` uses it as the lookup key.
pub provider_name: String,
/// Client identifier registered with the provider.
pub client_id: String,
/// Presumably the location of the client secret in a secrets vault; this
/// file only stores the string — confirm against the caller.
pub client_secret_vault_path: String,
/// Optional OIDC endpoint configuration; `new` leaves it as `None`.
pub oidc_config: Option<OIDCProviderConfig>,
/// Optional OAuth2 endpoint configuration; `new` leaves it as `None`.
pub oauth2_config: Option<OAuth2ClientConfig>,
/// Whether the provider is enabled; `ProviderRegistry::list_enabled` filters on it.
pub enabled: bool,
/// Scopes for this provider; `new` fills in `openid`, `profile`, `email`.
pub scopes: Vec<String>,
}
/// Serializable OAuth2 settings attached to an `ExternalAuthProvider`.
///
/// The fields have the same names as the corresponding `OAuth2Client` fields;
/// no code visible in this file reads them.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct OAuth2ClientConfig {
/// Authorization endpoint URL.
pub authorization_endpoint: String,
/// Token endpoint URL.
pub token_endpoint: String,
/// PKCE preference flag.
pub use_pkce: bool,
}
impl ExternalAuthProvider {
    /// Creates an enabled provider record with a random UUID v4 id, no
    /// protocol-specific configuration, and the default scopes `openid`,
    /// `profile`, `email`.
    pub fn new(
        provider_type: ProviderType,
        provider_name: impl Into<String>,
        client_id: impl Into<String>,
        client_secret_vault_path: impl Into<String>,
    ) -> Self {
        let id = uuid::Uuid::new_v4().to_string();
        let provider_name = provider_name.into();
        let client_id = client_id.into();
        let client_secret_vault_path = client_secret_vault_path.into();
        let scopes = ["openid", "profile", "email"]
            .iter()
            .map(|scope| (*scope).to_owned())
            .collect();

        Self {
            id,
            provider_type,
            provider_name,
            client_id,
            client_secret_vault_path,
            oidc_config: None,
            oauth2_config: None,
            enabled: true,
            scopes,
        }
    }

    /// Overwrites the `enabled` flag.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Replaces the scope list.
    pub fn set_scopes(&mut self, scopes: Vec<String>) {
        self.scopes = scopes;
    }
}
/// Registry of external providers keyed by provider name.
///
/// The map sits behind `Arc<Mutex<..>>`, so clones of the registry share the
/// same underlying map.
#[derive(Debug, Clone)]
pub struct ProviderRegistry {
// Keyed by `ExternalAuthProvider::provider_name` (see `register`).
providers: Arc<std::sync::Mutex<HashMap<String, ExternalAuthProvider>>>,
}
impl ProviderRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        Self {
            providers: Arc::default(),
        }
    }

    /// Locks the provider map, mapping a poisoned lock to `"Lock failed"`.
    fn guard(
        &self,
    ) -> Result<std::sync::MutexGuard<'_, HashMap<String, ExternalAuthProvider>>, String> {
        self.providers.lock().map_err(|_| "Lock failed".to_string())
    }

    /// Sets the `enabled` flag of the named provider; returns whether the
    /// provider exists.
    fn toggle(&self, name: &str, enabled: bool) -> Result<bool, String> {
        let mut map = self.guard()?;
        let found = map
            .get_mut(name)
            .map(|entry| entry.set_enabled(enabled))
            .is_some();
        Ok(found)
    }

    /// Stores `provider` under its `provider_name`, replacing any existing
    /// entry with the same name.
    pub fn register(&self, provider: ExternalAuthProvider) -> Result<(), String> {
        let mut map = self.guard()?;
        let key = provider.provider_name.clone();
        map.insert(key, provider);
        Ok(())
    }

    /// Returns a copy of the provider registered under `name`, if any.
    pub fn get(&self, name: &str) -> Result<Option<ExternalAuthProvider>, String> {
        let map = self.guard()?;
        Ok(map.get(name).cloned())
    }

    /// Returns copies of all providers whose `enabled` flag is set, in map
    /// iteration order.
    pub fn list_enabled(&self) -> Result<Vec<ExternalAuthProvider>, String> {
        let map = self.guard()?;
        let enabled = map
            .values()
            .filter_map(|entry| entry.enabled.then(|| entry.clone()))
            .collect();
        Ok(enabled)
    }

    /// Disables the named provider; `Ok(false)` means no such provider.
    pub fn disable(&self, name: &str) -> Result<bool, String> {
        self.toggle(name, false)
    }

    /// Enables the named provider; `Ok(false)` means no such provider.
    pub fn enable(&self, name: &str) -> Result<bool, String> {
        self.toggle(name, true)
    }
}
impl Default for ProviderRegistry {
fn default() -> Self {
Self::new()
}
}
/// A PKCE verifier together with the challenge derived from it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PKCEChallenge {
/// Randomly generated verifier string.
pub code_verifier: String,
/// Encoded SHA-256 digest of `code_verifier`.
pub code_challenge: String,
/// Challenge method label; `new` sets it to `"S256"`.
pub code_challenge_method: String,
}
impl PKCEChallenge {
    /// Generates a fresh verifier and its `S256` challenge.
    ///
    /// RFC 7636 requires the verifier to be 43–128 characters from the
    /// unreserved set, and the challenge to be
    /// `BASE64URL(SHA256(verifier))` without padding. The verifier is
    /// therefore built from two hyphen-less UUID v4 values (64 hex
    /// characters), and the digest is base64url-encoded rather than
    /// percent-encoded.
    pub fn new() -> Self {
        let verifier = format!(
            "{}{}",
            uuid::Uuid::new_v4().simple(),
            uuid::Uuid::new_v4().simple()
        );
        let challenge = Self::challenge_for(&verifier);
        Self {
            code_verifier: verifier,
            code_challenge: challenge,
            code_challenge_method: "S256".to_string(),
        }
    }

    /// Returns `true` when `verifier` hashes to the stored `code_challenge`.
    pub fn verify(&self, verifier: &str) -> bool {
        Self::challenge_for(verifier) == self.code_challenge
    }

    /// Computes the `S256` challenge: unpadded base64url of the SHA-256
    /// digest of `verifier`.
    fn challenge_for(verifier: &str) -> String {
        use sha2::{Digest, Sha256};
        let mut hasher = Sha256::new();
        hasher.update(verifier.as_bytes());
        let digest = hasher.finalize();
        Self::base64url_no_pad(&digest)
    }

    /// Encodes `bytes` with the URL-safe base64 alphabet (RFC 4648 §5) and no
    /// `=` padding. Hand-rolled so the file needs no additional dependency.
    fn base64url_no_pad(bytes: &[u8]) -> String {
        const ALPHABET: &[u8; 64] =
            b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";

        let mut encoded = String::with_capacity((bytes.len() * 4 + 2) / 3);
        for chunk in bytes.chunks(3) {
            // Pack up to three bytes into a 24-bit group; missing bytes are zero.
            let b0 = u32::from(chunk[0]);
            let b1 = u32::from(chunk.get(1).copied().unwrap_or(0));
            let b2 = u32::from(chunk.get(2).copied().unwrap_or(0));
            let group = (b0 << 16) | (b1 << 8) | b2;

            encoded.push(ALPHABET[((group >> 18) & 63) as usize] as char);
            encoded.push(ALPHABET[((group >> 12) & 63) as usize] as char);
            // Emit only the sextets that carry real input bits (no padding).
            if chunk.len() > 1 {
                encoded.push(ALPHABET[((group >> 6) & 63) as usize] as char);
            }
            if chunk.len() > 2 {
                encoded.push(ALPHABET[(group & 63) as usize] as char);
            }
        }
        encoded
    }
}
impl Default for PKCEChallenge {
fn default() -> Self {
Self::new()
}
}
/// A random `state` value with an expiry time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateParameter {
/// The state string; `new` assigns a random UUID v4 string.
pub state: String,
/// Instant after which `verify` rejects the state; `new` sets it 10 minutes ahead.
pub expires_at: DateTime<Utc>,
}
impl StateParameter {
    /// Generates a random UUID v4 state that expires 10 minutes from now.
    pub fn new() -> Self {
        let state = uuid::Uuid::new_v4().to_string();
        let expires_at = Utc::now() + Duration::minutes(10);
        Self { state, expires_at }
    }

    /// Returns `true` once the current time has reached `expires_at`.
    pub fn is_expired(&self) -> bool {
        Utc::now() >= self.expires_at
    }

    /// Returns `true` only when `provided_state` equals the stored state and
    /// the state has not expired.
    pub fn verify(&self, provided_state: &str) -> bool {
        if self.state != provided_state {
            return false;
        }
        !self.is_expired()
    }
}
impl Default for StateParameter {
fn default() -> Self {
Self::new()
}
}
/// A random `nonce` value with an expiry time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NonceParameter {
/// The nonce string; `new` assigns a random UUID v4 string.
pub nonce: String,
/// Instant after which `verify` rejects the nonce; `new` sets it 10 minutes ahead.
pub expires_at: DateTime<Utc>,
}
impl NonceParameter {
    /// Generates a random UUID v4 nonce that expires 10 minutes from now.
    pub fn new() -> Self {
        let nonce = uuid::Uuid::new_v4().to_string();
        let expires_at = Utc::now() + Duration::minutes(10);
        Self { nonce, expires_at }
    }

    /// Returns `true` once the current time has reached `expires_at`.
    pub fn is_expired(&self) -> bool {
        Utc::now() >= self.expires_at
    }

    /// Returns `true` only when `provided_nonce` equals the stored nonce and
    /// the nonce has not expired.
    pub fn verify(&self, provided_nonce: &str) -> bool {
        if self.nonce != provided_nonce {
            return false;
        }
        !self.is_expired()
    }
}
impl Default for NonceParameter {
fn default() -> Self {
Self::new()
}
}
/// Queue of `(session id, refresh time)` pairs kept in ascending time order.
///
/// The queue sits behind `Arc<Mutex<..>>`, so clones of the scheduler share
/// the same queue.
#[derive(Debug, Clone)]
pub struct TokenRefreshScheduler {
// Re-sorted by refresh time after every insertion (see `schedule_refresh`).
refresh_queue: Arc<std::sync::Mutex<Vec<(String, DateTime<Utc>)>>>,
}
impl TokenRefreshScheduler {
    /// Creates a scheduler with an empty queue.
    pub fn new() -> Self {
        Self {
            refresh_queue: Arc::default(),
        }
    }

    /// Adds `session_id` to the queue, due at `refresh_time`, keeping the
    /// queue ordered by time. Existing entries for the same session are kept.
    pub fn schedule_refresh(
        &self,
        session_id: String,
        refresh_time: DateTime<Utc>,
    ) -> Result<(), String> {
        let mut pending = self
            .refresh_queue
            .lock()
            .map_err(|_| "Lock failed".to_string())?;
        pending.push((session_id, refresh_time));
        // Stable sort: entries with equal times keep their insertion order.
        pending.sort_by(|left, right| left.1.cmp(&right.1));
        Ok(())
    }

    /// Removes and returns the earliest session if its refresh time has been
    /// reached; returns `Ok(None)` when the queue is empty or nothing is due.
    pub fn get_next_refresh(&self) -> Result<Option<String>, String> {
        let mut pending = self
            .refresh_queue
            .lock()
            .map_err(|_| "Lock failed".to_string())?;
        let head_is_due = pending
            .first()
            .map_or(false, |(_, due_at)| *due_at <= Utc::now());
        if !head_is_due {
            return Ok(None);
        }
        let (session_id, _) = pending.remove(0);
        Ok(Some(session_id))
    }

    /// Removes every queued entry for `session_id`; returns whether anything
    /// was removed.
    pub fn cancel_refresh(&self, session_id: &str) -> Result<bool, String> {
        let mut pending = self
            .refresh_queue
            .lock()
            .map_err(|_| "Lock failed".to_string())?;
        let count_before = pending.len();
        pending.retain(|(queued_id, _)| queued_id != session_id);
        Ok(pending.len() != count_before)
    }
}
impl Default for TokenRefreshScheduler {
fn default() -> Self {
Self::new()
}
}
/// Callback used by `TokenRefreshWorker` to refresh the tokens of one session.
#[async_trait::async_trait]
pub trait TokenRefresher: Send + Sync {
/// Refreshes the session identified by `session_id`.
///
/// The worker interprets the result as follows:
/// - `Ok(Some(expiry))`: the refresh succeeded and `expiry` is the new token
///   expiry; the worker schedules another refresh ahead of it.
/// - `Ok(None)`: logged as "session no longer exists"; not re-scheduled.
/// - `Err(_)`: logged as a failed refresh; not re-scheduled.
async fn refresh_session(&self, session_id: &str) -> Result<Option<DateTime<Utc>>, String>;
}
/// Background worker that periodically drains due entries from a
/// `TokenRefreshScheduler` and refreshes them through a `TokenRefresher`.
pub struct TokenRefreshWorker {
// Queue of sessions awaiting refresh.
scheduler: Arc<TokenRefreshScheduler>,
// Performs the actual refresh for a session.
refresher: Arc<dyn TokenRefresher>,
// Watch channel; `run` stops when it observes `true` or the sender is dropped.
cancel_rx: tokio::sync::watch::Receiver<bool>,
// Delay between two queue scans.
poll_interval: StdDuration,
}
impl TokenRefreshWorker {
pub fn new(
scheduler: Arc<TokenRefreshScheduler>,
refresher: Arc<dyn TokenRefresher>,
poll_interval: StdDuration,
) -> (Self, tokio::sync::watch::Sender<bool>) {
let (cancel_tx, cancel_rx) = tokio::sync::watch::channel(false);
(
Self {
scheduler,
refresher,
cancel_rx,
poll_interval,
},
cancel_tx,
)
}
pub async fn run(mut self) {
tracing::info!(
interval_secs = self.poll_interval.as_secs(),
"Token refresh worker started"
);
loop {
tokio::select! {
result = self.cancel_rx.changed() => {
if result.is_err() || *self.cancel_rx.borrow() {
tracing::info!("Token refresh worker stopped");
break;
}
},
() = tokio::time::sleep(self.poll_interval) => {
self.process_due_refreshes().await;
}
}
}
}
async fn process_due_refreshes(&self) {
while let Ok(Some(session_id)) = self.scheduler.get_next_refresh() {
match self.refresher.refresh_session(&session_id).await {
Ok(Some(new_expiry)) => {
let remaining = new_expiry - Utc::now();
let next_refresh_secs = (remaining.num_seconds() as f64 * 0.8) as i64;
let next_refresh = Utc::now() + Duration::seconds(next_refresh_secs);
if let Err(e) =
self.scheduler.schedule_refresh(session_id.clone(), next_refresh)
{
tracing::warn!(
session_id = %session_id,
error = %e,
"Failed to re-schedule token refresh"
);
}
},
Ok(None) => {
tracing::debug!(
session_id = %session_id,
"Session no longer exists, skipping refresh"
);
},
Err(e) => {
tracing::warn!(
session_id = %session_id,
error = %e,
"Token refresh failed"
);
},
}
}
}
}
/// Chooses between a primary provider and ordered fallbacks, skipping
/// providers that are currently marked unavailable.
///
/// The unavailable list sits behind `Arc<Mutex<..>>`, so clones of the manager
/// share it.
#[derive(Debug, Clone)]
pub struct ProviderFailoverManager {
// Preferred provider name.
primary_provider: String,
// Fallback provider names, tried in order.
fallback_providers: Vec<String>,
// `(provider name, unavailable-until)` pairs; an entry counts only while its time is in the future.
unavailable: Arc<std::sync::Mutex<Vec<(String, DateTime<Utc>)>>>,
}
impl ProviderFailoverManager {
    /// Creates a manager with all providers considered available.
    pub fn new(primary: String, fallbacks: Vec<String>) -> Self {
        Self {
            primary_provider: primary,
            fallback_providers: fallbacks,
            unavailable: Arc::new(std::sync::Mutex::new(Vec::new())),
        }
    }

    /// Returns the primary provider if it is available, otherwise the first
    /// available fallback.
    ///
    /// # Errors
    /// Returns `"No providers available"` when every provider is currently
    /// marked unavailable, or `"Lock failed"` when the internal lock is
    /// poisoned.
    pub fn get_available_provider(&self) -> Result<String, String> {
        let unavailable = self.unavailable.lock().map_err(|_| "Lock failed".to_string())?;
        let now = Utc::now();
        let is_blocked = |candidate: &String| {
            unavailable
                .iter()
                .any(|(name, until)| name == candidate && *until > now)
        };

        std::iter::once(&self.primary_provider)
            .chain(self.fallback_providers.iter())
            .find(|candidate| !is_blocked(candidate))
            .cloned()
            .ok_or_else(|| "No providers available".to_string())
    }

    /// Marks `provider` unavailable for `duration_seconds` from now.
    ///
    /// Expired entries are pruned on every call, and each provider keeps at
    /// most one entry, so the list cannot grow without bound. If the provider
    /// already has a cooldown that ends later than the new one, the later end
    /// time is kept. The duration is clamped to `u32::MAX` seconds so that
    /// the integer cast cannot wrap and the date arithmetic cannot panic.
    pub fn mark_unavailable(&self, provider: String, duration_seconds: u64) -> Result<(), String> {
        // Lossless widening; `i64::from` is not usable in a const context.
        const MAX_COOLDOWN_SECS: i64 = u32::MAX as i64;

        let cooldown_secs = i64::try_from(duration_seconds)
            .unwrap_or(MAX_COOLDOWN_SECS)
            .min(MAX_COOLDOWN_SECS);

        let mut unavailable = self.unavailable.lock().map_err(|_| "Lock failed".to_string())?;
        let now = Utc::now();
        let requested_until = now + Duration::seconds(cooldown_secs);

        // Keep the longest outstanding cooldown for this provider.
        let until = unavailable
            .iter()
            .filter(|(name, _)| *name == provider)
            .map(|(_, existing)| *existing)
            .fold(requested_until, |latest, existing| latest.max(existing));

        // Drop expired entries and any previous entry for this provider.
        unavailable.retain(|(name, existing)| *name != provider && *existing > now);
        unavailable.push((provider, until));
        Ok(())
    }

    /// Clears every unavailability entry for `provider`.
    pub fn mark_available(&self, provider: &str) -> Result<(), String> {
        let mut unavailable = self.unavailable.lock().map_err(|_| "Lock failed".to_string())?;
        unavailable.retain(|(name, _)| name != provider);
        Ok(())
    }
}
/// Audit record describing one OAuth-related event.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OAuthAuditEvent {
/// Event kind label (the tests in this file use `"authorization"` and `"token_exchange"`).
pub event_type: String,
/// Name of the provider involved.
pub provider: String,
/// User the event relates to, if set via `with_user_id`.
pub user_id: Option<String>,
/// Outcome label (the tests in this file use `"success"` and `"failed"`).
pub status: String,
/// Error description, if set via `with_error`.
pub error: Option<String>,
/// Creation time of the record, set by `new`.
pub timestamp: DateTime<Utc>,
/// Free-form key/value pairs added via `with_metadata`.
pub metadata: HashMap<String, String>,
}
impl OAuthAuditEvent {
    /// Creates an event stamped with the current time, with no user, no error
    /// and empty metadata.
    pub fn new(
        event_type: impl Into<String>,
        provider: impl Into<String>,
        status: impl Into<String>,
    ) -> Self {
        let event_type = event_type.into();
        let provider = provider.into();
        let status = status.into();
        Self {
            event_type,
            provider,
            user_id: None,
            status,
            error: None,
            timestamp: Utc::now(),
            metadata: HashMap::default(),
        }
    }

    /// Returns the event with `user_id` set.
    pub fn with_user_id(self, user_id: String) -> Self {
        Self {
            user_id: Some(user_id),
            ..self
        }
    }

    /// Returns the event with `error` set.
    pub fn with_error(self, error: String) -> Self {
        Self {
            error: Some(error),
            ..self
        }
    }

    /// Returns the event with `key` mapped to `value` in its metadata,
    /// replacing any previous value for that key.
    pub fn with_metadata(mut self, key: String, value: String) -> Self {
        let _previous = self.metadata.insert(key, value);
        self
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the OAuth2/OIDC types in this file. HTTP interactions
    //! are exercised against `wiremock` mock servers.

    use super::*;

    // ---- Plain data types -------------------------------------------------

    #[test]
    fn test_token_response_creation() {
        let token = TokenResponse::new("token123".to_string(), "Bearer".to_string(), 3600);
        assert_eq!(token.access_token, "token123");
        assert_eq!(token.token_type, "Bearer");
        assert_eq!(token.expires_in, 3600);
    }

    #[test]
    fn test_token_response_expiry_calculation() {
        let token = TokenResponse::new("token123".to_string(), "Bearer".to_string(), 3600);
        assert!(!token.is_expired());
    }

    #[test]
    fn test_id_token_claims_creation() {
        let exp = (Utc::now() + Duration::hours(1)).timestamp();
        let claims = IdTokenClaims::new(
            "https://provider.com".to_string(),
            "user123".to_string(),
            "client_id".to_string(),
            exp,
            Utc::now().timestamp(),
        );
        assert_eq!(claims.sub, "user123");
        assert!(!claims.is_expired());
    }

    #[test]
    fn test_id_token_claims_expiry() {
        let exp = (Utc::now() - Duration::hours(1)).timestamp();
        let claims = IdTokenClaims::new(
            "https://provider.com".to_string(),
            "user123".to_string(),
            "client_id".to_string(),
            exp,
            (Utc::now() - Duration::hours(2)).timestamp(),
        );
        assert!(claims.is_expired());
    }

    #[test]
    fn test_userinfo_creation() {
        let userinfo = UserInfo::new("user123".to_string());
        assert_eq!(userinfo.sub, "user123");
        assert!(userinfo.email.is_none());
    }

    #[test]
    fn test_oauth2_client_creation() {
        let client = OAuth2Client::new(
            "client_id",
            "client_secret",
            "https://provider.com/authorize",
            "https://provider.com/token",
        );
        assert_eq!(client.client_id, "client_id");
    }

    #[test]
    fn test_oauth2_client_with_scopes() {
        let scopes = vec!["openid".to_string(), "profile".to_string()];
        let client = OAuth2Client::new(
            "client_id",
            "client_secret",
            "https://provider.com/authorize",
            "https://provider.com/token",
        )
        .with_scopes(scopes.clone());
        assert_eq!(client.scopes, scopes);
    }

    #[test]
    fn test_oidc_provider_config_creation() {
        let config = OIDCProviderConfig::new(
            "https://provider.com".to_string(),
            "https://provider.com/authorize".to_string(),
            "https://provider.com/token".to_string(),
            "https://provider.com/jwks".to_string(),
        );
        assert_eq!(config.issuer, "https://provider.com");
    }

    #[test]
    fn test_oauth_session_creation() {
        let session = OAuthSession::new(
            "user_123".to_string(),
            ProviderType::OIDC,
            "auth0".to_string(),
            "auth0|user_id".to_string(),
            "access_token".to_string(),
            Utc::now() + Duration::hours(1),
        );
        assert_eq!(session.user_id, "user_123");
        assert!(!session.is_expired());
    }

    #[test]
    fn test_oauth_session_token_refresh() {
        let mut session = OAuthSession::new(
            "user_123".to_string(),
            ProviderType::OIDC,
            "auth0".to_string(),
            "auth0|user_id".to_string(),
            "old_token".to_string(),
            Utc::now() + Duration::hours(1),
        );
        let new_expiry = Utc::now() + Duration::hours(2);
        session.refresh_tokens("new_token".to_string(), new_expiry);
        assert_eq!(session.access_token, "new_token");
        assert!(session.last_refreshed.is_some());
    }

    #[test]
    fn test_external_auth_provider_creation() {
        let provider = ExternalAuthProvider::new(
            ProviderType::OIDC,
            "auth0",
            "client_id",
            "vault/path/to/secret",
        );
        assert_eq!(provider.provider_name, "auth0");
        assert!(provider.enabled);
    }

    // ---- Provider registry ------------------------------------------------

    #[test]
    fn test_provider_registry_register_and_get() {
        let registry = ProviderRegistry::new();
        let provider =
            ExternalAuthProvider::new(ProviderType::OIDC, "auth0", "client_id", "vault/path");
        registry.register(provider.clone()).unwrap();
        let retrieved = registry.get("auth0").unwrap();
        assert_eq!(retrieved, Some(provider));
    }

    #[test]
    fn test_provider_registry_list_enabled() {
        let registry = ProviderRegistry::new();
        let provider1 = ExternalAuthProvider::new(ProviderType::OIDC, "auth0", "id1", "path1");
        let provider2 = ExternalAuthProvider::new(ProviderType::OAuth2, "google", "id2", "path2");
        registry.register(provider1).unwrap();
        registry.register(provider2).unwrap();
        let enabled = registry.list_enabled().unwrap();
        assert_eq!(enabled.len(), 2);
    }

    #[test]
    fn test_provider_registry_disable_enable() {
        let registry = ProviderRegistry::new();
        let provider = ExternalAuthProvider::new(ProviderType::OIDC, "auth0", "id", "path");
        registry.register(provider).unwrap();
        registry.disable("auth0").unwrap();
        let retrieved = registry.get("auth0").unwrap();
        assert!(!retrieved.unwrap().enabled);
        registry.enable("auth0").unwrap();
        let retrieved = registry.get("auth0").unwrap();
        assert!(retrieved.unwrap().enabled);
    }

    // ---- PKCE, state and nonce --------------------------------------------

    #[test]
    fn test_pkce_challenge_generation() {
        let challenge = PKCEChallenge::new();
        assert!(!challenge.code_verifier.is_empty());
        assert!(!challenge.code_challenge.is_empty());
        assert_eq!(challenge.code_challenge_method, "S256");
    }

    #[test]
    fn test_pkce_verification() {
        let challenge = PKCEChallenge::new();
        let verifier = challenge.code_verifier.clone();
        assert!(challenge.verify(&verifier));
    }

    #[test]
    fn test_pkce_verification_fails_with_wrong_verifier() {
        let challenge = PKCEChallenge::new();
        assert!(!challenge.verify("wrong_verifier"));
    }

    #[test]
    fn test_state_parameter_generation() {
        let state = StateParameter::new();
        assert!(!state.state.is_empty());
        assert!(!state.is_expired());
    }

    #[test]
    fn test_state_parameter_verification() {
        let state = StateParameter::new();
        assert!(state.verify(&state.state));
    }

    #[test]
    fn test_state_parameter_verification_fails_with_wrong_state() {
        let state = StateParameter::new();
        assert!(!state.verify("wrong_state"));
    }

    #[test]
    fn test_nonce_parameter_generation() {
        let nonce = NonceParameter::new();
        assert!(!nonce.nonce.is_empty());
        assert!(!nonce.is_expired());
    }

    #[test]
    fn test_nonce_parameter_verification() {
        let nonce = NonceParameter::new();
        assert!(nonce.verify(&nonce.nonce));
    }

    // ---- Refresh scheduler ------------------------------------------------

    #[test]
    fn test_token_refresh_scheduler_schedule_and_retrieve() {
        let scheduler = TokenRefreshScheduler::new();
        // Already due, so it must be handed out immediately.
        let refresh_time = Utc::now() - Duration::seconds(10);
        scheduler.schedule_refresh("session_1".to_string(), refresh_time).unwrap();
        let next = scheduler.get_next_refresh().unwrap();
        assert_eq!(next, Some("session_1".to_string()));
    }

    #[test]
    fn test_token_refresh_scheduler_cancel() {
        let scheduler = TokenRefreshScheduler::new();
        let refresh_time = Utc::now() + Duration::hours(1);
        scheduler.schedule_refresh("session_1".to_string(), refresh_time).unwrap();
        let cancelled = scheduler.cancel_refresh("session_1").unwrap();
        assert!(cancelled);
    }

    // ---- Failover manager -------------------------------------------------

    #[test]
    fn test_failover_manager_primary_available() {
        let manager = ProviderFailoverManager::new("auth0".to_string(), vec!["google".to_string()]);
        let available = manager.get_available_provider().unwrap();
        assert_eq!(available, "auth0");
    }

    #[test]
    fn test_failover_manager_fallback() {
        let manager = ProviderFailoverManager::new("auth0".to_string(), vec!["google".to_string()]);
        manager.mark_unavailable("auth0".to_string(), 300).unwrap();
        let available = manager.get_available_provider().unwrap();
        assert_eq!(available, "google");
    }

    #[test]
    fn test_failover_manager_mark_available() {
        let manager = ProviderFailoverManager::new("auth0".to_string(), vec!["google".to_string()]);
        manager.mark_unavailable("auth0".to_string(), 300).unwrap();
        manager.mark_available("auth0").unwrap();
        let available = manager.get_available_provider().unwrap();
        assert_eq!(available, "auth0");
    }

    // ---- Audit events -----------------------------------------------------

    #[test]
    fn test_oauth_audit_event_creation() {
        let event = OAuthAuditEvent::new("authorization", "auth0", "success");
        assert_eq!(event.event_type, "authorization");
        assert_eq!(event.provider, "auth0");
        assert_eq!(event.status, "success");
    }

    #[test]
    fn test_oauth_audit_event_with_user_id() {
        let event = OAuthAuditEvent::new("token_exchange", "auth0", "success")
            .with_user_id("user_123".to_string());
        assert_eq!(event.user_id, Some("user_123".to_string()));
    }

    #[test]
    fn test_oauth_audit_event_with_error() {
        let event = OAuthAuditEvent::new("token_exchange", "auth0", "failed")
            .with_error("Provider unavailable".to_string());
        assert_eq!(event.error, Some("Provider unavailable".to_string()));
    }

    #[test]
    fn test_oauth_audit_event_with_metadata() {
        let event = OAuthAuditEvent::new("authorization", "auth0", "success")
            .with_metadata("ip_address".to_string(), "192.168.1.1".to_string());
        assert_eq!(event.metadata.get("ip_address"), Some(&"192.168.1.1".to_string()));
    }

    // ---- OAuth2 token endpoint (mock server) ------------------------------

    /// Builds an `OAuth2Client` whose token endpoint points at a mock server.
    fn mock_oauth2_client(token_endpoint: &str) -> OAuth2Client {
        OAuth2Client::new(
            "test_client",
            "test_secret",
            "https://example.com/authorize",
            token_endpoint,
        )
    }

    #[tokio::test]
    async fn test_exchange_code_sends_correct_request() {
        use wiremock::{
            Mock, MockServer, ResponseTemplate,
            matchers::{body_string_contains, method},
        };
        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(body_string_contains("grant_type=authorization_code"))
            .and(body_string_contains("code=auth_code_123"))
            .and(body_string_contains("client_id=test_client"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "at_real",
                "refresh_token": "rt_real",
                "token_type": "Bearer",
                "expires_in": 3600,
                "id_token": "ey.header.payload",
                "scope": "openid email"
            })))
            .mount(&mock_server)
            .await;
        let client = mock_oauth2_client(&format!("{}/token", mock_server.uri()));
        let response = client
            .exchange_code("auth_code_123", "http://localhost/callback")
            .await
            .unwrap();
        assert_eq!(response.access_token, "at_real");
        assert_eq!(response.refresh_token, Some("rt_real".to_string()));
        assert_eq!(response.expires_in, 3600);
        assert_eq!(response.id_token, Some("ey.header.payload".to_string()));
    }

    #[tokio::test]
    async fn test_exchange_code_handles_error_response() {
        use wiremock::{Mock, MockServer, ResponseTemplate, matchers::method};
        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(400).set_body_json(serde_json::json!({
                "error": "invalid_grant",
                "error_description": "Code expired"
            })))
            .mount(&mock_server)
            .await;
        let client = mock_oauth2_client(&format!("{}/token", mock_server.uri()));
        let result = client.exchange_code("expired_code", "http://localhost/callback").await;
        assert!(result.is_err());
        // Check for the provider's error code. The fixed message prefix
        // already contains the word "error", so matching on that alone would
        // pass even if the response body were dropped.
        assert!(result.unwrap_err().contains("invalid_grant"));
    }

    #[tokio::test]
    async fn test_refresh_token_sends_correct_request() {
        use wiremock::{
            Mock, MockServer, ResponseTemplate,
            matchers::{body_string_contains, method},
        };
        let mock_server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(body_string_contains("grant_type=refresh_token"))
            .and(body_string_contains("refresh_token=rt_abc"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "access_token": "new_at",
                "refresh_token": "new_rt",
                "token_type": "Bearer",
                "expires_in": 3600
            })))
            .mount(&mock_server)
            .await;
        let client = mock_oauth2_client(&format!("{}/token", mock_server.uri()));
        let response = client.refresh_token("rt_abc").await.unwrap();
        assert_eq!(response.access_token, "new_at");
        assert_eq!(response.refresh_token, Some("new_rt".to_string()));
    }

    // ---- OIDC client ------------------------------------------------------

    /// Provider configuration with placeholder `example.com` endpoints.
    fn test_oidc_config() -> OIDCProviderConfig {
        OIDCProviderConfig::new(
            "https://example.com".to_string(),
            "https://example.com/authorize".to_string(),
            "https://example.com/token".to_string(),
            "https://example.com/.well-known/jwks.json".to_string(),
        )
    }

    #[tokio::test]
    async fn test_get_userinfo_success() {
        use wiremock::{
            Mock, MockServer, ResponseTemplate,
            matchers::{header, method, path},
        };
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/userinfo"))
            .and(header("Authorization", "Bearer access_token_xyz"))
            .respond_with(ResponseTemplate::new(200).set_body_json(serde_json::json!({
                "sub": "user_789",
                "email": "real@example.com",
                "email_verified": true,
                "name": "Real User",
                "picture": "https://example.com/photo.jpg",
                "locale": "fr-FR"
            })))
            .mount(&mock_server)
            .await;
        let mut config = test_oidc_config();
        config.userinfo_endpoint = Some(format!("{}/userinfo", mock_server.uri()));
        let client = OIDCClient::new(config, "client_id", "secret");
        let user = client.get_userinfo("access_token_xyz").await.unwrap();
        assert_eq!(user.sub, "user_789");
        assert_eq!(user.email, Some("real@example.com".to_string()));
        assert_eq!(user.name, Some("Real User".to_string()));
    }

    #[tokio::test]
    async fn test_get_userinfo_no_endpoint() {
        let mut config = test_oidc_config();
        config.userinfo_endpoint = None;
        let client = OIDCClient::new(config, "client_id", "secret");
        let result = client.get_userinfo("token").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("No userinfo endpoint"));
    }

    #[tokio::test]
    async fn test_get_userinfo_server_error() {
        use wiremock::{Mock, MockServer, ResponseTemplate, matchers::method};
        let mock_server = MockServer::start().await;
        Mock::given(method("GET"))
            .respond_with(ResponseTemplate::new(500))
            .mount(&mock_server)
            .await;
        let mut config = test_oidc_config();
        config.userinfo_endpoint = Some(format!("{}/userinfo", mock_server.uri()));
        let client = OIDCClient::new(config, "client_id", "secret");
        let result = client.get_userinfo("token").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("500"));
    }

    #[tokio::test]
    async fn test_verify_id_token_rejects_missing_kid() {
        let config = test_oidc_config();
        let client = OIDCClient::new(config, "client_id", "secret");
        // The default header carries no `kid`.
        let header = jsonwebtoken::Header::new(jsonwebtoken::Algorithm::HS256);
        let claims = IdTokenClaims::new(
            "https://example.com".into(),
            "user_1".into(),
            "client_id".into(),
            (Utc::now() + Duration::hours(1)).timestamp(),
            Utc::now().timestamp(),
        );
        let token = jsonwebtoken::encode(
            &header,
            &claims,
            &jsonwebtoken::EncodingKey::from_secret(b"test-secret"),
        )
        .unwrap();
        let result = client.verify_id_token(&token, None).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("kid"));
    }

    // ---- Refresh worker ---------------------------------------------------

    #[tokio::test]
    async fn test_token_refresh_worker_processes_due_refresh() {
        /// Counts calls and always reports a one-hour expiry.
        struct MockRefresher {
            call_count: std::sync::atomic::AtomicU32,
        }
        #[async_trait::async_trait]
        impl TokenRefresher for MockRefresher {
            async fn refresh_session(
                &self,
                _session_id: &str,
            ) -> Result<Option<DateTime<Utc>>, String> {
                self.call_count.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                Ok(Some(Utc::now() + Duration::hours(1)))
            }
        }
        let scheduler = Arc::new(TokenRefreshScheduler::new());
        scheduler
            .schedule_refresh("session_1".to_string(), Utc::now() - Duration::seconds(1))
            .unwrap();
        let refresher = Arc::new(MockRefresher {
            call_count: std::sync::atomic::AtomicU32::new(0),
        });
        let (worker, cancel_tx) =
            TokenRefreshWorker::new(scheduler, refresher.clone(), StdDuration::from_millis(50));
        let handle = tokio::spawn(worker.run());
        // Allow a few poll intervals to elapse before stopping the worker.
        tokio::time::sleep(StdDuration::from_millis(200)).await;
        let _ = cancel_tx.send(true);
        handle.await.unwrap();
        assert!(refresher.call_count.load(std::sync::atomic::Ordering::Relaxed) >= 1);
    }

    #[tokio::test]
    async fn test_token_refresh_worker_handles_missing_session() {
        /// Always reports that the session no longer exists.
        struct NoSessionRefresher;
        #[async_trait::async_trait]
        impl TokenRefresher for NoSessionRefresher {
            async fn refresh_session(
                &self,
                _session_id: &str,
            ) -> Result<Option<DateTime<Utc>>, String> {
                Ok(None)
            }
        }
        let scheduler = Arc::new(TokenRefreshScheduler::new());
        scheduler
            .schedule_refresh("missing_session".to_string(), Utc::now() - Duration::seconds(1))
            .unwrap();
        let refresher = Arc::new(NoSessionRefresher);
        let (worker, cancel_tx) =
            TokenRefreshWorker::new(scheduler, refresher, StdDuration::from_millis(50));
        let handle = tokio::spawn(worker.run());
        tokio::time::sleep(StdDuration::from_millis(200)).await;
        let _ = cancel_tx.send(true);
        handle.await.unwrap();
    }

    #[tokio::test]
    async fn test_token_refresh_worker_cancellation() {
        /// Panics if invoked; the queue is empty, so it never should be.
        struct NeverCalledRefresher;
        #[async_trait::async_trait]
        impl TokenRefresher for NeverCalledRefresher {
            async fn refresh_session(
                &self,
                _session_id: &str,
            ) -> Result<Option<DateTime<Utc>>, String> {
                panic!("Should not be called");
            }
        }
        let scheduler = Arc::new(TokenRefreshScheduler::new());
        let refresher = Arc::new(NeverCalledRefresher);
        let (worker, cancel_tx) =
            TokenRefreshWorker::new(scheduler, refresher, StdDuration::from_secs(3600));
        let handle = tokio::spawn(worker.run());
        let _ = cancel_tx.send(true);
        handle.await.unwrap();
    }
}