//! Generic Tower authentication middleware.
//!
//! `AuthLayer` and `AuthService` wrap any inner service with bearer-token validation.
//! Both are generic over `TokenValidator`, so any identity provider can be plugged in.

6use std::convert::Infallible;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11
12use axum::body::Body;
13use axum::response::IntoResponse;
14use http::{Request, StatusCode};
15use tower::{Layer, Service};
16
17use crate::{AuthConfig, TokenValidator};
18
19/// Tower `Layer` that wraps services with token authentication.
20pub struct AuthLayer<V: TokenValidator> {
21    validator: Arc<V>,
22    config: AuthConfig,
23}
24
25// Manual Clone impl — only requires V: TokenValidator (V is behind Arc,
26// so Clone on V itself is never needed). Using #[derive(Clone)] would
27// add an unnecessary V: Clone bound that breaks validators containing
28// non-Clone types like RwLock.
29impl<V: TokenValidator> Clone for AuthLayer<V> {
30    fn clone(&self) -> Self {
31        Self {
32            validator: self.validator.clone(),
33            config: self.config.clone(),
34        }
35    }
36}
37
38impl<V: TokenValidator> AuthLayer<V> {
39    /// Create a new auth layer with the given validator and config.
40    pub fn new(validator: Arc<V>, config: AuthConfig) -> Self {
41        Self { validator, config }
42    }
43}
44
45impl<V: TokenValidator, S> Layer<S> for AuthLayer<V> {
46    type Service = AuthService<V, S>;
47
48    fn layer(&self, inner: S) -> Self::Service {
49        AuthService {
50            inner,
51            validator: self.validator.clone(),
52            config: self.config.clone(),
53        }
54    }
55}
56
57/// Tower `Service` that validates tokens before forwarding requests.
58///
59/// On successful validation, inserts `AuthenticatedUser` into request
60/// extensions where it's available to downstream handlers.
61pub struct AuthService<V: TokenValidator, S> {
62    inner: S,
63    validator: Arc<V>,
64    config: AuthConfig,
65}
66
67// Manual Clone impl — only requires S: Clone (V is behind Arc).
68impl<V: TokenValidator, S: Clone> Clone for AuthService<V, S> {
69    fn clone(&self) -> Self {
70        Self {
71            inner: self.inner.clone(),
72            validator: self.validator.clone(),
73            config: self.config.clone(),
74        }
75    }
76}
77
78impl<V, S> Service<Request<Body>> for AuthService<V, S>
79where
80    V: TokenValidator,
81    S: Service<Request<Body>, Error = Infallible> + Clone + Send + 'static,
82    S::Response: IntoResponse,
83    S::Future: Send,
84{
85    type Response = axum::response::Response;
86    type Error = Infallible;
87    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
88
89    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
90        self.inner.poll_ready(cx)
91    }
92
93    fn call(&mut self, mut req: Request<Body>) -> Self::Future {
94        let clone = self.inner.clone();
95        let mut inner = std::mem::replace(&mut self.inner, clone);
96
97        let validator = self.validator.clone();
98        let config = self.config.clone();
99
100        Box::pin(async move {
101            // Dev mode — no auth required
102            if !config.enabled {
103                let resp = inner
104                    .call(req)
105                    .await
106                    .unwrap_or_else(|infallible| match infallible {});
107                return Ok(resp.into_response());
108            }
109
110            // Extract bearer token
111            let token = match extract_bearer_token(&req) {
112                Some(t) => t.to_string(),
113                None => return Ok(unauthorized_response("missing or invalid bearer token")),
114            };
115
116            // Validate the token
117            match validator.validate(&token, &config).await {
118                Ok(user) => {
119                    req.extensions_mut().insert(user);
120                    let resp = inner
121                        .call(req)
122                        .await
123                        .unwrap_or_else(|infallible| match infallible {});
124                    Ok(resp.into_response())
125                }
126                Err(auth_err) => {
127                    log::warn!("Authentication failed: {auth_err}");
128                    Ok(unauthorized_response(&auth_err.to_string()))
129                }
130            }
131        })
132    }
133}
134
135/// Extract bearer token from the Authorization header.
136fn extract_bearer_token(req: &Request<Body>) -> Option<&str> {
137    req.headers()
138        .get(http::header::AUTHORIZATION)
139        .and_then(|v| v.to_str().ok())
140        .and_then(|v| v.strip_prefix("Bearer "))
141}
142
143/// Build a 401 Unauthorized response with WWW-Authenticate header.
144fn unauthorized_response(message: &str) -> axum::response::Response {
145    let body = serde_json::json!({
146        "error": {
147            "category": "authentication",
148            "message": message,
149        }
150    });
151
152    let resource_url = std::env::var("KASU_RESOURCE_URL")
153        .or_else(|_| std::env::var("TAPROOT_RESOURCE_URL"))
154        .unwrap_or_default();
155    let www_auth = format!(
156        r#"Bearer resource_metadata="{resource_url}/.well-known/oauth-protected-resource""#,
157    );
158
159    let mut response = (
160        StatusCode::UNAUTHORIZED,
161        [(http::header::CONTENT_TYPE, "application/json")],
162        serde_json::to_string(&body).unwrap_or_default(),
163    )
164        .into_response();
165
166    if let Ok(value) = http::HeaderValue::from_str(&www_auth) {
167        response
168            .headers_mut()
169            .insert(http::header::WWW_AUTHENTICATE, value);
170    }
171
172    response
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AuthError, AuthenticatedUser};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;
    use tower::ServiceExt;

    // A simple test validator that accepts "valid-token" and rejects everything else.
    struct TestValidator;

    impl TokenValidator for TestValidator {
        fn validate(
            &self,
            token: &str,
            _config: &AuthConfig,
        ) -> Pin<Box<dyn Future<Output = Result<AuthenticatedUser, AuthError>> + Send + '_>>
        {
            let token = token.to_string();
            Box::pin(async move {
                if token == "valid-token" {
                    Ok(AuthenticatedUser {
                        email: "alice@banyan.com".to_string(),
                        subject: "sub_123".to_string(),
                    })
                } else {
                    Err(AuthError::InvalidSignature("bad token".to_string()))
                }
            })
        }
    }

    fn test_config_enabled() -> AuthConfig {
        AuthConfig {
            enabled: true,
            audience: "test-audience".to_string(),
            domain: "banyan.com".to_string(),
        }
    }

    fn test_config_disabled() -> AuthConfig {
        AuthConfig {
            enabled: false,
            ..Default::default()
        }
    }

    /// Mock inner service that counts its calls and captures the
    /// `AuthenticatedUser` found in the request extensions.
    ///
    /// Clones share the same counters, so a clone kept by the test observes
    /// everything the wrapped copy saw.
    #[derive(Clone)]
    struct MockService {
        captured_user: Arc<Mutex<Option<AuthenticatedUser>>>,
        calls: Arc<AtomicUsize>,
    }

    impl MockService {
        fn new() -> Self {
            Self {
                captured_user: Arc::new(Mutex::new(None)),
                calls: Arc::new(AtomicUsize::new(0)),
            }
        }

        /// How many times `call` has been invoked on this service or its clones.
        fn call_count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }

        /// The user seen in the most recent request's extensions, if any.
        fn seen_user(&self) -> Option<AuthenticatedUser> {
            self.captured_user.lock().unwrap().clone()
        }
    }

    impl Service<Request<Body>> for MockService {
        type Response = axum::response::Response;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: Request<Body>) -> Self::Future {
            self.calls.fetch_add(1, Ordering::SeqCst);
            let captured = self.captured_user.clone();
            Box::pin(async move {
                let user = req.extensions().get::<AuthenticatedUser>().cloned();
                *captured.lock().unwrap() = user;
                Ok((StatusCode::OK, "ok").into_response())
            })
        }
    }

    /// Wrap a fresh `MockService` in the auth layer. Returns the wrapped
    /// service together with a probe clone that shares its counters.
    fn wrapped(config: AuthConfig) -> (AuthService<TestValidator, MockService>, MockService) {
        let mock = MockService::new();
        let probe = mock.clone();
        let service = AuthLayer::new(Arc::new(TestValidator), config).layer(mock);
        (service, probe)
    }

    fn request_with_auth(value: &str) -> Request<Body> {
        Request::builder()
            .header("Authorization", value)
            .body(Body::empty())
            .unwrap()
    }

    fn has_bearer_challenge(resp: &axum::response::Response) -> bool {
        resp.headers()
            .get(http::header::WWW_AUTHENTICATE)
            .and_then(|v| v.to_str().ok())
            .map_or(false, |v| v.starts_with("Bearer"))
    }

    #[test]
    fn test_extract_bearer_token_valid() {
        let req = request_with_auth("Bearer my-token-123");
        assert_eq!(extract_bearer_token(&req), Some("my-token-123"));
    }

    #[test]
    fn test_extract_bearer_token_missing() {
        let req = Request::builder().body(Body::empty()).unwrap();
        assert_eq!(extract_bearer_token(&req), None);
    }

    #[test]
    fn test_extract_bearer_token_wrong_scheme() {
        let req = request_with_auth("Basic dXNlcjpwYXNz");
        assert_eq!(extract_bearer_token(&req), None);
    }

    #[test]
    fn test_unauthorized_response_status() {
        let resp = unauthorized_response("test error");
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_unauthorized_response_headers() {
        let resp = unauthorized_response("test error");
        assert_eq!(
            resp.headers().get(http::header::CONTENT_TYPE).unwrap(),
            "application/json"
        );
        assert!(
            has_bearer_challenge(&resp),
            "a 401 must carry a Bearer WWW-Authenticate challenge"
        );
    }

    #[tokio::test]
    async fn test_middleware_disabled_passes_through() {
        let (service, probe) = wrapped(test_config_disabled());

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = service.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(probe.call_count(), 1);
        assert!(
            probe.seen_user().is_none(),
            "no user is injected when auth is disabled"
        );
    }

    #[tokio::test]
    async fn test_middleware_missing_token_returns_401() {
        let (service, probe) = wrapped(test_config_enabled());

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = service.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        assert!(has_bearer_challenge(&resp));
        assert_eq!(probe.call_count(), 0, "inner service must not be reached");
    }

    #[tokio::test]
    async fn test_middleware_wrong_scheme_returns_401() {
        let (service, probe) = wrapped(test_config_enabled());

        let resp = service
            .oneshot(request_with_auth("Basic dXNlcjpwYXNz"))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(probe.call_count(), 0, "inner service must not be reached");
    }

    #[tokio::test]
    async fn test_middleware_invalid_token_returns_401() {
        let (service, probe) = wrapped(test_config_enabled());

        let resp = service
            .oneshot(request_with_auth("Bearer bad-token"))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
        assert!(has_bearer_challenge(&resp));
        assert_eq!(probe.call_count(), 0, "inner service must not be reached");
    }

    #[tokio::test]
    async fn test_middleware_valid_token_passes_and_injects_user() {
        let (service, probe) = wrapped(test_config_enabled());

        let resp = service
            .oneshot(request_with_auth("Bearer valid-token"))
            .await
            .unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(probe.call_count(), 1);

        let user = probe
            .seen_user()
            .expect("AuthenticatedUser should be present");
        assert_eq!(user.email, "alice@banyan.com");
        assert_eq!(user.subject, "sub_123");
    }
}