1#![doc = include_str!("../README.md")]
2
3use std::{convert::Infallible, future::Future, marker::PhantomData, pin::Pin, task::Poll};
4
5use async_trait::async_trait;
6use headers::{authorization::Bearer, Authorization, HeaderMapExt};
7use http::{Request, Response, StatusCode};
8use jsonwebtoken::decode;
9pub use jsonwebtoken::{DecodingKey, Validation};
10use pin_project::pin_project;
11use serde::Deserialize;
12use tower::{Layer, Service};
13use tracing::{error, trace};
14
15#[async_trait]
17pub trait DecodingKeyFn: Send + Sync + Clone {
18 type Error: std::error::Error + Send;
19
20 async fn decoding_key(&self) -> Result<DecodingKey, Self::Error>;
21}
22
23#[async_trait]
24impl<F, O> DecodingKeyFn for F
25where
26 F: Fn() -> O + Sync + Send + Clone,
27 O: Future<Output = DecodingKey> + Send,
28{
29 type Error = Infallible;
30
31 async fn decoding_key(&self) -> Result<DecodingKey, Self::Error> {
32 Ok((self)().await)
33 }
34}
35
36#[async_trait]
37impl DecodingKeyFn for DecodingKey {
38 type Error = Infallible;
39
40 async fn decoding_key(&self) -> Result<DecodingKey, Self::Error> {
41 Ok(self.clone())
42 }
43}
44
45#[derive(Clone)]
50pub struct JwtLayer<Claim, F = DecodingKey> {
51 decoding_key_fn: F,
53 validation: Validation,
55 _phantom: PhantomData<Claim>,
56}
57
58impl<Claim, F: DecodingKeyFn> JwtLayer<Claim, F> {
59 pub fn new(validation: Validation, decoding_key_fn: F) -> Self {
62 Self {
63 decoding_key_fn,
64 validation,
65 _phantom: PhantomData,
66 }
67 }
68}
69
70impl<S, Claim, F: DecodingKeyFn> Layer<S> for JwtLayer<Claim, F> {
71 type Service = Jwt<S, Claim, F>;
72
73 fn layer(&self, inner: S) -> Self::Service {
74 Jwt {
75 inner,
76 decoding_key_fn: self.decoding_key_fn.clone(),
77 validation: Box::new(self.validation.clone()),
78 _phantom: self._phantom,
79 }
80 }
81}
82
83#[derive(Clone)]
85pub struct Jwt<S, Claim, F> {
86 inner: S,
87 decoding_key_fn: F,
88 validation: Box<Validation>,
90 _phantom: PhantomData<Claim>,
91}
92
93type AsyncTraitFuture<A> = Pin<Box<dyn Future<Output = A> + Send>>;
94
/// Response future returned by [`Jwt`].
///
/// A small state machine: a request either carried a malformed `Authorization`
/// header (`Error`), is being handled by the inner service (`WaitForFuture`), or
/// carries a bearer token whose decoding key is still being fetched
/// (`HasTokenWaitingForDecodingKey`).
#[pin_project(project = JwtFutureProj, project_replace = JwtFutureProjOwn)]
pub enum JwtFuture<
    DecKeyFn: DecodingKeyFn,
    TService: Service<Request<ReqBody>, Response = Response<ResBody>>,
    ReqBody,
    ResBody,
    Claim,
> {
    /// The `Authorization` header could not be parsed as a bearer token; resolves to
    /// `400 Bad Request` without calling the inner service. Also used as a temporary
    /// placeholder while the request and service are moved out of
    /// `HasTokenWaitingForDecodingKey`.
    Error,

    /// The request has been handed to the inner service; resolves to whatever the
    /// inner future yields.
    WaitForFuture {
        /// The inner service's response future.
        #[pin]
        future: TService::Future,
    },

    /// A bearer token was found; the request is held until the decoding key arrives
    /// and the token has been validated.
    HasTokenWaitingForDecodingKey {
        /// Token extracted from the `Authorization` header.
        bearer: Authorization<Bearer>,
        /// The request, forwarded to `service` once the token validates.
        request: Request<ReqBody>,
        /// In-flight key lookup from the [`DecodingKeyFn`].
        #[pin]
        decoding_key_future: AsyncTraitFuture<Result<DecodingKey, DecKeyFn::Error>>,
        /// Validation settings the token is checked against.
        validation: Box<Validation>,
        /// The inner service that will handle the request.
        service: TService,
        _phantom: PhantomData<Claim>,
    },
}
123
124impl<DecKeyFn, TService, ReqBody, ResBody, Claim> Future
125 for JwtFuture<DecKeyFn, TService, ReqBody, ResBody, Claim>
126where
127 DecKeyFn: DecodingKeyFn + 'static,
128 TService: Service<Request<ReqBody>, Response = Response<ResBody>>,
129 ResBody: Default,
130 for<'de> Claim: Deserialize<'de> + Send + Sync + Clone + 'static,
131{
132 type Output = Result<TService::Response, TService::Error>;
133
134 fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
135 match self.as_mut().project() {
136 JwtFutureProj::Error => {
137 let response = Response::builder()
138 .status(StatusCode::BAD_REQUEST)
139 .body(Default::default())
140 .unwrap();
141 Poll::Ready(Ok(response))
142 }
143 JwtFutureProj::WaitForFuture { future } => future.poll(cx),
144 JwtFutureProj::HasTokenWaitingForDecodingKey {
145 bearer,
146 decoding_key_future,
147 validation,
148 ..
149 } => match decoding_key_future.poll(cx) {
150 Poll::Pending => Poll::Pending,
151 Poll::Ready(Err(error)) => {
152 error!(
153 error = &error as &dyn std::error::Error,
154 "failed to get decoding key"
155 );
156 let response = Response::builder()
157 .status(StatusCode::SERVICE_UNAVAILABLE)
158 .body(Default::default())
159 .unwrap();
160
161 Poll::Ready(Ok(response))
162 }
163 Poll::Ready(Ok(decoding_key)) => {
164 let claim_result = RequestClaim::<Claim>::from_token(
165 bearer.token().trim(),
166 &decoding_key,
167 validation,
168 );
169 match claim_result {
170 Err(code) => {
171 error!(code = %code, "failed to decode JWT");
172
173 let response = Response::builder()
174 .status(code)
175 .body(Default::default())
176 .unwrap();
177
178 Poll::Ready(Ok(response))
179 }
180 Ok(claim) => {
181 let owned = self.as_mut().project_replace(JwtFuture::Error);
182 match owned {
183 JwtFutureProjOwn::HasTokenWaitingForDecodingKey {
184 mut request, mut service, ..
185 } => {
186 request.extensions_mut().insert(claim);
187 let future = service.call(request);
188 self.as_mut().set(JwtFuture::WaitForFuture { future });
189 self.poll(cx)
190 },
191 _ => unreachable!("We know that we're in the 'HasTokenWaitingForDecodingKey' state"),
192 }
193 }
194 }
195 }
196 },
197 }
198 }
199}
200
201impl<S, ReqBody, ResBody, Claim, F> Service<Request<ReqBody>> for Jwt<S, Claim, F>
202where
203 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Send + Clone + 'static,
204 S::Future: Send + 'static,
205 ResBody: Default,
206 F: DecodingKeyFn + 'static,
207 <F as DecodingKeyFn>::Error: 'static,
208 for<'de> Claim: Deserialize<'de> + Send + Sync + Clone + 'static,
209{
210 type Response = S::Response;
211 type Error = S::Error;
212 type Future = JwtFuture<F, S, ReqBody, ResBody, Claim>;
213
214 fn poll_ready(
215 &mut self,
216 cx: &mut std::task::Context<'_>,
217 ) -> std::task::Poll<Result<(), Self::Error>> {
218 self.inner.poll_ready(cx)
219 }
220
221 fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
222 match req.headers().typed_try_get::<Authorization<Bearer>>() {
223 Ok(Some(bearer)) => {
224 let decoding_key_fn = self.decoding_key_fn.clone();
225 let decoding_key_future =
226 Box::pin(async move { decoding_key_fn.decoding_key().await });
227 Self::Future::HasTokenWaitingForDecodingKey {
228 bearer,
229 request: req,
230 decoding_key_future,
231 validation: self.validation.clone(),
232 service: self.inner.clone(),
233 _phantom: self._phantom,
234 }
235 }
236 Ok(None) => {
237 let future = self.inner.call(req);
238
239 Self::Future::WaitForFuture { future }
240 }
241 Err(_) => Self::Future::Error,
242 }
243 }
244}
245
246#[derive(Clone, Debug)]
248pub struct RequestClaim<T>
249where
250 for<'de> T: Deserialize<'de>,
251{
252 pub claim: T,
254
255 pub token: String,
257}
258
259impl<T> RequestClaim<T>
260where
261 for<'de> T: Deserialize<'de>,
262{
263 pub fn from_token(
264 token: &str,
265 decoding_key: &DecodingKey,
266 validation: &Validation,
267 ) -> Result<Self, StatusCode> {
268 trace!("converting token to claim");
269 let claim: T = decode(token, decoding_key, validation)
270 .map_err(|err| {
271 error!(
272 error = &err as &dyn std::error::Error,
273 "failed to convert token to claim"
274 );
275 match err.kind() {
276 jsonwebtoken::errors::ErrorKind::ExpiredSignature => {
277 StatusCode::from_u16(499).unwrap() }
279 jsonwebtoken::errors::ErrorKind::InvalidSignature
280 | jsonwebtoken::errors::ErrorKind::InvalidAlgorithmName
281 | jsonwebtoken::errors::ErrorKind::InvalidIssuer
282 | jsonwebtoken::errors::ErrorKind::ImmatureSignature => {
283 StatusCode::UNAUTHORIZED
284 }
285 jsonwebtoken::errors::ErrorKind::InvalidToken
286 | jsonwebtoken::errors::ErrorKind::InvalidAlgorithm
287 | jsonwebtoken::errors::ErrorKind::Base64(_)
288 | jsonwebtoken::errors::ErrorKind::Json(_)
289 | jsonwebtoken::errors::ErrorKind::Utf8(_) => StatusCode::BAD_REQUEST,
290 jsonwebtoken::errors::ErrorKind::MissingAlgorithm => {
291 StatusCode::INTERNAL_SERVER_ERROR
292 }
293 jsonwebtoken::errors::ErrorKind::Crypto(_) => StatusCode::SERVICE_UNAVAILABLE,
294 _ => StatusCode::INTERNAL_SERVER_ERROR,
295 }
296 })?
297 .claims;
298
299 Ok(Self {
300 claim,
301 token: token.to_string(),
302 })
303 }
304}
305
#[cfg(test)]
mod tests {
    use axum::{routing::get, Extension, Router};
    use chrono::{Duration, Utc};
    use http_body_util::{BodyExt, Empty};
    use jsonwebtoken::{encode, EncodingKey, Header};
    use ring::{
        rand,
        signature::{self, Ed25519KeyPair, KeyPair},
    };
    use serde::Serialize;
    use std::ops::Add;
    use tower::ServiceExt;

    use super::*;

    /// Registered JWT claims carried by the test tokens.
    #[derive(Deserialize, Clone, Serialize)]
    pub struct Claim {
        /// Expiry, as a Unix timestamp in seconds.
        pub exp: usize,
        /// Issued-at time, as a Unix timestamp in seconds.
        iat: usize,
        /// Issuer; must match the issuer configured on the test `Validation`.
        iss: String,
        /// Not-before time, as a Unix timestamp in seconds.
        nbf: usize,
        /// Subject; echoed back by the test handler.
        pub sub: String,
    }

    impl Claim {
        /// Creates a claim for `sub`, issued now and valid for five minutes.
        pub fn new(sub: String) -> Self {
            let iat = Utc::now();
            let exp = iat.add(Duration::try_minutes(5).unwrap());

            Self {
                exp: exp.timestamp() as usize,
                iat: iat.timestamp() as usize,
                iss: "test-issuer".to_string(),
                nbf: iat.timestamp() as usize,
                sub,
            }
        }

        /// Signs the claim with EdDSA, mapping encoding failures to a status code.
        pub fn into_token(self, encoding_key: &EncodingKey) -> Result<String, StatusCode> {
            encode(
                &Header::new(jsonwebtoken::Algorithm::EdDSA),
                &self,
                encoding_key,
            )
            .map_err(|err| {
                error!(
                    error = &err as &dyn std::error::Error,
                    "failed to convert claim to token"
                );
                match err.kind() {
                    jsonwebtoken::errors::ErrorKind::Json(_) => StatusCode::INTERNAL_SERVER_ERROR,
                    jsonwebtoken::errors::ErrorKind::Crypto(_) => StatusCode::SERVICE_UNAVAILABLE,
                    _ => StatusCode::INTERNAL_SERVER_ERROR,
                }
            })
        }
    }

    /// End-to-end check of `JwtLayer` in front of an axum router.
    #[tokio::test]
    async fn authorization_layer() {
        let claim = Claim::new("ferries".to_string());

        // Fresh Ed25519 key pair: the PKCS#8 document signs, its public key verifies.
        let doc = signature::Ed25519KeyPair::generate_pkcs8(&rand::SystemRandom::new()).unwrap();
        let encoding_key = EncodingKey::from_ed_der(doc.as_ref());
        let pair = Ed25519KeyPair::from_pkcs8(doc.as_ref()).unwrap();
        let public_key = pair.public_key().as_ref().to_vec();
        let decoding_key = DecodingKey::from_ed_der(&public_key);

        let mut validation = Validation::new(jsonwebtoken::Algorithm::EdDSA);
        validation.set_issuer(&["test-issuer"]);

        // The handler answers 401 itself when no claim was attached, so requests the
        // layer passes through untouched are distinguishable from authenticated ones.
        let router = Router::new()
            .route(
                "/",
                get(|claim: Option<Extension<RequestClaim<Claim>>>| async move {
                    if let Some(Extension(claim)) = claim {
                        (StatusCode::OK, format!("Hello, {}", claim.claim.sub))
                    } else {
                        (StatusCode::UNAUTHORIZED, "Not authorized".to_string())
                    }
                }),
            )
            .layer(JwtLayer::<Claim, _>::new(validation, move || {
                let decoding_key = decoding_key.clone();

                async { decoding_key }
            }));

        // No `Authorization` header: forwarded as-is, so the handler sees no claim.
        let response = router
            .clone()
            .oneshot(
                http::Request::builder()
                    .uri("/")
                    .body(Empty::new())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::UNAUTHORIZED);

        let token = claim.clone().into_token(&encoding_key).unwrap();
        // Token without the `Bearer` scheme: the header does not parse, so the layer
        // itself rejects the request with 400.
        let response = router
            .clone()
            .oneshot(
                http::Request::builder()
                    .uri("/")
                    .header("authorization", token.clone())
                    .body(Empty::new())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::BAD_REQUEST);

        // Well-formed bearer token: validated, and the claim reaches the handler.
        let response = router
            .clone()
            .oneshot(
                http::Request::builder()
                    .uri("/")
                    .header("authorization", format!("Bearer {token}"))
                    .body(Empty::new())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        // Trailing whitespace after the token is tolerated because the token is trimmed
        // before decoding; the response body confirms the decoded subject.
        let response = router
            .clone()
            .oneshot(
                http::Request::builder()
                    .uri("/")
                    .header("Authorization", format!("Bearer {token} "))
                    .body(Empty::new())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body().collect().await.unwrap().to_bytes();

        assert_eq!(&body[..], b"Hello, ferries");
    }
}