use std::future::Future;
use std::sync::Mutex;
use std::time::{Duration, Instant};
use tokio::sync::Mutex as AsyncMutex;
/// Seconds subtracted from a token's reported lifetime so it is refreshed before it lapses.
const EXPIRATION_BUFFER_SECONDS: u64 = 2 * 60;
/// Lifetime, in seconds, assumed for a fetched token whose reported `expires_in` is zero.
const DEFAULT_EXPIRY_SECONDS: u64 = 60 * 60;
/// Caches an OAuth access token for one client and coordinates refetching it.
///
/// The token and its expiry live behind `inner`. Fetches are serialized by `fetch_lock`
/// (blocking path) and `async_fetch_lock` (async path).
pub struct OAuthTokenProvider {
    // Client credentials, exposed read-only via `client_id()` / `client_secret()`.
    client_id: String,
    client_secret: String,
    // Held for the duration of a fetch in `get_or_fetch`.
    fetch_lock: Mutex<()>,
    // Held for the duration of a fetch in `get_or_fetch_async`.
    async_fetch_lock: AsyncMutex<()>,
    // Cached token state.
    inner: Mutex<OAuthTokenProviderInner>,
}
/// Mutable token state stored inside `OAuthTokenProvider::inner`.
struct OAuthTokenProviderInner {
    /// Instant from which the cached token is treated as expired; `None` means it never expires.
    expires_at: Option<Instant>,
    /// Most recently stored access token, if any.
    access_token: Option<String>,
}
impl OAuthTokenProvider {
pub fn new(client_id: String, client_secret: String) -> Self {
Self {
client_id,
client_secret,
inner: Mutex::new(OAuthTokenProviderInner {
access_token: None,
expires_at: None,
}),
fetch_lock: Mutex::new(()),
async_fetch_lock: AsyncMutex::new(()),
}
}
pub fn client_id(&self) -> &str {
&self.client_id
}
pub fn client_secret(&self) -> &str {
&self.client_secret
}
pub fn set_token(&self, access_token: String, expires_in: u64) {
let mut inner = self.inner.lock().unwrap();
inner.access_token = Some(access_token);
if expires_in > 0 {
let effective_expires_in = expires_in.saturating_sub(EXPIRATION_BUFFER_SECONDS);
inner.expires_at = Some(Instant::now() + Duration::from_secs(effective_expires_in));
} else {
inner.expires_at = None;
}
}
pub fn get_token(&self) -> Option<String> {
let inner = self.inner.lock().unwrap();
if let Some(ref token) = inner.access_token {
if let Some(expires_at) = inner.expires_at {
if Instant::now() < expires_at {
return Some(token.clone());
}
} else {
return Some(token.clone());
}
}
None
}
pub fn get_or_fetch<F, E>(&self, fetch_func: F) -> Result<String, E>
where
F: FnOnce() -> Result<(String, u64), E>,
{
if let Some(token) = self.get_token() {
return Ok(token);
}
let _fetch_guard = self.fetch_lock.lock().unwrap();
if let Some(token) = self.get_token() {
return Ok(token);
}
let (access_token, expires_in) = fetch_func()?;
let effective_expires_in = if expires_in > 0 {
expires_in
} else {
DEFAULT_EXPIRY_SECONDS
};
self.set_token(access_token.clone(), effective_expires_in);
Ok(access_token)
}
pub async fn get_or_fetch_async<F, Fut, E>(&self, fetch_func: F) -> Result<String, E>
where
F: FnOnce() -> Fut,
Fut: Future<Output = Result<(String, u64), E>>,
{
if let Some(token) = self.get_token() {
return Ok(token);
}
let _fetch_guard = self.async_fetch_lock.lock().await;
if let Some(token) = self.get_token() {
return Ok(token);
}
let (access_token, expires_in) = fetch_func().await?;
let effective_expires_in = if expires_in > 0 {
expires_in
} else {
DEFAULT_EXPIRY_SECONDS
};
self.set_token(access_token.clone(), effective_expires_in);
Ok(access_token)
}
pub fn needs_refresh(&self) -> bool {
let inner = self.inner.lock().unwrap();
if inner.access_token.is_none() {
return true;
}
if let Some(expires_at) = inner.expires_at {
if Instant::now() >= expires_at {
return true;
}
}
false
}
pub fn reset(&self) {
let mut inner = self.inner.lock().unwrap();
inner.access_token = None;
inner.expires_at = None;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::thread;

    /// Builds a provider with fixed dummy credentials.
    fn new_provider() -> OAuthTokenProvider {
        OAuthTokenProvider::new("client_id".to_string(), "client_secret".to_string())
    }

    #[test]
    fn test_new_provider() {
        let provider = new_provider();
        assert_eq!(provider.client_id(), "client_id");
        assert_eq!(provider.client_secret(), "client_secret");
        assert!(provider.get_token().is_none());
        assert!(provider.needs_refresh());
    }

    #[test]
    fn test_set_and_get_token() {
        let provider = new_provider();
        provider.set_token("test_token".to_string(), 3600);
        assert_eq!(provider.get_token().as_deref(), Some("test_token"));
        assert!(!provider.needs_refresh());
    }

    #[test]
    fn test_expired_token() {
        let provider = new_provider();
        // A lifetime shorter than the refresh buffer is already considered expired.
        provider.set_token("test_token".to_string(), 1);
        assert!(provider.get_token().is_none());
        assert!(provider.needs_refresh());
    }

    #[test]
    fn test_zero_expiry_never_expires() {
        let provider = new_provider();
        provider.set_token("forever".to_string(), 0);
        assert_eq!(provider.get_token().as_deref(), Some("forever"));
        assert!(!provider.needs_refresh());
    }

    #[test]
    fn test_get_or_fetch() {
        let provider = new_provider();
        let result: Result<String, &str> =
            provider.get_or_fetch(|| Ok(("fetched_token".to_string(), 3600)));
        assert_eq!(result, Ok("fetched_token".to_string()));
        let result2: Result<String, &str> = provider.get_or_fetch(|| {
            panic!("Should not be called - token is cached");
        });
        assert_eq!(result2, Ok("fetched_token".to_string()));
    }

    #[test]
    fn test_fetch_zero_expiry_uses_default() {
        let provider = new_provider();
        let first: Result<String, &str> =
            provider.get_or_fetch(|| Ok(("default_ttl".to_string(), 0)));
        assert_eq!(first, Ok("default_ttl".to_string()));
        // DEFAULT_EXPIRY_SECONDS exceeds the buffer, so the token must now be cached.
        let second: Result<String, &str> = provider.get_or_fetch(|| {
            panic!("Should not be called - token is cached");
        });
        assert_eq!(second, Ok("default_ttl".to_string()));
    }

    #[test]
    fn test_fetch_error_leaves_cache_empty() {
        let provider = new_provider();
        let result: Result<String, &str> = provider.get_or_fetch(|| Err("boom"));
        assert_eq!(result, Err("boom"));
        assert!(provider.get_token().is_none());
        assert!(provider.needs_refresh());
    }

    #[test]
    fn test_lifetime_within_buffer_is_not_cached() {
        let provider = new_provider();
        let fetches = AtomicUsize::new(0);
        for _ in 0..2 {
            let result: Result<String, &str> = provider.get_or_fetch(|| {
                fetches.fetch_add(1, Ordering::SeqCst);
                Ok(("short".to_string(), EXPIRATION_BUFFER_SECONDS))
            });
            assert_eq!(result, Ok("short".to_string()));
        }
        // The token expires immediately after the buffer is applied, so each call refetches.
        assert_eq!(fetches.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_reset() {
        let provider = new_provider();
        provider.set_token("test_token".to_string(), 3600);
        assert!(provider.get_token().is_some());
        provider.reset();
        assert!(provider.get_token().is_none());
        assert!(provider.needs_refresh());
    }

    #[test]
    fn test_concurrent_access() {
        let provider = Arc::new(new_provider());
        let fetch_count = Arc::new(AtomicUsize::new(0));
        let handles: Vec<_> = (0..10)
            .map(|_| {
                let provider_clone = Arc::clone(&provider);
                let fetch_count_clone = Arc::clone(&fetch_count);
                thread::spawn(move || {
                    let result: Result<String, &str> = provider_clone.get_or_fetch(|| {
                        fetch_count_clone.fetch_add(1, Ordering::SeqCst);
                        thread::sleep(Duration::from_millis(10));
                        Ok(("concurrent_token".to_string(), 3600))
                    });
                    assert_eq!(result, Ok("concurrent_token".to_string()));
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap();
        }
        // The fetch lock plus the re-check under it make the fetch single-flight.
        let count = fetch_count.load(Ordering::SeqCst);
        assert_eq!(count, 1, "Fetch was called {} times", count);
    }
}