use async_trait::async_trait;
use serde::Deserialize;
use std::time::{Duration, Instant};
use tokio::sync::{Mutex, RwLock};
use crate::types::AuthError;
/// How long before a token's reported expiry the provider stops serving it
/// from the cache and fetches a replacement.
const DEFAULT_SKEW: Duration = Duration::from_millis(30_000);
/// Asynchronous source of access tokens.
///
/// Implementors must be shareable across tasks (`Send + Sync`) and
/// printable for diagnostics (`Debug`).
#[async_trait]
pub trait TokenProvider: std::fmt::Debug + Send + Sync {
    /// Returns an access token, or an [`AuthError`] if none could be obtained.
    async fn get_token(&self) -> Result<String, AuthError>;
}
#[derive(Debug, Deserialize)]
struct TokenResponse {
access_token: String,
#[allow(dead_code)]
token_type: String,
expires_in: u64,
}
/// An access token plus the instants at which it should be refreshed and at
/// which the server said it expires.
struct CachedToken {
    access_token: String,
    /// Hard expiry derived from the server-reported lifetime.
    /// `allow(dead_code)` keeps the build warning-free when no code path reads it.
    #[allow(dead_code)]
    expires_at: Instant,
    /// Instant after which the token is no longer considered usable.
    refresh_at: Instant,
}

impl CachedToken {
    /// Builds a cache entry for a token that is valid for `expires_in` from now.
    ///
    /// The refresh point is `skew` before expiry. However, `skew` is capped at
    /// half the token's lifetime. Without the cap, a token living shorter than
    /// `skew` would be unusable on arrival and could never be reused.
    ///
    /// If `now + expires_in` cannot be represented as an [`Instant`] (a
    /// nonsensical lifetime), the token is treated as expiring immediately
    /// rather than panicking on overflow.
    fn new(access_token: String, expires_in: Duration, skew: Duration) -> Self {
        let now = Instant::now();
        let expires_at = now.checked_add(expires_in).unwrap_or(now);
        let skew = skew.min(expires_in / 2);
        Self {
            access_token,
            refresh_at: expires_at.checked_sub(skew).unwrap_or(expires_at),
            expires_at,
        }
    }

    /// Returns `true` while the current time is before the refresh point.
    fn is_usable(&self) -> bool {
        Instant::now() < self.refresh_at
    }
}
/// [`TokenProvider`] that obtains tokens with the OAuth2 `client_credentials`
/// grant and keeps the most recent one in memory.
pub struct ClientCredentialsProvider {
    /// URL the token request is POSTed to.
    token_endpoint: String,
    /// Sent as the `client_id` form parameter.
    client_id: String,
    /// Sent as the `client_secret` form parameter; deliberately left out of `Debug`.
    client_secret: String,
    /// Sent as the `scope` form parameter when present.
    scope: Option<String>,
    /// Each entry is sent as its own `resource` form parameter.
    audience: Option<Vec<String>>,
    /// Most recently fetched token, if any.
    cache: RwLock<Option<CachedToken>>,
    /// Held while refreshing so that concurrent callers share one fetch.
    refresh_lock: Mutex<()>,
    /// HTTP client used for token requests.
    http: reqwest::Client,
}

impl std::fmt::Debug for ClientCredentialsProvider {
    // Written by hand rather than derived so the client secret, the cached
    // token and the internal locks never show up in debug output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            token_endpoint,
            client_id,
            scope,
            audience,
            ..
        } = self;
        f.debug_struct("ClientCredentialsProvider")
            .field("token_endpoint", token_endpoint)
            .field("client_id", client_id)
            .field("scope", scope)
            .field("audience", audience)
            .finish_non_exhaustive()
    }
}
impl ClientCredentialsProvider {
pub fn new(
token_endpoint: String,
client_id: String,
client_secret: String,
scope: Option<String>,
audience: Option<Vec<String>>,
) -> Self {
Self {
token_endpoint,
client_id,
client_secret,
scope,
audience,
cache: RwLock::new(None),
refresh_lock: Mutex::new(()),
http: reqwest::Client::new(),
}
}
pub fn new_unchecked_for_test(
token_endpoint: String,
client_id: String,
client_secret: String,
scope: Option<String>,
audience: Option<Vec<String>>,
http: reqwest::Client,
) -> Self {
Self {
token_endpoint,
client_id,
client_secret,
scope,
audience,
cache: RwLock::new(None),
refresh_lock: Mutex::new(()),
http,
}
}
async fn fetch_token(&self) -> Result<CachedToken, AuthError> {
let mut params = vec![
("grant_type", "client_credentials".to_string()),
("client_id", self.client_id.clone()),
("client_secret", self.client_secret.clone()),
];
if let Some(ref scope) = self.scope {
params.push(("scope", scope.clone()));
}
if let Some(ref audience) = self.audience {
for aud in audience {
params.push(("resource", aud.clone()));
}
}
let resp = self
.http
.post(&self.token_endpoint)
.form(¶ms)
.send()
.await
.map_err(|e| AuthError::ProviderUnavailable(format!("OAuth2 request failed: {e}")))?;
if !resp.status().is_success() {
let status = resp.status();
let body = resp.text().await.unwrap_or_default();
let sanitized = if body.len() > 128 {
format!("{}...(truncated)", &body[..128])
} else {
body
};
let message = format!("token endpoint returned {status}: {sanitized}"); return Err(AuthError::ProviderUnavailable(message));
}
let token_resp: TokenResponse = resp
.json()
.await
.map_err(|e| AuthError::ProviderUnavailable(format!("invalid OAuth2 response: {e}")))?;
Ok(CachedToken::new(
token_resp.access_token,
Duration::from_secs(token_resp.expires_in),
DEFAULT_SKEW,
))
}
}
#[async_trait]
impl TokenProvider for ClientCredentialsProvider {
async fn get_token(&self) -> Result<String, AuthError> {
{
let cache = self.cache.read().await;
if let Some(ref cached) = *cache
&& cached.is_usable()
{
return Ok(cached.access_token.clone());
}
}
let _guard = self.refresh_lock.lock().await;
{
let cache = self.cache.read().await;
if let Some(ref cached) = *cache
&& cached.is_usable()
{
return Ok(cached.access_token.clone());
}
}
let cached = self.fetch_token().await?;
let token = cached.access_token.clone();
{
let mut cache = self.cache.write().await;
*cache = Some(cached);
}
Ok(token)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use wiremock::matchers::{body_string_contains, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Path on the mock server that every provider under test is pointed at.
    const TOKEN_PATH: &str = "/protocol/openid-connect/token";

    /// JSON body of a successful token response.
    fn token_response(access_token: &str, expires_in: u64) -> serde_json::Value {
        serde_json::json!({
            "access_token": access_token,
            "token_type": "Bearer",
            "expires_in": expires_in,
        })
    }

    /// Provider aimed at `server`'s token path with the given credentials,
    /// no scope, and the given audience.
    fn provider_for(
        server: &MockServer,
        client_id: &str,
        client_secret: &str,
        audience: Option<Vec<String>>,
    ) -> ClientCredentialsProvider {
        ClientCredentialsProvider::new(
            format!("{}{}", server.uri(), TOKEN_PATH),
            client_id.to_string(),
            client_secret.to_string(),
            None,
            audience,
        )
    }

    #[tokio::test]
    async fn test_get_token_fresh() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path(TOKEN_PATH))
            .respond_with(ResponseTemplate::new(200).set_body_json(token_response("abc123", 300)))
            .mount(&server)
            .await;

        let provider = provider_for(&server, "test-client", "test-secret", None);
        assert_eq!(provider.get_token().await.unwrap(), "abc123");
    }

    #[tokio::test]
    async fn test_get_token_uses_cache() {
        let server = MockServer::start().await;
        // `expect(1)` is verified when `server` is dropped.
        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(200).set_body_json(token_response("cached", 300)))
            .expect(1)
            .mount(&server)
            .await;

        let provider = provider_for(&server, "c", "s", None);
        let first = provider.get_token().await.unwrap();
        let second = provider.get_token().await.unwrap();
        assert_eq!(first, "cached");
        assert_eq!(second, "cached");
    }

    #[tokio::test]
    async fn test_get_token_refreshes_when_stale() {
        let server = MockServer::start().await;
        // The first request receives a 1-second token; later ones a long-lived token.
        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(200).set_body_json(token_response("first", 1)))
            .up_to_n_times(1)
            .mount(&server)
            .await;
        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(200).set_body_json(token_response("second", 300)))
            .mount(&server)
            .await;

        let provider = provider_for(&server, "c", "s", None);
        assert_eq!(provider.get_token().await.unwrap(), "first");
        tokio::time::sleep(Duration::from_millis(1100)).await;
        assert_eq!(provider.get_token().await.unwrap(), "second");
    }

    #[tokio::test]
    async fn test_get_token_server_error() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .respond_with(ResponseTemplate::new(500))
            .mount(&server)
            .await;

        let provider = provider_for(&server, "c", "s", None);
        let err = provider.get_token().await.unwrap_err();
        assert!(matches!(err, AuthError::ProviderUnavailable(_)));
    }

    #[tokio::test]
    async fn test_get_token_invalid_response() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(serde_json::json!({"error": "invalid_grant"})),
            )
            .mount(&server)
            .await;

        let provider = provider_for(&server, "c", "s", None);
        let err = provider.get_token().await.unwrap_err();
        assert!(matches!(err, AuthError::ProviderUnavailable(_)));
    }

    #[tokio::test]
    async fn test_get_token_sends_audience_as_resource() {
        let server = MockServer::start().await;
        // Only a request carrying the URL-encoded `resource` parameter matches.
        Mock::given(method("POST"))
            .and(body_string_contains(
                "resource=https%3A%2F%2Fapi.example.com",
            ))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(token_response("aud-token", 300)),
            )
            .mount(&server)
            .await;

        let provider = provider_for(
            &server,
            "c",
            "s",
            Some(vec!["https://api.example.com".into()]),
        );
        assert_eq!(provider.get_token().await.unwrap(), "aud-token");
    }

    #[tokio::test]
    async fn test_single_flight_concurrent_callers() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .respond_with(
                ResponseTemplate::new(200)
                    .set_body_json(token_response("single-flight-token", 300)),
            )
            .expect(1)
            .mount(&server)
            .await;

        let provider = Arc::new(provider_for(&server, "c", "s", None));
        // All five tasks are spawned before any of them is awaited.
        let handles: Vec<_> = (0..5)
            .map(|_| {
                let provider = Arc::clone(&provider);
                tokio::spawn(async move { provider.get_token().await })
            })
            .collect();

        for handle in handles {
            assert_eq!(handle.await.unwrap().unwrap(), "single-flight-token");
        }
    }
}