use crate::error::{Error, Result};
use reqwest::Client;
use serde_json::{Value, json};
use tokio::sync::RwLock;
use url::Url;
/// A cached bearer token together with the Unix timestamp (whole seconds)
/// at which it expires.
type CachedToken = (String, u64);

/// Exchanges an API key for bearer tokens and caches the most recently
/// issued one so repeated callers can reuse it until it nears expiry.
pub struct AuthManager {
    /// Secret sent in the JSON body of every token request.
    api_key: String,
    /// Server root that the token endpoint path is resolved against.
    base_url: Url,
    /// HTTP client used to issue token requests.
    client: Client,
    /// Last token obtained from the server, or `None` before the first fetch
    /// or after the cache has been cleared.
    cached_token: RwLock<Option<CachedToken>>,
}
impl AuthManager {
    /// Seconds before expiry at which a cached token is treated as stale, so a
    /// token is not handed out moments before the server would reject it.
    const REFRESH_MARGIN_SECS: u64 = 60;

    /// Lifetime assumed for a token whose expiry cannot be read from a JWT
    /// `exp` claim.
    const FALLBACK_TTL_SECS: u64 = 3600;

    /// Creates a manager with an empty token cache.
    ///
    /// No request is made until a token is first asked for.
    pub fn new(api_key: String, base_url: Url, client: Client) -> Self {
        Self {
            api_key,
            base_url,
            client,
            cached_token: RwLock::new(None),
        }
    }

    /// Returns a bearer token.
    ///
    /// The cached token is reused while it has more than
    /// `REFRESH_MARGIN_SECS` of validity left; otherwise a new token is
    /// requested and cached. Concurrent callers that all find the cache stale
    /// queue on the cache's write lock. Only the first one performs the
    /// request, and the others pick up its result instead of each sending
    /// their own.
    ///
    /// # Errors
    ///
    /// Fails if a new token is needed and the request fails, the server
    /// answers with a non-success status, or the response carries no `token`
    /// string.
    pub async fn get_token(&self) -> Result<String> {
        // Fast path: a shared lock is enough when the cached token is fresh.
        // The read guard is a temporary and is released at the end of this
        // statement, before the write lock is requested below.
        let fresh = Self::fresh_token(self.cached_token.read().await.as_ref(), Self::now_secs());
        if let Some(token) = fresh {
            return Ok(token);
        }

        let mut cached = self.cached_token.write().await;
        let now = Self::now_secs();
        // Re-check: another task may have refreshed the token while this one
        // was waiting for the write lock.
        if let Some(token) = Self::fresh_token(cached.as_ref(), now) {
            return Ok(token);
        }
        if let Some((_, expires_at)) = cached.as_ref() {
            log::debug!(
                "Token expiring soon ({}s left), refreshing",
                expires_at.saturating_sub(now)
            );
        }
        // The write lock is held across the request on purpose. Other callers
        // wait for this refresh rather than starting their own.
        let (token, expires_at) = self.request_token().await?;
        *cached = Some((token.clone(), expires_at));
        Ok(token)
    }

    /// Returns a copy of the cached token if it remains valid for more than
    /// `REFRESH_MARGIN_SECS` past `now` (Unix seconds).
    fn fresh_token(entry: Option<&(String, u64)>, now: u64) -> Option<String> {
        entry
            .filter(|(_, expires_at)| now.saturating_add(Self::REFRESH_MARGIN_SECS) < *expires_at)
            .map(|(token, _)| token.clone())
    }

    /// Current wall-clock time in whole seconds since the Unix epoch. Returns
    /// 0 if the system clock reads earlier than the epoch.
    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Requests a token from the server and returns `(token, expires_at)`
    /// without touching the cache.
    ///
    /// `expires_at` comes from the token's JWT `exp` claim when one can be
    /// read. Otherwise it is `now + FALLBACK_TTL_SECS`.
    async fn request_token(&self) -> Result<(String, u64)> {
        // NOTE(review): the leading `/` makes this an absolute path, so any path
        // prefix on `base_url` (e.g. `https://host/v1/`) is discarded by `join`.
        // Confirm that is intended.
        let token_url = self.base_url.join("/api/auth/token")?;
        let response = self
            .client
            .post(token_url)
            .json(&json!({ "api_key": self.api_key }))
            .send()
            .await?;
        let status = response.status();
        if !status.is_success() {
            // Best effort: an unreadable body must not hide the status code.
            let text = response.text().await.unwrap_or_default();
            return Err(Error::Authentication(format!(
                "Failed to get token ({}): {}",
                status, text
            )));
        }
        let token_data: Value = response.json().await?;
        let token = token_data["token"]
            .as_str()
            .ok_or_else(|| Error::Authentication("No token in response".to_string()))?
            .to_string();
        let expires_at = Self::extract_jwt_expiry(&token)
            .unwrap_or_else(|| Self::now_secs() + Self::FALLBACK_TTL_SECS);
        Ok((token, expires_at))
    }

    /// Requests a new token unconditionally and stores it in the cache.
    ///
    /// The cache is only written after a successful request. On error it is
    /// left as it was.
    async fn fetch_new_token(&self) -> Result<String> {
        let (token, expires_at) = self.request_token().await?;
        let mut cached = self.cached_token.write().await;
        *cached = Some((token.clone(), expires_at));
        Ok(token)
    }

    /// Reads the `exp` claim (Unix seconds) from a JWT's payload segment.
    ///
    /// The signature is NOT verified; the value is only used to schedule
    /// refreshes. Returns `None` in these cases:
    /// - the token does not have exactly three `.`-separated parts;
    /// - the payload is not base64url-encoded JSON;
    /// - `exp` is missing, negative, or not a number.
    fn extract_jwt_expiry(token: &str) -> Option<u64> {
        use base64::Engine;
        let parts: Vec<&str> = token.split('.').collect();
        if parts.len() != 3 {
            return None;
        }
        // JWS mandates unpadded base64url, but some issuers append `=` padding
        // anyway. Strip it so those tokens still decode.
        let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .decode(parts[1].trim_end_matches('='))
            .ok()?;
        let claims: Value = serde_json::from_slice(&payload).ok()?;
        let exp = claims.get("exp")?;
        // A NumericDate may be non-integer; truncate fractional seconds.
        exp.as_u64().or_else(|| {
            exp.as_f64()
                .filter(|secs| secs.is_finite() && *secs >= 0.0)
                .map(|secs| secs as u64)
        })
    }

    /// Discards any cached token and requests a new one.
    ///
    /// # Errors
    ///
    /// Same as [`AuthManager::get_token`]. On failure the cache stays empty.
    pub async fn refresh_token(&self) -> Result<String> {
        self.clear_cache().await;
        self.fetch_new_token().await
    }

    /// Drops the cached token so the next [`AuthManager::get_token`] call
    /// requests a new one.
    pub async fn clear_cache(&self) {
        let mut cached = self.cached_token.write().await;
        *cached = None;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a loopback URL on which nothing is listening, so token requests
    /// fail fast with "connection refused". This avoids depending on whatever
    /// happens to run on a fixed port such as 8080.
    fn unreachable_base_url() -> Url {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").expect("bind ephemeral port");
        let port = listener.local_addr().expect("query bound address").port();
        drop(listener);
        Url::parse(&format!("http://127.0.0.1:{port}")).expect("valid URL")
    }

    /// Builds a manager with API key `test-key` whose token endpoint is unreachable.
    fn offline_manager() -> AuthManager {
        AuthManager::new("test-key".to_string(), unreachable_base_url(), Client::new())
    }

    /// Puts `token` into the cache with the given expiry (Unix seconds).
    async fn seed_cache(auth: &AuthManager, token: &str, expires_at: u64) {
        let mut cached = auth.cached_token.write().await;
        *cached = Some((token.to_string(), expires_at));
    }

    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    #[tokio::test]
    async fn test_new_starts_with_empty_cache() {
        let auth = offline_manager();
        assert_eq!(auth.api_key, "test-key");
        assert!(auth.cached_token.read().await.is_none());
    }

    #[tokio::test]
    async fn test_get_token_returns_cached_token() {
        let auth = offline_manager();
        seed_cache(&auth, "cached-token", u64::MAX).await;
        // The endpoint is unreachable, so success proves no request was made.
        let token = auth.get_token().await.expect("fresh cached token should be reused");
        assert_eq!(token, "cached-token");
    }

    #[tokio::test]
    async fn test_get_token_refreshes_near_expiry() {
        let auth = offline_manager();
        // Still valid, but inside the refresh margin, so a fetch is attempted.
        seed_cache(&auth, "stale-token", now_secs() + 30).await;
        assert!(
            auth.get_token().await.is_err(),
            "token inside the refresh margin should trigger a (failing) fetch"
        );
    }

    #[tokio::test]
    async fn test_refresh_token() {
        let auth = offline_manager();
        seed_cache(&auth, "old-token", u64::MAX).await;
        let result = auth.refresh_token().await;
        assert!(result.is_err(), "refresh must hit the (unreachable) server");
        assert!(
            auth.cached_token.read().await.is_none(),
            "the old token must be discarded even when the refresh fails"
        );
    }

    #[tokio::test]
    async fn test_clear_cache() {
        let auth = offline_manager();
        seed_cache(&auth, "old-token", u64::MAX).await;
        auth.clear_cache().await;
        let cached = auth.cached_token.read().await;
        assert!(cached.is_none());
    }

    #[tokio::test]
    async fn test_expired_token_not_returned() {
        let auth = offline_manager();
        seed_cache(&auth, "expired-token", now_secs() - 600).await;
        let result = auth.get_token().await;
        assert!(
            result.is_err(),
            "should fail to fetch new token (no server)"
        );
    }

    #[test]
    fn test_extract_jwt_expiry() {
        use base64::Engine;
        let payload = serde_json::json!({"exp": 1700000000, "sub": "test"});
        let encoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
            .encode(serde_json::to_vec(&payload).unwrap());
        let fake_jwt = format!("eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9.{encoded}.fakesig");
        let exp = AuthManager::extract_jwt_expiry(&fake_jwt);
        assert_eq!(exp, Some(1700000000));
    }

    #[test]
    fn test_extract_jwt_expiry_invalid() {
        assert_eq!(AuthManager::extract_jwt_expiry("not-a-jwt"), None);
        assert_eq!(AuthManager::extract_jwt_expiry("a.b"), None);
        assert_eq!(AuthManager::extract_jwt_expiry("a.!!!.c"), None);
    }
}