use std::sync::Arc;
use chrono::Utc;
use tokio::sync::RwLock;
use crate::error::XApiError;
use super::{save_tokens, TokenRefreshResponse, Tokens, REFRESH_WINDOW_SECS, TOKEN_URL};
/// Holds an access/refresh token pair and refreshes it against `TOKEN_URL`
/// when the access token is close to expiry.
pub struct TokenManager {
    /// Current token set, shared through an `Arc` and guarded by an async `RwLock`.
    tokens: Arc<RwLock<Tokens>>,
    /// Held while a refresh is in progress so that only one task refreshes at a time.
    refresh_lock: tokio::sync::Mutex<()>,
    /// Client identifier sent as `client_id` with each refresh request.
    client_id: String,
    /// File location that refreshed tokens are saved to.
    token_path: std::path::PathBuf,
    /// HTTP client used to call the token endpoint.
    http_client: reqwest::Client,
}
impl TokenManager {
pub fn new(tokens: Tokens, client_id: String, token_path: std::path::PathBuf) -> Self {
Self {
tokens: Arc::new(RwLock::new(tokens)),
client_id,
http_client: reqwest::Client::new(),
token_path,
refresh_lock: tokio::sync::Mutex::new(()),
}
}
pub async fn get_access_token(&self) -> Result<String, XApiError> {
self.refresh_if_needed().await?;
let tokens = self.tokens.read().await;
Ok(tokens.access_token.clone())
}
pub fn tokens_lock(&self) -> Arc<RwLock<Tokens>> {
self.tokens.clone()
}
pub async fn refresh_if_needed(&self) -> Result<(), XApiError> {
{
let tokens = self.tokens.read().await;
let seconds_until_expiry = tokens
.expires_at
.signed_duration_since(Utc::now())
.num_seconds();
if seconds_until_expiry >= REFRESH_WINDOW_SECS {
return Ok(());
}
}
let _guard = self.refresh_lock.lock().await;
{
let tokens = self.tokens.read().await;
let seconds_until_expiry = tokens
.expires_at
.signed_duration_since(Utc::now())
.num_seconds();
if seconds_until_expiry >= REFRESH_WINDOW_SECS {
return Ok(());
}
}
self.do_refresh().await
}
async fn do_refresh(&self) -> Result<(), XApiError> {
let refresh_token = {
let tokens = self.tokens.read().await;
tokens.refresh_token.clone()
};
tracing::info!("Refreshing X API access token");
let params = [
("grant_type", "refresh_token"),
("refresh_token", &refresh_token),
("client_id", &self.client_id),
];
let response = self
.http_client
.post(TOKEN_URL)
.form(¶ms)
.send()
.await
.map_err(|e| XApiError::Network { source: e })?;
if !response.status().is_success() {
let status = response.status().as_u16();
let body = response.text().await.unwrap_or_default();
tracing::error!(
status,
body_len = body.len(),
"Token refresh failed (response body redacted)"
);
return Err(XApiError::AuthExpired);
}
let body: TokenRefreshResponse = response
.json()
.await
.map_err(|e| XApiError::Network { source: e })?;
let new_tokens = Tokens {
access_token: body.access_token,
refresh_token: body.refresh_token,
expires_at: Utc::now() + chrono::Duration::seconds(body.expires_in),
scopes: body
.scope
.split_whitespace()
.map(|s| s.to_string())
.collect(),
};
tracing::info!(
expires_at = %new_tokens.expires_at,
"Token refreshed successfully"
);
{
let mut tokens = self.tokens.write().await;
*tokens = new_tokens.clone();
}
save_tokens(&new_tokens, &self.token_path).map_err(|e| {
tracing::error!(error = %e, "Failed to save refreshed tokens");
XApiError::ApiError {
status: 0,
message: format!("Failed to save tokens: {e}"),
}
})?;
Ok(())
}
}