fabryk_auth/
middleware.rs1use std::convert::Infallible;
7use std::future::Future;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11
12use axum::body::Body;
13use axum::response::IntoResponse;
14use http::{Request, StatusCode};
15use tower::{Layer, Service};
16
17use crate::{AuthConfig, TokenValidator};
18
19pub struct AuthLayer<V: TokenValidator> {
21 validator: Arc<V>,
22 config: AuthConfig,
23}
24
25impl<V: TokenValidator> Clone for AuthLayer<V> {
30 fn clone(&self) -> Self {
31 Self {
32 validator: self.validator.clone(),
33 config: self.config.clone(),
34 }
35 }
36}
37
38impl<V: TokenValidator> AuthLayer<V> {
39 pub fn new(validator: Arc<V>, config: AuthConfig) -> Self {
41 Self { validator, config }
42 }
43}
44
45impl<V: TokenValidator, S> Layer<S> for AuthLayer<V> {
46 type Service = AuthService<V, S>;
47
48 fn layer(&self, inner: S) -> Self::Service {
49 AuthService {
50 inner,
51 validator: self.validator.clone(),
52 config: self.config.clone(),
53 }
54 }
55}
56
57pub struct AuthService<V: TokenValidator, S> {
62 inner: S,
63 validator: Arc<V>,
64 config: AuthConfig,
65}
66
67impl<V: TokenValidator, S: Clone> Clone for AuthService<V, S> {
69 fn clone(&self) -> Self {
70 Self {
71 inner: self.inner.clone(),
72 validator: self.validator.clone(),
73 config: self.config.clone(),
74 }
75 }
76}
77
78impl<V, S> Service<Request<Body>> for AuthService<V, S>
79where
80 V: TokenValidator,
81 S: Service<Request<Body>, Error = Infallible> + Clone + Send + 'static,
82 S::Response: IntoResponse,
83 S::Future: Send,
84{
85 type Response = axum::response::Response;
86 type Error = Infallible;
87 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
88
89 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
90 self.inner.poll_ready(cx)
91 }
92
93 fn call(&mut self, mut req: Request<Body>) -> Self::Future {
94 let clone = self.inner.clone();
95 let mut inner = std::mem::replace(&mut self.inner, clone);
96
97 let validator = self.validator.clone();
98 let config = self.config.clone();
99
100 Box::pin(async move {
101 if !config.enabled {
103 let resp = inner
104 .call(req)
105 .await
106 .unwrap_or_else(|infallible| match infallible {});
107 return Ok(resp.into_response());
108 }
109
110 let token = match extract_bearer_token(&req) {
112 Some(t) => t.to_string(),
113 None => return Ok(unauthorized_response("missing or invalid bearer token")),
114 };
115
116 match validator.validate(&token, &config).await {
118 Ok(user) => {
119 req.extensions_mut().insert(user);
120 let resp = inner
121 .call(req)
122 .await
123 .unwrap_or_else(|infallible| match infallible {});
124 Ok(resp.into_response())
125 }
126 Err(auth_err) => {
127 log::warn!("Authentication failed: {auth_err}");
128 Ok(unauthorized_response(&auth_err.to_string()))
129 }
130 }
131 })
132 }
133}
134
135fn extract_bearer_token(req: &Request<Body>) -> Option<&str> {
137 req.headers()
138 .get(http::header::AUTHORIZATION)
139 .and_then(|v| v.to_str().ok())
140 .and_then(|v| v.strip_prefix("Bearer "))
141}
142
143fn unauthorized_response(message: &str) -> axum::response::Response {
145 let body = serde_json::json!({
146 "error": {
147 "category": "authentication",
148 "message": message,
149 }
150 });
151
152 let resource_url = std::env::var("KASU_RESOURCE_URL")
153 .or_else(|_| std::env::var("TAPROOT_RESOURCE_URL"))
154 .unwrap_or_default();
155 let www_auth = format!(
156 r#"Bearer resource_metadata="{resource_url}/.well-known/oauth-protected-resource""#,
157 );
158
159 let mut response = (
160 StatusCode::UNAUTHORIZED,
161 [(http::header::CONTENT_TYPE, "application/json")],
162 serde_json::to_string(&body).unwrap_or_default(),
163 )
164 .into_response();
165
166 if let Ok(value) = http::HeaderValue::from_str(&www_auth) {
167 response
168 .headers_mut()
169 .insert(http::header::WWW_AUTHENTICATE, value);
170 }
171
172 response
173}
174
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{AuthError, AuthenticatedUser};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Mutex;
    use tower::ServiceExt;

    /// Validator stub: accepts exactly `"valid-token"` and rejects everything
    /// else with `AuthError::InvalidSignature`.
    struct TestValidator;

    impl TokenValidator for TestValidator {
        fn validate(
            &self,
            token: &str,
            _config: &AuthConfig,
        ) -> Pin<Box<dyn Future<Output = Result<AuthenticatedUser, AuthError>> + Send + '_>>
        {
            let token = token.to_string();
            Box::pin(async move {
                if token == "valid-token" {
                    Ok(AuthenticatedUser {
                        email: "alice@banyan.com".to_string(),
                        subject: "sub_123".to_string(),
                    })
                } else {
                    Err(AuthError::InvalidSignature("bad token".to_string()))
                }
            })
        }
    }

    /// Configuration with authentication switched on.
    fn test_config_enabled() -> AuthConfig {
        AuthConfig {
            enabled: true,
            audience: "test-audience".to_string(),
            domain: "banyan.com".to_string(),
        }
    }

    /// Configuration with authentication switched off.
    fn test_config_disabled() -> AuthConfig {
        AuthConfig {
            enabled: false,
            ..Default::default()
        }
    }

    /// Inner-service stub that records how often it was called and which
    /// `AuthenticatedUser` (if any) was present in the request extensions.
    #[derive(Clone)]
    struct MockService {
        captured_user: Arc<Mutex<Option<AuthenticatedUser>>>,
        calls: Arc<AtomicUsize>,
    }

    impl MockService {
        fn new() -> Self {
            Self {
                captured_user: Arc::new(Mutex::new(None)),
                calls: Arc::new(AtomicUsize::new(0)),
            }
        }
    }

    impl Service<Request<Body>> for MockService {
        type Response = axum::response::Response;
        type Error = Infallible;
        type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;

        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
            Poll::Ready(Ok(()))
        }

        fn call(&mut self, req: Request<Body>) -> Self::Future {
            self.calls.fetch_add(1, Ordering::SeqCst);
            let captured = self.captured_user.clone();
            Box::pin(async move {
                let user = req.extensions().get::<AuthenticatedUser>().cloned();
                *captured.lock().unwrap() = user;
                Ok((StatusCode::OK, "ok").into_response())
            })
        }
    }

    /// Asserts that `resp` is a 401 carrying a `Bearer` challenge that points
    /// at the protected-resource metadata document.
    fn assert_unauthorized(resp: &axum::response::Response) {
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);

        let challenge = resp
            .headers()
            .get(http::header::WWW_AUTHENTICATE)
            .expect("401 response should carry a WWW-Authenticate challenge");
        // Compared as bytes so the check does not depend on the value of the
        // resource-URL environment variables being ASCII.
        assert!(
            challenge.as_bytes().starts_with(b"Bearer "),
            "unexpected challenge: {:?}",
            challenge
        );
        assert!(
            challenge
                .as_bytes()
                .ends_with(b"/.well-known/oauth-protected-resource\""),
            "unexpected challenge: {:?}",
            challenge
        );
    }

    #[test]
    fn test_extract_bearer_token_valid() {
        let req = Request::builder()
            .header("Authorization", "Bearer my-token-123")
            .body(Body::empty())
            .unwrap();
        assert_eq!(extract_bearer_token(&req), Some("my-token-123"));
    }

    #[test]
    fn test_extract_bearer_token_missing() {
        let req = Request::builder().body(Body::empty()).unwrap();
        assert_eq!(extract_bearer_token(&req), None);
    }

    #[test]
    fn test_extract_bearer_token_wrong_scheme() {
        let req = Request::builder()
            .header("Authorization", "Basic dXNlcjpwYXNz")
            .body(Body::empty())
            .unwrap();
        assert_eq!(extract_bearer_token(&req), None);
    }

    #[test]
    fn test_extract_bearer_token_scheme_without_credentials() {
        let req = Request::builder()
            .header("Authorization", "Bearer")
            .body(Body::empty())
            .unwrap();
        assert_eq!(extract_bearer_token(&req), None);
    }

    #[test]
    fn test_unauthorized_response_status() {
        let resp = unauthorized_response("test error");
        assert_unauthorized(&resp);
        assert_eq!(
            resp.headers()
                .get(http::header::CONTENT_TYPE)
                .map(|value| value.as_bytes()),
            Some(&b"application/json"[..])
        );
    }

    #[tokio::test]
    async fn test_middleware_disabled_passes_through() {
        let mock = MockService::new();
        let captured = mock.captured_user.clone();
        let calls = mock.calls.clone();
        let layer = AuthLayer::new(Arc::new(TestValidator), test_config_disabled());
        let service = layer.layer(mock);

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = service.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);

        // The request reached the inner service exactly once, unauthenticated.
        assert_eq!(calls.load(Ordering::SeqCst), 1);
        assert!(captured.lock().unwrap().is_none());
    }

    #[tokio::test]
    async fn test_middleware_missing_token_returns_401() {
        let mock = MockService::new();
        let calls = mock.calls.clone();
        let layer = AuthLayer::new(Arc::new(TestValidator), test_config_enabled());
        let service = layer.layer(mock);

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = service.oneshot(req).await.unwrap();
        assert_unauthorized(&resp);

        // A rejected request must never reach the inner service.
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_middleware_invalid_token_returns_401() {
        let mock = MockService::new();
        let calls = mock.calls.clone();
        let layer = AuthLayer::new(Arc::new(TestValidator), test_config_enabled());
        let service = layer.layer(mock);

        let req = Request::builder()
            .header("Authorization", "Bearer bad-token")
            .body(Body::empty())
            .unwrap();
        let resp = service.oneshot(req).await.unwrap();
        assert_unauthorized(&resp);

        // A rejected request must never reach the inner service.
        assert_eq!(calls.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_middleware_valid_token_passes_and_injects_user() {
        let mock = MockService::new();
        let captured = mock.captured_user.clone();
        let calls = mock.calls.clone();
        let layer = AuthLayer::new(Arc::new(TestValidator), test_config_enabled());
        let service = layer.layer(mock);

        let req = Request::builder()
            .header("Authorization", "Bearer valid-token")
            .body(Body::empty())
            .unwrap();
        let resp = service.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
        assert_eq!(calls.load(Ordering::SeqCst), 1);

        let user = captured.lock().unwrap();
        let user = user.as_ref().expect("AuthenticatedUser should be present");
        assert_eq!(user.email, "alice@banyan.com");
        assert_eq!(user.subject, "sub_123");
    }
}