use std::sync::Arc;
use reqwest::header::HeaderValue;
use crate::ClientError;
use crate::auth::AuthProvider;
/// An [`AuthProvider`] that delegates to a fixed, ordered list of other providers.
///
/// Members are consulted in the order they were given to [`Chain::new`].
#[derive(Debug)]
pub struct Chain {
    /// Member providers, consulted front to back (index 0 first).
    providers: Vec<Arc<dyn AuthProvider>>,
}

impl Chain {
    /// Creates a chain over `providers`, preserving their order.
    ///
    /// The vector is stored exactly as given; an empty vector is accepted.
    #[must_use]
    pub fn new(providers: Vec<Arc<dyn AuthProvider>>) -> Self {
        Self { providers }
    }
}
/// Routes header lookups and refreshes through the chain's members, in order.
#[async_trait::async_trait]
impl AuthProvider for Chain {
    /// Returns the header produced by the first member that succeeds.
    ///
    /// Members are awaited one at a time, front to back, and the walk stops at the
    /// first success. Errors from members that fail before that are discarded.
    ///
    /// # Errors
    ///
    /// - If every member fails, the error from the *last* member is returned.
    /// - If the chain has no members, a `ClientError::Auth` error is returned.
    async fn authorization_header(&self) -> crate::Result<HeaderValue> {
        // Only the most recent failure is kept; it is what the caller sees if
        // nobody in the chain can produce a header.
        let mut failure: Option<ClientError> = None;
        for member in &self.providers {
            match member.authorization_header().await {
                Err(err) => failure = Some(err),
                success => return success,
            }
        }
        Err(match failure {
            Some(err) => err,
            None => ClientError::Auth("empty AuthProvider chain".into()),
        })
    }

    /// Refreshes the first member that can currently produce a header.
    ///
    /// Each member is probed by awaiting its `authorization_header`, and the probed
    /// header is discarded. Members whose probe fails are skipped. Only the first
    /// member whose probe succeeds is refreshed; later members are never touched.
    /// If no member passes the probe, which includes an empty chain, nothing is
    /// refreshed and `Ok(())` is returned.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by the targeted member's `refresh`.
    async fn refresh(&self) -> crate::Result<()> {
        // Find the member that is actually serving headers right now.
        let mut target = None;
        for member in &self.providers {
            if member.authorization_header().await.is_ok() {
                target = Some(member);
                break;
            }
        }
        match target {
            Some(member) => member.refresh().await,
            None => Ok(()),
        }
    }
}
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    reason = "test code: a panicking unwrap is the intended failure diagnostic"
)]
mod tests {
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    use super::{AuthProvider, Chain, ClientError, HeaderValue};
    use crate::auth::{Basic, Bearer};

    /// Test double whose header lookup and refresh both fail with an `Auth`
    /// error derived from its label.
    #[derive(Debug)]
    struct AlwaysFails(&'static str);

    #[async_trait::async_trait]
    impl AuthProvider for AlwaysFails {
        async fn authorization_header(&self) -> crate::Result<HeaderValue> {
            Err(ClientError::Auth(self.0.into()))
        }

        async fn refresh(&self) -> crate::Result<()> {
            let label = self.0;
            Err(ClientError::Auth(format!("{label} (refresh)")))
        }
    }

    /// Test double that always yields a header and counts its refreshes.
    #[derive(Debug, Default)]
    struct RefreshCounter {
        refreshes: AtomicUsize,
    }

    impl RefreshCounter {
        /// Number of `refresh` calls observed so far.
        fn refresh_calls(&self) -> usize {
            self.refreshes.load(Ordering::SeqCst)
        }
    }

    #[async_trait::async_trait]
    impl AuthProvider for RefreshCounter {
        async fn authorization_header(&self) -> crate::Result<HeaderValue> {
            Ok(HeaderValue::from_static("Bearer dummy"))
        }

        async fn refresh(&self) -> crate::Result<()> {
            self.refreshes.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }
    }

    /// Test double that yields a (stale) header but whose refresh always fails.
    #[derive(Debug)]
    struct FailingRefresher;

    #[async_trait::async_trait]
    impl AuthProvider for FailingRefresher {
        async fn authorization_header(&self) -> crate::Result<HeaderValue> {
            Ok(HeaderValue::from_static("Bearer stale"))
        }

        async fn refresh(&self) -> crate::Result<()> {
            Err(ClientError::Auth("idp down".into()))
        }
    }

    #[tokio::test]
    async fn returns_first_successful_header() {
        let members: Vec<Arc<dyn AuthProvider>> = vec![
            Arc::new(AlwaysFails("first-fail")),
            Arc::new(Bearer::new("good-token").unwrap()),
            Arc::new(Basic::new("never", "used").unwrap()),
        ];
        let chain = Chain::new(members);
        let header = chain.authorization_header().await.unwrap();
        assert_eq!(header, "Bearer good-token");
    }

    #[tokio::test]
    async fn propagates_last_error_when_all_fail() {
        let members: Vec<Arc<dyn AuthProvider>> = ["first", "second", "last"]
            .iter()
            .map(|&label| Arc::new(AlwaysFails(label)) as Arc<dyn AuthProvider>)
            .collect();
        let chain = Chain::new(members);
        let err = chain.authorization_header().await.unwrap_err();
        assert!(
            matches!(&err, ClientError::Auth(msg) if msg == "last"),
            "expected Auth(\"last\"), got {err:?}"
        );
    }

    #[tokio::test]
    async fn empty_chain_errors() {
        let chain = Chain::new(Vec::new());
        let err = chain.authorization_header().await.unwrap_err();
        assert!(matches!(err, ClientError::Auth(_)), "got {err:?}");
    }

    #[tokio::test]
    async fn default_refresh_is_noop_for_static_providers() {
        let token = Bearer::new("opaque").unwrap();
        token.refresh().await.unwrap();
        let credentials = Basic::new("alice", "pw").unwrap();
        credentials.refresh().await.unwrap();
    }

    #[tokio::test]
    async fn chain_refresh_targets_only_the_first_header_producer() {
        let head = Arc::new(RefreshCounter::default());
        let tail = Arc::new(RefreshCounter::default());
        let members: Vec<Arc<dyn AuthProvider>> = vec![head.clone(), tail.clone()];
        let chain = Chain::new(members);
        chain.refresh().await.unwrap();
        assert_eq!(head.refresh_calls(), 1);
        assert_eq!(
            tail.refresh_calls(),
            0,
            "later members must not be touched when an earlier one already produces the header"
        );
    }

    #[tokio::test]
    async fn chain_refresh_skips_members_that_cannot_produce_a_header() {
        let counter = Arc::new(RefreshCounter::default());
        let members: Vec<Arc<dyn AuthProvider>> =
            vec![Arc::new(AlwaysFails("nope")), counter.clone()];
        let chain = Chain::new(members);
        chain.refresh().await.unwrap();
        assert_eq!(counter.refresh_calls(), 1);
    }

    #[tokio::test]
    async fn chain_refresh_propagates_error_from_targeted_provider() {
        let members: Vec<Arc<dyn AuthProvider>> = vec![
            Arc::new(FailingRefresher),
            Arc::new(Bearer::new("static").unwrap()),
        ];
        let chain = Chain::new(members);
        let err = chain.refresh().await.unwrap_err();
        assert!(
            matches!(&err, ClientError::Auth(msg) if msg == "idp down"),
            "expected the stateful provider's refresh error to surface, got {err:?}"
        );
    }

    #[tokio::test]
    async fn chain_refresh_is_ok_when_no_member_produces_a_header() {
        let members: Vec<Arc<dyn AuthProvider>> = vec![
            Arc::new(AlwaysFails("first")),
            Arc::new(AlwaysFails("last")),
        ];
        let chain = Chain::new(members);
        chain.refresh().await.unwrap();
    }

    #[tokio::test]
    async fn empty_chain_refresh_is_ok() {
        let chain = Chain::new(Vec::new());
        chain.refresh().await.unwrap();
    }
}