axum_gate/
jwt.rs

//! JWT-related models (such as claims) and the JSON Web Token codec that encodes and decodes them.
2use crate::Error;
3use crate::codecs::CodecService;
4use crate::jsonwebtoken::{DecodingKey, EncodingKey, Header, Validation};
5use chrono::{TimeDelta, Utc};
6use serde::{Deserialize, Serialize, de::DeserializeOwned};
7use serde_with::skip_serializing_none;
8use std::collections::HashSet;
9use std::marker::PhantomData;
10
11/// Registered/reserved claims by IANA/JWT spec, see
12/// [auth0](https://auth0.com/docs/secure/tokens/json-web-tokens/json-web-token-claims) for more
13/// information.
14#[derive(Serialize, Deserialize, Clone, Debug)]
15#[skip_serializing_none]
16pub struct RegisteredClaims {
17    /// Issuer of the JWT
18    #[serde(rename = "iss")]
19    pub issuer: Option<HashSet<String>>,
20    /// Subject of the JWT (the user)
21    #[serde(rename = "sub")]
22    pub subject: Option<String>,
23    /// Recipient for which the JWT is intended
24    #[serde(rename = "aud")]
25    pub audience: Option<HashSet<String>>,
26    /// Time after which the JWT expires
27    #[serde(rename = "exp")]
28    pub expiration_time: Option<u64>,
29    /// Time before which the JWT must not be accepted for processing
30    #[serde(rename = "nbf")]
31    pub not_before_time: Option<u64>,
32    /// Time at which the JWT was issued; can be used to determine age of the JWT
33    #[serde(rename = "iat")]
34    pub issued_at_time: Option<u64>,
35    /// Unique identifier; can be used to prevent the JWT from being replayed (allows a token to be used only once)
36    #[serde(rename = "jti")]
37    pub jwt_id: Option<String>,
38}
39
40impl Default for RegisteredClaims {
41    /// Initializes the claims with `expiration_time` set to 1 week.
42    fn default() -> Self {
43        Self {
44            issuer: None,
45            subject: None,
46            audience: None,
47            expiration_time: Some((Utc::now() + TimeDelta::weeks(1)).timestamp() as u64),
48            not_before_time: None,
49            issued_at_time: None,
50            jwt_id: None,
51        }
52    }
53}
54
55/// Default claims for the use with `axum-gate`s [JsonWebToken] codec.
56#[derive(Serialize, Deserialize, Clone, Debug)]
57pub struct JwtClaims<CustomClaims> {
58    /// The registered claims of a JWT.
59    #[serde(flatten)]
60    pub registered_claims: RegisteredClaims,
61    /// Your custom claims that are added to the JWT.
62    #[serde(flatten)]
63    pub custom_claims: CustomClaims,
64}
65
66impl<CustomClaims> JwtClaims<CustomClaims> {
67    /// Creates a new claim with default registered claims and the given custom claims.
68    pub fn new(custom_claims: CustomClaims) -> Self {
69        Self {
70            registered_claims: RegisteredClaims::default(),
71            custom_claims,
72        }
73    }
74
75    /// Creates new claims with the given registered claims.
76    pub fn new_with_registered(
77        custom_claims: CustomClaims,
78        registered_claims: RegisteredClaims,
79    ) -> Self {
80        Self {
81            custom_claims,
82            registered_claims,
83        }
84    }
85}
86
87/// Options to configure the [JsonWebToken] codec.
88pub struct JsonWebTokenOptions {
89    /// Key for encoding.
90    pub enc_key: EncodingKey,
91    /// Key for decoding.
92    pub dec_key: DecodingKey,
93    /// The header used for encoding.
94    pub header: Option<Header>,
95    /// Validation options.
96    pub validation: Option<Validation>,
97}
98
99impl Default for JsonWebTokenOptions {
100    /// Creates a random, alphanumeric 60 char key and uses it for en- and decoding (symmetric).
101    /// [Header] and [Validation] are set with its default values.
102    fn default() -> Self {
103        use rand::{Rng, distr::Alphanumeric, rng};
104
105        let authentication_secret: String = rng()
106            .sample_iter(&Alphanumeric)
107            .take(60)
108            .map(char::from)
109            .collect();
110        Self {
111            enc_key: EncodingKey::from_secret(authentication_secret.as_bytes()),
112            dec_key: DecodingKey::from_secret(authentication_secret.as_bytes()),
113            header: Some(Header::default()),
114            validation: Some(Validation::default()),
115        }
116    }
117}
118
119/// Encrypts using the given keys as JWT using [jsonwebtoken].
120#[derive(Clone)]
121pub struct JsonWebToken<P> {
122    /// Key for encoding.
123    enc_key: EncodingKey,
124    /// Key for decoding.
125    dec_key: DecodingKey,
126    /// The header used for encoding.
127    pub header: Header,
128    /// Validation options for the JWT.
129    pub validation: Validation,
130    phantom_payload: PhantomData<P>,
131}
132
133impl<P> JsonWebToken<P> {
134    /// Creates a new instance with the given encoding and decoding keys.
135    pub fn new_with_options(options: JsonWebTokenOptions) -> Self {
136        let JsonWebTokenOptions {
137            enc_key,
138            dec_key,
139            header,
140            validation,
141        } = options;
142        Self {
143            enc_key,
144            dec_key,
145            header: header.unwrap_or(Header::default()),
146            validation: validation.unwrap_or(Validation::default()),
147            phantom_payload: PhantomData,
148        }
149    }
150}
151
152impl<P> Default for JsonWebToken<P> {
153    fn default() -> Self {
154        Self::new_with_options(JsonWebTokenOptions::default())
155    }
156}
157
158impl<P> CodecService for JsonWebToken<P>
159where
160    P: Serialize + DeserializeOwned + Clone,
161{
162    type Payload = P;
163    fn encode(&self, payload: &Self::Payload) -> Result<Vec<u8>, Error> {
164        let web_token = jsonwebtoken::encode(&self.header, payload, &self.enc_key)
165            .map_err(|e| Error::Codec(format!("{e}")))?;
166        Ok(web_token.as_bytes().to_vec())
167    }
168    /// Decodes the given value.
169    ///
170    /// # Errors
171    /// Returns an error if the header stored in [JsonWebToken] does not match the decoded value.
172    /// The header can be retrieved from [JsonWebToken::header].
173    fn decode(&self, encoded_value: &[u8]) -> Result<Self::Payload, Error> {
174        let claims = jsonwebtoken::decode::<Self::Payload>(
175            &String::from_utf8_lossy(encoded_value),
176            &self.dec_key,
177            &self.validation,
178        )
179        .map_err(|e| Error::Codec(format!("{e}")))?;
180
181        if self.header != claims.header {
182            return Err(Error::Codec(format!(
183                "Header of the decoded value does not match the one used for encoding."
184            )));
185        }
186
187        Ok(claims.claims)
188    }
189}