1use std::future::Future;
2use std::marker::PhantomData;
3use std::pin::Pin;
4use std::sync::Arc;
5use std::task::{Context, Poll};
6
7use axum::body::Body;
8use axum::response::IntoResponse;
9use http::Request;
10use serde::de::DeserializeOwned;
11use tower::{Layer, Service};
12
13use crate::Error;
14
15use super::claims::Claims;
16use super::decoder::JwtDecoder;
17use super::error::JwtError;
18use super::revocation::Revocation;
19use super::source::{BearerSource, TokenSource};
20
21pub struct JwtLayer<T> {
32 decoder: JwtDecoder,
33 sources: Arc<[Arc<dyn TokenSource>]>,
34 revocation: Option<Arc<dyn Revocation>>,
35 _marker: PhantomData<T>,
36}
37
38impl<T> JwtLayer<T>
39where
40 T: DeserializeOwned + Clone + Send + Sync + 'static,
41{
42 pub fn new(decoder: JwtDecoder) -> Self {
45 Self {
46 decoder,
47 sources: Arc::from(vec![Arc::new(BearerSource) as Arc<dyn TokenSource>]),
48 revocation: None,
49 _marker: PhantomData,
50 }
51 }
52
53 pub fn with_sources(mut self, sources: Vec<Arc<dyn TokenSource>>) -> Self {
57 self.sources = Arc::from(sources);
58 self
59 }
60
61 pub fn with_revocation(mut self, revocation: Arc<dyn Revocation>) -> Self {
64 self.revocation = Some(revocation);
65 self
66 }
67}
68
69impl<T> Clone for JwtLayer<T> {
70 fn clone(&self) -> Self {
71 Self {
72 decoder: self.decoder.clone(),
73 sources: self.sources.clone(),
74 revocation: self.revocation.clone(),
75 _marker: PhantomData,
76 }
77 }
78}
79
80impl<Svc, T> Layer<Svc> for JwtLayer<T>
81where
82 T: DeserializeOwned + Clone + Send + Sync + 'static,
83{
84 type Service = JwtMiddleware<Svc, T>;
85
86 fn layer(&self, inner: Svc) -> Self::Service {
87 JwtMiddleware {
88 inner,
89 decoder: self.decoder.clone(),
90 sources: self.sources.clone(),
91 revocation: self.revocation.clone(),
92 _marker: PhantomData,
93 }
94 }
95}
96
97pub struct JwtMiddleware<Svc, T> {
99 inner: Svc,
100 decoder: JwtDecoder,
101 sources: Arc<[Arc<dyn TokenSource>]>,
102 revocation: Option<Arc<dyn Revocation>>,
103 _marker: PhantomData<T>,
104}
105
106impl<Svc: Clone, T> Clone for JwtMiddleware<Svc, T> {
107 fn clone(&self) -> Self {
108 Self {
109 inner: self.inner.clone(),
110 decoder: self.decoder.clone(),
111 sources: self.sources.clone(),
112 revocation: self.revocation.clone(),
113 _marker: PhantomData,
114 }
115 }
116}
117
118impl<Svc, T> Service<Request<Body>> for JwtMiddleware<Svc, T>
119where
120 Svc: Service<Request<Body>, Response = http::Response<Body>> + Clone + Send + 'static,
121 Svc::Future: Send + 'static,
122 Svc::Error: Into<Box<dyn std::error::Error + Send + Sync>> + Send + 'static,
123 T: DeserializeOwned + Clone + Send + Sync + 'static,
124{
125 type Response = http::Response<Body>;
126 type Error = Svc::Error;
127 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
128
129 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
130 self.inner.poll_ready(cx)
131 }
132
133 fn call(&mut self, request: Request<Body>) -> Self::Future {
134 let decoder = self.decoder.clone();
135 let sources = self.sources.clone();
136 let revocation = self.revocation.clone();
137 let mut inner = self.inner.clone();
138 std::mem::swap(&mut self.inner, &mut inner);
139
140 Box::pin(async move {
141 let (mut parts, body) = request.into_parts();
142
143 let token = sources.iter().find_map(|s| s.extract(&parts));
145 let token = match token {
146 Some(t) => t,
147 None => {
148 let err = Error::unauthorized("unauthorized")
149 .chain(JwtError::MissingToken)
150 .with_code(JwtError::MissingToken.code());
151 return Ok(err.into_response());
152 }
153 };
154
155 let claims: Claims<T> = match decoder.decode(&token) {
157 Ok(c) => c,
158 Err(e) => return Ok(e.into_response()),
159 };
160
161 if let (Some(rev), Some(jti)) = (&revocation, claims.token_id()) {
163 match rev.is_revoked(jti).await {
164 Ok(true) => {
165 let err = Error::unauthorized("unauthorized")
166 .chain(JwtError::Revoked)
167 .with_code(JwtError::Revoked.code());
168 return Ok(err.into_response());
169 }
170 Err(e) => {
171 tracing::warn!(error = %e, jti = jti, "JWT revocation check failed");
172 let err = Error::unauthorized("unauthorized")
173 .chain(JwtError::RevocationCheckFailed)
174 .with_code(JwtError::RevocationCheckFailed.code());
175 return Ok(err.into_response());
176 }
177 Ok(false) => {} }
179 }
180
181 parts.extensions.insert(claims);
183
184 let request = Request::from_parts(parts, body);
185 inner.call(request).await
186 })
187 }
188}
189
#[cfg(test)]
mod tests {
    use super::*;
    use http::{Response, StatusCode};
    use std::convert::Infallible;
    use tower::ServiceExt;

    use crate::auth::jwt::{Claims, JwtConfig, JwtEncoder};

    /// Custom claims payload used throughout these tests.
    #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
    struct TestClaims {
        role: String,
    }

    /// Config shared by encoder and decoder so tokens minted here verify.
    fn test_config() -> JwtConfig {
        JwtConfig {
            secret: "test-secret-key-at-least-32-bytes-long!".into(),
            default_expiry: None,
            leeway: 0,
            issuer: None,
            audience: None,
        }
    }

    /// Current Unix time in whole seconds.
    fn now_secs() -> u64 {
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap();
        since_epoch.as_secs()
    }

    /// Mints a token for subject `user_1` with role `admin`, expiring in an hour.
    fn make_token(config: &JwtConfig) -> String {
        let claims = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_sub("user_1")
        .with_exp(now_secs() + 3600);

        JwtEncoder::from_config(config).encode(&claims).unwrap()
    }

    /// Builds an empty-body request carrying `token` as a bearer credential.
    fn bearer_request(token: &str) -> Request<Body> {
        Request::builder()
            .header("Authorization", format!("Bearer {token}"))
            .body(Body::empty())
            .unwrap()
    }

    /// Inner service: replies "ok" when the middleware attached claims,
    /// "no-claims" otherwise.
    async fn echo_handler(req: Request<Body>) -> Result<Response<Body>, Infallible> {
        let body = match req.extensions().get::<Claims<TestClaims>>() {
            Some(_) => "ok",
            None => "no-claims",
        };
        Ok(Response::new(Body::from(body)))
    }

    #[tokio::test]
    async fn valid_token_passes_through() {
        let config = test_config();
        let token = make_token(&config);
        let svc = JwtLayer::<TestClaims>::new(JwtDecoder::from_config(&config))
            .layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn missing_header_returns_401() {
        let config = test_config();
        let svc = JwtLayer::<TestClaims>::new(JwtDecoder::from_config(&config))
            .layer(tower::service_fn(echo_handler));

        let bare = Request::builder().body(Body::empty()).unwrap();
        let resp = svc.oneshot(bare).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn expired_token_returns_401() {
        let config = test_config();
        // Expiry ten seconds in the past; the config allows zero leeway.
        let stale = Claims::new(TestClaims {
            role: "admin".into(),
        })
        .with_exp(now_secs() - 10);
        let token = JwtEncoder::from_config(&config).encode(&stale).unwrap();

        let svc = JwtLayer::<TestClaims>::new(JwtDecoder::from_config(&config))
            .layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn tampered_token_returns_401() {
        let config = test_config();
        let token = make_token(&config);

        // Corrupt one character in the middle of the last `.`-separated
        // segment (the signature part of a compact JWT).
        let last_dot = token.rfind('.').unwrap();
        let target = last_dot + (token.len() - last_dot) / 2;
        let mut raw = token.into_bytes();
        raw[target] = match raw[target] {
            b'A' => b'Z',
            _ => b'A',
        };
        let token = String::from_utf8(raw).unwrap();

        let svc = JwtLayer::<TestClaims>::new(JwtDecoder::from_config(&config))
            .layer(tower::service_fn(echo_handler));

        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    }

    #[tokio::test]
    async fn claims_inserted_into_extensions() {
        let config = test_config();
        let token = make_token(&config);
        let layer = JwtLayer::<TestClaims>::new(JwtDecoder::from_config(&config));

        // The inner service itself asserts on what the middleware attached.
        let inspecting = tower::service_fn(|req: Request<Body>| async move {
            let claims = req.extensions().get::<Claims<TestClaims>>().unwrap();
            assert_eq!(claims.custom.role, "admin");
            assert_eq!(claims.subject(), Some("user_1"));
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });

        let svc = layer.layer(inspecting);
        let resp = svc.oneshot(bearer_request(&token)).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn custom_token_source_works() {
        let config = test_config();
        let token = make_token(&config);

        // Replace the default source with a query-string source.
        let query_only =
            vec![Arc::new(super::super::source::QuerySource("token")) as Arc<dyn TokenSource>];
        let svc = JwtLayer::<TestClaims>::new(JwtDecoder::from_config(&config))
            .with_sources(query_only)
            .layer(tower::service_fn(echo_handler));

        let req = Request::builder()
            .uri(format!("/path?token={token}"))
            .body(Body::empty())
            .unwrap();
        let resp = svc.oneshot(req).await.unwrap();
        assert_eq!(resp.status(), StatusCode::OK);
    }
}