use async_trait::async_trait;
use secrecy::{ExposeSecret, SecretString};
use thiserror::Error;
use crate::error::{Error, Result};
/// Failures that can occur while obtaining credentials.
///
/// Every variant renders with an `auth:` prefix (via `thiserror`) and converts
/// into the crate-wide `Error` through its `From<AuthError>` impl.
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
pub enum AuthError {
    /// No credential was available. When `source_hint` is set it is appended
    /// to the message as ` (source: ...)`.
    ///
    /// `ChainedCredentialProvider` treats this variant as "try the next
    /// provider" rather than as a hard failure.
    // NOTE(review): presumably named `source_hint` rather than `source`
    // because thiserror treats a field called `source` as the error's cause.
    #[error("auth: no credential available{}", source_hint.as_ref().map(|s| format!(" (source: {s})")).unwrap_or_default())]
    Missing {
        source_hint: Option<String>,
    },
    /// A credential was obtained but rejected; `message` carries the reason.
    #[error("auth: credential refused: {message}")]
    Refused {
        message: String,
    },
    /// The credential has expired. An optional `message` is appended as `: ...`.
    #[error("auth: credential expired{}", message.as_ref().map(|m| format!(": {m}")).unwrap_or_default())]
    Expired {
        message: Option<String>,
    },
    /// The backing credential source could not be reached; `message` says why.
    #[error("auth: credential source unreachable: {message}")]
    SourceUnreachable {
        message: String,
    },
}
impl AuthError {
    /// An [`AuthError::Missing`] that carries no source hint.
    #[must_use]
    pub const fn missing() -> Self {
        Self::Missing { source_hint: None }
    }

    /// An [`AuthError::Missing`] naming the credential source that came up empty.
    pub fn missing_from(source: impl Into<String>) -> Self {
        let hint: String = source.into();
        Self::Missing {
            source_hint: Some(hint),
        }
    }

    /// An [`AuthError::Refused`] carrying the rejection reason.
    pub fn refused(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::Refused { message }
    }

    /// An [`AuthError::Expired`] without additional detail.
    #[must_use]
    pub const fn expired() -> Self {
        Self::Expired { message: None }
    }

    /// An [`AuthError::Expired`] with an explanatory message.
    pub fn expired_with(message: impl Into<String>) -> Self {
        let detail: String = message.into();
        Self::Expired {
            message: Some(detail),
        }
    }

    /// An [`AuthError::SourceUnreachable`] describing why the source failed.
    pub fn source_unreachable(message: impl Into<String>) -> Self {
        let message = message.into();
        Self::SourceUnreachable { message }
    }
}
impl From<AuthError> for Error {
    /// Wraps an [`AuthError`] as the crate error's `Auth` variant, which makes
    /// `?` work on `AuthError` inside functions returning the crate `Result`.
    fn from(auth_failure: AuthError) -> Self {
        Self::Auth(auth_failure)
    }
}
/// A single HTTP header (name plus secret value) produced by a
/// [`CredentialProvider`].
///
/// The value is held as a [`SecretString`]; reading it requires going through
/// [`ExposeSecret::expose_secret`].
#[derive(Clone, Debug)]
pub struct Credentials {
    /// Header to set, e.g. `authorization` or `x-api-key`.
    pub header_name: http::HeaderName,
    /// Complete header value, including any scheme prefix such as `Bearer `.
    pub header_value: SecretString,
}
/// Asynchronous source of [`Credentials`].
///
/// The `Send + Sync + 'static` bound lets implementations be shared as
/// `Arc<dyn CredentialProvider>` (as `ChainedCredentialProvider` does).
#[async_trait]
pub trait CredentialProvider: Send + Sync + 'static {
    /// Resolves the current credentials.
    ///
    /// # Errors
    ///
    /// Returning `AuthError::Missing` signals "no credential here". A chained
    /// provider falls through to the next source on that variant and treats
    /// any other error as a hard failure.
    async fn resolve(&self) -> Result<Credentials>;
}
/// Static API key sent verbatim under a configurable header name.
#[derive(Debug)]
pub struct ApiKeyProvider {
    /// Header the key is sent under (already validated at construction).
    header_name: http::HeaderName,
    /// The raw key, returned unchanged as the header value.
    api_key: SecretString,
}
impl ApiKeyProvider {
pub fn new(header_name: &str, api_key: impl Into<SecretString>) -> Result<Self> {
let header_name = http::HeaderName::from_bytes(header_name.as_bytes())
.map_err(|e| Error::config(format!("invalid header name: {e}")))?;
Ok(Self {
header_name,
api_key: api_key.into(),
})
}
pub fn anthropic(api_key: impl Into<SecretString>) -> Self {
Self {
header_name: http::HeaderName::from_static("x-api-key"),
api_key: api_key.into(),
}
}
}
#[async_trait]
impl CredentialProvider for ApiKeyProvider {
    /// Returns the configured header name paired with the stored key.
    ///
    /// Never fails; the same credentials are produced on every call.
    async fn resolve(&self) -> Result<Credentials> {
        let Self {
            header_name,
            api_key,
        } = self;
        let creds = Credentials {
            header_name: header_name.clone(),
            header_value: api_key.clone(),
        };
        Ok(creds)
    }
}
/// Static token sent as `Authorization: Bearer <token>`.
#[derive(Debug)]
pub struct BearerProvider {
    /// The bare token, without the `Bearer ` prefix (added at resolve time).
    token: SecretString,
}
impl BearerProvider {
    /// Creates a provider for `token`; pass the bare token without `Bearer `.
    pub fn new(token: impl Into<SecretString>) -> Self {
        let token = token.into();
        Self { token }
    }
}
#[async_trait]
impl CredentialProvider for BearerProvider {
    /// Produces an `Authorization` header whose value is `Bearer <token>`.
    ///
    /// The prefixed value is rebuilt on every call and handed back wrapped in
    /// a `SecretString`. Never fails.
    async fn resolve(&self) -> Result<Credentials> {
        let token = self.token.expose_secret();
        let header_value = SecretString::from(format!("Bearer {token}"));
        Ok(Credentials {
            header_name: http::header::AUTHORIZATION,
            header_value,
        })
    }
}
/// Wraps another provider and reuses its credentials until they are `ttl` old.
pub struct CachedCredentialProvider<P> {
    /// Provider consulted on a cache miss; held in an `Arc` so it can also be
    /// shared with other owners (see `from_arc`).
    inner: std::sync::Arc<P>,
    /// Maximum age of a cached entry before it is refreshed.
    ttl: std::time::Duration,
    /// Cache slot. Uses an async mutex because the guard is held across the
    /// inner provider's `.await` during a refresh.
    state: tokio::sync::Mutex<CachedState>,
}
/// Mutable cache contents for [`CachedCredentialProvider`].
struct CachedState {
    /// Last successfully resolved credentials and the instant they were
    /// fetched; `None` until the first successful resolve.
    cached: Option<(Credentials, std::time::Instant)>,
}
impl<P> CachedCredentialProvider<P>
where
    P: CredentialProvider,
{
    /// Takes ownership of `inner` and caches its credentials for `ttl`.
    pub fn new(inner: P, ttl: std::time::Duration) -> Self {
        Self::from_arc(std::sync::Arc::new(inner), ttl)
    }

    /// Like [`Self::new`], but wraps a provider that is already shared via `Arc`.
    pub fn from_arc(inner: std::sync::Arc<P>, ttl: std::time::Duration) -> Self {
        let empty = CachedState { cached: None };
        Self {
            inner,
            ttl,
            state: tokio::sync::Mutex::new(empty),
        }
    }

    /// Age at which a cached credential stops being served and is refreshed.
    pub const fn ttl(&self) -> std::time::Duration {
        self.ttl
    }
}
#[async_trait]
impl<P> CredentialProvider for CachedCredentialProvider<P>
where
    P: CredentialProvider,
{
    /// Returns the cached credentials while they are younger than `ttl`;
    /// otherwise resolves through the inner provider and caches the result.
    ///
    /// The state lock stays held across the inner `resolve().await`, so
    /// concurrent callers wait for that single refresh instead of each calling
    /// the inner provider.
    ///
    /// # Errors
    ///
    /// Propagates the inner provider's error unchanged; a failed refresh
    /// leaves the existing cache entry as it was.
    async fn resolve(&self) -> Result<Credentials> {
        let mut state = self.state.lock().await;

        let hit = state
            .cached
            .as_ref()
            .filter(|(_, fetched_at)| fetched_at.elapsed() < self.ttl)
            .map(|(creds, _)| creds.clone());
        if let Some(creds) = hit {
            // Release the lock before handing the value back.
            drop(state);
            return Ok(creds);
        }

        let fresh = self.inner.resolve().await?;
        state.cached = Some((fresh.clone(), std::time::Instant::now()));
        drop(state);
        Ok(fresh)
    }
}
/// Consults a list of providers in order, falling through on
/// `AuthError::Missing` and stopping at the first success or hard error.
pub struct ChainedCredentialProvider {
    /// Providers in priority order (index 0 is tried first).
    providers: Vec<std::sync::Arc<dyn CredentialProvider>>,
}
impl ChainedCredentialProvider {
    /// Builds a chain that tries `providers` from first to last.
    #[must_use]
    pub const fn new(providers: Vec<std::sync::Arc<dyn CredentialProvider>>) -> Self {
        Self { providers }
    }

    /// Number of providers in the chain.
    #[must_use]
    pub fn len(&self) -> usize {
        let Self { providers } = self;
        providers.len()
    }

    /// Whether the chain has no providers. Resolving an empty chain always
    /// yields `AuthError::Missing`.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let Self { providers } = self;
        providers.is_empty()
    }
}
#[async_trait]
impl CredentialProvider for ChainedCredentialProvider {
async fn resolve(&self) -> Result<Credentials> {
for provider in &self.providers {
match provider.resolve().await {
Ok(creds) => return Ok(creds),
Err(Error::Auth(AuthError::Missing { .. })) => {}
Err(other) => return Err(other),
}
}
Err(AuthError::missing_from(format!(
"chained: {} provider(s) exhausted",
self.providers.len()
))
.into())
}
}
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    reason = "tests should panic loudly on unexpected errors; unwrap keeps assertions terse"
)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Test double that counts `resolve` calls and always yields the same
    /// configured outcome.
    struct CountingProvider {
        calls: Arc<AtomicUsize>,
        outcome: Outcome,
    }

    /// Fixed result produced by a [`CountingProvider`] on every call.
    enum Outcome {
        Ok(SecretString),
        Missing,
        Refused(String),
    }

    impl CountingProvider {
        /// Builds a provider with `outcome` and returns it together with a
        /// shared handle to its call counter.
        fn with_outcome(outcome: Outcome) -> (Self, Arc<AtomicUsize>) {
            let calls = Arc::new(AtomicUsize::new(0));
            (
                Self {
                    calls: Arc::clone(&calls),
                    outcome,
                },
                calls,
            )
        }

        fn ok(token: &str) -> (Self, Arc<AtomicUsize>) {
            Self::with_outcome(Outcome::Ok(SecretString::from(token.to_owned())))
        }

        fn missing() -> (Self, Arc<AtomicUsize>) {
            Self::with_outcome(Outcome::Missing)
        }

        fn refused(msg: &str) -> (Self, Arc<AtomicUsize>) {
            Self::with_outcome(Outcome::Refused(msg.to_owned()))
        }
    }

    #[async_trait]
    impl CredentialProvider for CountingProvider {
        async fn resolve(&self) -> Result<Credentials> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            match &self.outcome {
                Outcome::Ok(token) => Ok(Credentials {
                    header_name: http::header::AUTHORIZATION,
                    header_value: token.clone(),
                }),
                Outcome::Missing => Err(AuthError::missing().into()),
                Outcome::Refused(msg) => Err(AuthError::refused(msg.clone()).into()),
            }
        }
    }

    #[tokio::test]
    async fn cached_provider_serves_from_cache_within_ttl() {
        let (inner, calls) = CountingProvider::ok("tok-1");
        let cached = CachedCredentialProvider::new(inner, Duration::from_mins(1));
        let _ = cached.resolve().await.unwrap();
        let _ = cached.resolve().await.unwrap();
        let _ = cached.resolve().await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn cached_provider_refreshes_after_ttl() {
        let (inner, calls) = CountingProvider::ok("tok-2");
        let cached = CachedCredentialProvider::new(inner, Duration::from_millis(20));
        let _ = cached.resolve().await.unwrap();
        tokio::time::sleep(Duration::from_millis(40)).await;
        let _ = cached.resolve().await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn cached_provider_does_not_cache_errors() {
        let (inner, calls) = CountingProvider::refused("vault: 503");
        let cached = CachedCredentialProvider::new(inner, Duration::from_mins(1));
        assert!(cached.resolve().await.is_err());
        assert!(cached.resolve().await.is_err());
        assert_eq!(
            calls.load(Ordering::SeqCst),
            2,
            "a failed resolve must not populate the cache"
        );
    }

    #[tokio::test]
    async fn chained_provider_falls_through_on_missing() {
        let (a, a_calls) = CountingProvider::missing();
        let (b, b_calls) = CountingProvider::ok("from-b");
        let chain = ChainedCredentialProvider::new(vec![Arc::new(a), Arc::new(b)]);
        let creds = chain.resolve().await.unwrap();
        assert_eq!(creds.header_name, http::header::AUTHORIZATION);
        assert_eq!(creds.header_value.expose_secret(), "from-b");
        assert_eq!(a_calls.load(Ordering::SeqCst), 1);
        assert_eq!(b_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn chained_provider_short_circuits_on_real_error() {
        let (a, a_calls) = CountingProvider::refused("vault: 401");
        let (b, b_calls) = CountingProvider::ok("from-b");
        let chain = ChainedCredentialProvider::new(vec![Arc::new(a), Arc::new(b)]);
        let err = chain.resolve().await.unwrap_err();
        assert!(matches!(err, Error::Auth(AuthError::Refused { .. })));
        assert_eq!(a_calls.load(Ordering::SeqCst), 1);
        assert_eq!(
            b_calls.load(Ordering::SeqCst),
            0,
            "chain must not consult later providers after a real failure"
        );
    }

    #[tokio::test]
    async fn chained_provider_returns_missing_when_all_sources_exhausted() {
        let (a, _) = CountingProvider::missing();
        let (b, _) = CountingProvider::missing();
        let chain = ChainedCredentialProvider::new(vec![Arc::new(a), Arc::new(b)]);
        let err = chain.resolve().await.unwrap_err();
        assert!(matches!(err, Error::Auth(AuthError::Missing { .. })));
    }

    #[tokio::test]
    async fn chained_provider_with_no_sources_reports_missing() {
        let chain = ChainedCredentialProvider::new(Vec::new());
        assert!(chain.is_empty());
        let err = chain.resolve().await.unwrap_err();
        assert!(matches!(err, Error::Auth(AuthError::Missing { .. })));
    }
}