1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
use anyhow::bail;

use crate::jwe::{JweContext, JweDecrypter, JweEncrypter, JweHeader};
use crate::jwk::{Jwk, JwkSet};
use crate::jws::{JwsContext, JwsHeader, JwsSigner, JwsVerifier};
use crate::jwt::{self, JwtPayload};
use crate::{JoseError, JoseHeader, Map, Value};

/// Context for encoding and decoding JWTs.
///
/// Signed (JWS) tokens are delegated to an inner [`JwsContext`] and encrypted
/// (JWE) tokens to an inner [`JweContext`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct JwtContext {
    /// Sub-context used when signing / verifying JWTs.
    jws_context: JwsContext,
    /// Sub-context used when encrypting / decrypting JWTs.
    jwe_context: JweContext,
}

impl JwtContext {
    /// Create a context whose JWS and JWE sub-contexts are built with their
    /// own `new()` constructors.
    pub fn new() -> Self {
        Self {
            jws_context: JwsContext::new(),
            jwe_context: JweContext::new(),
        }
    }

    /// Test whether a critical header claim name is acceptable.
    ///
    /// Only the JWS sub-context is queried. This type's
    /// [`add_acceptable_critical`](Self::add_acceptable_critical) and
    /// [`remove_acceptable_critical`](Self::remove_acceptable_critical) update
    /// the JWS and JWE sub-contexts together.
    ///
    /// # Arguments
    ///
    /// * `name` - a critical header claim name
    pub fn is_acceptable_critical(&self, name: &str) -> bool {
        self.jws_context.is_acceptable_critical(name)
    }

    /// Add an acceptable critical header claim name to both the JWS and JWE
    /// sub-contexts.
    ///
    /// Even when `"b64"` is registered here, the JWS-based encode/decode
    /// methods of this type still reject headers that mark `b64` as critical.
    ///
    /// # Arguments
    ///
    /// * `name` - an acceptable critical header claim name
    pub fn add_acceptable_critical(&mut self, name: &str) {
        self.jws_context.add_acceptable_critical(name);
        self.jwe_context.add_acceptable_critical(name);
    }

    /// Remove an acceptable critical header claim name from both the JWS and
    /// JWE sub-contexts.
    ///
    /// # Arguments
    ///
    /// * `name` - an acceptable critical header claim name
    pub fn remove_acceptable_critical(&mut self, name: &str) {
        self.jws_context.remove_acceptable_critical(name);
        self.jwe_context.remove_acceptable_critical(name);
    }

    /// Return the compact string representation of an unsecured JWT, i.e. one
    /// signed with the "none" algorithm.
    ///
    /// # Arguments
    ///
    /// * `payload` - The payload data.
    /// * `header` - The JWS header claims.
    ///
    /// # Errors
    ///
    /// Same as [`encode_with_signer`](Self::encode_with_signer).
    pub fn encode_unsecured(
        &self,
        payload: &JwtPayload,
        header: &JwsHeader,
    ) -> Result<String, JoseError> {
        self.encode_with_signer(payload, header, &jwt::None.signer())
    }

    /// Return the compact string representation of a JWT signed by `signer`.
    ///
    /// The payload's claims set is serialized to JSON and passed to the JWS
    /// sub-context for compact serialization.
    ///
    /// # Arguments
    ///
    /// * `payload` - The payload data.
    /// * `header` - The JWS header claims.
    /// * `signer` - a signer object.
    ///
    /// # Errors
    ///
    /// Returns `JoseError::InvalidJwtFormat` if `header` lists `b64` as a
    /// critical claim or if the claims set cannot be serialized. A `JoseError`
    /// raised by the JWS sub-context is returned unchanged.
    pub fn encode_with_signer(
        &self,
        payload: &JwtPayload,
        header: &JwsHeader,
        signer: &dyn JwsSigner,
    ) -> Result<String, JoseError> {
        (|| -> anyhow::Result<String> {
            // This JWT implementation does not support the `b64` header claim.
            if let Some(vals) = header.critical() {
                if vals.contains(&"b64") {
                    bail!("JWT is not support b64 header claim.");
                }
            }

            let payload_bytes = serde_json::to_vec(payload.claims_set())?;
            let jwt = self
                .jws_context
                .serialize_compact(&payload_bytes, header, signer)?;
            Ok(jwt)
        })()
        .map_err(Self::convert_error)
    }

    /// Return the compact string representation of a JWT encrypted by
    /// `encrypter`.
    ///
    /// # Arguments
    ///
    /// * `payload` - The payload data.
    /// * `header` - The JWE header claims.
    /// * `encrypter` - an encrypter object.
    ///
    /// # Errors
    ///
    /// Returns `JoseError::InvalidJwtFormat` if the claims set cannot be
    /// serialized. A `JoseError` raised by the JWE sub-context is returned
    /// unchanged.
    pub fn encode_with_encrypter(
        &self,
        payload: &JwtPayload,
        header: &JweHeader,
        encrypter: &dyn JweEncrypter,
    ) -> Result<String, JoseError> {
        (|| -> anyhow::Result<String> {
            let payload_bytes = serde_json::to_vec(payload.claims_set())?;
            let jwt = self
                .jwe_context
                .serialize_compact(&payload_bytes, header, encrypter)?;
            Ok(jwt)
        })()
        .map_err(Self::convert_error)
    }

    /// Return the JOSE header decoded from a compact JWT.
    ///
    /// The header is only base64url-decoded and parsed. No signature is
    /// verified and nothing is decrypted.
    ///
    /// # Arguments
    ///
    /// * `input` - a JWT string representation.
    ///
    /// # Errors
    ///
    /// Returns `JoseError::InvalidJwtFormat` if the input does not have 3 (JWS)
    /// or 5 (JWE) dot-separated parts, or if the first part is not valid
    /// unpadded base64url-encoded JSON. A `JoseError` raised while building the
    /// header is returned unchanged.
    pub fn decode_header(&self, input: impl AsRef<[u8]>) -> Result<Box<dyn JoseHeader>, JoseError> {
        (|| -> anyhow::Result<Box<dyn JoseHeader>> {
            let parts: Vec<&[u8]> = input.as_ref().split(|b| *b == b'.').collect();
            // Compact JWS serialization has 3 parts; compact JWE has 5.
            if parts.len() != 3 && parts.len() != 5 {
                bail!("The input cannot be recognized as a JWT.");
            }

            // In both forms the first part is the encoded JSON header.
            let header = base64::decode_config(parts[0], base64::URL_SAFE_NO_PAD)?;
            let header: Map<String, Value> = serde_json::from_slice(&header)?;
            if parts.len() == 3 {
                // JWS
                Ok(Box::new(JwsHeader::from_map(header)?))
            } else {
                // JWE
                Ok(Box::new(JweHeader::from_map(header)?))
            }
        })()
        .map_err(Self::convert_error)
    }

    /// Return the JWT object decoded with the "none" algorithm.
    ///
    /// # Arguments
    ///
    /// * `input` - a JWT string representation.
    ///
    /// # Errors
    ///
    /// Same as [`decode_with_verifier_selector`](Self::decode_with_verifier_selector).
    pub fn decode_unsecured(
        &self,
        input: impl AsRef<[u8]>,
    ) -> Result<(JwtPayload, JwsHeader), JoseError> {
        self.decode_with_verifier(input, &jwt::None.verifier())
    }

    /// Return the JWT object decoded by the given verifier.
    ///
    /// # Arguments
    ///
    /// * `input` - a JWT string representation.
    /// * `verifier` - a verifier of the signing algorithm.
    ///
    /// # Errors
    ///
    /// Same as [`decode_with_verifier_selector`](Self::decode_with_verifier_selector).
    pub fn decode_with_verifier(
        &self,
        input: impl AsRef<[u8]>,
        verifier: &dyn JwsVerifier,
    ) -> Result<(JwtPayload, JwsHeader), JoseError> {
        self.decode_with_verifier_selector(input, |_header| Ok(Some(verifier)))
    }

    /// Return the JWT object decoded with a verifier chosen from its header.
    ///
    /// # Arguments
    ///
    /// * `input` - a JWT string representation.
    /// * `selector` - a function for selecting the verifying algorithm from the
    ///   decoded JWS header; returning `Ok(None)` selects no verifier.
    ///
    /// # Errors
    ///
    /// Returns `JoseError::InvalidJwtFormat` if the header lists `b64` as a
    /// critical claim, or if the payload is not a JSON object. A `JoseError`
    /// returned by `selector`, the JWS sub-context, or payload construction is
    /// returned unchanged.
    pub fn decode_with_verifier_selector<'a, F>(
        &self,
        input: impl AsRef<[u8]>,
        selector: F,
    ) -> Result<(JwtPayload, JwsHeader), JoseError>
    where
        F: Fn(&JwsHeader) -> Result<Option<&'a dyn JwsVerifier>, JoseError>,
    {
        (|| -> anyhow::Result<(JwtPayload, JwsHeader)> {
            let (payload, header) =
                self.jws_context
                    .deserialize_compact_with_selector(input, |header| {
                        (|| -> anyhow::Result<Option<&'a dyn JwsVerifier>> {
                            let verifier = match selector(&header)? {
                                Some(val) => val,
                                None => return Ok(None),
                            };

                            // Inspect the token's own `crit` claim (not this
                            // context's acceptable list), mirroring the check
                            // in `encode_with_signer`.
                            if let Some(vals) = header.critical() {
                                if vals.contains(&"b64") {
                                    bail!("JWT is not supported b64 header claim.");
                                }
                            }

                            Ok(Some(verifier))
                        })()
                        .map_err(Self::convert_error)
                    })?;

            let payload: Map<String, Value> = serde_json::from_slice(&payload)?;
            let payload = JwtPayload::from_map(payload)?;

            Ok((payload, header))
        })()
        .map_err(Self::convert_error)
    }

    /// Return the JWT object decoded by using a JWK set.
    ///
    /// The JWKs that `jwk_set` returns for the header's key ID are tried in
    /// order. The first one for which `selector` yields a verifier is used. If
    /// the header has no key ID, no verifier is selected.
    ///
    /// # Arguments
    ///
    /// * `input` - a JWT string representation.
    /// * `jwk_set` - a JWK set.
    /// * `selector` - a function for selecting the verifying algorithm.
    ///
    /// # Errors
    ///
    /// Same as [`decode_with_verifier_selector`](Self::decode_with_verifier_selector).
    pub fn decode_with_verifier_in_jwk_set<F>(
        &self,
        input: impl AsRef<[u8]>,
        jwk_set: &JwkSet,
        selector: F,
    ) -> Result<(JwtPayload, JwsHeader), JoseError>
    where
        F: Fn(&Jwk) -> Result<Option<&dyn JwsVerifier>, JoseError>,
    {
        self.decode_with_verifier_selector(input, |header| {
            let key_id = match header.key_id() {
                Some(val) => val,
                None => return Ok(None),
            };

            for jwk in jwk_set.get(key_id) {
                if let Some(val) = selector(jwk)? {
                    return Ok(Some(val));
                }
            }
            Ok(None)
        })
    }

    /// Return the JWT object decoded by the given decrypter.
    ///
    /// # Arguments
    ///
    /// * `input` - a JWT string representation.
    /// * `decrypter` - a decrypter of the decrypting algorithm.
    ///
    /// # Errors
    ///
    /// Same as [`decode_with_decrypter_selector`](Self::decode_with_decrypter_selector).
    pub fn decode_with_decrypter(
        &self,
        input: impl AsRef<[u8]>,
        decrypter: &dyn JweDecrypter,
    ) -> Result<(JwtPayload, JweHeader), JoseError> {
        self.decode_with_decrypter_selector(input, |_header| Ok(Some(decrypter)))
    }

    /// Return the JWT object decoded with a decrypter chosen from its header.
    ///
    /// # Arguments
    ///
    /// * `input` - a JWT string representation.
    /// * `selector` - a function for selecting the decrypting algorithm from
    ///   the decoded JWE header; returning `Ok(None)` selects no decrypter.
    ///
    /// # Errors
    ///
    /// Returns `JoseError::InvalidJwtFormat` if the decrypted payload is not a
    /// JSON object. A `JoseError` returned by `selector`, the JWE sub-context,
    /// or payload construction is returned unchanged.
    pub fn decode_with_decrypter_selector<'a, F>(
        &self,
        input: impl AsRef<[u8]>,
        selector: F,
    ) -> Result<(JwtPayload, JweHeader), JoseError>
    where
        F: Fn(&JweHeader) -> Result<Option<&'a dyn JweDecrypter>, JoseError>,
    {
        (|| -> anyhow::Result<(JwtPayload, JweHeader)> {
            // Decrypter choice is delegated to `selector` as-is.
            let (payload, header) = self
                .jwe_context
                .deserialize_compact_with_selector(input, |header| Ok(selector(&header)?))?;

            let payload: Map<String, Value> = serde_json::from_slice(&payload)?;
            let payload = JwtPayload::from_map(payload)?;

            Ok((payload, header))
        })()
        .map_err(Self::convert_error)
    }

    /// Return the JWT object decoded by using a JWK set.
    ///
    /// The JWKs that `jwk_set` returns for the header's key ID are tried in
    /// order. The first one for which `selector` yields a decrypter is used. If
    /// the header has no key ID, no decrypter is selected.
    ///
    /// # Arguments
    ///
    /// * `input` - a JWT string representation.
    /// * `jwk_set` - a JWK set.
    /// * `selector` - a function for selecting the decrypting algorithm.
    ///
    /// # Errors
    ///
    /// Same as [`decode_with_decrypter_selector`](Self::decode_with_decrypter_selector).
    pub fn decode_with_decrypter_in_jwk_set<F>(
        &self,
        input: impl AsRef<[u8]>,
        jwk_set: &JwkSet,
        selector: F,
    ) -> Result<(JwtPayload, JweHeader), JoseError>
    where
        F: Fn(&Jwk) -> Result<Option<&dyn JweDecrypter>, JoseError>,
    {
        self.decode_with_decrypter_selector(input, |header| {
            let key_id = match header.key_id() {
                Some(val) => val,
                None => return Ok(None),
            };

            for jwk in jwk_set.get(key_id) {
                if let Some(val) = selector(jwk)? {
                    return Ok(Some(val));
                }
            }
            Ok(None)
        })
    }

    /// Convert an error raised inside one of this impl's `anyhow` closures
    /// back into a [`JoseError`]. Errors that already are `JoseError`s are
    /// returned unchanged; anything else is wrapped as
    /// `JoseError::InvalidJwtFormat`.
    fn convert_error(err: anyhow::Error) -> JoseError {
        match err.downcast::<JoseError>() {
            Ok(err) => err,
            Err(err) => JoseError::InvalidJwtFormat(err),
        }
    }
}