modo/auth/session/jwt/
middleware.rs1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use axum::body::Body;
7use axum::response::IntoResponse;
8use http::Request;
9use tower::{Layer, Service};
10
11use crate::Error;
12use crate::auth::session::Session;
13
14use super::claims::Claims;
15use super::decoder::JwtDecoder;
16use super::error::JwtError;
17use super::service::JwtSessionService;
18use super::source::{BearerSource, TokenSource};
19use crate::auth::session::token::SessionToken;
20
21#[derive(Clone)]
34pub struct JwtLayer {
35 decoder: JwtDecoder,
36 sources: Arc<[Arc<dyn TokenSource>]>,
37 service: Option<JwtSessionService>,
41}
42
43impl JwtLayer {
44 pub fn new(decoder: JwtDecoder) -> Self {
51 Self {
52 decoder,
53 sources: Arc::from(vec![Arc::new(BearerSource) as Arc<dyn TokenSource>]),
54 service: None,
55 }
56 }
57
58 pub fn from_service(service: JwtSessionService) -> Self {
69 let decoder = service.decoder().clone();
70 Self {
71 decoder,
72 sources: Arc::from(vec![Arc::new(BearerSource) as Arc<dyn TokenSource>]),
73 service: Some(service),
74 }
75 }
76
77 pub fn with_sources(mut self, sources: Vec<Arc<dyn TokenSource>>) -> Self {
81 self.sources = Arc::from(sources);
82 self
83 }
84}
85
86impl<Svc> Layer<Svc> for JwtLayer {
87 type Service = JwtMiddleware<Svc>;
88
89 fn layer(&self, inner: Svc) -> Self::Service {
90 JwtMiddleware {
91 inner,
92 decoder: self.decoder.clone(),
93 sources: self.sources.clone(),
94 service: self.service.clone(),
95 }
96 }
97}
98
99pub struct JwtMiddleware<Svc> {
101 inner: Svc,
102 decoder: JwtDecoder,
103 sources: Arc<[Arc<dyn TokenSource>]>,
104 service: Option<JwtSessionService>,
105}
106
107impl<Svc: Clone> Clone for JwtMiddleware<Svc> {
108 fn clone(&self) -> Self {
109 Self {
110 inner: self.inner.clone(),
111 decoder: self.decoder.clone(),
112 sources: self.sources.clone(),
113 service: self.service.clone(),
114 }
115 }
116}
117
118impl<Svc> Service<Request<Body>> for JwtMiddleware<Svc>
119where
120 Svc: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
121 Svc::Future: Send + 'static,
122 Svc::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
123{
124 type Response = http::Response<Body>;
125 type Error = Svc::Error;
126 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
127
128 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
129 self.inner.poll_ready(cx)
130 }
131
132 fn call(&mut self, request: Request<Body>) -> Self::Future {
133 let decoder = self.decoder.clone();
134 let sources = self.sources.clone();
135 let service = self.service.clone();
136 let mut inner = self.inner.clone();
137 std::mem::swap(&mut self.inner, &mut inner);
138
139 Box::pin(async move {
140 let (mut parts, body) = request.into_parts();
141
142 let token = sources.iter().find_map(|s| s.extract(&parts));
144 let token = match token {
145 Some(t) => t,
146 None => {
147 let err = Error::unauthorized("unauthorized")
148 .chain(JwtError::MissingToken)
149 .with_code(JwtError::MissingToken.code());
150 return Ok(err.into_response());
151 }
152 };
153
154 let claims: Claims = match decoder.decode(&token) {
156 Ok(c) => c,
157 Err(e) => return Ok(e.into_response()),
158 };
159
160 if let Some(svc) = service {
165 if claims.aud.as_deref() != Some("access") {
167 let err = Error::unauthorized("unauthorized").with_code("auth:aud_mismatch");
168 return Ok(err.into_response());
169 }
170
171 if svc.config().stateful_validation {
173 let jti = match claims.jti.as_deref() {
174 Some(j) => j,
175 None => {
176 let err = Error::unauthorized("unauthorized")
177 .with_code("auth:session_not_found");
178 return Ok(err.into_response());
179 }
180 };
181
182 let session_token = match SessionToken::from_raw(jti) {
183 Some(t) => t,
184 None => {
185 let err = Error::unauthorized("unauthorized")
186 .with_code("auth:session_not_found");
187 return Ok(err.into_response());
188 }
189 };
190
191 let lookup = svc.store().read_by_token_hash(&session_token.hash()).await;
193 let raw = match lookup {
194 Err(e) => return Ok(e.into_response()),
195 Ok(None) => {
196 let err = Error::unauthorized("unauthorized")
197 .with_code("auth:session_not_found");
198 return Ok(err.into_response());
199 }
200 Ok(Some(row)) => row,
201 };
202
203 parts.extensions.insert(Session::from(raw));
204 }
205 }
206
207 parts.extensions.insert(claims);
209
210 let request = Request::from_parts(parts, body);
211 inner.call(request).await
212 })
213 }
214}
215
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Response, StatusCode};
    use std::convert::Infallible;
    use tower::ServiceExt;

    use crate::auth::session::jwt::{Claims, JwtEncoder, JwtSessionsConfig};

    fn test_config() -> JwtSessionsConfig {
        JwtSessionsConfig {
            signing_secret: "test-secret-key-at-least-32-bytes-long!".into(),
            ..JwtSessionsConfig::default()
        }
    }

    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Encodes a token for `user_1` that expires one hour from now.
    fn make_token(config: &JwtSessionsConfig) -> String {
        let encoder = JwtEncoder::from_config(config);
        let claims = Claims::new().with_sub("user_1").with_exp(now_secs() + 3600);
        encoder.encode(&claims).unwrap()
    }

    /// Builds a request carrying `token` in an `Authorization: Bearer` header.
    fn bearer_request(token: &str) -> Request<Body> {
        Request::builder()
            .header("Authorization", format!("Bearer {token}"))
            .body(Body::empty())
            .unwrap()
    }

    /// A layer that reads tokens only from the `token` query parameter.
    fn query_only_layer(decoder: JwtDecoder) -> JwtLayer {
        JwtLayer::new(decoder).with_sources(vec![
            Arc::new(super::super::source::QuerySource("token")) as Arc<dyn TokenSource>,
        ])
    }

    async fn echo_handler(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let has_claims = req.extensions().get::<Claims>().is_some();
        let body = if has_claims { "ok" } else { "no-claims" };
        Ok(Response::new(Body::from(body)))
    }

    #[tokio::test]
    async fn valid_token_passes_through() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        let svc = JwtLayer::new(decoder).layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn missing_header_returns_401() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let svc = JwtLayer::new(decoder).layer(tower::service_fn(echo_handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn expired_token_returns_401() {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        let decoder = JwtDecoder::from_config(&config);
        // Expire an hour in the past so the assertion still holds if the
        // decoder tolerates some clock-skew leeway on `exp`.
        let claims = Claims::new().with_exp(now_secs() - 3600);
        let token = encoder.encode(&claims).unwrap();
        let svc = JwtLayer::new(decoder).layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn tampered_token_returns_401() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        // Flip one character in the middle of the signature segment.
        let dot = token.rfind('.').unwrap();
        let mid = dot + (token.len() - dot) / 2;
        let mut bytes = token.into_bytes();
        bytes[mid] = if bytes[mid] == b'A' { b'Z' } else { b'A' };
        let token = String::from_utf8(bytes).unwrap();
        let svc = JwtLayer::new(decoder).layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn claims_inserted_into_extensions() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        let layer = JwtLayer::new(decoder);

        let inner = tower::service_fn(|req: Request<Body>| async move {
            let claims = req.extensions().get::<Claims>().unwrap();
            assert_eq!(claims.subject(), Some("user_1"));
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });

        let svc = layer.layer(inner);
        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn custom_token_source_works() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        let svc = query_only_layer(decoder).layer(tower::service_fn(echo_handler));

        let req = Request::builder()
            .uri(format!("/path?token={token}"))
            .body(Body::empty())
            .unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn custom_sources_replace_bearer_default() {
        let config = test_config();
        let decoder = JwtDecoder::from_config(&config);
        let token = make_token(&config);
        let svc = query_only_layer(decoder).layer(tower::service_fn(echo_handler));

        // `with_sources` replaces the default list, so a valid bearer header
        // is no longer consulted.
        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }
}