use std::future::Future;
use std::time::{Duration, SystemTime};
use async_trait::async_trait;
use crate::domain::error::{ServiceError, StygianError};
/// An access token together with its optional refresh token, expiry, and scopes.
#[derive(Debug, Clone)]
pub struct TokenSet {
    /// The token value handed to callers that need to authenticate.
    pub access_token: String,
    /// Refresh token, when one is available.
    pub refresh_token: Option<String>,
    /// Instant at which the token expires; `None` means it never reports as expired.
    pub expires_at: Option<SystemTime>,
    /// Scopes associated with the token.
    pub scopes: Vec<String>,
}

impl TokenSet {
    /// Reports whether the token has expired or will expire within 60 seconds.
    ///
    /// A token without an `expires_at` is never considered expired. Otherwise
    /// the expiry is compared against "now + 60 s", so a token about to lapse
    /// is already reported as expired. Should that addition overflow, the
    /// comparison falls back to [`SystemTime::UNIX_EPOCH`].
    #[must_use]
    pub fn is_expired(&self) -> bool {
        /// Safety margin added to the current time before comparing.
        const EXPIRY_MARGIN: Duration = Duration::from_secs(60);

        self.expires_at.is_some_and(|expiry| {
            let cutoff = SystemTime::now()
                .checked_add(EXPIRY_MARGIN)
                .unwrap_or(SystemTime::UNIX_EPOCH);
            expiry <= cutoff
        })
    }
}
/// Errors produced while loading, refreshing, or resolving a token.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute;
/// for tuple variants the `String` payload is interpolated into that text.
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
/// No token is available (for example, `resolve_token` returns this when
/// `load_token` yields `None`).
#[error("no token found — please run the auth flow")]
TokenNotFound,
/// The token has expired.
#[error("token expired")]
TokenExpired,
/// Refreshing the token failed; the payload describes the failure.
#[error("token refresh failed: {0}")]
RefreshFailed(String),
/// Storing the token failed; the payload describes the failure.
#[error("token storage failed: {0}")]
StorageFailed(String),
/// The auth flow failed; the payload describes the failure.
#[error("auth flow failed: {0}")]
AuthFlowFailed(String),
/// The token is invalid; the payload describes why.
#[error("invalid token: {0}")]
InvalidToken(String),
}
impl From<AuthError> for StygianError {
    /// Wraps an [`AuthError`] as a service-level authentication failure.
    ///
    /// Only the error's `Display` text is carried over; the original variant
    /// is not preserved.
    fn from(err: AuthError) -> Self {
        let message = err.to_string();
        Self::Service(ServiceError::AuthenticationFailed(message))
    }
}
/// Source of authentication tokens.
///
/// Implementors must be `Send + Sync`, and both methods return `Send` futures.
pub trait AuthPort: Send + Sync {
/// Loads the currently available token, if any.
///
/// Resolves to `Ok(None)` when there is no token; `resolve_token` turns that
/// case into `AuthError::TokenNotFound`.
fn load_token(
&self,
) -> impl Future<Output = std::result::Result<Option<TokenSet>, AuthError>> + Send;
/// Obtains a replacement token.
///
/// `resolve_token` calls this when the loaded token reports itself as
/// expired and then returns the `access_token` of the resulting `TokenSet`.
fn refresh_token(
&self,
) -> impl Future<Output = std::result::Result<TokenSet, AuthError>> + Send;
}
/// Resolves a usable access token from `port`.
///
/// Loads the stored token and, if it reports itself as expired, asks the port
/// for a refreshed one. The access token of whichever set ends up current is
/// returned.
///
/// # Errors
///
/// * [`AuthError::TokenNotFound`] when the port has no token.
/// * Any error returned by the port's `load_token` or `refresh_token`.
pub async fn resolve_token(port: &impl AuthPort) -> std::result::Result<String, AuthError> {
    let Some(stored) = port.load_token().await? else {
        return Err(AuthError::TokenNotFound);
    };

    // Only hit the refresh path when the stored token is (about to be) expired.
    let current = if stored.is_expired() {
        port.refresh_token().await?
    } else {
        stored
    };

    Ok(current.access_token)
}
/// Token resolution exposed through `#[async_trait]`.
///
/// `AuthPort` methods return `impl Future`, which prevents using it as
/// `dyn AuthPort`; this trait offers the resolve step with an `async_trait`
/// method instead. A blanket implementation in this file covers every
/// `AuthPort`.
// NOTE(review): presumably intended for `dyn ErasedAuthPort` call sites — confirm against callers.
#[async_trait]
pub trait ErasedAuthPort: Send + Sync {
/// Resolves an access token, or returns the `AuthError` that prevented it.
async fn erased_resolve_token(&self) -> std::result::Result<String, AuthError>;
}
/// Blanket implementation: every `AuthPort` is also an `ErasedAuthPort`.
#[async_trait]
impl<T: AuthPort> ErasedAuthPort for T {
/// Delegates directly to `resolve_token`, passing `self` as the port.
async fn erased_resolve_token(&self) -> std::result::Result<String, AuthError> {
resolve_token(self).await
}
}
/// Auth port that reads its token from an environment variable.
///
/// Only the variable *name* is stored, so the derived `Debug` output never
/// contains the token value itself.
#[derive(Debug, Clone)]
pub struct EnvAuthPort {
    /// Name of the environment variable holding the token.
    var_name: String,
}

impl EnvAuthPort {
    /// Creates a port that reads the token from the environment variable
    /// named `var_name`.
    ///
    /// The environment is not consulted here; the variable is read each time
    /// the token is loaded.
    #[must_use]
    pub fn new(var_name: impl Into<String>) -> Self {
        Self {
            var_name: var_name.into(),
        }
    }
}
impl AuthPort for EnvAuthPort {
    /// Reads the token from the configured environment variable.
    ///
    /// Yields `Ok(None)` when the variable cannot be read (any error from
    /// `std::env::var`) or is empty. A non-empty value becomes a [`TokenSet`]
    /// with no refresh token, no expiry, and no scopes. This method never
    /// returns `Err`.
    async fn load_token(&self) -> std::result::Result<Option<TokenSet>, AuthError> {
        let token_set = std::env::var(&self.var_name)
            .ok()
            .filter(|value| !value.is_empty())
            .map(|value| TokenSet {
                access_token: value,
                refresh_token: None,
                expires_at: None,
                scopes: Vec::new(),
            });
        Ok(token_set)
    }

    /// Performs no refresh; always returns [`AuthError::TokenNotFound`].
    async fn refresh_token(&self) -> std::result::Result<TokenSet, AuthError> {
        Err(AuthError::TokenNotFound)
    }
}
#[cfg(test)]
mod tests {
    // `unwrap` keeps the test bodies short, and the `unsafe` blocks wrap the
    // environment-variable mutation in `env_auth_port_loads_from_env`.
    #![allow(clippy::unwrap_used, unsafe_code)]

    use super::*;

    /// Builds a `TokenSet` that differs from the others only in its expiry.
    fn token_with_expiry(expires_at: Option<SystemTime>) -> TokenSet {
        TokenSet {
            access_token: "tok".to_string(),
            refresh_token: None,
            expires_at,
            scopes: vec![],
        }
    }

    /// Port that always yields the same, never-expiring token.
    struct FixedToken(String);

    impl AuthPort for FixedToken {
        async fn load_token(&self) -> std::result::Result<Option<TokenSet>, AuthError> {
            Ok(Some(TokenSet {
                access_token: self.0.clone(),
                refresh_token: None,
                expires_at: None,
                scopes: vec![],
            }))
        }

        async fn refresh_token(&self) -> std::result::Result<TokenSet, AuthError> {
            Err(AuthError::RefreshFailed("no refresh token".to_string()))
        }
    }

    /// Port that has no token at all.
    struct NoToken;

    impl AuthPort for NoToken {
        async fn load_token(&self) -> std::result::Result<Option<TokenSet>, AuthError> {
            Ok(None)
        }

        async fn refresh_token(&self) -> std::result::Result<TokenSet, AuthError> {
            Err(AuthError::TokenNotFound)
        }
    }

    /// Port whose stored token expired an hour ago and whose refresh succeeds.
    struct ExpiredToken {
        new_token: String,
    }

    impl AuthPort for ExpiredToken {
        async fn load_token(&self) -> std::result::Result<Option<TokenSet>, AuthError> {
            Ok(Some(TokenSet {
                access_token: "old_token".to_string(),
                refresh_token: Some("ref".to_string()),
                expires_at: SystemTime::now().checked_sub(Duration::from_secs(3600)),
                scopes: vec![],
            }))
        }

        async fn refresh_token(&self) -> std::result::Result<TokenSet, AuthError> {
            Ok(TokenSet {
                access_token: self.new_token.clone(),
                refresh_token: None,
                expires_at: None,
                scopes: vec![],
            })
        }
    }

    #[test]
    fn not_expired_when_no_expiry() {
        let ts = token_with_expiry(None);
        assert!(!ts.is_expired());
    }

    #[test]
    fn expired_when_past_expiry() {
        let ts = token_with_expiry(SystemTime::now().checked_sub(Duration::from_secs(300)));
        assert!(ts.is_expired());
    }

    // A token expiring in 30 s falls inside the 60 s safety margin, so it must
    // already be reported as expired.
    #[test]
    fn expired_within_60s_margin() {
        let ts = token_with_expiry(SystemTime::now().checked_add(Duration::from_secs(30)));
        assert!(ts.is_expired());
    }

    #[test]
    fn not_expired_outside_60s_margin() {
        let ts = token_with_expiry(SystemTime::now().checked_add(Duration::from_secs(120)));
        assert!(!ts.is_expired());
    }

    #[tokio::test]
    async fn resolve_token_returns_access_token() {
        let auth = FixedToken("tok_abc".to_string());
        let token = resolve_token(&auth).await.unwrap();
        assert_eq!(token, "tok_abc");
    }

    #[tokio::test]
    async fn resolve_token_returns_err_when_no_token() {
        let auth = NoToken;
        assert!(resolve_token(&auth).await.is_err());
    }

    #[tokio::test]
    async fn resolve_token_refreshes_when_expired() {
        let auth = ExpiredToken {
            new_token: "fresh_tok".to_string(),
        };
        let token = resolve_token(&auth).await.unwrap();
        assert_eq!(token, "fresh_tok");
    }

    #[tokio::test]
    async fn env_auth_port_loads_from_env() {
        // SAFETY: the variable name is used by no other test in this module.
        // NOTE(review): assumes no non-std code reads or writes the process
        // environment concurrently while the test harness runs — confirm.
        unsafe { std::env::set_var("_STYGIAN_TEST_TOKEN_1", "env_tok_xyz") };
        let auth = EnvAuthPort::new("_STYGIAN_TEST_TOKEN_1");
        let token = resolve_token(&auth).await.unwrap();
        assert_eq!(token, "env_tok_xyz");
        // SAFETY: same reasoning as for `set_var` above.
        unsafe { std::env::remove_var("_STYGIAN_TEST_TOKEN_1") };
    }

    #[tokio::test]
    async fn env_auth_port_returns_none_when_unset() {
        let auth = EnvAuthPort::new("_STYGIAN_TEST_MISSING_VAR_9999");
        let ts = auth.load_token().await.unwrap();
        assert!(ts.is_none());
    }
}