use super::access_token::{AccessToken, is_token_expired};
use super::token_type::TokenType;
use crate::Result;
use async_trait::async_trait;
use chrono::Utc;
use std::sync::{
Arc,
atomic::{AtomicBool, Ordering},
};
use tokio::sync::{Notify, RwLock};
use tracing::debug;
/// Source of access tokens obtained through a [`TokenType`].
#[async_trait]
pub trait TokenStorage: Send + Sync {
    /// The [`TokenType`] this storage obtains new tokens from.
    fn token_type(&self) -> Arc<dyn TokenType>;

    /// Returns an access token, refreshing the stored one first when the
    /// implementation decides it can no longer be used.
    async fn token(&self) -> Result<String>;

    /// Refreshes the stored access token and returns it.
    ///
    /// Implementations may return the stored token unchanged when it is still
    /// valid (the in-memory implementation re-checks expiry before fetching).
    async fn refresh_access_token(&self) -> Result<String>;
}
/// In-memory [`TokenStorage`] that caches a single access token and lets only
/// one caller of `token()` at a time perform a refresh.
pub struct MemoryTokenStorage {
    /// Cached token; replaced by `refresh_access_token` under the write lock.
    access_token: Arc<RwLock<AccessToken>>,
    /// `true` while one caller of `token()` is refreshing, so other callers wait
    /// for it instead of starting their own refresh.
    refreshing: Arc<AtomicBool>,
    /// Used to wake callers that are waiting for an in-progress refresh.
    notify: Arc<Notify>,
    /// Provider used to fetch new tokens.
    token_type: Arc<dyn TokenType>,
}
impl MemoryTokenStorage {
    /// Creates an empty storage that fetches tokens from `token_type`.
    ///
    /// The cache starts with an empty token whose expiry is the creation time.
    /// NOTE(review): presumably `is_token_expired` reports this placeholder as
    /// expired, so the first `token()` call performs a fetch — confirm.
    pub fn new(token_type: Arc<dyn TokenType>) -> Self {
        let placeholder = AccessToken {
            access_token: String::new(),
            expired_at: Utc::now(),
        };
        let access_token = Arc::new(RwLock::new(placeholder));

        Self {
            access_token,
            refreshing: Arc::new(AtomicBool::new(false)),
            notify: Arc::new(Notify::new()),
            token_type,
        }
    }
}
#[async_trait]
impl TokenStorage for MemoryTokenStorage {
async fn token(&self) -> Result<String> {
{
let guard = self.access_token.read().await;
if !is_token_expired(&guard) {
return Ok(guard.access_token.clone());
}
}
if self
.refreshing
.compare_exchange(false, true, Ordering::AcqRel, Ordering::Acquire)
.is_ok()
{
match self.refresh_access_token().await {
Ok(token) => {
self.refreshing.store(false, Ordering::Release);
self.notify.notify_waiters();
Ok(token)
}
Err(e) => {
self.refreshing.store(false, Ordering::Release);
self.notify.notify_waiters();
Err(e)
}
}
} else {
self.notify.notified().await;
let guard = self.access_token.read().await;
Ok(guard.access_token.clone())
}
}
async fn refresh_access_token(&self) -> Result<String> {
let mut guard = self.access_token.write().await;
if !is_token_expired(&guard) {
debug!("token already refreshed by another thread");
return Ok(guard.access_token.clone());
}
debug!("performing network request to refresh token");
let builder = self.token_type.token().await?;
guard.access_token = builder.access_token.clone();
guard.expired_at = builder.expired_at;
debug!("fresh access token: {:#?}", guard);
Ok(guard.access_token.clone())
}
fn token_type(&self) -> Arc<dyn TokenType> {
self.token_type.clone()
}
}