keygate_jwt/
token.rs

1use base64ct::Base64UrlUnpadded;
2use base64ct::Encoding;
3use serde::{de::DeserializeOwned, Serialize};
4
5use crate::claims::*;
6use crate::common::*;
7use crate::ensure;
8use crate::error::*;
9use crate::jwt_header::*;
10
/// Default maximum length, in bytes, of a token's base64url-encoded header
/// segment; longer headers are rejected with `JWTError::HeaderTooLarge`.
pub const MAX_HEADER_LENGTH: usize = 8 * 1024;
12
/// Stateless entry point for building, verifying and inspecting JWT tokens.
///
/// This is a unit type: all of its functionality is exposed through
/// associated functions rather than methods on an instance.
pub struct Token;
15
16/// JWT token information useful before signature/tag verification
17#[derive(Debug, Clone, Default)]
18pub struct TokenMetadata {
19    pub(crate) jwt_header: JWTHeader,
20}
21
22impl TokenMetadata {
23    /// The JWT algorithm for this token ("alg")
24    /// This information should not be trusted: it is unprotected and can be
25    /// freely modified by a third party. Clients should ignore it and use
26    /// the correct type of key directly.
27    pub fn algorithm(&self) -> &str {
28        &self.jwt_header.algorithm
29    }
30
31    /// The content type for this token ("cty")
32    pub fn content_type(&self) -> Option<&str> {
33        self.jwt_header.content_type.as_deref()
34    }
35
36    /// The key, or public key identifier for this token ("kid")
37    pub fn key_id(&self) -> Option<&str> {
38        self.jwt_header.key_id.as_deref()
39    }
40
41    /// The signature type for this token ("typ")
42    pub fn signature_type(&self) -> Option<&str> {
43        self.jwt_header.signature_type.as_deref()
44    }
45
46    /// The set of raw critical properties for this token ("crit")
47    pub fn critical(&self) -> Option<&[String]> {
48        self.jwt_header.critical.as_deref()
49    }
50
51    /// The certificate chain for this token ("x5c")
52    /// This information should not be trusted: it is unprotected and can be
53    /// freely modified by a third party.
54    pub fn certificate_chain(&self) -> Option<&[String]> {
55        self.jwt_header.certificate_chain.as_deref()
56    }
57
58    /// The key set URL for this token ("jku")
59    /// This information should not be trusted: it is unprotected and can be
60    /// freely modified by a third party. At the bare minimum, you should
61    /// check that the URL belongs to the domain you expect.
62    pub fn key_set_url(&self) -> Option<&str> {
63        self.jwt_header.key_set_url.as_deref()
64    }
65
66    /// The public key for this token ("jwk")
67    /// This information should not be trusted: it is unprotected and can be
68    /// freely modified by a third party. At the bare minimum, you should
69    /// check that it's in a set of public keys you already trust.
70    pub fn public_key(&self) -> Option<&str> {
71        self.jwt_header.public_key.as_deref()
72    }
73
74    /// The certificate URL for this token ("x5u")
75    /// This information should not be trusted: it is unprotected and can be
76    /// freely modified by a third party. At the bare minimum, you should
77    /// check that the URL belongs to the domain you expect.
78    pub fn certificate_url(&self) -> Option<&str> {
79        self.jwt_header.certificate_url.as_deref()
80    }
81
82    /// URLsafe-base64-encoded SHA256 hash of the X.509 certificate for this
83    /// token ("x5t#256") In practice, it can also be any string
84    /// representing the public key. This information should not be trusted:
85    /// it is unprotected and can be freely modified by a third party.
86    pub fn certificate_sha256_thumbprint(&self) -> Option<&str> {
87        self.jwt_header.certificate_sha256_thumbprint.as_deref()
88    }
89}
90
91impl Token {
92    pub(crate) fn build<AuthenticationOrSignatureFn, CustomClaims: Serialize + DeserializeOwned>(
93        jwt_header: &JWTHeader,
94        claims: JWTClaims<CustomClaims>,
95        authentication_or_signature_fn: AuthenticationOrSignatureFn,
96    ) -> Result<String, JWTError>
97    where
98        AuthenticationOrSignatureFn: FnOnce(&str) -> Result<Vec<u8>, JWTError>,
99    {
100        let jwt_header_json = serde_json::to_string(&jwt_header)?;
101        let claims_json = serde_json::to_string(&claims)?;
102        let authenticated = format!(
103            "{}.{}",
104            Base64UrlUnpadded::encode_string(jwt_header_json.as_bytes()),
105            Base64UrlUnpadded::encode_string(claims_json.as_bytes())
106        );
107        let authentication_tag_or_signature = authentication_or_signature_fn(&authenticated)?;
108        let mut token = authenticated;
109        token.push('.');
110        token.push_str(&Base64UrlUnpadded::encode_string(
111            &authentication_tag_or_signature,
112        ));
113        Ok(token)
114    }
115
116    pub(crate) fn verify<AuthenticationOrSignatureFn, CustomClaims: Serialize + DeserializeOwned>(
117        jwt_alg_name: &'static str,
118        token: &str,
119        options: Option<VerificationOptions>,
120        authentication_or_signature_fn: AuthenticationOrSignatureFn,
121    ) -> Result<JWTClaims<CustomClaims>, JWTError>
122    where
123        AuthenticationOrSignatureFn: FnOnce(&str, &[u8]) -> Result<(), JWTError>,
124    {
125        let options = options.unwrap_or_default();
126
127        if let Some(max_token_length) = options.max_token_length {
128            ensure!(token.len() <= max_token_length, JWTError::TokenTooLong);
129        }
130
131        let mut parts = token.split('.');
132        let jwt_header_b64 = parts.next().ok_or(JWTError::CompactEncodingError)?;
133        ensure!(
134            jwt_header_b64.len() <= options.max_header_length.unwrap_or(MAX_HEADER_LENGTH),
135            JWTError::HeaderTooLarge
136        );
137        let claims_b64 = parts.next().ok_or(JWTError::CompactEncodingError)?;
138        let authentication_tag_b64 = parts.next().ok_or(JWTError::CompactEncodingError)?;
139        ensure!(parts.next().is_none(), JWTError::CompactEncodingError);
140        let jwt_header: JWTHeader =
141            serde_json::from_slice(&Base64UrlUnpadded::decode_vec(jwt_header_b64)?)?;
142        if let Some(signature_type) = &jwt_header.signature_type {
143            let signature_type_uc = signature_type.to_uppercase();
144            ensure!(
145                signature_type_uc == "JWT" || signature_type_uc.ends_with("+JWT"),
146                JWTError::NotJWT
147            );
148        }
149        ensure!(
150            jwt_header.algorithm == jwt_alg_name,
151            JWTError::AlgorithmMismatch
152        );
153        if let Some(required_key_id) = &options.required_key_id {
154            if let Some(key_id) = &jwt_header.key_id {
155                ensure!(key_id == required_key_id, JWTError::KeyIdentifierMismatch);
156            } else {
157                return Err(JWTError::MissingJWTKeyIdentifier);
158            }
159        }
160        let authentication_tag = Base64UrlUnpadded::decode_vec(authentication_tag_b64)?;
161        let authenticated = &token[..jwt_header_b64.len() + 1 + claims_b64.len()];
162        authentication_or_signature_fn(authenticated, &authentication_tag)?;
163        let claims: JWTClaims<CustomClaims> =
164            serde_json::from_slice(&Base64UrlUnpadded::decode_vec(claims_b64)?)?;
165        claims.validate(&options)?;
166        Ok(claims)
167    }
168
169    /// Decode token information that can be usedful prior to signature/tag
170    /// verification
171    pub fn decode_metadata(token: &str) -> Result<TokenMetadata, JWTError> {
172        let mut parts = token.split('.');
173        let jwt_header_b64 = parts.next().ok_or(JWTError::CompactEncodingError)?;
174        ensure!(
175            jwt_header_b64.len() <= MAX_HEADER_LENGTH,
176            JWTError::HeaderTooLarge
177        );
178        let jwt_header: JWTHeader =
179            serde_json::from_slice(&Base64UrlUnpadded::decode_vec(jwt_header_b64)?)?;
180        Ok(TokenMetadata { jwt_header })
181    }
182}
183
#[test]
fn should_verify_token() {
    use crate::prelude::*;

    // Values embedded in the token and then demanded by the verifier.
    let (issuer, audience, nonce) = ("issuer", "recipient", "some_nonce");

    let key_pair = Ed25519KeyPair::generate();
    let token = key_pair
        .sign(
            Claims::create(Duration::from_mins(10))
                .with_issuer(issuer)
                .with_audience(audience)
                .with_nonce(nonce),
        )
        .unwrap();

    let expectations = VerificationOptions {
        allowed_audiences: Some(HashSet::from_strings(&[audience])),
        allowed_issuers: Some(HashSet::from_strings(&[issuer])),
        required_nonce: Some(nonce.to_string()),
        ..Default::default()
    };
    let public_key = key_pair.public_key();
    public_key
        .verify_token::<NoCustomClaims>(&token, Some(expectations))
        .unwrap();
}
211
#[test]
fn multiple_audiences() {
    use std::collections::HashSet;

    use crate::prelude::*;

    // The token names three audiences; allowing just one of them must suffice.
    let token_audiences: HashSet<&str> = ["audience 1", "audience 2", "audience 3"]
        .iter()
        .copied()
        .collect();
    let allowed = VerificationOptions {
        allowed_audiences: Some(HashSet::from_strings(&["audience 1"])),
        ..Default::default()
    };

    let key_pair = Ed25519KeyPair::generate();
    let token = key_pair
        .sign(Claims::create(Duration::from_mins(10)).with_audiences(token_audiences))
        .unwrap();

    let public_key = key_pair.public_key();
    public_key
        .verify_token::<NoCustomClaims>(&token, Some(allowed))
        .unwrap();
}
236
#[test]
fn explicitly_empty_audiences() {
    use std::collections::HashSet;

    use crate::prelude::*;

    let key_pair = Ed25519KeyPair::generate();

    // Each case: (claims to sign, whether the verified claims should report
    // an audience after the sign/verify round trip).
    let cases = [
        // An explicitly empty audience set still round-trips as present.
        (
            Claims::create(Duration::from_mins(10)).with_audiences(HashSet::<&str>::new()),
            true,
        ),
        // A single empty-string audience still round-trips as present.
        (Claims::create(Duration::from_mins(10)).with_audience(""), true),
        // Without any audience, none is reported.
        (Claims::create(Duration::from_mins(10)), false),
    ];

    for (claims, expect_audiences) in cases {
        let token = key_pair.sign(claims).unwrap();
        let decoded = key_pair
            .public_key()
            .verify_token::<NoCustomClaims>(&token, None)
            .unwrap();
        assert_eq!(decoded.audiences.is_some(), expect_audiences);
    }
}