use std::sync::Arc;
use std::time::{Duration, Instant};
use reqwest::Client;
use serde::Deserialize;
use thiserror::Error;
use tokio::sync::Mutex;
/// Errors produced while constructing a `TokenProvider` or obtaining a token.
///
/// Every message is prefixed with `@smooai/config:`.
#[derive(Debug, Error)]
pub enum TokenProviderError {
/// The token endpoint answered with a non-success HTTP status.
/// `status` is the numeric status code and `body` is the response body text
/// (empty when the body could not be read).
#[error("@smooai/config: OAuth token exchange failed: HTTP {status} {body}")]
OAuthFailed { status: u16, body: String },
/// The response parsed, but `access_token` was absent or an empty string.
#[error("@smooai/config: OAuth token endpoint returned no access_token")]
MissingAccessToken,
/// The HTTP request itself failed; wraps the underlying `reqwest::Error`.
#[error("@smooai/config: OAuth request failed: {0}")]
Request(#[from] reqwest::Error),
/// The response body could not be deserialized as the expected token JSON;
/// wraps the underlying `serde_json::Error`.
#[error("@smooai/config: OAuth response not JSON: {0}")]
BadJson(#[from] serde_json::Error),
/// A constructor argument was rejected (for example an empty string);
/// the payload is the human-readable reason.
#[error("@smooai/config: {0}")]
InvalidArgument(String),
}
/// The fields of the token endpoint's JSON response that this module reads.
///
/// Both fields are optional so that a response missing either one still
/// deserializes; the caller decides how to handle the absence.
#[derive(Deserialize)]
struct TokenResponse {
// The issued token; a missing or empty value is treated as an error by the caller.
access_token: Option<String>,
// Token lifetime in seconds, as reported by the endpoint.
expires_in: Option<i64>,
}
/// An access token together with the instant at which it expires.
///
/// `Debug` is implemented by hand rather than derived so that the token value
/// can never end up in logs or panic messages through `{:?}` formatting.
#[derive(Clone)]
struct CachedToken {
    /// The token value handed out to callers.
    access_token: String,
    /// Monotonic-clock instant at which the token stops being valid.
    expires_at: Instant,
}

impl std::fmt::Debug for CachedToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CachedToken")
            // Never print the credential itself.
            .field("access_token", &"<redacted>")
            .field("expires_at", &self.expires_at)
            .finish()
    }
}
/// Obtains OAuth access tokens via a form-encoded POST to `{auth_url}/token`
/// and caches the most recent one until it is about to expire.
///
/// `Debug` is implemented by hand rather than derived: a derived impl would
/// print `client_secret` (and the cached access token) in clear text whenever
/// the provider is formatted with `{:?}`.
pub struct TokenProvider {
    /// Base URL of the auth server, stored without trailing `/`.
    auth_url: String,
    /// Value sent as the `client_id` form field.
    client_id: String,
    /// Value sent as the `client_secret` form field; never printed.
    client_secret: String,
    /// How long before expiry a cached token is already treated as stale.
    refresh_window: Duration,
    /// HTTP client used for the token request.
    http_client: Client,
    /// Most recently fetched token, if any, behind an async mutex.
    cache: Mutex<Option<CachedToken>>,
}

impl std::fmt::Debug for TokenProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TokenProvider")
            .field("auth_url", &self.auth_url)
            .field("client_id", &self.client_id)
            // Never print the credential itself.
            .field("client_secret", &"<redacted>")
            .field("refresh_window", &self.refresh_window)
            // `cache` holds the live access token and `http_client` is noise;
            // both are deliberately omitted.
            .finish_non_exhaustive()
    }
}
impl TokenProvider {
    /// Refresh window used by [`TokenProvider::new`].
    const DEFAULT_REFRESH_WINDOW: Duration = Duration::from_secs(60);

    /// Token lifetime (seconds) assumed when the endpoint reports none, a
    /// non-positive one, or one too large to represent.
    const DEFAULT_EXPIRES_IN_SECS: u64 = 3600;

    /// Creates a provider with a 60-second refresh window and a fresh
    /// `reqwest::Client`.
    ///
    /// # Errors
    /// Returns [`TokenProviderError::InvalidArgument`] if any argument is empty.
    pub fn new(
        auth_url: &str,
        client_id: &str,
        client_secret: &str,
    ) -> Result<Self, TokenProviderError> {
        Self::with_options(
            auth_url,
            client_id,
            client_secret,
            Self::DEFAULT_REFRESH_WINDOW,
            Client::new(),
        )
    }

    /// Creates a provider with an explicit refresh window and HTTP client.
    ///
    /// Trailing `/` characters are stripped from `auth_url` before it is stored.
    ///
    /// # Errors
    /// Returns [`TokenProviderError::InvalidArgument`] if `auth_url`,
    /// `client_id` or `client_secret` is empty (checked in that order).
    pub fn with_options(
        auth_url: &str,
        client_id: &str,
        client_secret: &str,
        refresh_window: Duration,
        http_client: Client,
    ) -> Result<Self, TokenProviderError> {
        let required = [
            (auth_url, "TokenProvider requires auth_url"),
            (client_id, "TokenProvider requires client_id"),
            (client_secret, "TokenProvider requires client_secret"),
        ];
        for &(value, message) in &required {
            if value.is_empty() {
                return Err(TokenProviderError::InvalidArgument(message.to_string()));
            }
        }

        Ok(Self {
            auth_url: auth_url.trim_end_matches('/').to_string(),
            client_id: client_id.to_string(),
            client_secret: client_secret.to_string(),
            refresh_window,
            http_client,
            cache: Mutex::new(None),
        })
    }

    /// Returns a cached access token if it is still outside the refresh
    /// window, otherwise fetches a new one, caches it and returns it.
    ///
    /// The cache lock is held while the refresh request is in flight, so
    /// concurrent callers wait for that request instead of each issuing
    /// their own.
    ///
    /// # Errors
    /// Propagates any error from the token request; the cache is left
    /// untouched in that case.
    pub async fn get_access_token(&self) -> Result<String, TokenProviderError> {
        let mut guard = self.cache.lock().await;

        if let Some(cached) = guard.as_ref() {
            // If subtracting the window is not representable, fall back to
            // the raw expiry instant.
            let refresh_at = cached
                .expires_at
                .checked_sub(self.refresh_window)
                .unwrap_or(cached.expires_at);
            if Instant::now() < refresh_at {
                return Ok(cached.access_token.clone());
            }
        }

        let token = self.refresh().await?;
        let access_token = token.access_token.clone();
        *guard = Some(token);
        Ok(access_token)
    }

    /// Drops the cached token so the next [`TokenProvider::get_access_token`]
    /// call performs a fresh request.
    pub async fn invalidate(&self) {
        *self.cache.lock().await = None;
    }

    /// Performs the form-encoded POST to `{auth_url}/token` and turns the
    /// response into a [`CachedToken`].
    ///
    /// # Errors
    /// * [`TokenProviderError::Request`] – sending the request or reading a
    ///   successful response body failed.
    /// * [`TokenProviderError::OAuthFailed`] – the endpoint returned a
    ///   non-success status.
    /// * [`TokenProviderError::BadJson`] – the body could not be deserialized.
    /// * [`TokenProviderError::MissingAccessToken`] – `access_token` was
    ///   absent or empty.
    async fn refresh(&self) -> Result<CachedToken, TokenProviderError> {
        let url = format!("{}/token", self.auth_url);
        let form = [
            ("grant_type", "client_credentials"),
            ("provider", "client_credentials"),
            ("client_id", self.client_id.as_str()),
            ("client_secret", self.client_secret.as_str()),
        ];

        let resp = self.http_client.post(&url).form(&form).send().await?;
        let status = resp.status();

        if !status.is_success() {
            // The body is only diagnostic here, so a failed read degrades to
            // an empty string rather than masking the HTTP status.
            let body = resp.text().await.unwrap_or_default();
            return Err(TokenProviderError::OAuthFailed {
                status: status.as_u16(),
                body,
            });
        }

        // On success the body is the payload: surface a read failure as a
        // request error instead of a misleading JSON error on "".
        let body = resp.text().await?;
        let parsed: TokenResponse = serde_json::from_str(&body)?;

        let access_token = parsed
            .access_token
            .filter(|t| !t.is_empty())
            .ok_or(TokenProviderError::MissingAccessToken)?;

        let lifetime_secs = parsed
            .expires_in
            .filter(|n| *n > 0)
            // Lossless: the value is strictly positive at this point.
            .map(|n| n as u64)
            .unwrap_or(Self::DEFAULT_EXPIRES_IN_SECS);

        // `Instant + Duration` panics on overflow, and `expires_in` comes
        // from the server, so add with a check and fall back to the default
        // lifetime for absurdly large values.
        let now = Instant::now();
        let expires_at = now
            .checked_add(Duration::from_secs(lifetime_secs))
            .unwrap_or_else(|| now + Duration::from_secs(Self::DEFAULT_EXPIRES_IN_SECS));

        Ok(CachedToken {
            access_token,
            expires_at,
        })
    }
}
/// A reference-counted (`Arc`) handle to a [`TokenProvider`].
pub type SharedTokenProvider = Arc<TokenProvider>;