use crate::error::PortError;
use async_trait::async_trait;
use reqwest::Client as HttpClient;
use serde::Deserialize;
use std::fmt;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use url::Url;
/// How a client obtains its bearer token.
#[derive(Clone)]
pub enum AuthStrategy {
    /// A pre-issued token handed out unchanged on every request.
    StaticToken(String),
    /// Tokens fetched (and cached) through the OAuth 2.0 client credentials grant.
    ClientCredentials(ClientCredentialsOptions),
    /// A caller-supplied token source.
    Provider(Arc<dyn TokenProvider>),
}

impl fmt::Debug for AuthStrategy {
    /// Formats the strategy without printing the static token.
    ///
    /// `ClientCredentials` defers to `ClientCredentialsOptions`'s own `Debug` impl, and custom
    /// providers render as `<custom>` because `dyn TokenProvider` carries no `Debug` bound.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::StaticToken(_) => f.debug_tuple("StaticToken").finish(),
            Self::ClientCredentials(options) => {
                let mut tuple = f.debug_tuple("ClientCredentials");
                tuple.field(options);
                tuple.finish()
            }
            Self::Provider(_) => {
                let mut tuple = f.debug_tuple("Provider");
                tuple.field(&"<custom>");
                tuple.finish()
            }
        }
    }
}
impl AuthStrategy {
    /// Turns this strategy into a shareable [`TokenProvider`].
    ///
    /// A `Provider` strategy is returned as-is; the other variants are wrapped in the matching
    /// built-in provider.
    ///
    /// # Errors
    ///
    /// Returns `PortError::Configuration` when the HTTP client used by the client credentials
    /// flow cannot be built.
    pub fn into_provider(self) -> Result<Arc<dyn TokenProvider>, PortError> {
        let provider: Arc<dyn TokenProvider> = match self {
            Self::StaticToken(token) => Arc::new(StaticTokenProvider::new(token)),
            Self::ClientCredentials(options) => {
                Arc::new(ClientCredentialsTokenProvider::new(options)?)
            }
            Self::Provider(existing) => existing,
        };
        Ok(provider)
    }
}
#[derive(Clone, Debug)]
pub struct ClientCredentialsOptions {
pub client_id: String,
pub client_secret: String,
pub token_url: Url,
pub minimum_ttl: Duration,
}
/// Source of bearer tokens.
///
/// The `Send + Sync` supertraits allow implementations to be shared as
/// `Arc<dyn TokenProvider>`, which is how `AuthStrategy::Provider` stores them.
#[async_trait]
pub trait TokenProvider: Send + Sync {
/// Returns a bearer token, or a `PortError` if one cannot be obtained.
///
/// The implementations in this module either return a fixed token or fetch one over HTTP
/// and cache it.
async fn bearer_token(&self) -> Result<String, PortError>;
}
/// [`TokenProvider`] that always returns the same pre-issued bearer token.
///
/// `Debug` is implemented by hand so the token never appears in formatted output.
pub struct StaticTokenProvider {
    token: String,
}

impl fmt::Debug for StaticTokenProvider {
    /// Formats the provider with the token replaced by `<redacted>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("StaticTokenProvider")
            .field("token", &"<redacted>")
            .finish()
    }
}

impl StaticTokenProvider {
    /// Creates a provider that hands out `token` unchanged on every call.
    pub fn new(token: impl Into<String>) -> Self {
        Self { token: token.into() }
    }
}
#[async_trait]
impl TokenProvider for StaticTokenProvider {
    /// Returns a copy of the configured token; this implementation never fails.
    async fn bearer_token(&self) -> Result<String, PortError> {
        let token = String::clone(&self.token);
        Ok(token)
    }
}
/// A bearer token together with the instant at which it stops being valid.
#[derive(Clone)]
struct CachedToken {
    value: String,
    /// Expiry derived from the server's `expires_in`, with no safety margin applied;
    /// `should_refresh` applies `minimum_ttl` when the token is read back.
    expires_at: Instant,
}

/// [`TokenProvider`] backed by the OAuth 2.0 client credentials grant, keeping the most recent
/// token in an in-memory cache.
struct ClientCredentialsTokenProvider {
    options: ClientCredentialsOptions,
    http: HttpClient,
    cache: Mutex<Option<CachedToken>>,
}

impl ClientCredentialsTokenProvider {
    /// Lifetime assumed when the token response omits `expires_in`. Also used as the fallback
    /// when the advertised lifetime is too large to represent as an `Instant`.
    const DEFAULT_EXPIRES_IN: Duration = Duration::from_secs(3600);

    /// Builds the provider and its dedicated HTTP client.
    ///
    /// # Errors
    ///
    /// Returns `PortError::Configuration` if the HTTP client cannot be constructed.
    fn new(options: ClientCredentialsOptions) -> Result<Self, PortError> {
        let http = HttpClient::builder().build().map_err(|err| {
            PortError::Configuration(format!("failed to build OAuth client: {err}"))
        })?;
        Ok(Self { options, http, cache: Mutex::new(None) })
    }

    /// Returns a copy of the cached token, if any.
    fn cached_token(&self) -> Option<CachedToken> {
        // The cache is only ever replaced wholesale, so even a poisoned lock guards a
        // consistent value; recover it rather than panicking inside library code.
        let guard = self.cache.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        guard.clone()
    }

    /// Replaces the cached token with `token`.
    fn store_token(&self, token: CachedToken) {
        let mut guard = self.cache.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        *guard = Some(token);
    }

    /// Returns `true` when a token expiring at `expires_at` has at most `minimum_ttl` of
    /// lifetime left and should therefore be replaced.
    fn should_refresh(&self, expires_at: Instant) -> bool {
        match Instant::now().checked_add(self.options.minimum_ttl) {
            Some(deadline) => deadline >= expires_at,
            // A margin too large to represent can never be satisfied by any token.
            None => true,
        }
    }

    /// Requests a new token from `token_url` with the client credentials grant.
    ///
    /// # Errors
    ///
    /// Propagates transport, HTTP-status and JSON-decoding failures as `PortError`, and
    /// returns `PortError::Configuration` when the server answers with a blank access token.
    async fn fetch_token(&self) -> Result<CachedToken, PortError> {
        let form = [
            ("grant_type", "client_credentials"),
            ("client_id", self.options.client_id.as_str()),
            ("client_secret", self.options.client_secret.as_str()),
        ];
        let response = self
            .http
            .post(self.options.token_url.clone())
            .form(&form)
            .send()
            .await?
            .error_for_status()?;
        let payload: OAuthTokenResponse = response.json().await?;
        if payload.access_token.trim().is_empty() {
            return Err(PortError::Configuration(
                "OAuth client credentials flow returned an empty access token".into(),
            ));
        }
        let lifetime = payload
            .expires_in
            .map_or(Self::DEFAULT_EXPIRES_IN, Duration::from_secs);
        let now = Instant::now();
        // Record the real expiry and let `should_refresh` apply the `minimum_ttl` margin.
        // Subtracting the margin here as well would apply it twice, so tokens living less than
        // `2 * minimum_ttl` would never be reused. `expires_in` is server-controlled, so the
        // addition is checked instead of risking an overflow panic.
        let expires_at = now
            .checked_add(lifetime)
            .unwrap_or_else(|| now + Self::DEFAULT_EXPIRES_IN);
        Ok(CachedToken { value: payload.access_token, expires_at })
    }
}
#[async_trait]
impl TokenProvider for ClientCredentialsTokenProvider {
    /// Returns the cached token while `should_refresh` says it is still fresh; otherwise
    /// fetches a new token, caches it, and returns it.
    ///
    /// The cache lock is taken only briefly to read and to store, never across the network
    /// request. Concurrent callers that all see a stale cache may therefore each fetch a token;
    /// the last one stored wins.
    async fn bearer_token(&self) -> Result<String, PortError> {
        let still_fresh = self
            .cached_token()
            .filter(|cached| !self.should_refresh(cached.expires_at));
        if let Some(cached) = still_fresh {
            return Ok(cached.value);
        }

        let fetched = self.fetch_token().await?;
        let value = fetched.value.clone();
        self.store_token(fetched);
        Ok(value)
    }
}
#[derive(Debug, Deserialize)]
struct OAuthTokenResponse {
access_token: String,
#[allow(dead_code)]
token_type: Option<String>,
expires_in: Option<u64>,
}