axum_jwt_auth/
local.rs

1use async_trait::async_trait;
2use derive_builder::Builder;
3use jsonwebtoken::{DecodingKey, TokenData, Validation};
4use serde::de::DeserializeOwned;
5
6use crate::{Error, JwtDecoder};
7
8/// Local decoder
9/// It uses the given JWKS to decode the JWT tokens.
10#[derive(Clone, Builder)]
11pub struct LocalDecoder {
12    keys: Vec<DecodingKey>,
13    validation: Validation,
14}
15
16impl LocalDecoder {
17    pub fn new(keys: Vec<DecodingKey>, validation: Validation) -> Result<Self, Error> {
18        if keys.is_empty() {
19            return Err(Error::Configuration("No decoding keys provided".into()));
20        }
21
22        if validation.algorithms.is_empty() {
23            return Err(Error::Configuration(
24                "Validation algorithm is required".into(),
25            ));
26        }
27
28        if validation.aud.is_none() {
29            return Err(Error::Configuration(
30                "Validation audience is required".into(),
31            ));
32        }
33
34        Ok(Self { keys, validation })
35    }
36
37    pub fn builder() -> LocalDecoderBuilder {
38        LocalDecoderBuilder::default()
39    }
40}
41
42#[async_trait]
43impl<T> JwtDecoder<T> for LocalDecoder
44where
45    T: for<'de> DeserializeOwned,
46{
47    async fn decode(&self, token: &str) -> Result<TokenData<T>, Error> {
48        // Try to decode the token with each key in the cache
49        // If none of them work, return the error from the last one
50        let mut last_error: Option<Error> = None;
51        for key in self.keys.iter() {
52            match jsonwebtoken::decode::<T>(token, key, &self.validation) {
53                Ok(token_data) => return Ok(token_data),
54                Err(e) => {
55                    tracing::error!("Error decoding token: {}", e);
56                    last_error = Some(Error::Jwt(e));
57                }
58            }
59        }
60
61        Err(last_error.unwrap_or_else(|| Error::Configuration("No keys available".into())))
62    }
63}