pykrete_jsonwebkey/
lib.rs

1#![deny(rust_2018_idioms, unreachable_pub)]
2#![forbid(unsafe_code)]
3
4//! *[JSON Web Key (JWK)](https://tools.ietf.org/html/rfc7517#section-4.3) (de)serialization, generation, and conversion.*
5//!
//! **Note**: key lengths are expressed in the type system through
//! `generic_array::typenum` lengths (e.g. `U32`, `U48`); the crate root enables no nightly-only features.
8//!
9//! ## Examples
10//!
11//! ### Deserializing from JSON
12//!
13//! ```
14//! extern crate pykrete_jsonwebkey as jwk;
15//! // Generated using https://mkjwk.org/.
16//! let jwt_str = r#"{
17//!    "kty": "oct",
18//!    "use": "sig",
19//!    "kid": "my signing key",
20//!    "k": "Wpj30SfkzM_m0Sa_B2NqNw",
21//!    "alg": "HS256"
22//! }"#;
23//! let the_jwk: jwk::JsonWebKey = jwt_str.parse().unwrap();
//! println!("{:#}", the_jwk); // looks like `jwt_str` but with reordered fields.
25//! ```
26//!
27//! ### Using with other crates
28//!
29//! ```
30//! #[cfg(all(feature = "generate", feature = "jwt-convert"))] {
31//! extern crate jsonwebtoken as jwt;
32//! extern crate pykrete_jsonwebkey as jwk;
33//!
34//! #[derive(serde::Serialize, serde::Deserialize)]
35//! struct TokenClaims {
36//!     exp: usize,
37//! }
38//!
39//! let mut my_jwk = jwk::JsonWebKey::new(jwk::Key::generate_p256());
//! my_jwk.set_algorithm(jwk::Algorithm::ES256).unwrap();
41//!
42//! let alg: jwt::Algorithm = my_jwk.algorithm.unwrap().into();
43//! let token = jwt::encode(
44//!     &jwt::Header::new(alg),
45//!     &TokenClaims {
46//!         exp: 0,
47//!     },
48//!     &my_jwk.key.to_encoding_key(),
49//! ).unwrap();
50//!
51//! let mut validation = jwt::Validation::new(alg);
52//! validation.validate_exp = false;
53//! jwt::decode::<TokenClaims>(&token, &my_jwk.key.to_decoding_key(), &validation).unwrap();
54//! }
55//! ```
56//!
57//! ## Features
58//!
//! * `pkcs-convert` - enables `Key::{try_to_der, to_der, try_to_pem, to_pem}`.
//!                    This pulls in the [yasna](https://crates.io/crates/yasna) crate.
//! * `generate` - enables `Key::{generate_p256, generate_symmetric}`.
//!                This pulls in the [p256](https://crates.io/crates/p256) and [rand](https://crates.io/crates/rand) crates.
//! * `jwt-convert` - enables conversions to types in the [jsonwebtoken](https://crates.io/crates/jsonwebtoken) crate.
//! * `thumbprint` - enables `Key::{thumbprint, try_thumbprint_using_hasher}`.
64
65mod byte_array;
66mod byte_vec;
67mod key_ops;
68#[cfg(test)]
69mod tests;
70mod utils;
71
72use std::{borrow::Cow, fmt};
73
74use generic_array::typenum::{U32, U48};
75use serde::{Deserialize, Serialize};
76
77pub use byte_array::ByteArray;
78pub use byte_vec::ByteVec;
79pub use key_ops::KeyOps;
80
81#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
82pub struct JsonWebKey {
83    #[serde(flatten)]
84    pub key: Box<Key>,
85
86    #[serde(default, rename = "use", skip_serializing_if = "Option::is_none")]
87    pub key_use: Option<KeyUse>,
88
89    #[serde(default, skip_serializing_if = "KeyOps::is_empty")]
90    pub key_ops: KeyOps,
91
92    #[serde(default, rename = "kid", skip_serializing_if = "Option::is_none")]
93    pub key_id: Option<String>,
94
95    #[serde(default, rename = "alg", skip_serializing_if = "Option::is_none")]
96    pub algorithm: Option<Algorithm>,
97
98    #[serde(default, flatten, skip_serializing_if = "X509Params::is_empty")]
99    pub x5: X509Params,
100}
101
102#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
103pub struct X509Params {
104    /// x5u: The URL of the X.509 cert corresponding to this key.
105    #[serde(default, rename = "x5u", skip_serializing_if = "Option::is_none")]
106    url: Option<String>,
107
108    /// x5c: The certificate chain used to verify this key.
109    #[serde(default, rename = "x5c", skip_serializing_if = "Option::is_none")]
110    cert_chain: Option<Vec<String>>,
111
112    /// x5t: The SHA-1 thumbprint of the DER-encoded X.509 version of the public key.
113    #[serde(default, rename = "x5t", skip_serializing_if = "Option::is_none")]
114    thumbprint: Option<String>,
115
116    /// x5t#S256: The same data as the thumbprint, but digested using SHA-256
117    #[serde(default, rename = "x5t#S256", skip_serializing_if = "Option::is_none")]
118    thumbprint_sha256: Option<String>,
119}
120
121impl X509Params {
122    fn is_empty(&self) -> bool {
123        matches!(
124            self,
125            X509Params {
126                url: None,
127                cert_chain: None,
128                thumbprint: None,
129                thumbprint_sha256: None,
130            }
131        )
132    }
133}
134
135impl JsonWebKey {
136    pub fn new(key: Key) -> Self {
137        Self {
138            key: Box::new(key),
139            key_use: None,
140            key_ops: KeyOps::empty(),
141            key_id: None,
142            algorithm: None,
143            x5: Default::default(),
144        }
145    }
146
147    pub fn set_algorithm(&mut self, alg: Algorithm) -> Result<(), Error> {
148        Self::validate_algorithm(alg, &self.key)?;
149        self.algorithm = Some(alg);
150        Ok(())
151    }
152
153    pub fn from_slice(bytes: impl AsRef<[u8]>) -> Result<Self, Error> {
154        Ok(serde_json::from_slice(bytes.as_ref())?)
155    }
156
157    fn validate_algorithm(alg: Algorithm, key: &Key) -> Result<(), Error> {
158        use Algorithm::*;
159        use Key::*;
160        match (alg, key) {
161            (
162                ES256,
163                EC {
164                    curve: Curve::P256 { .. },
165                    ..
166                },
167            )
168            | (
169                ES384,
170                EC {
171                    curve: Curve::P384 { .. },
172                },
173            )
174            | (RS256, RSA { .. })
175            | (RS384, RSA { .. })
176            | (RS512, RSA { .. })
177            | (HS256, Symmetric { .. }) => Ok(()),
178            (HS384, Symmetric { .. }) => Ok(()),
179            (HS512, Symmetric { .. }) => Ok(()),
180            _ => Err(Error::MismatchedAlgorithm),
181        }
182    }
183}
184
185impl std::str::FromStr for JsonWebKey {
186    type Err = Error;
187    fn from_str(json: &str) -> Result<Self, Self::Err> {
188        let jwk = Self::from_slice(json.as_bytes())?;
189
190        let alg = match jwk.algorithm {
191            Some(alg) => alg,
192            None => return Ok(jwk),
193        };
194        Self::validate_algorithm(alg, &jwk.key).map(|_| jwk)
195    }
196}
197
198impl std::fmt::Display for JsonWebKey {
199    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200        if f.alternate() {
201            write!(f, "{}", serde_json::to_string_pretty(self).unwrap())
202        } else {
203            write!(f, "{}", serde_json::to_string(self).unwrap())
204        }
205    }
206}
207
208#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
209#[serde(tag = "kty")]
210#[allow(clippy::upper_case_acronyms)]
211pub enum Key {
212    /// An elliptic curve, as per [RFC 7518 §6.2](https://tools.ietf.org/html/rfc7518#section-6.2).
213    EC {
214        #[serde(flatten)]
215        curve: Curve,
216    },
217    /// An elliptic curve, as per [RFC 7518 §6.3](https://tools.ietf.org/html/rfc7518#section-6.3).
218    /// See also: [RFC 3447](https://tools.ietf.org/html/rfc3447).
219    RSA {
220        #[serde(flatten)]
221        public: RsaPublic,
222        #[serde(flatten, default, skip_serializing_if = "Option::is_none")]
223        private: Option<RsaPrivate>,
224    },
225    /// A symmetric key, as per [RFC 7518 §6.4](https://tools.ietf.org/html/rfc7518#section-6.4).
226    #[serde(rename = "oct")]
227    Symmetric {
228        #[serde(rename = "k")]
229        key: ByteVec,
230    },
231}
232
#[cfg(feature = "thumbprint")]
impl Key {
    /// Computes the JWK thumbprint defined by
    /// [RFC 7638](https://datatracker.ietf.org/doc/html/rfc7638) with SHA-256.
    ///
    /// Panics if building the canonical JSON fails.
    pub fn thumbprint(&self) -> String {
        self.try_thumbprint_using_hasher::<sha2::Sha256>().unwrap()
    }

    /// Computes the JWK thumbprint defined by
    /// [RFC 7638](https://datatracker.ietf.org/doc/html/rfc7638) with the hash `H`.
    ///
    /// Only the required members of the key type are serialized, in the fixed
    /// order written below; the resulting JSON is hashed and the digest passed
    /// through `utils::base64_encode`.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error if a member fails to serialize.
    pub fn try_thumbprint_using_hasher<H: sha2::digest::Digest>(
        &self,
    ) -> Result<String, serde_json::Error> {
        use serde::ser::{SerializeStruct, Serializer};

        let mut json = serde_json::Serializer::new(Vec::new());

        // Emits one JSON object holding exactly the listed members, in order.
        macro_rules! write_members {
            ($($name:literal => $value:expr),+ $(,)?) => {{
                let mut members = json.serialize_struct("", [$($name),+].len())?;
                $(members.serialize_field($name, $value)?;)+
                members.end()?;
            }};
        }

        match self {
            Self::EC {
                curve: curve @ Curve::P256 { x, y, .. },
            } => write_members!("crv" => curve.name(), "kty" => "EC", "x" => x, "y" => y),
            Self::EC {
                curve: curve @ Curve::P384 { x, y, .. },
            } => write_members!("crv" => curve.name(), "kty" => "EC", "x" => x, "y" => y),
            Self::RSA {
                public: RsaPublic { e, n },
                ..
            } => write_members!("e" => e, "kty" => "RSA", "n" => n),
            Self::Symmetric { key } => write_members!("k" => key, "kty" => "oct"),
        }

        let digest = H::digest(json.into_inner());
        Ok(crate::utils::base64_encode(digest))
    }
}
289
290impl Key {
291    /// Returns true iff this key only contains private components (i.e. a private asymmetric
292    /// key or a symmetric key).
293    pub fn is_private(&self) -> bool {
294        matches!(
295            self,
296            Self::Symmetric { .. }
297                | Self::EC {
298                    curve: Curve::P256 { d: Some(_), .. },
299                    ..
300                }
301                | Self::EC {
302                    curve: Curve::P384 { d: Some(_), .. },
303                    ..
304                }
305                | Self::RSA {
306                    private: Some(_),
307                    ..
308                }
309        )
310    }
311
312    /// Returns the public part of this key (symmetric keys have no public parts).
313    pub fn to_public(&self) -> Option<Cow<'_, Self>> {
314        if !self.is_private() {
315            return Some(Cow::Borrowed(self));
316        }
317        Some(Cow::Owned(match self {
318            Self::Symmetric { .. } => return None,
319            Self::EC {
320                curve: Curve::P256 { x, y, .. },
321            } => Self::EC {
322                curve: Curve::P256 {
323                    x: x.clone(),
324                    y: y.clone(),
325                    d: None,
326                },
327            },
328            Self::EC {
329                curve: Curve::P384 { x, y, .. },
330            } => Self::EC {
331                curve: Curve::P384 {
332                    x: x.clone(),
333                    y: y.clone(),
334                    d: None,
335                },
336            },
337            Self::RSA { public, .. } => Self::RSA {
338                public: public.clone(),
339                private: None,
340            },
341        }))
342    }
343
344    /// If this key is asymmetric, encodes it as PKCS#8.
345    #[cfg(feature = "pkcs-convert")]
346    pub fn try_to_der(&self) -> Result<Vec<u8>, ConversionError> {
347        use num_bigint::BigUint;
348        use yasna::{models::ObjectIdentifier, DERWriter, DERWriterSeq, Tag};
349
350        use crate::utils::pkcs8;
351
352        if let Self::Symmetric { .. } = self {
353            return Err(ConversionError::NotAsymmetric);
354        }
355
356        Ok(match self {
357            Self::EC {
358                curve: Curve::P256 { d, x, y },
359            } => {
360                let ec_public_oid = ObjectIdentifier::from_slice(&[1, 2, 840, 10045, 2, 1]);
361                let prime256v1_oid = ObjectIdentifier::from_slice(&[1, 2, 840, 10045, 3, 1, 7]);
362                let oids = &[Some(&ec_public_oid), Some(&prime256v1_oid)];
363
364                let write_public = |writer: DERWriter<'_>| {
365                    let public_bytes: Vec<u8> = [0x04 /* uncompressed */]
366                        .iter()
367                        .chain(x.iter())
368                        .chain(y.iter())
369                        .copied()
370                        .collect();
371                    writer.write_bitvec_bytes(&public_bytes, 8 * (32 * 2 + 1));
372                };
373
374                match d {
375                    Some(private_point) => {
376                        pkcs8::write_private(oids, |writer: &mut DERWriterSeq<'_>| {
377                            writer.next().write_i8(1); // version
378                            writer.next().write_bytes(private_point);
379                            // The following tagged value is optional. OpenSSL produces it,
380                            // but many tools, including jwt.io and `jsonwebtoken`, don't like it,
381                            // so we don't include it.
382                            // writer.next().write_tagged(Tag::context(0), |writer| {
383                            //     writer.write_oid(&prime256v1_oid)
384                            // });
385                            writer.next().write_tagged(Tag::context(1), write_public);
386                        })
387                    }
388                    None => pkcs8::write_public(oids, write_public),
389                }
390            }
391            Self::EC {
392                curve: Curve::P384 { d, x, y },
393            } => {
394                let ec_public_oid = ObjectIdentifier::from_slice(&[1, 2, 840, 10045, 2, 1]);
395                let prime384v1_oid = ObjectIdentifier::from_slice(&[1, 3, 132, 0, 34]);
396                let oids = &[Some(&ec_public_oid), Some(&prime384v1_oid)];
397
398                let write_public = |writer: DERWriter<'_>| {
399                    let public_bytes: Vec<u8> = [0x04 /* uncompressed */]
400                        .iter()
401                        .chain(x.iter())
402                        .chain(y.iter())
403                        .copied()
404                        .collect();
405                    writer.write_bitvec_bytes(&public_bytes, 8 * (48 * 2 + 1));
406                };
407
408                match d {
409                    Some(private_point) => {
410                        pkcs8::write_private(oids, |writer: &mut DERWriterSeq<'_>| {
411                            writer.next().write_i8(1); // version
412                            writer.next().write_bytes(private_point);
413                            writer.next().write_tagged(Tag::context(1), write_public);
414                        })
415                    }
416                    None => pkcs8::write_public(oids, write_public),
417                }
418            }
419            Self::RSA { public, private } => {
420                let rsa_encryption_oid = ObjectIdentifier::from_slice(&[
421                    1, 2, 840, 113549, 1, 1, 1, // rsaEncryption
422                ]);
423                let oids = &[Some(&rsa_encryption_oid), None];
424                let write_bytevec = |writer: DERWriter<'_>, vec: &ByteVec| {
425                    let bigint = BigUint::from_bytes_be(vec);
426                    writer.write_biguint(&bigint);
427                };
428
429                let write_public = |writer: &mut DERWriterSeq<'_>| {
430                    write_bytevec(writer.next(), &public.n);
431                    writer.next().write_u32(PUBLIC_EXPONENT);
432                };
433
434                let write_private = |writer: &mut DERWriterSeq<'_>, private: &RsaPrivate| {
435                    // https://tools.ietf.org/html/rfc3447#appendix-A.1.2
436                    writer.next().write_i8(0); // version (two-prime)
437                    write_public(writer);
438                    write_bytevec(writer.next(), &private.d);
439                    macro_rules! write_opt_bytevecs {
440                            ($($param:ident),+) => {{
441                                $(write_bytevec(writer.next(), private.$param.as_ref().unwrap());)+
442                            }};
443                        }
444                    write_opt_bytevecs!(p, q, dp, dq, qi);
445                };
446
447                match private {
448                    Some(
449                        private @ RsaPrivate {
450                            d: _,
451                            p: Some(_),
452                            q: Some(_),
453                            dp: Some(_),
454                            dq: Some(_),
455                            qi: Some(_),
456                        },
457                    ) => pkcs8::write_private(oids, |writer| write_private(writer, private)),
458                    Some(_) => return Err(ConversionError::MissingRsaParams),
459                    None => pkcs8::write_public(oids, |writer| {
460                        let body =
461                            yasna::construct_der(|writer| writer.write_sequence(write_public));
462                        writer.write_bitvec_bytes(&body, body.len() * 8);
463                    }),
464                }
465            }
466            Self::Symmetric { .. } => unreachable!("checked above"),
467        })
468    }
469
470    /// Unwrapping `try_to_der`.
471    /// Panics if the key is not asymmetric or there are missing RSA components.
472    #[cfg(feature = "pkcs-convert")]
473    pub fn to_der(&self) -> Vec<u8> {
474        self.try_to_der().unwrap()
475    }
476
477    /// If this key is asymmetric, encodes it as PKCS#8 with PEM armoring.
478    #[cfg(feature = "pkcs-convert")]
479    pub fn try_to_pem(&self) -> Result<String, ConversionError> {
480        use base64::{engine::general_purpose::STANDARD, Engine};
481        use std::fmt::Write;
482        let der_b64 = STANDARD.encode(self.try_to_der()?);
483        let key_ty = if self.is_private() {
484            "PRIVATE"
485        } else {
486            "PUBLIC"
487        };
488        let mut pem = String::new();
489        writeln!(&mut pem, "-----BEGIN {} KEY-----", key_ty).unwrap();
490        //^ re: `unwrap`, if writing to a string fails, we've got bigger issues.
491        const MAX_LINE_LEN: usize = 64;
492        for i in (0..der_b64.len()).step_by(MAX_LINE_LEN) {
493            writeln!(
494                &mut pem,
495                "{}",
496                &der_b64[i..std::cmp::min(i + MAX_LINE_LEN, der_b64.len())]
497            )
498            .unwrap();
499        }
500        writeln!(&mut pem, "-----END {} KEY-----", key_ty).unwrap();
501        Ok(pem)
502    }
503
504    /// Unwrapping `try_to_pem`.
505    /// Panics if the key is not asymmetric or there are missing RSA components.
506    #[cfg(feature = "pkcs-convert")]
507    pub fn to_pem(&self) -> String {
508        self.try_to_pem().unwrap()
509    }
510
511    /// Generates a new symmetric key with the specified number of bits.
512    /// Best used with one of the HS algorithms (e.g., HS256).
513    #[cfg(feature = "generate")]
514    pub fn generate_symmetric(num_bits: usize) -> Self {
515        use rand::RngCore;
516        let mut bytes = vec![0; num_bits / 8];
517        rand::thread_rng().fill_bytes(&mut bytes);
518        Self::Symmetric { key: bytes.into() }
519    }
520
521    /// Generates a new EC keypair using the prime256 curve.
522    /// Used with the ES256 algorithm.
523    #[cfg(feature = "generate")]
524    pub fn generate_p256() -> Self {
525        use p256::elliptic_curve::{self as elliptic_curve, sec1::ToEncodedPoint};
526
527        let sk = elliptic_curve::SecretKey::random(&mut rand::thread_rng());
528        let sk_scalar = p256::Scalar::from(&sk);
529
530        let pk = p256::ProjectivePoint::GENERATOR * sk_scalar;
531        let pk_bytes = &pk
532            .to_affine()
533            .to_encoded_point(false /* compress */)
534            .to_bytes()[1..];
535        let (x_bytes, y_bytes) = pk_bytes.split_at(32);
536
537        Self::EC {
538            curve: Curve::P256 {
539                d: Some(sk_scalar.to_bytes().into()),
540                x: ByteArray::from_slice(x_bytes),
541                y: ByteArray::from_slice(y_bytes),
542            },
543        }
544    }
545}
546
547#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
548#[serde(tag = "crv")]
549pub enum Curve {
550    /// Parameters of the prime256v1 (P256) curve.
551    #[serde(rename = "P-256")]
552    P256 {
553        /// The private scalar.
554        #[serde(skip_serializing_if = "Option::is_none")]
555        d: Option<ByteArray<U32>>,
556        /// The curve point x coordinate.
557        x: ByteArray<U32>,
558        /// The curve point y coordinate.
559        y: ByteArray<U32>,
560    },
561    /// Parameters of the prime384v1 (P384) curve.
562    #[serde(rename = "P-384")]
563    P384 {
564        /// The private scalar.
565        #[serde(skip_serializing_if = "Option::is_none")]
566        d: Option<ByteArray<U48>>,
567        /// The curve point x coordinate.
568        x: ByteArray<U48>,
569        /// The curve point y coordinate.
570        y: ByteArray<U48>,
571    },
572}
573
574impl Curve {
575    pub fn name(&self) -> &'static str {
576        match self {
577            Self::P256 { .. } => "P-256",
578            Self::P384 { .. } => "P-256",
579        }
580    }
581}
582
583impl fmt::Display for Curve {
584    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
585        match self {
586            Self::P256 { x, y, .. } => f
587                .debug_struct("Curve::P256")
588                .field("x", x)
589                .field("y", y)
590                .finish(),
591            Self::P384 { x, y, .. } => f
592                .debug_struct("Curve::P384")
593                .field("x", x)
594                .field("y", y)
595                .finish(),
596        }
597    }
598}
599
600#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
601pub struct RsaPublic {
602    /// The standard public exponent, 65537.
603    pub e: PublicExponent,
604    /// The modulus, p*q.
605    pub n: ByteVec,
606}
607
/// Numeric value of the only supported RSA public exponent.
const PUBLIC_EXPONENT: u32 = 65_537;
/// Base64 of the exponent bytes `01 00 01` (little-endian, zeros stripped).
const PUBLIC_EXPONENT_B64: &str = "AQAB";
/// Base64, with `=` padding, of the four exponent bytes `01 00 01 00`.
const PUBLIC_EXPONENT_B64_PADDED: &str = "AQABAA==";
611
/// Marker for the standard RSA public exponent, 65537.
///
/// The type carries no data: it is the only exponent value this crate accepts.
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Eq)]
pub struct PublicExponent;
615
616impl Serialize for PublicExponent {
617    fn serialize<S: serde::ser::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
618        PUBLIC_EXPONENT_B64.serialize(s)
619    }
620}
621
622impl<'de> Deserialize<'de> for PublicExponent {
623    fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
624        let e = String::deserialize(d)?;
625        if e == PUBLIC_EXPONENT_B64 || e == PUBLIC_EXPONENT_B64_PADDED {
626            Ok(Self)
627        } else {
628            Err(serde::de::Error::custom(&format!(
629                "public exponent must be {}",
630                PUBLIC_EXPONENT
631            )))
632        }
633    }
634}
635
636#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
637pub struct RsaPrivate {
638    /// Private exponent.
639    pub d: ByteVec,
640    /// First prime factor.
641    #[serde(default, skip_serializing_if = "Option::is_none")]
642    pub p: Option<ByteVec>,
643    /// Second prime factor.
644    #[serde(default, skip_serializing_if = "Option::is_none")]
645    pub q: Option<ByteVec>,
646    /// First factor Chinese Remainder Theorem (CRT) exponent.
647    #[serde(default, skip_serializing_if = "Option::is_none")]
648    pub dp: Option<ByteVec>,
649    /// Second factor CRT exponent.
650    #[serde(default, skip_serializing_if = "Option::is_none")]
651    pub dq: Option<ByteVec>,
652    /// First CRT coefficient.
653    #[serde(default, skip_serializing_if = "Option::is_none")]
654    pub qi: Option<ByteVec>,
655}
656
657impl fmt::Debug for RsaPrivate {
658    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
659        f.write_str("RsaPrivate")
660    }
661}
662
663#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
664pub enum KeyUse {
665    #[serde(rename = "sig")]
666    Signing,
667    #[serde(rename = "enc")]
668    Encryption,
669}
670
671#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
672#[allow(clippy::upper_case_acronyms)]
673pub enum Algorithm {
674    HS256,
675    HS384,
676    HS512,
677    RS256,
678    RS384,
679    RS512,
680    ES256,
681    ES384,
682}
683
684impl Algorithm {
685    pub fn name(&self) -> &'static str {
686        match self {
687            Self::HS256 => "hs256",
688            Self::HS384 => "hs384",
689            Self::HS512 => "hs512",
690            Self::RS256 => "rs256",
691            Self::RS384 => "rs384",
692            Self::RS512 => "rs512",
693            Self::ES256 => "es256",
694            Self::ES384 => "es384",
695        }
696    }
697}
698
// Conversions to `jsonwebtoken` types, scoped inside an anonymous const so the
// `jwt` alias does not leak into the rest of the module.
#[cfg(feature = "jwt-convert")]
const _IMPL_JWT_CONVERSIONS: () = {
    use jsonwebtoken as jwt;

    impl From<Algorithm> for jwt::Algorithm {
        /// Maps each algorithm to the `jsonwebtoken` variant of the same name.
        fn from(alg: Algorithm) -> Self {
            match alg {
                Algorithm::HS256 => Self::HS256,
                Algorithm::HS384 => Self::HS384,
                Algorithm::HS512 => Self::HS512,
                Algorithm::RS256 => Self::RS256,
                Algorithm::RS384 => Self::RS384,
                Algorithm::RS512 => Self::RS512,
                Algorithm::ES256 => Self::ES256,
                Algorithm::ES384 => Self::ES384,
            }
        }
    }

    impl Key {
        /// Converts a private key into a `jsonwebtoken::EncodingKey`.
        ///
        /// # Errors
        ///
        /// Returns [`ConversionError::NotPrivate`] for a public key, or the
        /// error from `try_to_pem` if PEM encoding fails.
        pub fn try_to_encoding_key(&self) -> Result<jwt::EncodingKey, ConversionError> {
            match self {
                _ if !self.is_private() => Err(ConversionError::NotPrivate),
                Self::Symmetric { key } => Ok(jwt::EncodingKey::from_secret(key)),
                // The unwraps below are not expected to panic: the key is known to be
                // private, and the successful output of `try_to_pem` is tested to be valid.
                Self::EC { .. } => {
                    let pem = self.try_to_pem()?;
                    Ok(jwt::EncodingKey::from_ec_pem(pem.as_bytes()).unwrap())
                }
                Self::RSA { .. } => {
                    let pem = self.try_to_pem()?;
                    Ok(jwt::EncodingKey::from_rsa_pem(pem.as_bytes()).unwrap())
                }
            }
        }

        /// Unwrapping variant of `try_to_encoding_key`.
        ///
        /// # Panics
        ///
        /// Panics if the key is public.
        pub fn to_encoding_key(&self) -> jwt::EncodingKey {
            Self::try_to_encoding_key(self).unwrap()
        }

        /// Converts this key into a `jsonwebtoken::DecodingKey`: the secret
        /// itself for symmetric keys, the PEM of the public part otherwise.
        ///
        /// Panics if the public part cannot be PEM-encoded or is rejected by
        /// `jsonwebtoken`.
        pub fn to_decoding_key(&self) -> jwt::DecodingKey {
            // Asymmetric keys always have a public part, so `to_public` is `Some` for them.
            let public_pem = || self.to_public().unwrap().to_pem();
            match self {
                Self::Symmetric { key } => jwt::DecodingKey::from_secret(key),
                Self::EC { .. } => {
                    jwt::DecodingKey::from_ec_pem(public_pem().as_bytes()).unwrap()
                }
                Self::RSA { .. } => {
                    jwt::DecodingKey::from_rsa_pem(public_pem().as_bytes()).unwrap()
                }
            }
        }
    }
};
760
761#[derive(Debug, thiserror::Error)]
762pub enum Error {
763    #[error(transparent)]
764    Serde(#[from] serde_json::Error),
765
766    #[error(transparent)]
767    Base64Decode(#[from] base64::DecodeError),
768
769    #[error("mismatched algorithm for key type")]
770    MismatchedAlgorithm,
771}
772
773#[derive(Debug, thiserror::Error)]
774pub enum ConversionError {
775    #[error("encoding RSA JWK as PKCS#8 requires specifing all of p, q, dp, dq, qi")]
776    MissingRsaParams,
777
778    #[error("a symmetric key can not be encoded using PKCS#8")]
779    NotAsymmetric,
780
781    #[cfg(feature = "jwt-convert")]
782    #[error("a public key cannot be converted to a `jsonwebtoken::EncodingKey`")]
783    NotPrivate,
784}