modo/auth/session/jwt/
middleware.rs1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use axum::body::Body;
7use axum::response::IntoResponse;
8use http::Request;
9use tower::{Layer, Service};
10
11use crate::auth::session::Session;
12
13use super::claims::Claims;
14use super::decoder::{JwtDecoder, auth_err, jwt_err};
15use super::error::JwtError;
16use super::service::JwtSessionService;
17use super::source::{BearerSource, TokenSource};
18use crate::auth::session::token::SessionToken;
19
20#[derive(Clone)]
33pub struct JwtLayer {
34 decoder: JwtDecoder,
35 sources: Arc<[Arc<dyn TokenSource>]>,
36 service: Option<JwtSessionService>,
40}
41
42impl JwtLayer {
43 pub fn new(decoder: JwtDecoder) -> Self {
50 Self {
51 decoder,
52 sources: Arc::from(vec![Arc::new(BearerSource) as Arc<dyn TokenSource>]),
53 service: None,
54 }
55 }
56
57 pub fn from_service(service: JwtSessionService) -> Self {
68 let decoder = service.decoder().clone();
69 Self {
70 decoder,
71 sources: Arc::from(vec![Arc::new(BearerSource) as Arc<dyn TokenSource>]),
72 service: Some(service),
73 }
74 }
75
76 pub fn with_sources(mut self, sources: Vec<Arc<dyn TokenSource>>) -> Self {
80 self.sources = Arc::from(sources);
81 self
82 }
83}
84
85impl<Svc> Layer<Svc> for JwtLayer {
86 type Service = JwtMiddleware<Svc>;
87
88 fn layer(&self, inner: Svc) -> Self::Service {
89 JwtMiddleware {
90 inner,
91 decoder: self.decoder.clone(),
92 sources: self.sources.clone(),
93 service: self.service.clone(),
94 }
95 }
96}
97
98pub struct JwtMiddleware<Svc> {
100 inner: Svc,
101 decoder: JwtDecoder,
102 sources: Arc<[Arc<dyn TokenSource>]>,
103 service: Option<JwtSessionService>,
104}
105
106impl<Svc: Clone> Clone for JwtMiddleware<Svc> {
107 fn clone(&self) -> Self {
108 Self {
109 inner: self.inner.clone(),
110 decoder: self.decoder.clone(),
111 sources: self.sources.clone(),
112 service: self.service.clone(),
113 }
114 }
115}
116
117impl<Svc> Service<Request<Body>> for JwtMiddleware<Svc>
118where
119 Svc: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
120 Svc::Future: Send + 'static,
121 Svc::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
122{
123 type Response = http::Response<Body>;
124 type Error = Svc::Error;
125 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
126
127 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
128 self.inner.poll_ready(cx)
129 }
130
131 fn call(&mut self, request: Request<Body>) -> Self::Future {
132 let decoder = self.decoder.clone();
133 let sources = self.sources.clone();
134 let service = self.service.clone();
135 let mut inner = self.inner.clone();
136 std::mem::swap(&mut self.inner, &mut inner);
137
138 Box::pin(async move {
139 let (mut parts, body) = request.into_parts();
140
141 let token = match sources.iter().find_map(|s| s.extract(&parts)) {
142 Some(t) => t,
143 None => return Ok(jwt_err(JwtError::MissingToken).into_response()),
144 };
145
146 let claims: Claims = match decoder.decode(&token) {
147 Ok(c) => c,
148 Err(e) => return Ok(e.into_response()),
149 };
150
151 if let Some(svc) = service {
152 if claims.aud.as_deref() != Some("access") {
153 return Ok(auth_err("auth:aud_mismatch").into_response());
154 }
155
156 if svc.config().stateful_validation {
157 let session_token = match claims.jti.as_deref().and_then(SessionToken::from_raw)
158 {
159 Some(t) => t,
160 None => {
161 return Ok(auth_err("auth:session_not_found").into_response());
162 }
163 };
164
165 let raw = match svc.store().read_by_token_hash(&session_token.hash()).await {
166 Err(e) => return Ok(e.into_response()),
167 Ok(None) => {
168 return Ok(auth_err("auth:session_not_found").into_response());
169 }
170 Ok(Some(row)) => row,
171 };
172
173 parts.extensions.insert(Session::from(raw));
174 }
175 }
176
177 parts.extensions.insert(claims);
178
179 let request = Request::from_parts(parts, body);
180 inner.call(request).await
181 })
182 }
183}
184
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Response, StatusCode};
    use std::convert::Infallible;
    use tower::ServiceExt;

    use crate::auth::session::jwt::{Claims, JwtEncoder, JwtSessionsConfig};

    fn test_config() -> JwtSessionsConfig {
        JwtSessionsConfig {
            signing_secret: "test-secret-key-at-least-32-bytes-long!".into(),
            ..JwtSessionsConfig::default()
        }
    }

    fn now_secs() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs()
    }

    /// Encodes a token for `user_1` that expires one hour from now.
    fn make_token(config: &JwtSessionsConfig) -> String {
        let encoder = JwtEncoder::from_config(config);
        let claims = Claims::new().with_sub("user_1").with_exp(now_secs() + 3600);
        encoder.encode(&claims).unwrap()
    }

    /// Builds a request carrying `token` in an `Authorization: Bearer` header.
    fn bearer_request(token: &str) -> Request<Body> {
        Request::builder()
            .header("Authorization", format!("Bearer {token}"))
            .body(Body::empty())
            .unwrap()
    }

    async fn echo_handler(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let has_claims = req.extensions().get::<Claims>().is_some();
        let body = if has_claims { "ok" } else { "no-claims" };
        Ok(Response::new(Body::from(body)))
    }

    #[tokio::test]
    async fn valid_token_passes_through() {
        let config = test_config();
        let token = make_token(&config);
        let svc = JwtLayer::new(JwtDecoder::from_config(&config))
            .layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn missing_header_returns_401() {
        let config = test_config();
        let svc = JwtLayer::new(JwtDecoder::from_config(&config))
            .layer(tower::service_fn(echo_handler));

        let req = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn expired_token_returns_401() {
        let config = test_config();
        let encoder = JwtEncoder::from_config(&config);
        // Expire well in the past so the assertion does not depend on how much
        // clock-skew leeway the decoder is configured to allow.
        let claims = Claims::new().with_exp(now_secs() - 3600);
        let token = encoder.encode(&claims).unwrap();
        let svc = JwtLayer::new(JwtDecoder::from_config(&config))
            .layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn tampered_token_returns_401() {
        let config = test_config();
        let token = make_token(&config);
        // Flip one character in the middle of the signature segment (after the
        // last '.'): header and payload still parse, but the signature no longer
        // matches. The middle is used rather than the last character, whose low
        // bits may be base64 padding.
        let dot = token.rfind('.').unwrap();
        let mid = dot + (token.len() - dot) / 2;
        let mut bytes = token.into_bytes();
        bytes[mid] = if bytes[mid] == b'A' { b'Z' } else { b'A' };
        let token = String::from_utf8(bytes).unwrap();
        let svc = JwtLayer::new(JwtDecoder::from_config(&config))
            .layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn token_signed_with_other_secret_returns_401() {
        let config = test_config();
        let other = JwtSessionsConfig {
            signing_secret: "a-completely-different-secret-key-32b!!".into(),
            ..JwtSessionsConfig::default()
        };
        let token = make_token(&other);
        let svc = JwtLayer::new(JwtDecoder::from_config(&config))
            .layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn claims_inserted_into_extensions() {
        let config = test_config();
        let token = make_token(&config);
        let layer = JwtLayer::new(JwtDecoder::from_config(&config));

        let inner = tower::service_fn(|req: Request<Body>| async move {
            let claims = req.extensions().get::<Claims>().unwrap();
            assert_eq!(claims.subject(), Some("user_1"));
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });

        let svc = layer.layer(inner);
        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn custom_token_source_works() {
        let config = test_config();
        let token = make_token(&config);
        let layer = JwtLayer::new(JwtDecoder::from_config(&config)).with_sources(vec![Arc::new(
            super::super::source::QuerySource("token"),
        )
            as Arc<dyn TokenSource>]);
        let svc = layer.layer(tower::service_fn(echo_handler));

        let req = Request::builder()
            .uri(format!("/path?token={token}"))
            .body(Body::empty())
            .unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn custom_sources_replace_bearer_default() {
        let config = test_config();
        let token = make_token(&config);
        let layer = JwtLayer::new(JwtDecoder::from_config(&config)).with_sources(vec![Arc::new(
            super::super::source::QuerySource("token"),
        )
            as Arc<dyn TokenSource>]);
        let svc = layer.layer(tower::service_fn(echo_handler));

        // Only the query source is configured, so a valid Bearer header is ignored.
        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }
}