use std::fmt;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use tokio::sync::RwLock;
/// A source of OAuth access tokens.
///
/// The `Send + Sync + 'static` supertraits let a provider be shared between
/// threads (for example behind an `Arc`) and stored without borrowed data.
#[async_trait]
pub trait TokenProvider: Send + Sync + 'static {
    /// Returns an access token, or an [`OAuthClientError`] if none could be obtained.
    async fn get_token(&self) -> Result<String, OAuthClientError>;
}
/// Errors produced while discovering, configuring, or using an OAuth
/// client-credentials token provider.
///
/// Every variant carries a human-readable detail message.
#[derive(Debug)]
#[non_exhaustive]
pub enum OAuthClientError {
    /// Fetching or decoding the authorization-server metadata failed.
    Discovery(String),
    /// The token request could not be sent, or the server answered with a
    /// non-success status.
    TokenRequest(String),
    /// The token endpoint's success response could not be decoded.
    InvalidResponse(String),
    /// The builder was missing a required field.
    BuildError(String),
}

impl fmt::Display for OAuthClientError {
    /// Formats as `OAuth <category>: <detail>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Discovery(detail) => write!(f, "OAuth discovery error: {}", detail),
            Self::TokenRequest(detail) => write!(f, "OAuth token request error: {}", detail),
            Self::InvalidResponse(detail) => write!(f, "OAuth invalid response: {}", detail),
            Self::BuildError(detail) => write!(f, "OAuth builder error: {}", detail),
        }
    }
}

/// Uses the default `source()` (none); the detail is already in the message.
impl std::error::Error for OAuthClientError {}
/// An access token together with the instant at which its server-reported
/// lifetime ends.
///
/// NOTE(review): the derived `Debug` prints `access_token` in full.
#[derive(Debug, Clone)]
struct CachedToken {
    access_token: String,
    /// Computed at fetch time as "now + `expires_in`" from the token response.
    expires_at: Instant,
}
/// JSON body of a successful token-endpoint response.
#[derive(Debug, serde::Deserialize)]
struct TokenResponse {
    /// The credential returned to callers of `get_token`.
    access_token: String,
    // Never read. It is non-optional, so a response without `token_type` fails
    // to deserialize. `allow(dead_code)` silences the unused-field warning.
    #[allow(dead_code)]
    token_type: String,
    /// Token lifetime in seconds; the fetch code assumes 3600 when it is absent.
    expires_in: Option<u64>,
    // The granted scopes are accepted when present but never read, hence the allow.
    #[allow(dead_code)]
    scope: Option<String>,
}
/// The part of the authorization-server metadata document that `discover`
/// reads. Other fields in the document are ignored (serde's default for
/// unknown fields).
#[derive(Debug, serde::Deserialize)]
struct AuthServerMetadata {
    /// URL that token requests are POSTed to.
    token_endpoint: String,
}
/// Shared state behind `OAuthClientCredentials`; every clone of the provider
/// points at the same instance through an `Arc`.
struct OAuthClientCredentialsInner {
    client_id: String,
    // Sent in the Basic `Authorization` header; deliberately left out of `Debug`.
    client_secret: String,
    token_endpoint: String,
    /// Space-separated scope list sent as the `scope` form parameter; `None`
    /// sends no `scope` parameter.
    scopes: Option<String>,
    /// How long before `expires_at` a cached token is already treated as expired.
    refresh_buffer: Duration,
    /// HTTP client used for token requests.
    client: reqwest::Client,
    /// Most recently fetched token; `None` until the first successful fetch.
    cache: RwLock<Option<CachedToken>>,
}
/// OAuth 2.0 client-credentials token provider with an in-memory token cache.
///
/// Cloning is cheap: all clones share one configuration and one token cache
/// through the inner `Arc`.
#[derive(Clone)]
pub struct OAuthClientCredentials {
    inner: Arc<OAuthClientCredentialsInner>,
}
/// Redacting `Debug`: prints only the non-secret configuration.
///
/// `client_secret`, the HTTP client and the token cache are omitted. The output
/// ends in `..` so readers can see that fields were left out on purpose.
impl fmt::Debug for OAuthClientCredentials {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner = &self.inner;
        f.debug_struct("OAuthClientCredentials")
            .field("client_id", &inner.client_id)
            .field("token_endpoint", &inner.token_endpoint)
            .field("scopes", &inner.scopes)
            .field("refresh_buffer", &inner.refresh_buffer)
            // Never print the secret; `..` marks the omitted fields.
            .finish_non_exhaustive()
    }
}
impl OAuthClientCredentials {
pub fn builder() -> OAuthClientCredentialsBuilder {
OAuthClientCredentialsBuilder::default()
}
pub async fn discover(
issuer: &str,
client_id: impl Into<String>,
client_secret: impl Into<String>,
) -> Result<Self, OAuthClientError> {
let client = reqwest::Client::new();
let url = format!(
"{}/.well-known/oauth-authorization-server",
issuer.trim_end_matches('/')
);
let metadata: AuthServerMetadata = client
.get(&url)
.send()
.await
.map_err(|e| OAuthClientError::Discovery(e.to_string()))?
.json()
.await
.map_err(|e| OAuthClientError::Discovery(e.to_string()))?;
Self::builder()
.client_id(client_id)
.client_secret(client_secret)
.token_endpoint(metadata.token_endpoint)
.build()
.map_err(|e| OAuthClientError::Discovery(e.to_string()))
}
fn is_token_valid(token: &CachedToken, buffer: Duration) -> bool {
token
.expires_at
.checked_sub(buffer)
.is_some_and(|effective| Instant::now() < effective)
}
async fn fetch_token(&self) -> Result<CachedToken, OAuthClientError> {
use base64::Engine;
let credentials = base64::engine::general_purpose::STANDARD.encode(format!(
"{}:{}",
self.inner.client_id, self.inner.client_secret
));
let mut body = "grant_type=client_credentials".to_string();
if let Some(ref scopes) = self.inner.scopes {
body.push_str("&scope=");
body.push_str(scopes);
}
let response = self
.inner
.client
.post(&self.inner.token_endpoint)
.header("Authorization", format!("Basic {}", credentials))
.header("Content-Type", "application/x-www-form-urlencoded")
.body(body)
.send()
.await
.map_err(|e| OAuthClientError::TokenRequest(e.to_string()))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(OAuthClientError::TokenRequest(format!(
"HTTP {}: {}",
status, body
)));
}
let token_response: TokenResponse = response
.json()
.await
.map_err(|e| OAuthClientError::InvalidResponse(e.to_string()))?;
let expires_in = Duration::from_secs(token_response.expires_in.unwrap_or(3600));
let expires_at = Instant::now() + expires_in;
Ok(CachedToken {
access_token: token_response.access_token,
expires_at,
})
}
}
#[async_trait]
impl TokenProvider for OAuthClientCredentials {
async fn get_token(&self) -> Result<String, OAuthClientError> {
{
let cache = self.inner.cache.read().await;
if let Some(ref token) = *cache
&& Self::is_token_valid(token, self.inner.refresh_buffer)
{
return Ok(token.access_token.clone());
}
}
let mut cache = self.inner.cache.write().await;
if let Some(ref token) = *cache
&& Self::is_token_valid(token, self.inner.refresh_buffer)
{
return Ok(token.access_token.clone());
}
let token = self.fetch_token().await?;
let access_token = token.access_token.clone();
*cache = Some(token);
Ok(access_token)
}
}
/// Builder for `OAuthClientCredentials`, obtained from
/// `OAuthClientCredentials::builder()`.
///
/// `client_id`, `client_secret` and `token_endpoint` are required. Scopes,
/// refresh buffer and HTTP client are optional.
#[derive(Default)]
#[must_use = "a builder does nothing until `build` is called"]
pub struct OAuthClientCredentialsBuilder {
    client_id: Option<String>,
    client_secret: Option<String>,
    token_endpoint: Option<String>,
    // Already joined into the space-separated wire form.
    scopes: Option<String>,
    refresh_buffer: Option<Duration>,
    client: Option<reqwest::Client>,
}
impl OAuthClientCredentialsBuilder {
    /// Refresh buffer used when `refresh_buffer` is never called.
    const DEFAULT_REFRESH_BUFFER: Duration = Duration::from_secs(30);

    /// Sets the OAuth client identifier (required).
    pub fn client_id(mut self, client_id: impl Into<String>) -> Self {
        self.client_id = Some(client_id.into());
        self
    }

    /// Sets the client secret (required). It is never included in `Debug` output.
    pub fn client_secret(mut self, client_secret: impl Into<String>) -> Self {
        self.client_secret = Some(client_secret.into());
        self
    }

    /// Sets the URL that token requests are POSTed to (required).
    pub fn token_endpoint(mut self, url: impl Into<String>) -> Self {
        self.token_endpoint = Some(url.into());
        self
    }

    /// Sets the scopes requested with each token, joined with single spaces.
    ///
    /// Empty strings are skipped, because RFC 6749 §3.3 does not allow empty
    /// scope tokens. If no non-empty scope remains, any previously configured
    /// scopes are cleared and no `scope` parameter is sent.
    pub fn scopes(mut self, scopes: impl IntoIterator<Item = impl Into<String>>) -> Self {
        let requested: Vec<String> = scopes
            .into_iter()
            .map(Into::<String>::into)
            .filter(|scope| !scope.is_empty())
            .collect();
        self.scopes = (!requested.is_empty()).then(|| requested.join(" "));
        self
    }

    /// Sets how long before expiry a cached token is refreshed (default 30s).
    pub fn refresh_buffer(mut self, duration: Duration) -> Self {
        self.refresh_buffer = Some(duration);
        self
    }

    /// Supplies the HTTP client used for token requests. Defaults to
    /// `reqwest::Client::default()`.
    pub fn http_client(mut self, client: reqwest::Client) -> Self {
        self.client = Some(client);
        self
    }

    /// Validates the configuration and creates the provider with an empty cache.
    ///
    /// # Errors
    ///
    /// Returns `OAuthClientError::BuildError` naming the first missing required
    /// field, checked in the order `client_id`, `client_secret`, `token_endpoint`.
    pub fn build(self) -> Result<OAuthClientCredentials, OAuthClientError> {
        let client_id = self
            .client_id
            .ok_or_else(|| OAuthClientError::BuildError("client_id is required".into()))?;
        let client_secret = self
            .client_secret
            .ok_or_else(|| OAuthClientError::BuildError("client_secret is required".into()))?;
        let token_endpoint = self
            .token_endpoint
            .ok_or_else(|| OAuthClientError::BuildError("token_endpoint is required".into()))?;
        let inner = OAuthClientCredentialsInner {
            client_id,
            client_secret,
            token_endpoint,
            scopes: self.scopes,
            refresh_buffer: self.refresh_buffer.unwrap_or(Self::DEFAULT_REFRESH_BUFFER),
            client: self.client.unwrap_or_default(),
            cache: RwLock::new(None),
        };
        Ok(OAuthClientCredentials {
            inner: Arc::new(inner),
        })
    }
}
// Unit tests. None of them contact a real authorization server. The fetch path
// is exercised only by seeding the cache or by targeting an unreachable address.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    // `build` must name the missing required field in its error message.
    #[test]
    fn test_builder_missing_client_id() {
        let err = OAuthClientCredentials::builder()
            .client_secret("secret")
            .token_endpoint("https://auth.example.com/token")
            .build()
            .unwrap_err();
        assert!(err.to_string().contains("client_id"));
    }

    #[test]
    fn test_builder_missing_client_secret() {
        let err = OAuthClientCredentials::builder()
            .client_id("id")
            .token_endpoint("https://auth.example.com/token")
            .build()
            .unwrap_err();
        assert!(err.to_string().contains("client_secret"));
    }

    #[test]
    fn test_builder_missing_token_endpoint() {
        let err = OAuthClientCredentials::builder()
            .client_id("id")
            .client_secret("secret")
            .build()
            .unwrap_err();
        assert!(err.to_string().contains("token_endpoint"));
    }

    // With only the required fields set, the optional ones take their defaults:
    // no scopes and a 30s refresh buffer.
    #[test]
    fn test_builder_success() {
        let provider = OAuthClientCredentials::builder()
            .client_id("my-client")
            .client_secret("my-secret")
            .token_endpoint("https://auth.example.com/token")
            .build()
            .unwrap();
        assert_eq!(provider.inner.client_id, "my-client");
        assert_eq!(
            provider.inner.token_endpoint,
            "https://auth.example.com/token"
        );
        assert!(provider.inner.scopes.is_none());
        assert_eq!(provider.inner.refresh_buffer, Duration::from_secs(30));
    }

    // Scopes are stored space-joined, in the order given.
    #[test]
    fn test_builder_with_scopes() {
        let provider = OAuthClientCredentials::builder()
            .client_id("id")
            .client_secret("secret")
            .token_endpoint("https://auth.example.com/token")
            .scopes(["mcp:tools", "mcp:resources"])
            .build()
            .unwrap();
        assert_eq!(
            provider.inner.scopes.as_deref(),
            Some("mcp:tools mcp:resources")
        );
    }

    #[test]
    fn test_builder_with_refresh_buffer() {
        let provider = OAuthClientCredentials::builder()
            .client_id("id")
            .client_secret("secret")
            .token_endpoint("https://auth.example.com/token")
            .refresh_buffer(Duration::from_secs(60))
            .build()
            .unwrap();
        assert_eq!(provider.inner.refresh_buffer, Duration::from_secs(60));
    }

    // `Debug` shows identifying configuration but must not leak the secret.
    // The secret is literally "secret", so any occurrence of that word fails
    // the test.
    #[test]
    fn test_debug_impl() {
        let provider = OAuthClientCredentials::builder()
            .client_id("my-client")
            .client_secret("secret")
            .token_endpoint("https://auth.example.com/token")
            .build()
            .unwrap();
        let debug = format!("{:?}", provider);
        assert!(debug.contains("my-client"));
        assert!(debug.contains("auth.example.com"));
        assert!(!debug.contains("secret"));
    }

    // Covers the three cases against a 30s buffer: expiry well beyond the
    // buffer, expiry inside the buffer, and already expired.
    #[test]
    fn test_token_validity() {
        let valid_token = CachedToken {
            access_token: "valid".into(),
            expires_at: Instant::now() + Duration::from_secs(300),
        };
        assert!(OAuthClientCredentials::is_token_valid(
            &valid_token,
            Duration::from_secs(30)
        ));
        let expiring_soon = CachedToken {
            access_token: "expiring".into(),
            expires_at: Instant::now() + Duration::from_secs(10),
        };
        assert!(!OAuthClientCredentials::is_token_valid(
            &expiring_soon,
            Duration::from_secs(30)
        ));
        let expired = CachedToken {
            access_token: "expired".into(),
            expires_at: Instant::now() - Duration::from_secs(10),
        };
        assert!(!OAuthClientCredentials::is_token_valid(
            &expired,
            Duration::from_secs(30)
        ));
    }

    // Pins the exact `Display` text of every variant.
    #[test]
    fn test_error_display() {
        let err = OAuthClientError::Discovery("not found".into());
        assert_eq!(err.to_string(), "OAuth discovery error: not found");
        let err = OAuthClientError::TokenRequest("timeout".into());
        assert_eq!(err.to_string(), "OAuth token request error: timeout");
        let err = OAuthClientError::InvalidResponse("bad json".into());
        assert_eq!(err.to_string(), "OAuth invalid response: bad json");
        let err = OAuthClientError::BuildError("missing field".into());
        assert_eq!(err.to_string(), "OAuth builder error: missing field");
    }

    // A fresh token is seeded directly into the cache, so `get_token` must return
    // it on both calls without contacting the (fake) endpoint.
    #[tokio::test]
    async fn test_caching_returns_same_token() {
        let provider = OAuthClientCredentials::builder()
            .client_id("id")
            .client_secret("secret")
            .token_endpoint("https://auth.example.com/token")
            .build()
            .unwrap();
        {
            let mut cache = provider.inner.cache.write().await;
            *cache = Some(CachedToken {
                access_token: "cached-token-123".into(),
                expires_at: Instant::now() + Duration::from_secs(300),
            });
        }
        let token = provider.get_token().await.unwrap();
        assert_eq!(token, "cached-token-123");
        let token2 = provider.get_token().await.unwrap();
        assert_eq!(token2, "cached-token-123");
    }

    // An expired cached token forces a fetch. The fetch targets port 1 on
    // localhost, and the resulting connection failure must surface as
    // `TokenRequest`.
    // NOTE(review): assumes nothing is listening on 127.0.0.1:1.
    #[tokio::test]
    async fn test_expired_token_triggers_refresh_attempt() {
        let provider = OAuthClientCredentials::builder()
            .client_id("id")
            .client_secret("secret")
            .token_endpoint("http://127.0.0.1:1/nonexistent")
            .build()
            .unwrap();
        {
            let mut cache = provider.inner.cache.write().await;
            *cache = Some(CachedToken {
                access_token: "expired-token".into(),
                expires_at: Instant::now() - Duration::from_secs(60),
            });
        }
        let err = provider.get_token().await.unwrap_err();
        assert!(matches!(err, OAuthClientError::TokenRequest(_)));
    }

    // `TokenProvider` can be implemented by other types; the counter confirms
    // every call reaches the implementation.
    #[tokio::test]
    async fn test_custom_token_provider() {
        let call_count = Arc::new(AtomicUsize::new(0));
        let count = call_count.clone();
        struct CountingProvider {
            count: Arc<AtomicUsize>,
        }
        #[async_trait]
        impl TokenProvider for CountingProvider {
            async fn get_token(&self) -> Result<String, OAuthClientError> {
                let n = self.count.fetch_add(1, Ordering::SeqCst);
                Ok(format!("token-{}", n))
            }
        }
        let provider = CountingProvider { count };
        assert_eq!(provider.get_token().await.unwrap(), "token-0");
        assert_eq!(provider.get_token().await.unwrap(), "token-1");
        assert_eq!(call_count.load(Ordering::SeqCst), 2);
    }

    // Cloning shares the inner state (and therefore the token cache) rather than
    // copying it.
    #[test]
    fn test_clone() {
        let provider = OAuthClientCredentials::builder()
            .client_id("id")
            .client_secret("secret")
            .token_endpoint("https://auth.example.com/token")
            .build()
            .unwrap();
        let cloned = provider.clone();
        assert!(Arc::ptr_eq(&provider.inner, &cloned.inner));
    }
}