1#![deny(missing_docs)]
2use async_trait::async_trait;
18use neuron_secret::SecretValue;
19use std::sync::Arc;
20use std::time::SystemTime;
21use thiserror::Error;
22
23#[non_exhaustive]
25#[derive(Debug, Error)]
26pub enum AuthError {
27 #[error("auth failed: {0}")]
29 AuthFailed(String),
30
31 #[error("scope unavailable: {0}")]
33 ScopeUnavailable(String),
34
35 #[error("backend error: {0}")]
37 BackendError(String),
38
39 #[error("{0}")]
41 Other(#[from] Box<dyn std::error::Error + Send + Sync>),
42}
43
/// Describes the token a caller is asking an [`AuthProvider`] for.
///
/// Build one with [`AuthRequest::new`] and the chained `with_*` methods.
/// Every field starts empty.
#[non_exhaustive]
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct AuthRequest {
    /// Audience the token should be issued for, if any.
    pub audience: Option<String>,
    /// Requested scopes, in the order they were added (duplicates are kept).
    pub scopes: Vec<String>,
    /// Resource the token should grant access to, if any.
    pub resource: Option<String>,
    /// Identifier of the requesting actor (e.g. a workflow id), if any.
    pub actor: Option<String>,
}

impl AuthRequest {
    /// Creates an empty request: no audience, scopes, resource, or actor.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the audience, replacing any previous value.
    // The builder methods consume `self`; `#[must_use]` flags callers that
    // drop the returned value and thereby lose the setting.
    #[must_use]
    pub fn with_audience(mut self, audience: impl Into<String>) -> Self {
        self.audience = Some(audience.into());
        self
    }

    /// Appends a single scope. May be called repeatedly.
    #[must_use]
    pub fn with_scope(mut self, scope: impl Into<String>) -> Self {
        self.scopes.push(scope.into());
        self
    }

    /// Appends every scope yielded by `scopes`, in iteration order.
    ///
    /// Equivalent to calling [`AuthRequest::with_scope`] once per item. An
    /// empty iterator leaves the request unchanged.
    #[must_use]
    pub fn with_scopes<I, S>(mut self, scopes: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.scopes.extend(scopes.into_iter().map(Into::into));
        self
    }

    /// Sets the target resource, replacing any previous value.
    #[must_use]
    pub fn with_resource(mut self, resource: impl Into<String>) -> Self {
        self.resource = Some(resource.into());
        self
    }

    /// Sets the requesting actor, replacing any previous value.
    #[must_use]
    pub fn with_actor(mut self, actor: impl Into<String>) -> Self {
        self.actor = Some(actor.into());
        self
    }
}
88
89pub struct AuthToken {
92 inner: SecretValue,
93 expires_at: Option<SystemTime>,
94}
95
96impl AuthToken {
97 pub fn new(bytes: Vec<u8>, expires_at: Option<SystemTime>) -> Self {
99 Self {
100 inner: SecretValue::new(bytes),
101 expires_at,
102 }
103 }
104
105 pub fn permanent(bytes: Vec<u8>) -> Self {
107 Self::new(bytes, None)
108 }
109
110 pub fn with_bytes<R>(&self, f: impl FnOnce(&[u8]) -> R) -> R {
112 self.inner.with_bytes(f)
113 }
114
115 pub fn is_expired(&self) -> bool {
117 self.expires_at
118 .map(|exp| SystemTime::now() > exp)
119 .unwrap_or(false)
120 }
121
122 pub fn expires_at(&self) -> Option<SystemTime> {
124 self.expires_at
125 }
126}
127
128impl std::fmt::Debug for AuthToken {
129 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
130 f.debug_struct("AuthToken")
131 .field("value", &"[REDACTED]")
132 .field("expires_at", &self.expires_at)
133 .finish()
134 }
135}
136
/// A source of [`AuthToken`]s.
///
/// The `Send + Sync` supertraits allow implementations to be shared as
/// `Arc<dyn AuthProvider>`, which is how [`AuthProviderChain`] stores them.
#[async_trait]
pub trait AuthProvider: Send + Sync {
    /// Attempts to obtain a token for `request`.
    ///
    /// # Errors
    ///
    /// Returns an [`AuthError`] when no token can be produced. Which variant
    /// is used is up to the implementation.
    async fn provide(&self, request: &AuthRequest) -> Result<AuthToken, AuthError>;
}
143
144pub struct AuthProviderChain {
146 providers: Vec<Arc<dyn AuthProvider>>,
147}
148
149impl AuthProviderChain {
150 pub fn new() -> Self {
152 Self {
153 providers: Vec::new(),
154 }
155 }
156
157 pub fn with_provider(mut self, provider: Arc<dyn AuthProvider>) -> Self {
159 self.providers.push(provider);
160 self
161 }
162
163 pub fn add(&mut self, provider: Arc<dyn AuthProvider>) {
165 self.providers.push(provider);
166 }
167}
168
169impl Default for AuthProviderChain {
170 fn default() -> Self {
171 Self::new()
172 }
173}
174
175#[async_trait]
176impl AuthProvider for AuthProviderChain {
177 async fn provide(&self, request: &AuthRequest) -> Result<AuthToken, AuthError> {
178 let mut last_err = None;
179 for provider in &self.providers {
180 match provider.provide(request).await {
181 Ok(token) => return Ok(token),
182 Err(e) => last_err = Some(e),
183 }
184 }
185 Err(last_err.unwrap_or_else(|| AuthError::AuthFailed("no providers configured".into())))
186 }
187}
188
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    fn _assert_send_sync<T: Send + Sync>() {}

    #[test]
    fn auth_provider_is_object_safe_send_sync() {
        _assert_send_sync::<Box<dyn AuthProvider>>();
        _assert_send_sync::<Arc<dyn AuthProvider>>();
    }

    #[test]
    fn auth_token_debug_is_redacted() {
        let token = AuthToken::permanent(b"secret-token".to_vec());
        let debug = format!("{:?}", token);
        assert!(debug.contains("[REDACTED]"));
        assert!(!debug.contains("secret-token"));
    }

    #[test]
    fn auth_token_with_bytes_exposes_content() {
        let token = AuthToken::permanent(b"my-token".to_vec());
        token.with_bytes(|bytes| {
            assert_eq!(bytes, b"my-token");
        });
    }

    #[test]
    fn auth_token_permanent_never_expires() {
        let token = AuthToken::permanent(b"token".to_vec());
        assert!(!token.is_expired());
        assert!(token.expires_at().is_none());
    }

    #[test]
    fn auth_token_past_expiry_is_expired() {
        let token = AuthToken::new(b"token".to_vec(), Some(SystemTime::UNIX_EPOCH));
        assert!(token.is_expired());
        assert_eq!(token.expires_at(), Some(SystemTime::UNIX_EPOCH));
    }

    #[test]
    fn auth_token_future_expiry_is_not_expired() {
        let expiry = SystemTime::now() + Duration::from_secs(3600);
        let token = AuthToken::new(b"token".to_vec(), Some(expiry));
        assert!(!token.is_expired());
        assert_eq!(token.expires_at(), Some(expiry));
    }

    #[test]
    fn auth_request_builder() {
        let req = AuthRequest::new()
            .with_audience("https://vault.internal")
            .with_scope("read:secrets")
            .with_scope("write:audit")
            .with_resource("secret/data/api-key")
            .with_actor("workflow-001");
        assert_eq!(req.audience.as_deref(), Some("https://vault.internal"));
        assert_eq!(req.scopes.len(), 2);
        // Scopes keep insertion order.
        assert_eq!(req.scopes, ["read:secrets", "write:audit"]);
        assert_eq!(req.resource.as_deref(), Some("secret/data/api-key"));
        assert_eq!(req.actor.as_deref(), Some("workflow-001"));
    }

    struct AlwaysFailProvider;
    #[async_trait]
    impl AuthProvider for AlwaysFailProvider {
        async fn provide(&self, _request: &AuthRequest) -> Result<AuthToken, AuthError> {
            Err(AuthError::AuthFailed("always fails".into()))
        }
    }

    struct StaticTokenProvider {
        token: Vec<u8>,
    }
    #[async_trait]
    impl AuthProvider for StaticTokenProvider {
        async fn provide(&self, _request: &AuthRequest) -> Result<AuthToken, AuthError> {
            Ok(AuthToken::permanent(self.token.clone()))
        }
    }

    #[tokio::test]
    async fn chain_empty_returns_error() {
        let chain = AuthProviderChain::new();
        let err = chain.provide(&AuthRequest::new()).await.unwrap_err();
        assert_eq!(err.to_string(), "auth failed: no providers configured");
    }

    #[tokio::test]
    async fn chain_first_success_wins() {
        let chain = AuthProviderChain::new()
            .with_provider(Arc::new(StaticTokenProvider {
                token: b"first".to_vec(),
            }))
            .with_provider(Arc::new(StaticTokenProvider {
                token: b"second".to_vec(),
            }));
        let token = chain.provide(&AuthRequest::new()).await.unwrap();
        token.with_bytes(|b| assert_eq!(b, b"first"));
    }

    #[tokio::test]
    async fn chain_skips_failures() {
        let chain = AuthProviderChain::new()
            .with_provider(Arc::new(AlwaysFailProvider))
            .with_provider(Arc::new(StaticTokenProvider {
                token: b"fallback".to_vec(),
            }));
        let token = chain.provide(&AuthRequest::new()).await.unwrap();
        token.with_bytes(|b| assert_eq!(b, b"fallback"));
    }

    #[tokio::test]
    async fn chain_all_fail_returns_last_error() {
        let chain = AuthProviderChain::new()
            .with_provider(Arc::new(AlwaysFailProvider))
            .with_provider(Arc::new(AlwaysFailProvider));
        let result = chain.provide(&AuthRequest::new()).await;
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().to_string(), "auth failed: always fails");
    }

    #[test]
    fn auth_error_display_all_variants() {
        assert_eq!(
            AuthError::AuthFailed("bad token".into()).to_string(),
            "auth failed: bad token"
        );
        assert_eq!(
            AuthError::ScopeUnavailable("admin".into()).to_string(),
            "scope unavailable: admin"
        );
        assert_eq!(
            AuthError::BackendError("connection refused".into()).to_string(),
            "backend error: connection refused"
        );
        // `Other` displays exactly the wrapped error's message, with no prefix.
        let boxed: Box<dyn std::error::Error + Send + Sync> = "wrapped failure".into();
        assert_eq!(AuthError::from(boxed).to_string(), "wrapped failure");
    }
}