1use base64ct::Base64UrlUnpadded;
2use base64ct::Encoding;
3use serde::{de::DeserializeOwned, Serialize};
4
5use crate::claims::*;
6use crate::common::*;
7use crate::ensure;
8use crate::error::*;
9use crate::jwt_header::*;
10
/// Default upper bound, in bytes, on the base64url-encoded header segment of a token.
///
/// `Token::verify` applies it when `VerificationOptions::max_header_length` is unset, and
/// `Token::decode_metadata` always applies it.
pub const MAX_HEADER_LENGTH: usize = 8 * 1024;
12
/// Data-less namespace for the compact-serialization routines (`build`, `verify`,
/// `decode_metadata`).
pub struct Token;
15
16#[derive(Debug, Clone, Default)]
18pub struct TokenMetadata {
19 pub(crate) jwt_header: JWTHeader,
20}
21
22impl TokenMetadata {
23 pub fn algorithm(&self) -> &str {
28 &self.jwt_header.algorithm
29 }
30
31 pub fn content_type(&self) -> Option<&str> {
33 self.jwt_header.content_type.as_deref()
34 }
35
36 pub fn key_id(&self) -> Option<&str> {
38 self.jwt_header.key_id.as_deref()
39 }
40
41 pub fn signature_type(&self) -> Option<&str> {
43 self.jwt_header.signature_type.as_deref()
44 }
45
46 pub fn critical(&self) -> Option<&[String]> {
48 self.jwt_header.critical.as_deref()
49 }
50
51 pub fn certificate_chain(&self) -> Option<&[String]> {
55 self.jwt_header.certificate_chain.as_deref()
56 }
57
58 pub fn key_set_url(&self) -> Option<&str> {
63 self.jwt_header.key_set_url.as_deref()
64 }
65
66 pub fn public_key(&self) -> Option<&str> {
71 self.jwt_header.public_key.as_deref()
72 }
73
74 pub fn certificate_url(&self) -> Option<&str> {
79 self.jwt_header.certificate_url.as_deref()
80 }
81
82 pub fn certificate_sha256_thumbprint(&self) -> Option<&str> {
87 self.jwt_header.certificate_sha256_thumbprint.as_deref()
88 }
89}
90
91impl Token {
92 pub(crate) fn build<AuthenticationOrSignatureFn, CustomClaims: Serialize + DeserializeOwned>(
93 jwt_header: &JWTHeader,
94 claims: JWTClaims<CustomClaims>,
95 authentication_or_signature_fn: AuthenticationOrSignatureFn,
96 ) -> Result<String, JWTError>
97 where
98 AuthenticationOrSignatureFn: FnOnce(&str) -> Result<Vec<u8>, JWTError>,
99 {
100 let jwt_header_json = serde_json::to_string(&jwt_header)?;
101 let claims_json = serde_json::to_string(&claims)?;
102 let authenticated = format!(
103 "{}.{}",
104 Base64UrlUnpadded::encode_string(jwt_header_json.as_bytes()),
105 Base64UrlUnpadded::encode_string(claims_json.as_bytes())
106 );
107 let authentication_tag_or_signature = authentication_or_signature_fn(&authenticated)?;
108 let mut token = authenticated;
109 token.push('.');
110 token.push_str(&Base64UrlUnpadded::encode_string(
111 &authentication_tag_or_signature,
112 ));
113 Ok(token)
114 }
115
116 pub(crate) fn verify<AuthenticationOrSignatureFn, CustomClaims: Serialize + DeserializeOwned>(
117 jwt_alg_name: &'static str,
118 token: &str,
119 options: Option<VerificationOptions>,
120 authentication_or_signature_fn: AuthenticationOrSignatureFn,
121 ) -> Result<JWTClaims<CustomClaims>, JWTError>
122 where
123 AuthenticationOrSignatureFn: FnOnce(&str, &[u8]) -> Result<(), JWTError>,
124 {
125 let options = options.unwrap_or_default();
126
127 if let Some(max_token_length) = options.max_token_length {
128 ensure!(token.len() <= max_token_length, JWTError::TokenTooLong);
129 }
130
131 let mut parts = token.split('.');
132 let jwt_header_b64 = parts.next().ok_or(JWTError::CompactEncodingError)?;
133 ensure!(
134 jwt_header_b64.len() <= options.max_header_length.unwrap_or(MAX_HEADER_LENGTH),
135 JWTError::HeaderTooLarge
136 );
137 let claims_b64 = parts.next().ok_or(JWTError::CompactEncodingError)?;
138 let authentication_tag_b64 = parts.next().ok_or(JWTError::CompactEncodingError)?;
139 ensure!(parts.next().is_none(), JWTError::CompactEncodingError);
140 let jwt_header: JWTHeader =
141 serde_json::from_slice(&Base64UrlUnpadded::decode_vec(jwt_header_b64)?)?;
142 if let Some(signature_type) = &jwt_header.signature_type {
143 let signature_type_uc = signature_type.to_uppercase();
144 ensure!(
145 signature_type_uc == "JWT" || signature_type_uc.ends_with("+JWT"),
146 JWTError::NotJWT
147 );
148 }
149 ensure!(
150 jwt_header.algorithm == jwt_alg_name,
151 JWTError::AlgorithmMismatch
152 );
153 if let Some(required_key_id) = &options.required_key_id {
154 if let Some(key_id) = &jwt_header.key_id {
155 ensure!(key_id == required_key_id, JWTError::KeyIdentifierMismatch);
156 } else {
157 return Err(JWTError::MissingJWTKeyIdentifier);
158 }
159 }
160 let authentication_tag = Base64UrlUnpadded::decode_vec(authentication_tag_b64)?;
161 let authenticated = &token[..jwt_header_b64.len() + 1 + claims_b64.len()];
162 authentication_or_signature_fn(authenticated, &authentication_tag)?;
163 let claims: JWTClaims<CustomClaims> =
164 serde_json::from_slice(&Base64UrlUnpadded::decode_vec(claims_b64)?)?;
165 claims.validate(&options)?;
166 Ok(claims)
167 }
168
169 pub fn decode_metadata(token: &str) -> Result<TokenMetadata, JWTError> {
172 let mut parts = token.split('.');
173 let jwt_header_b64 = parts.next().ok_or(JWTError::CompactEncodingError)?;
174 ensure!(
175 jwt_header_b64.len() <= MAX_HEADER_LENGTH,
176 JWTError::HeaderTooLarge
177 );
178 let jwt_header: JWTHeader =
179 serde_json::from_slice(&Base64UrlUnpadded::decode_vec(jwt_header_b64)?)?;
180 Ok(TokenMetadata { jwt_header })
181 }
182}
183
#[test]
fn should_verify_token() {
    use crate::prelude::*;

    // Sign a token carrying issuer, audience and nonce claims...
    let (iss, aud, expected_nonce) = ("issuer", "recipient", "some_nonce");
    let signer = Ed25519KeyPair::generate();
    let token = signer
        .sign(
            Claims::create(Duration::from_mins(10))
                .with_issuer(iss)
                .with_audience(aud)
                .with_nonce(expected_nonce),
        )
        .unwrap();

    // ...and require all three of them when verifying.
    let verification = VerificationOptions {
        allowed_issuers: Some(HashSet::from_strings(&[iss])),
        allowed_audiences: Some(HashSet::from_strings(&[aud])),
        required_nonce: Some(expected_nonce.to_string()),
        ..Default::default()
    };
    signer
        .public_key()
        .verify_token::<NoCustomClaims>(&token, Some(verification))
        .unwrap();
}
211
#[test]
fn multiple_audiences() {
    use std::collections::HashSet;

    use crate::prelude::*;

    let key_pair = Ed25519KeyPair::generate();

    let audiences: HashSet<&str> = ["audience 1", "audience 2", "audience 3"]
        .iter()
        .copied()
        .collect();
    let claims = Claims::create(Duration::from_mins(10)).with_audiences(audiences);
    let token = key_pair.sign(claims).unwrap();

    // Accepted: the allowed audience is one of the token's audiences.
    let options = VerificationOptions {
        allowed_audiences: Some(HashSet::from_strings(&["audience 1"])),
        ..Default::default()
    };
    key_pair
        .public_key()
        .verify_token::<NoCustomClaims>(&token, Some(options))
        .unwrap();

    // Rejected: none of the token's audiences is allowed.
    let options = VerificationOptions {
        allowed_audiences: Some(HashSet::from_strings(&["audience 4"])),
        ..Default::default()
    };
    assert!(key_pair
        .public_key()
        .verify_token::<NoCustomClaims>(&token, Some(options))
        .is_err());
}
236
#[test]
fn explicitly_empty_audiences() {
    use std::collections::HashSet;

    use crate::prelude::*;

    let key_pair = Ed25519KeyPair::generate();

    // Each case pairs the claims to sign with whether the verified claims should still carry an
    // audience set. An explicitly empty set and an empty-string audience must survive the round
    // trip; omitting the audience entirely must not produce one.
    let cases = vec![
        (
            Claims::create(Duration::from_mins(10)).with_audiences(HashSet::<&str>::new()),
            true,
        ),
        (Claims::create(Duration::from_mins(10)).with_audience(""), true),
        (Claims::create(Duration::from_mins(10)), false),
    ];

    for (claims, expect_audiences) in cases {
        let token = key_pair.sign(claims).unwrap();
        let decoded = key_pair
            .public_key()
            .verify_token::<NoCustomClaims>(&token, None)
            .unwrap();
        assert_eq!(decoded.audiences.is_some(), expect_audiences);
    }
}