1use async_trait::async_trait;
19use secrecy::{ExposeSecret, SecretString};
20use thiserror::Error;
21
22use crate::error::{Error, Result};
23
24#[derive(Debug, Clone, Error)]
29#[non_exhaustive]
30pub enum AuthError {
31 #[error("auth: no credential available{}", source_hint.as_ref().map(|s| format!(" (source: {s})")).unwrap_or_default())]
36 Missing {
37 source_hint: Option<String>,
40 },
41
42 #[error("auth: credential refused: {message}")]
48 Refused {
49 message: String,
51 },
52
53 #[error("auth: credential expired{}", message.as_ref().map(|m| format!(": {m}")).unwrap_or_default())]
58 Expired {
59 message: Option<String>,
61 },
62
63 #[error("auth: credential source unreachable: {message}")]
69 SourceUnreachable {
70 message: String,
72 },
73}
74
75impl AuthError {
76 #[must_use]
78 pub const fn missing() -> Self {
79 Self::Missing { source_hint: None }
80 }
81
82 pub fn missing_from(source: impl Into<String>) -> Self {
85 Self::Missing {
86 source_hint: Some(source.into()),
87 }
88 }
89
90 pub fn refused(message: impl Into<String>) -> Self {
92 Self::Refused {
93 message: message.into(),
94 }
95 }
96
97 #[must_use]
99 pub const fn expired() -> Self {
100 Self::Expired { message: None }
101 }
102
103 pub fn expired_with(message: impl Into<String>) -> Self {
105 Self::Expired {
106 message: Some(message.into()),
107 }
108 }
109
110 pub fn source_unreachable(message: impl Into<String>) -> Self {
112 Self::SourceUnreachable {
113 message: message.into(),
114 }
115 }
116}
117
118impl From<AuthError> for Error {
119 fn from(err: AuthError) -> Self {
120 Self::Auth(err)
121 }
122}
123
/// Resolved authentication material: a single header name/value pair.
///
/// `header_value` is kept in a `SecretString`; read it with `expose_secret()`.
#[derive(Clone, Debug)]
pub struct Credentials {
    /// Name of the header to set (e.g. `x-api-key` or `authorization`).
    pub header_name: http::HeaderName,
    /// Header value, including any scheme prefix such as `Bearer `.
    pub header_value: SecretString,
}
136
/// A source of credentials, resolved asynchronously on demand.
///
/// Implementors must be `Send + Sync + 'static` so they can be shared as
/// `Arc<dyn CredentialProvider>` (as `ChainedCredentialProvider` does).
///
/// Convention in this module: returning `AuthError::Missing` means "no
/// credential here" and lets a chain fall through to the next provider; any
/// other error is treated as a hard failure.
#[async_trait]
pub trait CredentialProvider: Send + Sync + 'static {
    /// Produces the header name and value to authenticate with.
    async fn resolve(&self) -> Result<Credentials>;
}
149
/// Provider that always returns a fixed API key in a configurable header.
#[derive(Debug)]
pub struct ApiKeyProvider {
    // Header the key is sent in; validated at construction time.
    header_name: http::HeaderName,
    // The key itself, returned verbatim as the header value.
    api_key: SecretString,
}
158
159impl ApiKeyProvider {
160 pub fn new(header_name: &str, api_key: impl Into<SecretString>) -> Result<Self> {
165 let header_name = http::HeaderName::from_bytes(header_name.as_bytes())
166 .map_err(|e| Error::config(format!("invalid header name: {e}")))?;
167 Ok(Self {
168 header_name,
169 api_key: api_key.into(),
170 })
171 }
172
173 pub fn anthropic(api_key: impl Into<SecretString>) -> Self {
175 Self {
176 header_name: http::HeaderName::from_static("x-api-key"),
177 api_key: api_key.into(),
178 }
179 }
180}
181
182#[async_trait]
183impl CredentialProvider for ApiKeyProvider {
184 async fn resolve(&self) -> Result<Credentials> {
185 Ok(Credentials {
186 header_name: self.header_name.clone(),
187 header_value: self.api_key.clone(),
188 })
189 }
190}
191
/// Provider that sends a fixed token as `Authorization: Bearer <token>`.
#[derive(Debug)]
pub struct BearerProvider {
    // Raw token without the `Bearer ` prefix; the prefix is added in `resolve`.
    token: SecretString,
}
198
199impl BearerProvider {
200 pub fn new(token: impl Into<SecretString>) -> Self {
202 Self {
203 token: token.into(),
204 }
205 }
206}
207
208#[async_trait]
209impl CredentialProvider for BearerProvider {
210 async fn resolve(&self) -> Result<Credentials> {
211 let formatted = format!("Bearer {}", self.token.expose_secret());
212 Ok(Credentials {
213 header_name: http::header::AUTHORIZATION,
214 header_value: SecretString::from(formatted),
215 })
216 }
217}
218
/// Wraps another provider and reuses its successful result for a fixed TTL.
///
/// Failed resolutions are not cached; the next call retries the inner provider.
pub struct CachedCredentialProvider<P> {
    // Wrapped provider; held in an `Arc` so `from_arc` can share an existing one.
    inner: std::sync::Arc<P>,
    // How long a cached entry is served before it is refreshed.
    ttl: std::time::Duration,
    // Cache slot. An async mutex because the lock is held while the inner
    // provider is awaited in `resolve`.
    state: tokio::sync::Mutex<CachedState>,
}
239
/// Mutable cache contents guarded by `CachedCredentialProvider::state`.
struct CachedState {
    // Last successfully resolved credentials plus the instant they were
    // stored; `None` until the first successful resolution.
    cached: Option<(Credentials, std::time::Instant)>,
}
243
244impl<P> CachedCredentialProvider<P>
245where
246 P: CredentialProvider,
247{
248 pub fn new(inner: P, ttl: std::time::Duration) -> Self {
251 Self {
252 inner: std::sync::Arc::new(inner),
253 ttl,
254 state: tokio::sync::Mutex::new(CachedState { cached: None }),
255 }
256 }
257
258 pub fn from_arc(inner: std::sync::Arc<P>, ttl: std::time::Duration) -> Self {
260 Self {
261 inner,
262 ttl,
263 state: tokio::sync::Mutex::new(CachedState { cached: None }),
264 }
265 }
266
267 pub const fn ttl(&self) -> std::time::Duration {
269 self.ttl
270 }
271}
272
273#[async_trait]
274impl<P> CredentialProvider for CachedCredentialProvider<P>
275where
276 P: CredentialProvider,
277{
278 async fn resolve(&self) -> Result<Credentials> {
279 let mut guard = self.state.lock().await;
280 if let Some((creds, fetched_at)) = &guard.cached
281 && fetched_at.elapsed() < self.ttl
282 {
283 let cached = creds.clone();
284 drop(guard);
288 return Ok(cached);
289 }
290 let fresh = self.inner.resolve().await?;
294 guard.cached = Some((fresh.clone(), std::time::Instant::now()));
295 drop(guard);
296 Ok(fresh)
297 }
298}
299
/// Tries a list of providers in order.
///
/// The first success wins. `AuthError::Missing` moves on to the next
/// provider, and any other error stops the chain immediately. If every
/// provider reports `Missing`, the chain itself returns `AuthError::Missing`.
pub struct ChainedCredentialProvider {
    // Consulted front to back.
    providers: Vec<std::sync::Arc<dyn CredentialProvider>>,
}
313
impl ChainedCredentialProvider {
    /// Builds a chain that consults `providers` in the given order.
    #[must_use]
    pub const fn new(providers: Vec<std::sync::Arc<dyn CredentialProvider>>) -> Self {
        Self { providers }
    }

    /// Number of providers in the chain.
    #[must_use]
    pub fn len(&self) -> usize {
        self.providers.len()
    }

    /// Returns `true` when the chain has no providers. Resolving an empty
    /// chain always yields `AuthError::Missing`.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.providers.is_empty()
    }
}
335
336#[async_trait]
337impl CredentialProvider for ChainedCredentialProvider {
338 async fn resolve(&self) -> Result<Credentials> {
339 for provider in &self.providers {
340 match provider.resolve().await {
341 Ok(creds) => return Ok(creds),
342 Err(Error::Auth(AuthError::Missing { .. })) => {}
346 Err(other) => return Err(other),
347 }
348 }
349 Err(AuthError::missing_from(format!(
350 "chained: {} provider(s) exhausted",
351 self.providers.len()
352 ))
353 .into())
354 }
355}
356
#[cfg(test)]
// `unwrap` is the intended failure mechanism in these tests.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Test double: counts `resolve` calls and always returns the same
    /// pre-programmed outcome.
    struct CountingProvider {
        calls: Arc<AtomicUsize>,
        outcome: Outcome,
    }

    /// What a `CountingProvider` returns from every `resolve` call.
    enum Outcome {
        Ok(SecretString),
        Missing,
        Refused(String),
    }

    impl CountingProvider {
        /// Shared constructor: returns the provider plus a handle to its call counter.
        fn with_outcome(outcome: Outcome) -> (Self, Arc<AtomicUsize>) {
            let counter = Arc::new(AtomicUsize::new(0));
            let provider = Self {
                calls: Arc::clone(&counter),
                outcome,
            };
            (provider, counter)
        }

        fn ok(token: &str) -> (Self, Arc<AtomicUsize>) {
            Self::with_outcome(Outcome::Ok(SecretString::from(token.to_owned())))
        }

        fn missing() -> (Self, Arc<AtomicUsize>) {
            Self::with_outcome(Outcome::Missing)
        }

        fn refused(msg: &str) -> (Self, Arc<AtomicUsize>) {
            Self::with_outcome(Outcome::Refused(msg.to_owned()))
        }
    }

    #[async_trait]
    impl CredentialProvider for CountingProvider {
        async fn resolve(&self) -> Result<Credentials> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            let header_value = match &self.outcome {
                Outcome::Ok(token) => token.clone(),
                Outcome::Missing => return Err(AuthError::missing().into()),
                Outcome::Refused(msg) => return Err(AuthError::refused(msg.clone()).into()),
            };
            Ok(Credentials {
                header_name: http::header::AUTHORIZATION,
                header_value,
            })
        }
    }

    #[tokio::test]
    async fn cached_provider_serves_from_cache_within_ttl() {
        let (inner, calls) = CountingProvider::ok("tok-1");
        let cached = CachedCredentialProvider::new(inner, Duration::from_secs(60));
        for _ in 0..3 {
            cached.resolve().await.unwrap();
        }
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn cached_provider_refreshes_after_ttl() {
        let (inner, calls) = CountingProvider::ok("tok-2");
        let cached = CachedCredentialProvider::new(inner, Duration::from_millis(20));
        cached.resolve().await.unwrap();
        // Sleep well past the TTL so the second call has to refetch.
        tokio::time::sleep(Duration::from_millis(40)).await;
        cached.resolve().await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn chained_provider_falls_through_on_missing() {
        let (first, first_calls) = CountingProvider::missing();
        let (second, second_calls) = CountingProvider::ok("from-b");
        let chain = ChainedCredentialProvider::new(vec![Arc::new(first), Arc::new(second)]);

        let creds = chain.resolve().await.unwrap();

        assert_eq!(creds.header_name, http::header::AUTHORIZATION);
        assert_eq!(creds.header_value.expose_secret(), "from-b");
        assert_eq!(first_calls.load(Ordering::SeqCst), 1);
        assert_eq!(second_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn chained_provider_short_circuits_on_real_error() {
        let (first, first_calls) = CountingProvider::refused("vault: 401");
        let (second, second_calls) = CountingProvider::ok("from-b");
        let chain = ChainedCredentialProvider::new(vec![Arc::new(first), Arc::new(second)]);

        let err = chain.resolve().await.unwrap_err();

        assert!(matches!(err, Error::Auth(AuthError::Refused { .. })));
        assert_eq!(first_calls.load(Ordering::SeqCst), 1);
        assert_eq!(
            second_calls.load(Ordering::SeqCst),
            0,
            "chain must not consult later providers after a real failure"
        );
    }

    #[tokio::test]
    async fn chained_provider_returns_missing_when_all_sources_exhausted() {
        let (first, _) = CountingProvider::missing();
        let (second, _) = CountingProvider::missing();
        let chain = ChainedCredentialProvider::new(vec![Arc::new(first), Arc::new(second)]);

        let err = chain.resolve().await.unwrap_err();

        assert!(matches!(err, Error::Auth(AuthError::Missing { .. })));
    }
}