pub mod service_credentials;
mod static_credentials;
pub mod token_store;
pub mod user_credentials;
pub use service_credentials::{ServiceAccessKeyCredentials, ServiceCredentials, ServiceToken};
pub use static_credentials::StaticCredentials;
#[cfg(feature = "tokio")]
pub mod auto_refresh;
use std::{
sync::Arc,
time::{Duration, SystemTime, UNIX_EPOCH},
};
use async_trait::async_trait;
use miette::Diagnostic;
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// # Panics
///
/// Panics if the system clock reports a time earlier than `UNIX_EPOCH`.
fn now_secs() -> u64 {
    let since_epoch = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .expect("Expected system time to be greater than UNIX_EPOCH");
    since_epoch.as_secs()
}
/// Expiry bookkeeping for a serializable token.
///
/// Implementors only supply [`TokenExpiry::expires_at_secs`]. The default methods
/// derive expiry and refresh decisions from it by comparing against the current
/// time (seconds since the Unix epoch) plus a leeway.
pub trait TokenExpiry<'a>: Clone + Serialize + Deserialize<'a> {
    /// How many seconds before `expires_at_secs` the token is already treated as expired.
    const EXPIRY_LEEWAY_SECONDS: u64 = 60;
    /// How many seconds before `expires_at_secs` a refresh becomes due.
    const REFRESH_LEEWAY_SECONDS: u64 = 180;
    /// Interval returned by `refresh_interval` once the refresh point has been reached.
    const MIN_REFRESH_INTERVAL_SECONDS: u64 = 10;

    /// Absolute expiry time, in seconds since the Unix epoch.
    fn expires_at_secs(&self) -> u64;

    /// `true` when the token expires within `EXPIRY_LEEWAY_SECONDS` of now (or already has).
    fn is_expired(&self) -> bool {
        let cutoff = now_secs() + Self::EXPIRY_LEEWAY_SECONDS;
        self.expires_at_secs() < cutoff
    }

    /// `true` when the token expires within `REFRESH_LEEWAY_SECONDS` of now (or already has).
    fn should_refresh(&self) -> bool {
        let cutoff = now_secs() + Self::REFRESH_LEEWAY_SECONDS;
        self.expires_at_secs() < cutoff
    }

    /// Time remaining until the refresh point (`expires_at - REFRESH_LEEWAY_SECONDS`).
    ///
    /// If that point is now or in the past, returns `MIN_REFRESH_INTERVAL_SECONDS`.
    fn refresh_interval(&self) -> Duration {
        let refresh_deadline = now_secs() + Self::REFRESH_LEEWAY_SECONDS;
        let wait_secs = self
            .expires_at_secs()
            .checked_sub(refresh_deadline)
            // A zero gap means the deadline is exactly now; treat it like "overdue".
            .filter(|&secs| secs != 0)
            .unwrap_or(Self::MIN_REFRESH_INTERVAL_SECONDS);
        Duration::from_secs(wait_secs)
    }

    /// `MIN_REFRESH_INTERVAL_SECONDS` as a `Duration`.
    fn min_refresh_interval() -> Duration {
        Duration::from_secs(Self::MIN_REFRESH_INTERVAL_SECONDS)
    }
}
/// Errors returned by [`Credentials::get_token`].
///
/// The three boxed variants are marked `#[diagnostic(transparent)]`, so their
/// miette diagnostic metadata is forwarded from the wrapped error. Their
/// Display output prefixes the wrapped error's message with the variant name.
#[derive(Diagnostic, Error, Debug)]
pub enum GetTokenError {
    /// Wraps an error raised while refreshing a token.
    #[error("RefreshTokenFailed: {0}")]
    #[diagnostic(transparent)]
    RefreshTokenFailed(Box<dyn Diagnostic + Send + Sync>),
    /// Wraps an error raised while acquiring a new token.
    #[error("AcquireNewTokenFailed: {0}")]
    #[diagnostic(transparent)]
    AcquireNewTokenFailed(Box<dyn Diagnostic + Send + Sync>),
    /// Wraps an error raised while persisting a token.
    #[error("PersistTokenError: {0}")]
    #[diagnostic(transparent)]
    PersistTokenError(Box<dyn Diagnostic + Send + Sync>),
    /// No usable token is available (none present, or the one present has expired).
    #[error("Token missing or expired")]
    #[diagnostic(help("Token is missing or expired"))]
    MissingOrExpired,
}
#[derive(Error, Debug)]
#[error("RefreshTokenFailed: {0}")]
pub struct ClearTokenError(pub Box<dyn Diagnostic + Send + Sync>);
/// A source of tokens that can be queried asynchronously and shared across threads.
#[async_trait]
pub trait Credentials: Send + Sync + 'static {
    /// The token type produced by this credential source.
    type Token;

    /// Returns a token, or a [`GetTokenError`] describing why none could be produced.
    async fn get_token(&self) -> Result<Self::Token, GetTokenError>;

    /// Clears the current token. Presumably this discards any cached or persisted
    /// token; confirm against the implementors.
    async fn clear_token(&self) -> Result<(), ClearTokenError>;

    /// Reports whether a token can currently be obtained.
    ///
    /// Emits a debug log under the `credentials` target, then calls
    /// [`Credentials::get_token`] and discards the token. Any work `get_token`
    /// does, such as acquiring a token, therefore also happens here.
    async fn valid(&self) -> bool {
        tracing::debug!(target: "credentials", "Attempting to acquire token");
        matches!(self.get_token().await, Ok(_))
    }
}
/// [`Credentials`] that additionally expose a `refresh` operation.
///
/// NOTE(review): presumably driven by the `tokio`-gated `auto_refresh` module;
/// confirm there.
#[async_trait]
pub trait AutoRefreshable: Credentials {
    /// Performs a refresh and returns a `Duration`.
    ///
    /// NOTE(review): the returned value is presumably the delay before the next
    /// refresh should run (compare `TokenExpiry::refresh_interval`). Confirm
    /// against the implementors.
    async fn refresh(&self) -> Duration;
}
/// Lets an `Arc`-wrapped credentials value be used wherever a [`Credentials`] is
/// expected, so one instance can be shared between many owners.
///
/// Every trait method, including the defaulted `valid`, forwards to the inner value.
/// As a result, any override `C` provides is honoured through the `Arc`.
#[async_trait]
impl<C: Credentials> Credentials for Arc<C> {
    type Token = C::Token;

    async fn get_token(&self) -> Result<Self::Token, GetTokenError> {
        self.as_ref().get_token().await
    }

    async fn clear_token(&self) -> Result<(), ClearTokenError> {
        self.as_ref().clear_token().await
    }

    // Without this forward, the trait's default `valid` would run against the Arc
    // itself, silently bypassing any custom `valid` implemented by `C`.
    async fn valid(&self) -> bool {
        self.as_ref().valid().await
    }
}
#[cfg(test)]
pub(crate) mod test_utils {
    use std::sync::{
        atomic::{AtomicUsize, Ordering::SeqCst},
        Arc,
    };

    /// Atomic counters for observing how often, and how concurrently, some code
    /// section is entered. Clones share the same underlying counters.
    #[derive(Clone)]
    pub(crate) struct CountingState {
        /// Number of `enter` calls made so far.
        pub total: Arc<AtomicUsize>,
        /// Number of `enter` calls not yet matched by an `exit`.
        pub current: Arc<AtomicUsize>,
        /// Largest value `current` has reached via `enter`.
        pub peak: Arc<AtomicUsize>,
    }

    impl CountingState {
        /// Creates a state with all counters at zero.
        pub fn new() -> Self {
            let zeroed = || Arc::new(AtomicUsize::new(0));
            Self {
                total: zeroed(),
                current: zeroed(),
                peak: zeroed(),
            }
        }

        /// Records entry: bumps `total` and `current`, then raises `peak` if needed.
        pub fn enter(&self) {
            self.total.fetch_add(1, SeqCst);
            let active_now = 1 + self.current.fetch_add(1, SeqCst);
            self.peak.fetch_max(active_now, SeqCst);
        }

        /// Records exit by decrementing `current`.
        pub fn exit(&self) {
            self.current.fetch_sub(1, SeqCst);
        }

        /// Highest concurrency observed so far.
        pub fn peak(&self) -> usize {
            self.peak.load(SeqCst)
        }

        /// Total number of entries observed so far.
        pub fn total(&self) -> usize {
            self.total.load(SeqCst)
        }
    }
}