1use base64ct::{Base64UrlUnpadded, Encoding};
16use rand::thread_rng;
17use serde::{de::DeserializeOwned, Serialize};
18use signature::{rand_core::CryptoRngCore, RandomizedSigner, SignatureEncoding, Verifier};
19use thiserror::Error;
20
21use super::{header::JsonWebSignatureHeader, raw::RawJwt};
22use crate::{constraints::ConstraintSet, jwk::PublicJsonWebKeySet};
23
24#[derive(Clone, PartialEq, Eq)]
25pub struct Jwt<'a, T> {
26 raw: RawJwt<'a>,
27 header: JsonWebSignatureHeader,
28 payload: T,
29 signature: Vec<u8>,
30}
31
32impl<'a, T> std::fmt::Display for Jwt<'a, T> {
33 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34 write!(f, "{}", self.raw)
35 }
36}
37
38impl<'a, T> std::fmt::Debug for Jwt<'a, T>
39where
40 T: std::fmt::Debug,
41{
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 f.debug_struct("Jwt")
44 .field("raw", &"...")
45 .field("header", &self.header)
46 .field("payload", &self.payload)
47 .field("signature", &"...")
48 .finish()
49 }
50}
51
52#[derive(Debug, Error)]
53pub enum JwtDecodeError {
54 #[error(transparent)]
55 RawDecode {
56 #[from]
57 inner: super::raw::DecodeError,
58 },
59
60 #[error("failed to decode JWT header")]
61 DecodeHeader {
62 #[source]
63 inner: base64ct::Error,
64 },
65
66 #[error("failed to deserialize JWT header")]
67 DeserializeHeader {
68 #[source]
69 inner: serde_json::Error,
70 },
71
72 #[error("failed to decode JWT payload")]
73 DecodePayload {
74 #[source]
75 inner: base64ct::Error,
76 },
77
78 #[error("failed to deserialize JWT payload")]
79 DeserializePayload {
80 #[source]
81 inner: serde_json::Error,
82 },
83
84 #[error("failed to decode JWT signature")]
85 DecodeSignature {
86 #[source]
87 inner: base64ct::Error,
88 },
89}
90
91impl JwtDecodeError {
92 fn decode_header(inner: base64ct::Error) -> Self {
93 Self::DecodeHeader { inner }
94 }
95
96 fn deserialize_header(inner: serde_json::Error) -> Self {
97 Self::DeserializeHeader { inner }
98 }
99
100 fn decode_payload(inner: base64ct::Error) -> Self {
101 Self::DecodePayload { inner }
102 }
103
104 fn deserialize_payload(inner: serde_json::Error) -> Self {
105 Self::DeserializePayload { inner }
106 }
107
108 fn decode_signature(inner: base64ct::Error) -> Self {
109 Self::DecodeSignature { inner }
110 }
111}
112
113impl<'a, T> TryFrom<RawJwt<'a>> for Jwt<'a, T>
114where
115 T: DeserializeOwned,
116{
117 type Error = JwtDecodeError;
118 fn try_from(raw: RawJwt<'a>) -> Result<Self, Self::Error> {
119 let header_reader =
120 base64ct::Decoder::<'_, Base64UrlUnpadded>::new(raw.header().as_bytes())
121 .map_err(JwtDecodeError::decode_header)?;
122 let header =
123 serde_json::from_reader(header_reader).map_err(JwtDecodeError::deserialize_header)?;
124
125 let payload_reader =
126 base64ct::Decoder::<'_, Base64UrlUnpadded>::new(raw.payload().as_bytes())
127 .map_err(JwtDecodeError::decode_payload)?;
128 let payload =
129 serde_json::from_reader(payload_reader).map_err(JwtDecodeError::deserialize_payload)?;
130
131 let signature = Base64UrlUnpadded::decode_vec(raw.signature())
132 .map_err(JwtDecodeError::decode_signature)?;
133
134 Ok(Self {
135 raw,
136 header,
137 payload,
138 signature,
139 })
140 }
141}
142
143impl<'a, T> TryFrom<&'a str> for Jwt<'a, T>
144where
145 T: DeserializeOwned,
146{
147 type Error = JwtDecodeError;
148 fn try_from(value: &'a str) -> Result<Self, Self::Error> {
149 let raw = RawJwt::try_from(value)?;
150 Self::try_from(raw)
151 }
152}
153
154impl<T> TryFrom<String> for Jwt<'static, T>
155where
156 T: DeserializeOwned,
157{
158 type Error = JwtDecodeError;
159 fn try_from(value: String) -> Result<Self, Self::Error> {
160 let raw = RawJwt::try_from(value)?;
161 Self::try_from(raw)
162 }
163}
164
165#[derive(Debug, Error)]
166pub enum JwtVerificationError {
167 #[error("failed to parse signature")]
168 ParseSignature,
169
170 #[error("signature verification failed")]
171 Verify {
172 #[source]
173 inner: signature::Error,
174 },
175}
176
177impl JwtVerificationError {
178 #[allow(clippy::needless_pass_by_value)]
179 fn parse_signature<E>(_inner: E) -> Self {
180 Self::ParseSignature
181 }
182
183 fn verify(inner: signature::Error) -> Self {
184 Self::Verify { inner }
185 }
186}
187
188#[derive(Debug, Error, Default)]
189#[error("none of the keys worked")]
190pub struct NoKeyWorked {
191 _inner: (),
192}
193
194impl<'a, T> Jwt<'a, T> {
195 pub fn header(&self) -> &JsonWebSignatureHeader {
197 &self.header
198 }
199
200 pub fn payload(&self) -> &T {
202 &self.payload
203 }
204
205 pub fn into_owned(self) -> Jwt<'static, T> {
206 Jwt {
207 raw: self.raw.into_owned(),
208 header: self.header,
209 payload: self.payload,
210 signature: self.signature,
211 }
212 }
213
214 pub fn verify<K, S>(&self, key: &K) -> Result<(), JwtVerificationError>
220 where
221 K: Verifier<S>,
222 S: SignatureEncoding,
223 {
224 let signature =
225 S::try_from(&self.signature).map_err(JwtVerificationError::parse_signature)?;
226
227 key.verify(self.raw.signed_part().as_bytes(), &signature)
228 .map_err(JwtVerificationError::verify)
229 }
230
231 pub fn verify_with_shared_secret(&self, secret: Vec<u8>) -> Result<(), NoKeyWorked> {
238 let verifier = crate::jwa::SymmetricKey::new_for_alg(secret, self.header().alg())
239 .map_err(|_| NoKeyWorked::default())?;
240
241 self.verify(&verifier).map_err(|_| NoKeyWorked::default())?;
242
243 Ok(())
244 }
245
246 pub fn verify_with_jwks(&self, jwks: &PublicJsonWebKeySet) -> Result<(), NoKeyWorked> {
253 let constraints = ConstraintSet::from(self.header());
254 let candidates = constraints.filter(&**jwks);
255
256 for candidate in candidates {
257 let Ok(key) = crate::jwa::AsymmetricVerifyingKey::from_jwk_and_alg(
258 candidate.params(),
259 self.header().alg(),
260 ) else {
261 continue;
262 };
263
264 if self.verify(&key).is_ok() {
265 return Ok(());
266 }
267 }
268
269 Err(NoKeyWorked::default())
270 }
271
272 pub fn as_str(&'a self) -> &'a str {
274 &self.raw
275 }
276
277 pub fn into_string(self) -> String {
279 self.raw.into()
280 }
281
282 pub fn into_parts(self) -> (JsonWebSignatureHeader, T) {
284 (self.header, self.payload)
285 }
286}
287
288#[derive(Debug, Error)]
289pub enum JwtSignatureError {
290 #[error("failed to serialize header")]
291 EncodeHeader {
292 #[source]
293 inner: serde_json::Error,
294 },
295
296 #[error("failed to serialize payload")]
297 EncodePayload {
298 #[source]
299 inner: serde_json::Error,
300 },
301
302 #[error("failed to sign")]
303 Signature {
304 #[from]
305 inner: signature::Error,
306 },
307}
308
309impl JwtSignatureError {
310 fn encode_header(inner: serde_json::Error) -> Self {
311 Self::EncodeHeader { inner }
312 }
313
314 fn encode_payload(inner: serde_json::Error) -> Self {
315 Self::EncodePayload { inner }
316 }
317}
318
319impl<T> Jwt<'static, T> {
320 pub fn sign<K, S>(
327 header: JsonWebSignatureHeader,
328 payload: T,
329 key: &K,
330 ) -> Result<Self, JwtSignatureError>
331 where
332 K: RandomizedSigner<S>,
333 S: SignatureEncoding,
334 T: Serialize,
335 {
336 #[allow(clippy::disallowed_methods)]
337 Self::sign_with_rng(&mut thread_rng(), header, payload, key)
338 }
339
340 pub fn sign_with_rng<R, K, S>(
347 rng: &mut R,
348 header: JsonWebSignatureHeader,
349 payload: T,
350 key: &K,
351 ) -> Result<Self, JwtSignatureError>
352 where
353 R: CryptoRngCore,
354 K: RandomizedSigner<S>,
355 S: SignatureEncoding,
356 T: Serialize,
357 {
358 let header_ = serde_json::to_vec(&header).map_err(JwtSignatureError::encode_header)?;
359 let header_ = Base64UrlUnpadded::encode_string(&header_);
360
361 let payload_ = serde_json::to_vec(&payload).map_err(JwtSignatureError::encode_payload)?;
362 let payload_ = Base64UrlUnpadded::encode_string(&payload_);
363
364 let mut inner = format!("{header_}.{payload_}");
365
366 let first_dot = header_.len();
367 let second_dot = inner.len();
368
369 let signature = key.try_sign_with_rng(rng, inner.as_bytes())?.to_vec();
370 let signature_ = Base64UrlUnpadded::encode_string(&signature);
371 inner.reserve_exact(1 + signature_.len());
372 inner.push('.');
373 inner.push_str(&signature_);
374
375 let raw = RawJwt::new(inner, first_dot, second_dot);
376
377 Ok(Self {
378 raw,
379 header,
380 payload,
381 signature,
382 })
383 }
384}
385
#[cfg(test)]
mod tests {
    // Key generation below uses `thread_rng`, which the project's clippy
    // configuration disallows.
    #![allow(clippy::disallowed_methods)]
    use mas_iana::jose::JsonWebSignatureAlg;
    use rand::thread_rng;

    use super::*;

    /// Header part of a sample HS256 token (`{"alg":"HS256","typ":"JWT"}`).
    const SAMPLE_HEADER: &str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9";
    /// Payload part of the same sample token.
    const SAMPLE_PAYLOAD: &str =
        "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ";
    /// Signature part of the same sample token.
    const SAMPLE_SIGNATURE: &str = "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";

    #[test]
    fn test_jwt_decode() {
        let token = format!("{SAMPLE_HEADER}.{SAMPLE_PAYLOAD}.{SAMPLE_SIGNATURE}");
        let jwt: Jwt<'_, serde_json::Value> = Jwt::try_from(token.as_str()).unwrap();

        assert_eq!(jwt.raw.header(), SAMPLE_HEADER);
        assert_eq!(jwt.raw.payload(), SAMPLE_PAYLOAD);
        assert_eq!(jwt.raw.signature(), SAMPLE_SIGNATURE);
        assert_eq!(
            jwt.raw.signed_part(),
            format!("{SAMPLE_HEADER}.{SAMPLE_PAYLOAD}")
        );
    }

    #[test]
    fn test_jwt_sign_and_verify() {
        let signing_key = ecdsa::SigningKey::<p256::NistP256>::random(&mut thread_rng());
        let header = JsonWebSignatureHeader::new(JsonWebSignatureAlg::Es256);
        let claims = serde_json::json!({"hello": "world"});

        let jwt = Jwt::sign::<_, ecdsa::Signature<_>>(header, claims, &signing_key)
            .expect("signing with a fresh P-256 key should succeed");
        jwt.verify::<_, ecdsa::Signature<_>>(signing_key.verifying_key())
            .expect("a freshly signed token should verify with its own key");
    }
}
421}