Skip to main content

neuron_auth/
lib.rs

1#![deny(missing_docs)]
2//! Authentication providers for neuron.
3//!
4//! This crate defines the [`AuthProvider`] trait for obtaining authentication
5//! credentials to access secret backends. It also provides [`AuthProviderChain`]
6//! for composing multiple providers (try in order until one succeeds, like
7//! AWS DefaultCredentialsChain).
8//!
9//! ## Separation of Concerns
10//!
11//! Auth providers produce credentials (tokens). Secret resolvers consume them.
12//! A `VaultResolver` takes an `Arc<dyn AuthProvider>` and uses it to authenticate
13//! before fetching secrets. This separation follows the pattern established by
14//! AWS SDK (`ProvideCredentials` vs `SecretsManagerClient`), vaultrs
15//! (`auth::*` vs `kv2::*`), and Google Cloud SDK.
16
17use async_trait::async_trait;
18use neuron_secret::SecretValue;
19use std::sync::Arc;
20use std::time::SystemTime;
21use thiserror::Error;
22
23/// Errors from authentication providers (crate-local, not in layer0).
24#[non_exhaustive]
25#[derive(Debug, Error)]
26pub enum AuthError {
27    /// Authentication failed (bad credentials, expired token, etc.).
28    #[error("auth failed: {0}")]
29    AuthFailed(String),
30
31    /// The requested scope or audience is not available.
32    #[error("scope unavailable: {0}")]
33    ScopeUnavailable(String),
34
35    /// Backend communication failure.
36    #[error("backend error: {0}")]
37    BackendError(String),
38
39    /// Catch-all.
40    #[error("{0}")]
41    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
42}
43
/// Context for an authentication request.
///
/// Every field is optional or empty by default. An empty request (see
/// [`AuthRequest::new`]) means "no specific context". Populate it fluently
/// with the `with_*` builder methods.
#[non_exhaustive]
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct AuthRequest {
    /// Target audience (OIDC audience, API identifier).
    pub audience: Option<String>,
    /// Requested scopes (OIDC scopes, OAuth2 scopes).
    pub scopes: Vec<String>,
    /// Target resource identifier (e.g., Vault path, AWS region).
    pub resource: Option<String>,
    /// Actor identity for audit (workflow ID, agent ID).
    pub actor: Option<String>,
}

// The builder methods take `self` by value and return the updated request.
// Calling one without using its result throws the change away, hence
// `#[must_use]` on each of them.
impl AuthRequest {
    /// Create an empty auth request (no specific context).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the target audience, replacing any previously set audience.
    #[must_use]
    pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
        self.audience = Some(audience.into());
        self
    }

    /// Append a scope. Repeated calls accumulate scopes in call order.
    #[must_use]
    pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
        self.scopes.push(scope.into());
        self
    }

    /// Set the target resource, replacing any previously set resource.
    #[must_use]
    pub fn with_resource(mut self, resource: impl Into<String>) -> Self {
        self.resource = Some(resource.into());
        self
    }

    /// Set the actor identity, replacing any previously set actor.
    #[must_use]
    pub fn with_actor(mut self, actor: impl Into<String>) -> Self {
        self.actor = Some(actor.into());
        self
    }
}
88
89/// An opaque authentication token with expiry.
90/// Uses [`SecretValue`] internally for in-memory protection.
91pub struct AuthToken {
92    inner: SecretValue,
93    expires_at: Option<SystemTime>,
94}
95
96impl AuthToken {
97    /// Create a new auth token.
98    pub fn new(bytes: Vec<u8>, expires_at: Option<SystemTime>) -> Self {
99        Self {
100            inner: SecretValue::new(bytes),
101            expires_at,
102        }
103    }
104
105    /// Create a token that never expires (for dev/test).
106    pub fn permanent(bytes: Vec<u8>) -> Self {
107        Self::new(bytes, None)
108    }
109
110    /// Scoped exposure of the token bytes.
111    pub fn with_bytes<R>(&self, f: impl FnOnce(&[u8]) -> R) -> R {
112        self.inner.with_bytes(f)
113    }
114
115    /// Check if this token has expired.
116    pub fn is_expired(&self) -> bool {
117        self.expires_at
118            .map(|exp| SystemTime::now() > exp)
119            .unwrap_or(false)
120    }
121
122    /// Returns when this token expires, if known.
123    pub fn expires_at(&self) -> Option<SystemTime> {
124        self.expires_at
125    }
126}
127
128impl std::fmt::Debug for AuthToken {
129    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130        f.debug_struct("AuthToken")
131            .field("value", &"[REDACTED]")
132            .field("expires_at", &self.expires_at)
133            .finish()
134    }
135}
136
137/// Provide authentication credentials for accessing a secret backend.
138#[async_trait]
139pub trait AuthProvider: Send + Sync {
140    /// Provide an authentication token for the given request context.
141    async fn provide(&self, request: &AuthRequest) -> Result<AuthToken, AuthError>;
142}
143
144/// Tries providers in order until one succeeds.
145pub struct AuthProviderChain {
146    providers: Vec<Arc<dyn AuthProvider>>,
147}
148
149impl AuthProviderChain {
150    /// Create a new empty chain.
151    pub fn new() -> Self {
152        Self {
153            providers: Vec::new(),
154        }
155    }
156
157    /// Add a provider to the end of the chain.
158    pub fn with_provider(mut self, provider: Arc<dyn AuthProvider>) -> Self {
159        self.providers.push(provider);
160        self
161    }
162
163    /// Add a provider to the end of the chain (mutable).
164    pub fn add(&mut self, provider: Arc<dyn AuthProvider>) {
165        self.providers.push(provider);
166    }
167}
168
169impl Default for AuthProviderChain {
170    fn default() -> Self {
171        Self::new()
172    }
173}
174
175#[async_trait]
176impl AuthProvider for AuthProviderChain {
177    async fn provide(&self, request: &AuthRequest) -> Result<AuthToken, AuthError> {
178        let mut last_err = None;
179        for provider in &self.providers {
180            match provider.provide(request).await {
181                Ok(token) => return Ok(token),
182                Err(e) => last_err = Some(e),
183            }
184        }
185        Err(last_err.unwrap_or_else(|| AuthError::AuthFailed("no providers configured".into())))
186    }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    fn _assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn auth_provider_is_object_safe_send_sync() {
        _assert_send_sync::<Box<dyn AuthProvider>>();
        _assert_send_sync::<Arc<dyn AuthProvider>>();
    }

    #[test]
    fn chain_and_error_are_send_sync() {
        // Both cross `.await` points in async callers, so they must be
        // thread-safe.
        _assert_send_sync::<AuthProviderChain>();
        _assert_send_sync::<AuthError>();
    }

    #[test]
    fn auth_token_debug_is_redacted() {
        let token = AuthToken::permanent(b"secret-token".to_vec());
        let debug = format!("{:?}", token);
        assert!(debug.contains("[REDACTED]"));
        assert!(!debug.contains("secret-token"));
    }

    #[test]
    fn auth_token_with_bytes_exposes_content() {
        let token = AuthToken::permanent(b"my-token".to_vec());
        token.with_bytes(|bytes| {
            assert_eq!(bytes, b"my-token");
        });
    }

    #[test]
    fn auth_token_permanent_never_expires() {
        let token = AuthToken::permanent(b"token".to_vec());
        assert!(!token.is_expired());
        assert!(token.expires_at().is_none());
    }

    #[test]
    fn auth_token_past_expiry_is_expired() {
        let token = AuthToken::new(b"token".to_vec(), Some(SystemTime::UNIX_EPOCH));
        assert!(token.is_expired());
        assert_eq!(token.expires_at(), Some(SystemTime::UNIX_EPOCH));
    }

    #[test]
    fn auth_token_future_expiry_is_not_expired() {
        let expiry = SystemTime::now() + Duration::from_secs(3600);
        let token = AuthToken::new(b"token".to_vec(), Some(expiry));
        assert!(!token.is_expired());
        assert_eq!(token.expires_at(), Some(expiry));
    }

    #[test]
    fn auth_request_builder() {
        let req = AuthRequest::new()
            .with_audience("https://vault.internal")
            .with_scope("read:secrets")
            .with_scope("write:audit")
            .with_resource("secret/data/api-key")
            .with_actor("workflow-001");
        assert_eq!(req.audience.as_deref(), Some("https://vault.internal"));
        assert_eq!(req.scopes.len(), 2);
        assert_eq!(req.resource.as_deref(), Some("secret/data/api-key"));
        assert_eq!(req.actor.as_deref(), Some("workflow-001"));
    }

    struct AlwaysFailProvider;
    #[async_trait]
    impl AuthProvider for AlwaysFailProvider {
        async fn provide(&self, _request: &AuthRequest) -> Result<AuthToken, AuthError> {
            Err(AuthError::AuthFailed("always fails".into()))
        }
    }

    /// Fails with a caller-chosen message so tests can tell providers apart.
    struct FailWithProvider(&'static str);
    #[async_trait]
    impl AuthProvider for FailWithProvider {
        async fn provide(&self, _request: &AuthRequest) -> Result<AuthToken, AuthError> {
            Err(AuthError::AuthFailed(self.0.into()))
        }
    }

    struct StaticTokenProvider {
        token: Vec<u8>,
    }
    #[async_trait]
    impl AuthProvider for StaticTokenProvider {
        async fn provide(&self, _request: &AuthRequest) -> Result<AuthToken, AuthError> {
            Ok(AuthToken::permanent(self.token.clone()))
        }
    }

    #[tokio::test]
    async fn chain_empty_returns_error() {
        let chain = AuthProviderChain::new();
        let err = chain.provide(&AuthRequest::new()).await.unwrap_err();
        assert_eq!(err.to_string(), "auth failed: no providers configured");

        let default_chain = AuthProviderChain::default();
        assert!(default_chain.provide(&AuthRequest::new()).await.is_err());
    }

    #[tokio::test]
    async fn chain_first_success_wins() {
        let chain = AuthProviderChain::new()
            .with_provider(Arc::new(StaticTokenProvider {
                token: b"first".to_vec(),
            }))
            .with_provider(Arc::new(StaticTokenProvider {
                token: b"second".to_vec(),
            }));
        let token = chain.provide(&AuthRequest::new()).await.unwrap();
        token.with_bytes(|b| assert_eq!(b, b"first"));
    }

    #[tokio::test]
    async fn chain_skips_failures() {
        let chain = AuthProviderChain::new()
            .with_provider(Arc::new(AlwaysFailProvider))
            .with_provider(Arc::new(StaticTokenProvider {
                token: b"fallback".to_vec(),
            }));
        let token = chain.provide(&AuthRequest::new()).await.unwrap();
        token.with_bytes(|b| assert_eq!(b, b"fallback"));
    }

    #[tokio::test]
    async fn chain_add_appends_provider() {
        let mut chain = AuthProviderChain::default();
        chain.add(Arc::new(AlwaysFailProvider));
        chain.add(Arc::new(StaticTokenProvider {
            token: b"added".to_vec(),
        }));
        let token = chain.provide(&AuthRequest::new()).await.unwrap();
        token.with_bytes(|b| assert_eq!(b, b"added"));
    }

    #[tokio::test]
    async fn chain_all_fail_returns_last_error() {
        // Distinct messages make it observable *which* failure is reported.
        let chain = AuthProviderChain::new()
            .with_provider(Arc::new(FailWithProvider("first")))
            .with_provider(Arc::new(FailWithProvider("second")));
        let result = chain.provide(&AuthRequest::new()).await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().to_string(), "auth failed: second");
    }

    #[test]
    fn auth_error_display_all_variants() {
        assert_eq!(
            AuthError::AuthFailed("bad token".into()).to_string(),
            "auth failed: bad token"
        );
        assert_eq!(
            AuthError::ScopeUnavailable("admin".into()).to_string(),
            "scope unavailable: admin"
        );
        assert_eq!(
            AuthError::BackendError("connection refused".into()).to_string(),
            "backend error: connection refused"
        );
        let boxed: Box<dyn std::error::Error + Send + Sync> = "boom".into();
        assert_eq!(AuthError::from(boxed).to_string(), "boom");
    }
}