Skip to main content

entelix_core/
auth.rs

1//! Credential resolution for transports.
2//!
3//! Per invariant 10, credentials live exclusively in this module and are
4//! plumbed through `Transport`. `ExecutionContext` does NOT embed a
5//! [`CredentialProvider`], so `Tool::execute` never sees a token.
6//!
7//! Two ready-made impls cover the two header conventions used by
8//! every shipped provider:
9//! - [`ApiKeyProvider`] — emits a custom header (e.g. Anthropic
10//!   `x-api-key: <key>`).
11//! - [`BearerProvider`] — emits `authorization: Bearer <token>`.
12//!
13//! Failures surface as [`Error::Auth`] carrying a typed [`AuthError`]
14//! so credential-chain bugs (missing keys, expired tokens, refused
15//! refresh) are distinguishable from generic provider HTTP failures
16//! at the application layer.
17
18use async_trait::async_trait;
19use secrecy::{ExposeSecret, SecretString};
20use thiserror::Error;
21
22use crate::error::{Error, Result};
23
24/// Typed credential failure. Public APIs raise [`Error::Auth`] which
25/// wraps this enum; downstream layers (retry policies, circuit
26/// breakers, dashboards) can match on the variant rather than on a
27/// stringly-typed `Error::Provider` blob.
28#[derive(Debug, Clone, Error)]
29#[non_exhaustive]
30pub enum AuthError {
31    /// No credential is configured for the requested scope. Most
32    /// often a deployment-time misconfiguration: the operator forgot
33    /// to wire a [`CredentialProvider`] into the transport, or a
34    /// chained provider exhausted every source without finding one.
35    #[error("auth: no credential available{}", source_hint.as_ref().map(|s| format!(" (source: {s})")).unwrap_or_default())]
36    Missing {
37        /// Human-readable hint about which source was expected
38        /// (`"env:ANTHROPIC_API_KEY"`, `"vault:secret/llm"`).
39        source_hint: Option<String>,
40    },
41
42    /// The credential resolved successfully but the provider rejected
43    /// it (HTTP 401/403, vendor-specific "invalid token" payloads).
44    /// Distinct from [`Self::Expired`] because retries against the
45    /// same source will keep failing — the operator must rotate the
46    /// secret or fix the IAM grant.
47    #[error("auth: credential refused: {message}")]
48    Refused {
49        /// Provider-supplied rejection message, normalised.
50        message: String,
51    },
52
53    /// The credential's TTL elapsed and the refresh path failed (or
54    /// is not configured). Caller can react by triggering a
55    /// rotation; downstream retry policies often back off briefly
56    /// and retry once.
57    #[error("auth: credential expired{}", message.as_ref().map(|m| format!(": {m}")).unwrap_or_default())]
58    Expired {
59        /// Optional detail (`"refresh endpoint returned 503"`).
60        message: Option<String>,
61    },
62
63    /// Resolving the credential required talking to a remote service
64    /// (vault, IMDS, KMS, OAuth refresh endpoint) and that service
65    /// was unreachable. Distinct from [`Self::Refused`] because the
66    /// credential itself may still be valid; the issue is transport
67    /// to the credential source.
68    #[error("auth: credential source unreachable: {message}")]
69    SourceUnreachable {
70        /// Description of the failed source call.
71        message: String,
72    },
73}
74
75impl AuthError {
76    /// Build a `Missing` variant with no source hint.
77    #[must_use]
78    pub const fn missing() -> Self {
79        Self::Missing { source_hint: None }
80    }
81
82    /// Build a `Missing` variant labelled with the source the caller
83    /// expected (`"env:OPENAI_API_KEY"`, `"chained:[env, vault]"`).
84    pub fn missing_from(source: impl Into<String>) -> Self {
85        Self::Missing {
86            source_hint: Some(source.into()),
87        }
88    }
89
90    /// Build a `Refused` variant from the provider's rejection message.
91    pub fn refused(message: impl Into<String>) -> Self {
92        Self::Refused {
93            message: message.into(),
94        }
95    }
96
97    /// Build an `Expired` variant with no extra detail.
98    #[must_use]
99    pub const fn expired() -> Self {
100        Self::Expired { message: None }
101    }
102
103    /// Build an `Expired` variant with refresh-path detail.
104    pub fn expired_with(message: impl Into<String>) -> Self {
105        Self::Expired {
106            message: Some(message.into()),
107        }
108    }
109
110    /// Build a `SourceUnreachable` variant.
111    pub fn source_unreachable(message: impl Into<String>) -> Self {
112        Self::SourceUnreachable {
113            message: message.into(),
114        }
115    }
116}
117
118impl From<AuthError> for Error {
119    fn from(err: AuthError) -> Self {
120        Self::Auth(err)
121    }
122}
123
124/// Header pair a transport adds immediately before sending.
125///
126/// The value is a `SecretString` so logging or `Debug` output never leaks
127/// it; the transport calls [`ExposeSecret::expose_secret`] only when
128/// assembling the wire request.
129#[derive(Clone, Debug)]
130pub struct Credentials {
131    /// HTTP header name (`x-api-key`, `authorization`, etc.).
132    pub header_name: http::HeaderName,
133    /// Secret-wrapped header value. Use `expose_secret()` at send time.
134    pub header_value: SecretString,
135}
136
137/// Async source-of-truth for credentials.
138///
139/// Implementors may cache, refresh OAuth tokens, call a vault, etc. The
140/// transport calls `resolve()` once per request and discards the result
141/// after the headers are written.
142#[async_trait]
143pub trait CredentialProvider: Send + Sync + 'static {
144    /// Resolve current credentials. Long-running impls should respect
145    /// `tokio` cancellation in their internals; the transport supplies the
146    /// `ExecutionContext` indirectly via the surrounding async task.
147    async fn resolve(&self) -> Result<Credentials>;
148}
149
150/// Static API-key provider. The header name is configurable so this works
151/// for both Anthropic (`x-api-key`) and any other vendor that uses a
152/// non-`Authorization` header.
153#[derive(Debug)]
154pub struct ApiKeyProvider {
155    header_name: http::HeaderName,
156    api_key: SecretString,
157}
158
159impl ApiKeyProvider {
160    /// Construct from a header name and a raw key string.
161    ///
162    /// Returns `Error::Config` if `header_name` cannot be parsed as a valid
163    /// HTTP header name.
164    pub fn new(header_name: &str, api_key: impl Into<SecretString>) -> Result<Self> {
165        let header_name = http::HeaderName::from_bytes(header_name.as_bytes())
166            .map_err(|e| Error::config(format!("invalid header name: {e}")))?;
167        Ok(Self {
168            header_name,
169            api_key: api_key.into(),
170        })
171    }
172
173    /// Convenience: Anthropic-style `x-api-key` provider.
174    pub fn anthropic(api_key: impl Into<SecretString>) -> Self {
175        Self {
176            header_name: http::HeaderName::from_static("x-api-key"),
177            api_key: api_key.into(),
178        }
179    }
180}
181
182#[async_trait]
183impl CredentialProvider for ApiKeyProvider {
184    async fn resolve(&self) -> Result<Credentials> {
185        Ok(Credentials {
186            header_name: self.header_name.clone(),
187            header_value: self.api_key.clone(),
188        })
189    }
190}
191
192/// `Authorization: Bearer <token>` provider. Used by `OpenAI`, Gemini, and
193/// most cloud transports as the inner credential.
194#[derive(Debug)]
195pub struct BearerProvider {
196    token: SecretString,
197}
198
199impl BearerProvider {
200    /// Construct from a raw token string.
201    pub fn new(token: impl Into<SecretString>) -> Self {
202        Self {
203            token: token.into(),
204        }
205    }
206}
207
208#[async_trait]
209impl CredentialProvider for BearerProvider {
210    async fn resolve(&self) -> Result<Credentials> {
211        let formatted = format!("Bearer {}", self.token.expose_secret());
212        Ok(Credentials {
213            header_name: http::header::AUTHORIZATION,
214            header_value: SecretString::from(formatted),
215        })
216    }
217}
218
219/// TTL cache wrapping any inner [`CredentialProvider`]. Resolves
220/// once, hands back the cached value until `ttl` elapses, then
221/// refreshes by calling the inner provider exactly once even under
222/// concurrent load (concurrent waiters share the in-flight refresh
223/// future).
224///
225/// The wrapper is the recommended baseline for production credential
226/// chains: short-lived bearer tokens (OAuth, AWS STS, Azure AAD) all
227/// expose a TTL, and refusing to cache hammers the credential source
228/// once per request.
229///
230/// On refresh failure the cache surfaces the inner error and does
231/// **not** poison the slot — a subsequent call retries. This keeps
232/// transient credential-source outages from cascading into
233/// permanent agent failure.
234pub struct CachedCredentialProvider<P> {
235    inner: std::sync::Arc<P>,
236    ttl: std::time::Duration,
237    state: tokio::sync::Mutex<CachedState>,
238}
239
240struct CachedState {
241    cached: Option<(Credentials, std::time::Instant)>,
242}
243
244impl<P> CachedCredentialProvider<P>
245where
246    P: CredentialProvider,
247{
248    /// Wrap `inner` with a TTL cache. The first call to `resolve`
249    /// populates the cache; subsequent calls within `ttl` reuse it.
250    pub fn new(inner: P, ttl: std::time::Duration) -> Self {
251        Self {
252            inner: std::sync::Arc::new(inner),
253            ttl,
254            state: tokio::sync::Mutex::new(CachedState { cached: None }),
255        }
256    }
257
258    /// Convenience constructor for impls already wrapped in `Arc`.
259    pub fn from_arc(inner: std::sync::Arc<P>, ttl: std::time::Duration) -> Self {
260        Self {
261            inner,
262            ttl,
263            state: tokio::sync::Mutex::new(CachedState { cached: None }),
264        }
265    }
266
267    /// Effective TTL.
268    pub const fn ttl(&self) -> std::time::Duration {
269        self.ttl
270    }
271}
272
273#[async_trait]
274impl<P> CredentialProvider for CachedCredentialProvider<P>
275where
276    P: CredentialProvider,
277{
278    async fn resolve(&self) -> Result<Credentials> {
279        let mut guard = self.state.lock().await;
280        if let Some((creds, fetched_at)) = &guard.cached
281            && fetched_at.elapsed() < self.ttl
282        {
283            let cached = creds.clone();
284            // Drop the guard before returning so callers waiting
285            // on the lock don't block on the cache-hit path's
286            // implicit drop at end-of-scope.
287            drop(guard);
288            return Ok(cached);
289        }
290        // Slot is empty or stale — refresh under the lock so
291        // concurrent callers share the result rather than pile on
292        // the credential source.
293        let fresh = self.inner.resolve().await?;
294        guard.cached = Some((fresh.clone(), std::time::Instant::now()));
295        drop(guard);
296        Ok(fresh)
297    }
298}
299
300/// Try a sequence of [`CredentialProvider`]s in order, returning
301/// the first one that resolves successfully. A provider that
302/// returns [`AuthError::Missing`] (the configured "not my source"
303/// signal) is skipped; any other [`enum@Error`] short-circuits and is
304/// returned to the caller — failed-but-real credential sources
305/// must surface their failure rather than silently fall through.
306///
307/// Typical layout: try environment first, fall back to vault, fall
308/// back to instance metadata. The chain is built once at
309/// transport-construction time; `resolve` is hot-path safe.
310pub struct ChainedCredentialProvider {
311    providers: Vec<std::sync::Arc<dyn CredentialProvider>>,
312}
313
314impl ChainedCredentialProvider {
315    /// Build a chain from the supplied provider list. An empty list
316    /// is permitted but pointless — every `resolve` call returns
317    /// [`AuthError::Missing`].
318    #[must_use]
319    pub const fn new(providers: Vec<std::sync::Arc<dyn CredentialProvider>>) -> Self {
320        Self { providers }
321    }
322
323    /// Number of providers in the chain.
324    #[must_use]
325    pub fn len(&self) -> usize {
326        self.providers.len()
327    }
328
329    /// True when no providers are registered.
330    #[must_use]
331    pub fn is_empty(&self) -> bool {
332        self.providers.is_empty()
333    }
334}
335
336#[async_trait]
337impl CredentialProvider for ChainedCredentialProvider {
338    async fn resolve(&self) -> Result<Credentials> {
339        for provider in &self.providers {
340            match provider.resolve().await {
341                Ok(creds) => return Ok(creds),
342                // Missing → fall through to the next provider; any
343                // other error is a real failure that must surface
344                // rather than be masked.
345                Err(Error::Auth(AuthError::Missing { .. })) => {}
346                Err(other) => return Err(other),
347            }
348        }
349        Err(AuthError::missing_from(format!(
350            "chained: {} provider(s) exhausted",
351            self.providers.len()
352        ))
353        .into())
354    }
355}
356
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::Duration;

    /// Canned result a [`CountingProvider`] hands back on every call.
    enum Outcome {
        Ok(SecretString),
        Missing,
        Refused(String),
    }

    /// Test double that counts how often it is resolved, used to
    /// check cache hits and the order in which a chain consults its
    /// providers.
    struct CountingProvider {
        calls: Arc<AtomicUsize>,
        outcome: Outcome,
    }

    impl CountingProvider {
        /// Shared constructor: returns the provider together with a
        /// handle on its call counter.
        fn with_outcome(outcome: Outcome) -> (Self, Arc<AtomicUsize>) {
            let calls = Arc::new(AtomicUsize::new(0));
            let provider = Self {
                calls: Arc::clone(&calls),
                outcome,
            };
            (provider, calls)
        }

        /// Provider that always succeeds with `token` as header value.
        fn ok(token: &str) -> (Self, Arc<AtomicUsize>) {
            Self::with_outcome(Outcome::Ok(SecretString::from(token.to_owned())))
        }

        /// Provider that always reports `AuthError::Missing`.
        fn missing() -> (Self, Arc<AtomicUsize>) {
            Self::with_outcome(Outcome::Missing)
        }

        /// Provider that always reports `AuthError::Refused` with `msg`.
        fn refused(msg: &str) -> (Self, Arc<AtomicUsize>) {
            Self::with_outcome(Outcome::Refused(msg.to_owned()))
        }
    }

    #[async_trait]
    impl CredentialProvider for CountingProvider {
        async fn resolve(&self) -> Result<Credentials> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            match &self.outcome {
                Outcome::Ok(token) => {
                    let header_value = token.clone();
                    Ok(Credentials {
                        header_name: http::header::AUTHORIZATION,
                        header_value,
                    })
                }
                Outcome::Missing => Err(AuthError::missing().into()),
                Outcome::Refused(msg) => Err(AuthError::refused(msg.clone()).into()),
            }
        }
    }

    #[tokio::test]
    async fn cached_provider_serves_from_cache_within_ttl() {
        let (inner, calls) = CountingProvider::ok("tok-1");
        let cached = CachedCredentialProvider::new(inner, Duration::from_mins(1));
        // Three resolves inside the TTL must reach the inner provider
        // exactly once.
        for _ in 0..3 {
            let _ = cached.resolve().await.unwrap();
        }
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn cached_provider_refreshes_after_ttl() {
        let (inner, calls) = CountingProvider::ok("tok-2");
        let cached = CachedCredentialProvider::new(inner, Duration::from_millis(20));
        let _ = cached.resolve().await.unwrap();
        // Wait past the TTL so the second resolve has to refresh.
        tokio::time::sleep(Duration::from_millis(40)).await;
        let _ = cached.resolve().await.unwrap();
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn chained_provider_falls_through_on_missing() {
        let (first, first_calls) = CountingProvider::missing();
        let (second, second_calls) = CountingProvider::ok("from-b");
        let chain = ChainedCredentialProvider::new(vec![Arc::new(first), Arc::new(second)]);
        let creds = chain.resolve().await.unwrap();
        assert_eq!(creds.header_name, http::header::AUTHORIZATION);
        assert_eq!(creds.header_value.expose_secret(), "from-b");
        assert_eq!(first_calls.load(Ordering::SeqCst), 1);
        assert_eq!(second_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn chained_provider_short_circuits_on_real_error() {
        // Refused is not Missing: the chain has to report it at once
        // instead of hiding the rejection by moving on.
        let (first, first_calls) = CountingProvider::refused("vault: 401");
        let (second, second_calls) = CountingProvider::ok("from-b");
        let chain = ChainedCredentialProvider::new(vec![Arc::new(first), Arc::new(second)]);
        let err = chain.resolve().await.unwrap_err();
        assert!(matches!(err, Error::Auth(AuthError::Refused { .. })));
        assert_eq!(first_calls.load(Ordering::SeqCst), 1);
        assert_eq!(
            second_calls.load(Ordering::SeqCst),
            0,
            "chain must not consult later providers after a real failure"
        );
    }

    #[tokio::test]
    async fn chained_provider_returns_missing_when_all_sources_exhausted() {
        let (first, _) = CountingProvider::missing();
        let (second, _) = CountingProvider::missing();
        let chain = ChainedCredentialProvider::new(vec![Arc::new(first), Arc::new(second)]);
        let err = chain.resolve().await.unwrap_err();
        assert!(matches!(err, Error::Auth(AuthError::Missing { .. })));
    }
}
484}