Skip to main content

mas_jose/jwt/
signed.rs

1// Copyright 2022 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use base64ct::{Base64UrlUnpadded, Encoding};
16use rand::thread_rng;
17use serde::{de::DeserializeOwned, Serialize};
18use signature::{rand_core::CryptoRngCore, RandomizedSigner, SignatureEncoding, Verifier};
19use thiserror::Error;
20
21use super::{header::JsonWebSignatureHeader, raw::RawJwt};
22use crate::{constraints::ConstraintSet, jwk::PublicJsonWebKeySet};
23
24#[derive(Clone, PartialEq, Eq)]
25pub struct Jwt<'a, T> {
26    raw: RawJwt<'a>,
27    header: JsonWebSignatureHeader,
28    payload: T,
29    signature: Vec<u8>,
30}
31
32impl<'a, T> std::fmt::Display for Jwt<'a, T> {
33    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
34        write!(f, "{}", self.raw)
35    }
36}
37
38impl<'a, T> std::fmt::Debug for Jwt<'a, T>
39where
40    T: std::fmt::Debug,
41{
42    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43        f.debug_struct("Jwt")
44            .field("raw", &"...")
45            .field("header", &self.header)
46            .field("payload", &self.payload)
47            .field("signature", &"...")
48            .finish()
49    }
50}
51
52#[derive(Debug, Error)]
53pub enum JwtDecodeError {
54    #[error(transparent)]
55    RawDecode {
56        #[from]
57        inner: super::raw::DecodeError,
58    },
59
60    #[error("failed to decode JWT header")]
61    DecodeHeader {
62        #[source]
63        inner: base64ct::Error,
64    },
65
66    #[error("failed to deserialize JWT header")]
67    DeserializeHeader {
68        #[source]
69        inner: serde_json::Error,
70    },
71
72    #[error("failed to decode JWT payload")]
73    DecodePayload {
74        #[source]
75        inner: base64ct::Error,
76    },
77
78    #[error("failed to deserialize JWT payload")]
79    DeserializePayload {
80        #[source]
81        inner: serde_json::Error,
82    },
83
84    #[error("failed to decode JWT signature")]
85    DecodeSignature {
86        #[source]
87        inner: base64ct::Error,
88    },
89}
90
91impl JwtDecodeError {
92    fn decode_header(inner: base64ct::Error) -> Self {
93        Self::DecodeHeader { inner }
94    }
95
96    fn deserialize_header(inner: serde_json::Error) -> Self {
97        Self::DeserializeHeader { inner }
98    }
99
100    fn decode_payload(inner: base64ct::Error) -> Self {
101        Self::DecodePayload { inner }
102    }
103
104    fn deserialize_payload(inner: serde_json::Error) -> Self {
105        Self::DeserializePayload { inner }
106    }
107
108    fn decode_signature(inner: base64ct::Error) -> Self {
109        Self::DecodeSignature { inner }
110    }
111}
112
113impl<'a, T> TryFrom<RawJwt<'a>> for Jwt<'a, T>
114where
115    T: DeserializeOwned,
116{
117    type Error = JwtDecodeError;
118    fn try_from(raw: RawJwt<'a>) -> Result<Self, Self::Error> {
119        let header_reader =
120            base64ct::Decoder::<'_, Base64UrlUnpadded>::new(raw.header().as_bytes())
121                .map_err(JwtDecodeError::decode_header)?;
122        let header =
123            serde_json::from_reader(header_reader).map_err(JwtDecodeError::deserialize_header)?;
124
125        let payload_reader =
126            base64ct::Decoder::<'_, Base64UrlUnpadded>::new(raw.payload().as_bytes())
127                .map_err(JwtDecodeError::decode_payload)?;
128        let payload =
129            serde_json::from_reader(payload_reader).map_err(JwtDecodeError::deserialize_payload)?;
130
131        let signature = Base64UrlUnpadded::decode_vec(raw.signature())
132            .map_err(JwtDecodeError::decode_signature)?;
133
134        Ok(Self {
135            raw,
136            header,
137            payload,
138            signature,
139        })
140    }
141}
142
143impl<'a, T> TryFrom<&'a str> for Jwt<'a, T>
144where
145    T: DeserializeOwned,
146{
147    type Error = JwtDecodeError;
148    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
149        let raw = RawJwt::try_from(value)?;
150        Self::try_from(raw)
151    }
152}
153
154impl<T> TryFrom<String> for Jwt<'static, T>
155where
156    T: DeserializeOwned,
157{
158    type Error = JwtDecodeError;
159    fn try_from(value: String) -> Result<Self, Self::Error> {
160        let raw = RawJwt::try_from(value)?;
161        Self::try_from(raw)
162    }
163}
164
165#[derive(Debug, Error)]
166pub enum JwtVerificationError {
167    #[error("failed to parse signature")]
168    ParseSignature,
169
170    #[error("signature verification failed")]
171    Verify {
172        #[source]
173        inner: signature::Error,
174    },
175}
176
177impl JwtVerificationError {
178    #[allow(clippy::needless_pass_by_value)]
179    fn parse_signature<E>(_inner: E) -> Self {
180        Self::ParseSignature
181    }
182
183    fn verify(inner: signature::Error) -> Self {
184        Self::Verify { inner }
185    }
186}
187
188#[derive(Debug, Error, Default)]
189#[error("none of the keys worked")]
190pub struct NoKeyWorked {
191    _inner: (),
192}
193
194impl<'a, T> Jwt<'a, T> {
195    /// Get the JWT header
196    pub fn header(&self) -> &JsonWebSignatureHeader {
197        &self.header
198    }
199
200    /// Get the JWT payload
201    pub fn payload(&self) -> &T {
202        &self.payload
203    }
204
205    pub fn into_owned(self) -> Jwt<'static, T> {
206        Jwt {
207            raw: self.raw.into_owned(),
208            header: self.header,
209            payload: self.payload,
210            signature: self.signature,
211        }
212    }
213
214    /// Verify the signature of this JWT using the given key.
215    ///
216    /// # Errors
217    ///
218    /// Returns an error if the signature is invalid.
219    pub fn verify<K, S>(&self, key: &K) -> Result<(), JwtVerificationError>
220    where
221        K: Verifier<S>,
222        S: SignatureEncoding,
223    {
224        let signature =
225            S::try_from(&self.signature).map_err(JwtVerificationError::parse_signature)?;
226
227        key.verify(self.raw.signed_part().as_bytes(), &signature)
228            .map_err(JwtVerificationError::verify)
229    }
230
231    /// Verify the signature of this JWT using the given symmetric key.
232    ///
233    /// # Errors
234    ///
235    /// Returns an error if the signature is invalid or if the algorithm is not
236    /// supported.
237    pub fn verify_with_shared_secret(&self, secret: Vec<u8>) -> Result<(), NoKeyWorked> {
238        let verifier = crate::jwa::SymmetricKey::new_for_alg(secret, self.header().alg())
239            .map_err(|_| NoKeyWorked::default())?;
240
241        self.verify(&verifier).map_err(|_| NoKeyWorked::default())?;
242
243        Ok(())
244    }
245
246    /// Verify the signature of this JWT using the given JWKS.
247    ///
248    /// # Errors
249    ///
250    /// Returns an error if the signature is invalid, if no key matches the
251    /// constraints, or if the algorithm is not supported.
252    pub fn verify_with_jwks(&self, jwks: &PublicJsonWebKeySet) -> Result<(), NoKeyWorked> {
253        let constraints = ConstraintSet::from(self.header());
254        let candidates = constraints.filter(&**jwks);
255
256        for candidate in candidates {
257            let Ok(key) = crate::jwa::AsymmetricVerifyingKey::from_jwk_and_alg(
258                candidate.params(),
259                self.header().alg(),
260            ) else {
261                continue;
262            };
263
264            if self.verify(&key).is_ok() {
265                return Ok(());
266            }
267        }
268
269        Err(NoKeyWorked::default())
270    }
271
272    /// Get the raw JWT string as a borrowed [`str`]
273    pub fn as_str(&'a self) -> &'a str {
274        &self.raw
275    }
276
277    /// Get the raw JWT string as an owned [`String`]
278    pub fn into_string(self) -> String {
279        self.raw.into()
280    }
281
282    /// Split the JWT into its parts (header and payload).
283    pub fn into_parts(self) -> (JsonWebSignatureHeader, T) {
284        (self.header, self.payload)
285    }
286}
287
288#[derive(Debug, Error)]
289pub enum JwtSignatureError {
290    #[error("failed to serialize header")]
291    EncodeHeader {
292        #[source]
293        inner: serde_json::Error,
294    },
295
296    #[error("failed to serialize payload")]
297    EncodePayload {
298        #[source]
299        inner: serde_json::Error,
300    },
301
302    #[error("failed to sign")]
303    Signature {
304        #[from]
305        inner: signature::Error,
306    },
307}
308
309impl JwtSignatureError {
310    fn encode_header(inner: serde_json::Error) -> Self {
311        Self::EncodeHeader { inner }
312    }
313
314    fn encode_payload(inner: serde_json::Error) -> Self {
315        Self::EncodePayload { inner }
316    }
317}
318
319impl<T> Jwt<'static, T> {
320    /// Sign the given payload with the given key.
321    ///
322    /// # Errors
323    ///
324    /// Returns an error if the payload could not be serialized or if the key
325    /// could not sign the payload.
326    pub fn sign<K, S>(
327        header: JsonWebSignatureHeader,
328        payload: T,
329        key: &K,
330    ) -> Result<Self, JwtSignatureError>
331    where
332        K: RandomizedSigner<S>,
333        S: SignatureEncoding,
334        T: Serialize,
335    {
336        #[allow(clippy::disallowed_methods)]
337        Self::sign_with_rng(&mut thread_rng(), header, payload, key)
338    }
339
340    /// Sign the given payload with the given key using the given RNG.
341    ///
342    /// # Errors
343    ///
344    /// Returns an error if the payload could not be serialized or if the key
345    /// could not sign the payload.
346    pub fn sign_with_rng<R, K, S>(
347        rng: &mut R,
348        header: JsonWebSignatureHeader,
349        payload: T,
350        key: &K,
351    ) -> Result<Self, JwtSignatureError>
352    where
353        R: CryptoRngCore,
354        K: RandomizedSigner<S>,
355        S: SignatureEncoding,
356        T: Serialize,
357    {
358        let header_ = serde_json::to_vec(&header).map_err(JwtSignatureError::encode_header)?;
359        let header_ = Base64UrlUnpadded::encode_string(&header_);
360
361        let payload_ = serde_json::to_vec(&payload).map_err(JwtSignatureError::encode_payload)?;
362        let payload_ = Base64UrlUnpadded::encode_string(&payload_);
363
364        let mut inner = format!("{header_}.{payload_}");
365
366        let first_dot = header_.len();
367        let second_dot = inner.len();
368
369        let signature = key.try_sign_with_rng(rng, inner.as_bytes())?.to_vec();
370        let signature_ = Base64UrlUnpadded::encode_string(&signature);
371        inner.reserve_exact(1 + signature_.len());
372        inner.push('.');
373        inner.push_str(&signature_);
374
375        let raw = RawJwt::new(inner, first_dot, second_dot);
376
377        Ok(Self {
378            raw,
379            header,
380            payload,
381            signature,
382        })
383    }
384}
385
#[cfg(test)]
mod tests {
    #![allow(clippy::disallowed_methods)]
    use mas_iana::jose::JsonWebSignatureAlg;
    use rand::thread_rng;

    use super::*;

    // Base64url parts of a sample HS256 token (its header declares
    // `"alg":"HS256"`).
    const ENCODED_HEADER: &str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9";
    const ENCODED_PAYLOAD: &str =
        "eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ";
    const ENCODED_SIGNATURE: &str = "SflKxwRJSMeKKF2QT4fwpMeJf36POk6yJV_adQssw5c";

    #[test]
    fn test_jwt_decode() {
        let signed_part = format!("{ENCODED_HEADER}.{ENCODED_PAYLOAD}");
        let token = format!("{signed_part}.{ENCODED_SIGNATURE}");

        let jwt: Jwt<'_, serde_json::Value> = Jwt::try_from(token.as_str()).unwrap();

        assert_eq!(jwt.raw.header(), ENCODED_HEADER);
        assert_eq!(jwt.raw.payload(), ENCODED_PAYLOAD);
        assert_eq!(jwt.raw.signature(), ENCODED_SIGNATURE);
        assert_eq!(jwt.raw.signed_part(), signed_part.as_str());
    }

    #[test]
    fn test_jwt_sign_and_verify() {
        let signing_key = ecdsa::SigningKey::<p256::NistP256>::random(&mut thread_rng());
        let header = JsonWebSignatureHeader::new(JsonWebSignatureAlg::Es256);
        let claims = serde_json::json!({"hello": "world"});

        let jwt = Jwt::sign::<_, ecdsa::Signature<_>>(header, claims, &signing_key).unwrap();
        jwt.verify::<_, ecdsa::Signature<_>>(signing_key.verifying_key())
            .unwrap();
    }
}