gel_jwt/
bare_key.rs

1use base64ct::Encoding;
2use const_oid::db::rfc5912::{ID_EC_PUBLIC_KEY, RSA_ENCRYPTION, SECP_256_R_1};
3use der::{asn1::BitString, Any, AnyRef, Decode, Encode, SliceReader};
4use elliptic_curve::{generic_array::GenericArray, sec1::FromEncodedPoint};
5use num_bigint_dig::BigUint;
6use p256::elliptic_curve::{sec1::ToEncodedPoint, JwkEcKey};
7use pem::Pem;
8use pkcs1::UintRef;
9use pkcs8::{
10    spki::{AlgorithmIdentifier, SubjectPublicKeyInfoOwned},
11    PrivateKeyInfo,
12};
13use ring::{
14    rand::SystemRandom,
15    signature::{RsaKeyPair, ECDSA_P256_SHA256_FIXED_SIGNING},
16};
17use rustls_pki_types::PrivatePkcs1KeyDer;
18use sec1::{EcParameters, EcPrivateKey};
19use serde::{Deserialize, Serialize};
20use std::{collections::HashMap, str::FromStr, vec::Vec};
21
22use crate::{KeyError, KeyType, KeyValidationError};
23
/// Shortest HMAC (`oct`) key accepted by validation, in bytes (128 bits).
const MIN_OCT_LEN_BYTES: usize = 128 / 8;
/// Minimum RSA key size, in bits.
/// NOTE(review): not referenced in this file's visible code; presumably
/// enforced by `validate_rsa_key_pair` — confirm.
const MIN_RSA_KEY_BITS: usize = 2048;

/// RSA key size used by `BarePrivateKey::generate`, in bits.
#[cfg(feature = "keygen")]
const DEFAULT_GEN_RSA_KEY_BITS: usize = 2048;
/// HMAC key length used by `BarePrivateKey::generate`, in bytes (256 bits).
#[cfg(feature = "keygen")]
const DEFAULT_GEN_OCT_LEN_BYTES: usize = 256 / 8;
31
32#[derive(zeroize::ZeroizeOnDrop, Eq, PartialEq, Clone)]
33pub(crate) struct HmacKey {
34    key: zeroize::Zeroizing<Vec<u8>>,
35}
36
37impl std::hash::Hash for HmacKey {
38    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
39        self.key.hash(state);
40    }
41}
42
43#[derive(derive_more::Debug, Serialize, Deserialize)]
44pub struct SerializedKeys {
45    pub keys: Vec<SerializedKey>,
46}
47
48/// Deserialize
49#[derive(derive_more::Debug)]
50pub enum SerializedKey {
51    Private(Option<String>, BarePrivateKey),
52    Public(Option<String>, BarePublicKey),
53    #[debug("UnknownOrInvalid({_0}, {_0}, ...)")]
54    UnknownOrInvalid(
55        #[allow(unused)] KeyError,
56        String,
57        HashMap<String, serde_json::Value>,
58    ),
59}
60
61impl<'de> serde::Deserialize<'de> for SerializedKey {
62    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
63    where
64        D: serde::Deserializer<'de>,
65    {
66        let map: HashMap<String, serde_json::Value> = HashMap::deserialize(deserializer)?;
67        let get = |k: &'static str| {
68            map.get(k)
69                .map(|s| s.as_str().unwrap_or_default())
70                .unwrap_or_default()
71        };
72
73        let kty = get("kty");
74        let kid = map
75            .get("kid")
76            .map(|v| v.as_str().unwrap_or_default().to_owned());
77
78        match kty {
79            "RSA" => {
80                // Check if private key by looking for p,q components
81                if map.contains_key("p") && map.contains_key("q") {
82                    // Private key
83                    match BarePrivateKey::from_jwt_rsa(
84                        get("n"),
85                        get("e"),
86                        get("d"),
87                        get("p"),
88                        get("q"),
89                        get("dp"),
90                        get("dq"),
91                        get("qi"),
92                    ) {
93                        Ok(key) => Ok(SerializedKey::Private(kid, key)),
94                        Err(e) => Ok(SerializedKey::UnknownOrInvalid(e, kty.to_string(), map)),
95                    }
96                } else {
97                    // Public key
98                    match BarePublicKey::from_jwt_rsa(get("n"), get("e")) {
99                        Ok(key) => Ok(SerializedKey::Public(kid, key)),
100                        Err(e) => Ok(SerializedKey::UnknownOrInvalid(e, kty.to_string(), map)),
101                    }
102                }
103            }
104            "EC" => {
105                // Check if private key by looking for d component
106                if map.contains_key("d") {
107                    // Private key
108                    match BarePrivateKey::from_jwt_ec(get("crv"), get("d"), get("x"), get("y")) {
109                        Ok(key) => Ok(SerializedKey::Private(kid, key)),
110                        Err(e) => Ok(SerializedKey::UnknownOrInvalid(e, kty.to_string(), map)),
111                    }
112                } else {
113                    // Public key
114                    match BarePublicKey::from_jwt_ec(get("crv"), get("x"), get("y")) {
115                        Ok(key) => Ok(SerializedKey::Public(kid, key)),
116                        Err(e) => Ok(SerializedKey::UnknownOrInvalid(e, kty.to_string(), map)),
117                    }
118                }
119            }
120            "oct" => match BarePrivateKey::from_jwt_oct(get("k")) {
121                Ok(key) => Ok(SerializedKey::Private(kid, key)),
122                Err(e) => Ok(SerializedKey::UnknownOrInvalid(e, kty.to_string(), map)),
123            },
124            _ => Ok(SerializedKey::UnknownOrInvalid(
125                KeyError::UnsupportedKeyType(kty.to_string()),
126                kty.to_string(),
127                map,
128            )),
129        }
130    }
131}
132
133impl serde::Serialize for SerializedKey {
134    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
135    where
136        S: serde::Serializer,
137    {
138        use serde::ser::SerializeMap;
139
140        let b64 = |s: &[u8]| zeroize::Zeroizing::new(base64ct::Base64UrlUnpadded::encode_string(s));
141
142        match self {
143            SerializedKey::Private(kid, key) => {
144                let mut map = serializer.serialize_map(None)?;
145                match &key.inner {
146                    BarePrivateKeyInner::RS256(key) => {
147                        let rsa = pkcs1::RsaPrivateKey::from_der(key.secret_pkcs1_der())
148                            .map_err(serde::ser::Error::custom)?;
149                        if let Some(kid) = kid {
150                            map.serialize_entry("kid", kid)?;
151                        }
152                        map.serialize_entry("kty", "RSA")?;
153                        map.serialize_entry("n", &b64(rsa.modulus.as_bytes()))?;
154                        map.serialize_entry("e", &b64(rsa.public_exponent.as_bytes()))?;
155                        map.serialize_entry("d", &b64(rsa.private_exponent.as_bytes()))?;
156
157                        // Add dp, dq, qi
158                        map.serialize_entry("dp", &b64(rsa.exponent1.as_bytes()))?;
159                        map.serialize_entry("dq", &b64(rsa.exponent2.as_bytes()))?;
160                        if rsa.other_prime_infos.is_none() {
161                            map.serialize_entry("p", &b64(rsa.prime1.as_bytes()))?;
162                            map.serialize_entry("q", &b64(rsa.prime2.as_bytes()))?;
163                        } else {
164                            return Err(serde::ser::Error::custom(
165                                "RSA private key must have 2 primes",
166                            ));
167                        }
168
169                        map.serialize_entry("qi", &b64(rsa.coefficient.as_bytes()))?;
170                    }
171                    BarePrivateKeyInner::ES256(key) => {
172                        if let Some(kid) = kid {
173                            map.serialize_entry("kid", kid)?;
174                        }
175                        map.serialize_entry("kty", "EC")?;
176                        map.serialize_entry("crv", "P-256")?;
177                        let public_key = key.public_key();
178                        let point = public_key.to_encoded_point(false);
179                        map.serialize_entry("x", &b64(point.x().unwrap()))?;
180                        map.serialize_entry("y", &b64(point.y().unwrap()))?;
181                        map.serialize_entry("d", &b64(key.to_bytes().as_ref()))?;
182                    }
183                    BarePrivateKeyInner::HS256(key) => {
184                        if let Some(kid) = kid {
185                            map.serialize_entry("kid", kid)?;
186                        }
187                        map.serialize_entry("kty", "oct")?;
188                        map.serialize_entry("k", &b64(&key.key))?;
189                    }
190                }
191                map.end()
192            }
193            SerializedKey::Public(kid, key) => {
194                let mut map = serializer.serialize_map(None)?;
195                match &key.inner {
196                    BarePublicKeyInner::RS256 { n, e } => {
197                        if let Some(kid) = kid {
198                            map.serialize_entry("kid", kid)?;
199                        }
200                        map.serialize_entry("kty", "RSA")?;
201                        map.serialize_entry("n", &b64(&n.to_bytes_be()))?;
202                        map.serialize_entry("e", &b64(&e.to_bytes_be()))?;
203                    }
204                    BarePublicKeyInner::ES256(key) => {
205                        if let Some(kid) = kid {
206                            map.serialize_entry("kid", kid)?;
207                        }
208                        map.serialize_entry("kty", "EC")?;
209                        map.serialize_entry("crv", "P-256")?;
210                        let point = key.to_encoded_point(false);
211                        map.serialize_entry("x", &b64(point.x().unwrap()))?;
212                        map.serialize_entry("y", &b64(point.y().unwrap()))?;
213                    }
214                    BarePublicKeyInner::HS256(key) => {
215                        if let Some(kid) = kid {
216                            map.serialize_entry("kid", kid)?;
217                        }
218                        map.serialize_entry("kty", "oct")?;
219                        map.serialize_entry("k", &b64(&key.key))?;
220                    }
221                }
222                map.end()
223            }
224            SerializedKey::UnknownOrInvalid(_, kty, map) => {
225                let mut new_map = serializer.serialize_map(None)?;
226                new_map.serialize_entry("kty", kty)?;
227                for (k, v) in map {
228                    new_map.serialize_entry(k, v)?;
229                }
230                new_map.end()
231            }
232        }
233    }
234}
235
236#[derive(Debug, PartialEq, Eq, Hash)]
237pub struct BareKey {
238    pub(crate) inner: BareKeyInner,
239}
240
241#[derive(Debug, Hash, PartialEq, Eq)]
242pub(crate) enum BareKeyInner {
243    Private(BarePrivateKeyInner),
244    Public(BarePublicKeyInner),
245}
246
247impl BareKey {
248    fn from_unvalidated(inner: BareKeyInner) -> Result<Self, KeyError> {
249        match inner {
250            BareKeyInner::Private(inner) => Ok(Self {
251                inner: BareKeyInner::Private(inner.validate()?),
252            }),
253            BareKeyInner::Public(inner) => Ok(Self {
254                inner: BareKeyInner::Public(inner.validate()?),
255            }),
256        }
257    }
258
259    pub fn key_type(&self) -> KeyType {
260        match &self.inner {
261            BareKeyInner::Private(key) => key.key_type(),
262            BareKeyInner::Public(key) => key.key_type(),
263        }
264    }
265
266    /// Load a key from a PEM-encoded string. Supported formats are PKCS1, PKCS8,
267    /// SEC1, and `JWT OCTAL KEY`.
268    pub fn from_pem(pem: &str) -> Result<Self, KeyError> {
269        let key = parse_pem(pem)?;
270        Self::from_parsed_unvalidated(&key).and_then(Self::from_unvalidated)
271    }
272
273    pub fn from_pem_multiple(pem: &str) -> Result<Vec<Result<Self, KeyError>>, KeyError> {
274        let mut keys = Vec::new();
275        let pems = pem::parse_many(pem).map_err(|_| KeyError::DecodeError)?;
276        for pem in pems {
277            let key = Self::from_parsed_unvalidated(&pem).and_then(Self::from_unvalidated);
278            keys.push(key);
279        }
280        Ok(keys)
281    }
282
283    fn from_parsed_unvalidated(pem: &Pem) -> Result<BareKeyInner, KeyError> {
284        match pem.tag() {
285            "JWT OCTAL KEY" => handle_oct_key(pem).map(BareKeyInner::Private),
286            // EC never appears in a raw "ECPublicKey" form, so treat this as
287            // SPKI format.
288            // https://www.rfc-editor.org/rfc/rfc5915
289            "PUBLIC KEY" | "EC PUBLIC KEY" => handle_spki_pubkey(pem).map(BareKeyInner::Public),
290            "RSA PUBLIC KEY" => handle_rsa_pubkey(pem).map(BareKeyInner::Public),
291            "EC PRIVATE KEY" => handle_ec_key(pem.contents()).map(BareKeyInner::Private),
292            "RSA PRIVATE KEY" => handle_rsa_key(pem).map(BareKeyInner::Private),
293            "PRIVATE KEY" => handle_pkcs8_key(pem.contents()).map(BareKeyInner::Private),
294            tag => Err(KeyError::UnsupportedKeyType(tag.to_string())),
295        }
296    }
297
298    pub fn try_to_public(&self) -> Result<BarePublicKey, KeyError> {
299        match &self.inner {
300            BareKeyInner::Private(key) => BarePublicKey::from_unvalidated(key.try_into()?),
301            BareKeyInner::Public(key) => Ok(BarePublicKey { inner: key.clone() }),
302        }
303    }
304
305    pub fn try_to_private(&self) -> Result<BarePrivateKey, KeyError> {
306        match &self.inner {
307            BareKeyInner::Private(key) => Ok(BarePrivateKey { inner: key.clone() }),
308            BareKeyInner::Public(_) => {
309                Err(KeyError::UnsupportedKeyType("No private key".to_string()))
310            }
311        }
312    }
313
314    pub fn try_into_public(self) -> Result<BarePublicKey, KeyError> {
315        match &self.inner {
316            BareKeyInner::Private(key) => BarePublicKey::from_unvalidated(key.try_into()?),
317            BareKeyInner::Public(key) => Ok(BarePublicKey { inner: key.clone() }),
318        }
319    }
320
321    pub fn try_into_private(self) -> Result<BarePrivateKey, KeyError> {
322        match &self.inner {
323            BareKeyInner::Private(key) => Ok(BarePrivateKey { inner: key.clone() }),
324            BareKeyInner::Public(_) => {
325                Err(KeyError::UnsupportedKeyType("No private key".to_string()))
326            }
327        }
328    }
329
330    pub fn to_pem(&self) -> String {
331        match &self.inner {
332            BareKeyInner::Private(key) => key.to_pem(),
333            BareKeyInner::Public(key) => key.to_pem(),
334        }
335    }
336
337    pub fn clone_key(&self) -> Self {
338        match &self.inner {
339            BareKeyInner::Private(key) => BareKey {
340                inner: BareKeyInner::Private(key.clone()),
341            },
342            BareKeyInner::Public(key) => BareKey {
343                inner: BareKeyInner::Public(key.clone()),
344            },
345        }
346    }
347}
348
349/// A bare private key contains one of the following:
350///
351/// - An RSA private key
352/// - An ECDSA private key (P-256)
353/// - A symmetric key
354#[derive(Debug, Hash, PartialEq, Eq)]
355pub struct BarePrivateKey {
356    pub(crate) inner: BarePrivateKeyInner,
357}
358
359impl std::fmt::Debug for BarePrivateKeyInner {
360    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
361        match &self {
362            BarePrivateKeyInner::RS256(_) => write!(f, "RS256(...)"),
363            BarePrivateKeyInner::ES256(_) => write!(f, "ES256(...)"),
364            BarePrivateKeyInner::HS256(_) => write!(f, "HS256(...)"),
365        }
366    }
367}
368
369impl std::hash::Hash for BarePrivateKeyInner {
370    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
371        match &self {
372            BarePrivateKeyInner::RS256(key) => {
373                let Ok(key) = ring::rsa::KeyPair::from_der(key.secret_pkcs1_der()) else {
374                    return;
375                };
376                key.public().as_ref().hash(state);
377            }
378            BarePrivateKeyInner::ES256(key) => {
379                let key = key.public_key();
380                let point = key.to_encoded_point(false);
381                point.hash(state);
382            }
383            BarePrivateKeyInner::HS256(key) => key.hash(state),
384        }
385    }
386}
387
388impl PartialEq for BarePrivateKeyInner {
389    fn eq(&self, other: &Self) -> bool {
390        match (&self, &other) {
391            (BarePrivateKeyInner::RS256(a), BarePrivateKeyInner::RS256(b)) => {
392                let Ok(a) = ring::rsa::KeyPair::from_der(a.secret_pkcs1_der()) else {
393                    return false;
394                };
395                let Ok(b) = ring::rsa::KeyPair::from_der(b.secret_pkcs1_der()) else {
396                    return false;
397                };
398                a.public().as_ref() == b.public().as_ref()
399            }
400            (BarePrivateKeyInner::ES256(a), BarePrivateKeyInner::ES256(b)) => {
401                let a = a.public_key();
402                let b = b.public_key();
403                let a = a.to_encoded_point(false);
404                let b = b.to_encoded_point(false);
405                a == b
406            }
407            (BarePrivateKeyInner::HS256(a), BarePrivateKeyInner::HS256(b)) => a == b,
408            _ => false,
409        }
410    }
411}
412
413impl Eq for BarePrivateKeyInner {}
414
415pub(crate) enum BarePrivateKeyInner {
416    /// APIs expose PKCS1 more than PKCS8 so we can work with that
417    RS256(rustls_pki_types::PrivatePkcs1KeyDer<'static>),
418    /// Use the raw p256 secret key
419    ES256(p256::SecretKey),
420    /// Bag 'o' bytes (self-zeroing).
421    HS256(HmacKey),
422}
423
424impl Clone for BarePrivateKeyInner {
425    fn clone(&self) -> Self {
426        match self {
427            BarePrivateKeyInner::RS256(key) => BarePrivateKeyInner::RS256(key.clone_key()),
428            BarePrivateKeyInner::ES256(key) => BarePrivateKeyInner::ES256(key.clone()),
429            BarePrivateKeyInner::HS256(key) => BarePrivateKeyInner::HS256(key.clone()),
430        }
431    }
432}
433
434/// In debug mode, using the openssl command to generate RSA keys is much faster
435/// than ring.
436#[allow(unused)]
437#[cfg(unix)]
438fn optional_openssl_rsa_keygen(bits: usize) -> Option<BarePrivateKey> {
439    use std::process::Command;
440    // Try to call `openssl genrsa {bits} > /dev/null 2>&1` and then parse the stdout
441    // as PEM. If we fail, just return None.
442    let output = Command::new("openssl")
443        .args(["genrsa", &bits.to_string()])
444        .output()
445        .ok()?;
446    if output.status.success() {
447        let rsa = BarePrivateKey::from_pem(&String::from_utf8(output.stdout).ok()?).ok()?;
448        Some(rsa)
449    } else {
450        None
451    }
452}
453
454#[allow(unused)]
455#[cfg(not(unix))]
456fn optional_openssl_rsa_keygen(bits: usize) -> Option<BarePrivateKey> {
457    None
458}
459
460impl BarePrivateKey {
461    fn from_unvalidated(inner: BarePrivateKeyInner) -> Result<Self, KeyError> {
462        Ok(Self {
463            inner: inner.validate()?,
464        })
465    }
466
467    /// Generate a new key of the given type. This may be slow for RSA keys
468    /// when running without compiler optimizations.
469    #[cfg(feature = "keygen")]
470    pub fn generate(key_type: KeyType) -> Result<Self, KeyError> {
471        use pkcs1::EncodeRsaPrivateKey;
472        use rand::{rngs::ThreadRng, Rng};
473        use ring::rand::SystemRandom;
474
475        match key_type {
476            KeyType::RS256 => {
477                // Because keygen is so slow in debug mode, we use openssl to generate.
478                #[cfg(debug_assertions)]
479                {
480                    let rsa = optional_openssl_rsa_keygen(DEFAULT_GEN_RSA_KEY_BITS);
481                    if let Some(rsa) = rsa {
482                        return Ok(rsa);
483                    }
484                }
485
486                let key =
487                    rsa::RsaPrivateKey::new(&mut rand::thread_rng(), DEFAULT_GEN_RSA_KEY_BITS)
488                        .unwrap();
489                let key = key.to_pkcs1_der().unwrap();
490                Self::from_unvalidated(BarePrivateKeyInner::RS256(PrivatePkcs1KeyDer::from(
491                    key.to_bytes().to_vec(),
492                )))
493            }
494            KeyType::ES256 => {
495                let key = ring::signature::EcdsaKeyPair::generate_pkcs8(
496                    &ECDSA_P256_SHA256_FIXED_SIGNING,
497                    &SystemRandom::new(),
498                )
499                .unwrap();
500                Self::from_unvalidated(handle_pkcs8_key(key.as_ref())?)
501            }
502            KeyType::HS256 => {
503                let mut rng = ThreadRng::default();
504                let mut key = zeroize::Zeroizing::new(vec![0; DEFAULT_GEN_OCT_LEN_BYTES]);
505                rng.fill(key.as_mut_slice());
506                Self::from_unvalidated(BarePrivateKeyInner::HS256(HmacKey { key }))
507            }
508        }
509    }
510
511    /// Load an ECDSA key from a JWK.
512    pub fn from_jwt_ec(crv: &str, d: &str, x: &str, y: &str) -> Result<Self, KeyError> {
513        if crv != "P-256" {
514            return Err(KeyError::UnsupportedKeyType(crv.to_string()));
515        }
516
517        // TODO: Not an ideal way to parse
518        let validation = |c: char| !c.is_alphanumeric() && c != '-' && c != '_';
519        if x.contains(validation) || y.contains(validation) || d.contains(validation) {
520            return Err(KeyError::DecodeError);
521        }
522        let jwk = JwkEcKey::from_str(&format!(
523            r#"{{"kty":"EC","crv":"P-256","x":"{x}","y":"{y}","d":"{d}"}}"#
524        ))
525        .map_err(|_| KeyError::DecodeError)?;
526
527        let key: p256::elliptic_curve::SecretKey<p256::NistP256> =
528            jwk.to_secret_key().map_err(|_| KeyError::DecodeError)?;
529
530        Self::from_unvalidated(BarePrivateKeyInner::ES256(key))
531    }
532
533    /// Load an RSA key from a JWK.
534    #[allow(clippy::too_many_arguments)]
535    pub fn from_jwt_rsa(
536        n: &str,
537        e: &str,
538        d: &str,
539        p: &str,
540        q: &str,
541        dp: &str,
542        dq: &str,
543        qinv: &str,
544    ) -> Result<Self, KeyError> {
545        let n = b64_decode(n)?;
546        let e = b64_decode(e)?;
547        let d = b64_decode(d)?;
548        let p = b64_decode(p)?;
549        let q = b64_decode(q)?;
550        let dp = b64_decode(dp)?;
551        let dq = b64_decode(dq)?;
552        let qinv = b64_decode(qinv)?;
553
554        let rsa = pkcs1::RsaPrivateKey {
555            modulus: UintRef::new(&n).map_err(|_| KeyError::DecodeError)?,
556            public_exponent: UintRef::new(&e).map_err(|_| KeyError::DecodeError)?,
557            private_exponent: UintRef::new(&d).map_err(|_| KeyError::DecodeError)?,
558            prime1: UintRef::new(&p).map_err(|_| KeyError::DecodeError)?,
559            prime2: UintRef::new(&q).map_err(|_| KeyError::DecodeError)?,
560            exponent1: UintRef::new(&dp).map_err(|_| KeyError::DecodeError)?,
561            exponent2: UintRef::new(&dq).map_err(|_| KeyError::DecodeError)?,
562            coefficient: UintRef::new(&qinv).map_err(|_| KeyError::DecodeError)?,
563            other_prime_infos: None,
564        };
565
566        let mut vec = Vec::with_capacity(n.len() * 4);
567        rsa.encode_to_vec(&mut vec)
568            .map_err(|_| KeyError::DecodeError)?;
569
570        Self::from_unvalidated(BarePrivateKeyInner::RS256(PrivatePkcs1KeyDer::from(vec)))
571    }
572
573    /// Load an HMAC key from a base64-encoded string.
574    pub fn from_jwt_oct(k: &str) -> Result<Self, KeyError> {
575        let key = b64_decode(k)?;
576        Self::from_unvalidated(BarePrivateKeyInner::HS256(HmacKey { key }))
577    }
578
579    /// Load an HMAC key from a raw byte slice.
580    pub fn from_raw_oct(key: &[u8]) -> Result<Self, KeyError> {
581        Self::from_unvalidated(BarePrivateKeyInner::HS256(HmacKey {
582            key: key.to_vec().into(),
583        }))
584    }
585
586    /// Load a key from a PEM-encoded string. Supported formats are PKCS1, PKCS8,
587    /// SEC1, and `JWT OCTAL KEY`.
588    pub fn from_pem(pem: &str) -> Result<Self, KeyError> {
589        BareKey::from_pem(pem)?.try_into_private()
590    }
591
592    pub fn from_pem_multiple(pem: &str) -> Result<Vec<Result<Self, KeyError>>, KeyError> {
593        Ok(BareKey::from_pem_multiple(pem)?
594            .into_iter()
595            .map(|key| key.and_then(|k| k.try_into_private()))
596            .collect())
597    }
598
599    pub fn to_public(&self) -> Result<BarePublicKey, KeyError> {
600        let inner = (&(self.inner)).try_into()?;
601        Ok(BarePublicKey { inner })
602    }
603
604    pub fn into_public(self) -> Result<BarePublicKey, KeyError> {
605        let inner = (&(self.inner)).try_into()?;
606        Ok(BarePublicKey { inner })
607    }
608
609    pub fn clone_key(&self) -> Self {
610        Self {
611            inner: self.inner.clone(),
612        }
613    }
614
615    pub fn to_pem(&self) -> String {
616        self.inner.to_pem()
617    }
618
619    pub fn to_pem_public(&self) -> Result<String, KeyError> {
620        self.inner.to_pem_public()
621    }
622
623    pub fn key_type(&self) -> KeyType {
624        self.inner.key_type()
625    }
626}
627
628impl BarePrivateKeyInner {
629    pub fn key_type(&self) -> KeyType {
630        match &self {
631            BarePrivateKeyInner::RS256(..) => KeyType::RS256,
632            BarePrivateKeyInner::ES256(..) => KeyType::ES256,
633            BarePrivateKeyInner::HS256(..) => KeyType::HS256,
634        }
635    }
636
637    pub fn to_pem(&self) -> String {
638        let key = match &self {
639            BarePrivateKeyInner::RS256(key) => {
640                pem::encode(&Pem::new("RSA PRIVATE KEY", key.secret_pkcs1_der()))
641            }
642            BarePrivateKeyInner::ES256(key) => {
643                let pkcs8 = pkcs8_from_ec(key).unwrap();
644                pem::encode(&Pem::new("PRIVATE KEY", pkcs8))
645            }
646            BarePrivateKeyInner::HS256(key) => {
647                pem::encode(&Pem::new("JWT OCTAL KEY", key.key.as_slice()))
648            }
649        };
650        key
651    }
652
653    /// Export this private key to a public key in PEM format.
654    pub fn to_pem_public(&self) -> Result<String, KeyError> {
655        let key = match &self {
656            BarePrivateKeyInner::RS256(key) => {
657                let pkcs1 = pkcs1::RsaPrivateKey::from_der(key.secret_pkcs1_der())
658                    .map_err(|_| KeyError::DecodeError)?;
659                BarePublicKeyInner::RS256 {
660                    n: BigUint::from_bytes_be(pkcs1.modulus.as_bytes()),
661                    e: BigUint::from_bytes_be(pkcs1.public_exponent.as_bytes()),
662                }
663                .to_pem()
664            }
665            BarePrivateKeyInner::ES256(key) => BarePublicKeyInner::ES256(key.public_key()).to_pem(),
666            _ => return Err(KeyError::UnsupportedKeyType(self.key_type().to_string())),
667        };
668        Ok(key)
669    }
670
671    fn validate(self) -> Result<Self, KeyError> {
672        match &self {
673            BarePrivateKeyInner::RS256(key) => {
674                validate_rsa_key_pair(key.secret_pkcs1_der())?;
675                Ok(self)
676            }
677            BarePrivateKeyInner::ES256(key) => {
678                validate_ecdsa_key_pair(key)?;
679                Ok(self)
680            }
681            BarePrivateKeyInner::HS256(key) => {
682                if key.key.len() < MIN_OCT_LEN_BYTES {
683                    return Err(KeyError::UnsupportedKeyType(format!(
684                        "oct key ({} bytes) < {} bytes",
685                        key.key.len(),
686                        MIN_OCT_LEN_BYTES
687                    )));
688                }
689                Ok(self)
690            }
691        }
692    }
693}
694
695fn parse_pem(pem: &str) -> Result<Pem, KeyError> {
696    pem::parse(pem).map_err(|_| KeyError::InvalidPem)
697}
698
699fn handle_oct_key(key: &Pem) -> Result<BarePrivateKeyInner, KeyError> {
700    let key = key.contents().to_vec().into();
701    Ok(BarePrivateKeyInner::HS256(HmacKey { key }))
702}
703
704fn handle_ec_key(key: &[u8]) -> Result<BarePrivateKeyInner, KeyError> {
705    let mut reader = SliceReader::new(key).map_err(|_| KeyError::DecodeError)?;
706    let decoded_key = EcPrivateKey::decode(&mut reader).map_err(|_| KeyError::DecodeError)?;
707
708    if let Some(parameters) = decoded_key.parameters {
709        if parameters.named_curve() == Some(SECP_256_R_1) {
710            let key = p256::SecretKey::from_slice(decoded_key.private_key)
711                .map_err(|_| KeyError::DecodeError)?;
712            return Ok(BarePrivateKeyInner::ES256(key));
713        }
714    }
715
716    Err(KeyError::InvalidEcParameters)
717}
718
719fn handle_rsa_key(key: &Pem) -> Result<BarePrivateKeyInner, KeyError> {
720    let mut reader = SliceReader::new(key.contents()).map_err(|_| KeyError::DecodeError)?;
721    let _decoded_key =
722        pkcs1::RsaPrivateKey::decode(&mut reader).map_err(|_| KeyError::DecodeError)?;
723
724    Ok(BarePrivateKeyInner::RS256(PrivatePkcs1KeyDer::from(
725        key.contents().to_vec(),
726    )))
727}
728
729fn handle_pkcs8_key(key: &[u8]) -> Result<BarePrivateKeyInner, KeyError> {
730    let mut reader = SliceReader::new(key).map_err(|_| KeyError::DecodeError)?;
731    let decoded_key = PrivateKeyInfo::decode(&mut reader).map_err(|_| KeyError::DecodeError)?;
732
733    match decoded_key.algorithm.oid {
734        ID_EC_PUBLIC_KEY => {
735            // Ensure the curve is P-256
736            if decoded_key.algorithm.parameters_oid() != Ok(SECP_256_R_1) {
737                return Err(KeyError::InvalidEcParameters);
738            }
739            let mut reader =
740                SliceReader::new(decoded_key.private_key).map_err(|_| KeyError::DecodeError)?;
741            let key = EcPrivateKey::decode(&mut reader).map_err(|_| KeyError::DecodeError)?;
742            let key =
743                p256::SecretKey::from_slice(key.private_key).map_err(|_| KeyError::DecodeError)?;
744            Ok(BarePrivateKeyInner::ES256(key))
745        }
746        RSA_ENCRYPTION => {
747            RsaKeyPair::from_der(decoded_key.private_key)
748                .map_err(|e| KeyError::KeyValidationError(KeyValidationError(e.to_string())))?;
749            Ok(BarePrivateKeyInner::RS256(PrivatePkcs1KeyDer::from(
750                decoded_key.private_key.to_vec(),
751            )))
752        }
753        _ => Err(KeyError::UnsupportedKeyType(
754            decoded_key.algorithm.oid.to_string(),
755        )),
756    }
757}
758
759fn pkcs8_from_ec(key: &p256::SecretKey) -> Result<Vec<u8>, KeyError> {
760    let key_bytes = key.to_bytes();
761    let public_key_bytes = key.public_key().to_sec1_bytes().into_vec();
762    let mut vec = Vec::new();
763    EcPrivateKey {
764        private_key: key_bytes.as_ref(),
765        parameters: Some(EcParameters::NamedCurve(SECP_256_R_1)),
766        public_key: Some(public_key_bytes.as_ref()),
767    }
768    .encode_to_vec(&mut vec)
769    .map_err(|_| KeyError::EncodeError)?;
770
771    let pkcs8 = pkcs8::PrivateKeyInfo {
772        algorithm: AlgorithmIdentifier {
773            oid: ID_EC_PUBLIC_KEY,
774            parameters: Some(AnyRef::from(&EcParameters::NamedCurve(SECP_256_R_1))),
775        },
776        private_key: &vec,
777        public_key: None,
778    };
779    let mut buf = Vec::new();
780    pkcs8
781        .encode_to_vec(&mut buf)
782        .map_err(|_| KeyError::EncodeError)?;
783    Ok(buf)
784}
785
786impl TryInto<jsonwebtoken::EncodingKey> for &BarePrivateKey {
787    type Error = KeyError;
788
789    fn try_into(self) -> Result<jsonwebtoken::EncodingKey, Self::Error> {
790        match &self.inner {
791            BarePrivateKeyInner::RS256(key) => Ok(jsonwebtoken::EncodingKey::from_rsa_der(
792                key.secret_pkcs1_der(),
793            )),
794            BarePrivateKeyInner::ES256(key) => {
795                Ok(jsonwebtoken::EncodingKey::from_ec_der(&pkcs8_from_ec(key)?))
796            }
797            BarePrivateKeyInner::HS256(key) => Ok(jsonwebtoken::EncodingKey::from_secret(&key.key)),
798        }
799    }
800}
801
802impl TryInto<jsonwebtoken::DecodingKey> for &BarePublicKey {
803    type Error = KeyError;
804
805    fn try_into(self) -> Result<jsonwebtoken::DecodingKey, Self::Error> {
806        match &self.inner {
807            BarePublicKeyInner::RS256 { n, e } => {
808                Ok(jsonwebtoken::DecodingKey::from_rsa_raw_components(
809                    &n.to_bytes_be(),
810                    &e.to_bytes_be(),
811                ))
812            }
813            BarePublicKeyInner::ES256(key) => {
814                Ok(jsonwebtoken::DecodingKey::from_ec_der(&key.to_sec1_bytes()))
815            }
816            BarePublicKeyInner::HS256(key) => Ok(jsonwebtoken::DecodingKey::from_secret(&key.key)),
817        }
818    }
819}
820
821impl TryFrom<&BarePrivateKeyInner> for BarePublicKeyInner {
822    type Error = KeyError;
823
824    fn try_from(key: &BarePrivateKeyInner) -> Result<Self, Self::Error> {
825        match key {
826            BarePrivateKeyInner::RS256(key) => {
827                let rsa = pkcs1::RsaPrivateKey::from_der(key.secret_pkcs1_der())
828                    .map_err(|_| KeyError::DecodeError)?;
829                let n = BigUint::from_bytes_be(rsa.modulus.as_bytes());
830                let e = BigUint::from_bytes_be(rsa.public_exponent.as_bytes());
831                Ok(BarePublicKeyInner::RS256 { n, e })
832            }
833            BarePrivateKeyInner::ES256(key) => {
834                let pk = key.public_key();
835                Ok(BarePublicKeyInner::ES256(pk))
836            }
837            BarePrivateKeyInner::HS256(key) => Ok(BarePublicKeyInner::HS256(key.clone())),
838        }
839    }
840}
841
/// A public (verification) key for one of the supported JWT algorithms
/// (RS256, ES256, HS256).
///
/// Constructors such as `from_jwt_ec`, `from_jwt_rsa` and `from_jwt_oct` validate
/// the key before wrapping it. For HS256 the "public" key is the shared HMAC
/// secret itself, so values of this type may still hold secret material.
#[derive(Debug, PartialEq, Eq, Hash)]
pub struct BarePublicKey {
    pub(crate) inner: BarePublicKeyInner,
}
846
847impl std::fmt::Debug for BarePublicKeyInner {
848    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
849        match &self {
850            BarePublicKeyInner::RS256 { n, e } => write!(f, "RS256({n}, {e})"),
851            BarePublicKeyInner::ES256(pk) => write!(f, "ES256({pk:?})"),
852            BarePublicKeyInner::HS256(_key) => write!(f, "HS256(...)"),
853        }
854    }
855}
856
857impl std::hash::Hash for BarePublicKeyInner {
858    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
859        match &self {
860            BarePublicKeyInner::RS256 { n, e } => {
861                n.hash(state);
862                e.hash(state);
863            }
864            BarePublicKeyInner::ES256(pk) => {
865                pk.to_encoded_point(false).hash(state);
866            }
867            BarePublicKeyInner::HS256(key) => {
868                key.hash(state);
869            }
870        }
871    }
872}
873
// The `PartialEq` impl compares key material of matching variants using the
// `Eq` implementations of `BigUint`, `p256::PublicKey` and `HmacKey`, so it is a
// full equivalence relation and this marker is sound.
impl Eq for BarePublicKeyInner {}
875
876impl PartialEq for BarePublicKeyInner {
877    fn eq(&self, other: &Self) -> bool {
878        match (&self, &other) {
879            (
880                BarePublicKeyInner::RS256 { n: n1, e: e1 },
881                BarePublicKeyInner::RS256 { n: n2, e: e2 },
882            ) => n1 == n2 && e1 == e2,
883            (BarePublicKeyInner::ES256(pk1), BarePublicKeyInner::ES256(pk2)) => pk1 == pk2,
884            (BarePublicKeyInner::HS256(key1), BarePublicKeyInner::HS256(key2)) => key1 == key2,
885            _ => false,
886        }
887    }
888}
889
/// Public-side key material for each supported JWT signing algorithm.
#[derive(Clone)]
pub(crate) enum BarePublicKeyInner {
    /// RSA public key: modulus `n` and public exponent `e`.
    RS256 { n: BigUint, e: BigUint },
    /// P-256 ECDSA public key.
    ES256(p256::PublicKey),
    /// HMAC shared secret. For a symmetric algorithm, verification uses the
    /// same key as signing.
    HS256(HmacKey),
}
896
897impl BarePublicKey {
898    fn from_unvalidated(inner: BarePublicKeyInner) -> Result<Self, KeyError> {
899        Ok(Self {
900            inner: inner.validate()?,
901        })
902    }
903
904    /// Load an ECDSA public key from a JWK.
905    pub fn from_jwt_ec(crv: &str, x: &str, y: &str) -> Result<Self, KeyError> {
906        if crv != "P-256" {
907            return Err(KeyError::UnsupportedKeyType(format!(
908                "EC curve ({crv}) not supported"
909            )));
910        }
911
912        let x = b64_decode(x)?;
913        let y = b64_decode(y)?;
914        let x = GenericArray::<u8, p256::U32>::from_slice(x.as_slice());
915        let y = GenericArray::<u8, p256::U32>::from_slice(y.as_slice());
916        let point = p256::EncodedPoint::from_affine_coordinates(x, y, false);
917        let key = p256::PublicKey::from_encoded_point(&point)
918            .into_option()
919            .ok_or(KeyError::DecodeError)?;
920        Self::from_unvalidated(BarePublicKeyInner::ES256(key))
921    }
922
923    /// Load an RSA public key from a JWK.
924    pub fn from_jwt_rsa(n: &str, e: &str) -> Result<Self, KeyError> {
925        let n = BigUint::from_bytes_be(&b64_decode(n)?);
926        let e = BigUint::from_bytes_be(&b64_decode(e)?);
927        Self::from_unvalidated(BarePublicKeyInner::RS256 { n, e })
928    }
929
930    pub fn from_jwt_oct(k: &str) -> Result<Self, KeyError> {
931        let key = b64_decode(k)?;
932        Self::from_unvalidated(BarePublicKeyInner::HS256(HmacKey { key }))
933    }
934
935    /// Creates a `BarePublicKey` from a PEM-encoded public or private key. If the
936    /// PEM-encoded file contains a private key, it will be converted to a public key
937    /// and the private key data will be discarded.
938    ///
939    /// Supported formats include the private key formats from [`BareKey::from_pem`],
940    /// `SPKI`-containers (`PUBLIC KEY` and `EC PUBLIC KEY`), and `RSA PUBLIC KEY`
941    /// traditional-style keys (`RsaPublicKey`).
942    pub fn from_pem(pem: &str) -> Result<Self, KeyError> {
943        let key = BareKey::from_pem(pem)?;
944        key.try_to_public()
945    }
946
947    pub fn from_pem_multiple(pem: &str) -> Result<Vec<Result<Self, KeyError>>, KeyError> {
948        Ok(BareKey::from_pem_multiple(pem)?
949            .into_iter()
950            .map(|key| key.and_then(|k| k.try_to_public()))
951            .collect())
952    }
953
954    pub fn clone_key(&self) -> Self {
955        Self {
956            inner: self.inner.clone(),
957        }
958    }
959
960    pub fn key_type(&self) -> KeyType {
961        self.inner.key_type()
962    }
963
964    pub fn to_pem(&self) -> String {
965        self.inner.to_pem()
966    }
967}
968
969impl BarePublicKeyInner {
970    pub fn key_type(&self) -> KeyType {
971        match &self {
972            BarePublicKeyInner::RS256 { .. } => KeyType::RS256,
973            BarePublicKeyInner::ES256(..) => KeyType::ES256,
974            BarePublicKeyInner::HS256(..) => KeyType::HS256,
975        }
976    }
977
978    pub fn to_pem(&self) -> String {
979        // We use unwrap() here but these cases should not be reachable
980        match &self {
981            BarePublicKeyInner::RS256 { n, e } => {
982                let mut v = Vec::new();
983                pkcs1::RsaPublicKey {
984                    modulus: UintRef::new(&n.to_bytes_be()).unwrap(),
985                    public_exponent: UintRef::new(&e.to_bytes_be()).unwrap(),
986                }
987                .encode_to_vec(&mut v)
988                .unwrap();
989                pem::encode(&Pem::new("RSA PUBLIC KEY", v))
990            }
991            BarePublicKeyInner::ES256(spki) => {
992                let spki = SubjectPublicKeyInfoOwned {
993                    algorithm: AlgorithmIdentifier {
994                        oid: ID_EC_PUBLIC_KEY,
995                        parameters: Some(
996                            AnyRef::from(&EcParameters::NamedCurve(SECP_256_R_1)).into(),
997                        ),
998                    },
999                    subject_public_key: BitString::from_bytes(&spki.to_sec1_bytes()).unwrap(),
1000                };
1001                let mut v = vec![];
1002                spki.encode_to_vec(&mut v).unwrap();
1003                pem::encode(&Pem::new("PUBLIC KEY", v))
1004            }
1005            BarePublicKeyInner::HS256(key) => {
1006                pem::encode(&Pem::new("JWT OCTAL KEY", key.key.as_slice()))
1007            }
1008        }
1009    }
1010
1011    fn validate(self) -> Result<Self, KeyError> {
1012        match &self {
1013            BarePublicKeyInner::RS256 { n, e } => validate_rsa_pubkey(n, e),
1014            BarePublicKeyInner::ES256(pk) => validate_ecdsa_pubkey(pk),
1015            BarePublicKeyInner::HS256(key) => {
1016                if key.key.len() < MIN_OCT_LEN_BYTES {
1017                    return Err(KeyError::UnsupportedKeyType(format!(
1018                        "oct key ({} bytes) < {} bytes",
1019                        key.key.len(),
1020                        MIN_OCT_LEN_BYTES
1021                    )));
1022                }
1023                Ok(())
1024            }
1025        }?;
1026        Ok(self)
1027    }
1028}
1029
1030fn handle_spki_pubkey(key: &Pem) -> Result<BarePublicKeyInner, KeyError> {
1031    let mut reader = SliceReader::new(key.contents()).map_err(|_| KeyError::DecodeError)?;
1032    let decoded_key = pkcs8::SubjectPublicKeyInfo::<Any, BitString>::decode(&mut reader)
1033        .map_err(|_| KeyError::DecodeError)?;
1034
1035    match decoded_key.algorithm.oid {
1036        ID_EC_PUBLIC_KEY => {
1037            let pk = p256::PublicKey::from_sec1_bytes(decoded_key.subject_public_key.raw_bytes())
1038                .map_err(|_| KeyError::DecodeError)?;
1039            Ok(BarePublicKeyInner::ES256(pk))
1040        }
1041        RSA_ENCRYPTION => {
1042            let pub_key = pkcs1::RsaPublicKey::from_der(decoded_key.subject_public_key.raw_bytes())
1043                .map_err(|_| KeyError::DecodeError)?;
1044            Ok(BarePublicKeyInner::RS256 {
1045                n: BigUint::from_bytes_be(pub_key.modulus.as_bytes()),
1046                e: BigUint::from_bytes_be(pub_key.public_exponent.as_bytes()),
1047            })
1048        }
1049        _ => Err(KeyError::UnsupportedKeyType(
1050            decoded_key.algorithm.oid.to_string(),
1051        )),
1052    }
1053}
1054
1055fn handle_rsa_pubkey(key: &Pem) -> Result<BarePublicKeyInner, KeyError> {
1056    let mut reader = SliceReader::new(key.contents()).map_err(|_| KeyError::DecodeError)?;
1057    let decoded_key =
1058        pkcs1::RsaPublicKey::decode(&mut reader).map_err(|_| KeyError::DecodeError)?;
1059    Ok(BarePublicKeyInner::RS256 {
1060        n: BigUint::from_bytes_be(decoded_key.modulus.as_bytes()),
1061        e: BigUint::from_bytes_be(decoded_key.public_exponent.as_bytes()),
1062    })
1063}
1064
1065/// Decode a base64 string with optional padding, since jwcrypto also seems to
1066/// accept this.
1067///
1068/// > JWKs make use of the base64url encoding as defined in RFC 4648 As allowed
1069/// > by Section 3.2 of the RFC, this specification mandates that base64url
1070/// > encoding when used with JWKs MUST NOT use padding. Notes on implementing
1071/// > base64url encoding can be found in the JWS specification.
1072fn b64_decode(s: &str) -> Result<zeroize::Zeroizing<Vec<u8>>, KeyError> {
1073    let vec = if s.ends_with('=') {
1074        base64ct::Base64Url::decode_vec(s).map_err(|_| KeyError::DecodeError)?
1075    } else {
1076        base64ct::Base64UrlUnpadded::decode_vec(s).map_err(|_| KeyError::DecodeError)?
1077    };
1078    Ok(zeroize::Zeroizing::new(vec))
1079}
1080
1081fn validate_ecdsa_key_pair(key: &p256::SecretKey) -> Result<(), KeyError> {
1082    let pkcs8_bytes = pkcs8_from_ec(key)?;
1083    let _keypair = ring::signature::EcdsaKeyPair::from_pkcs8(
1084        &ECDSA_P256_SHA256_FIXED_SIGNING,
1085        &pkcs8_bytes,
1086        &SystemRandom::new(),
1087    )
1088    .map_err(|e| KeyError::KeyValidationError(KeyValidationError(e.to_string())))?;
1089    Ok(())
1090}
1091
1092fn validate_rsa_key_pair(pkcs8: &[u8]) -> Result<(), KeyError> {
1093    let _keypair = ring::signature::RsaKeyPair::from_der(pkcs8)
1094        .map_err(|e| KeyError::KeyValidationError(KeyValidationError(e.to_string())))?;
1095    Ok(())
1096}
1097
1098fn validate_rsa_pubkey(n: &BigUint, e: &BigUint) -> Result<(), KeyError> {
1099    // TODO: Should we validate more than this?
1100    if e == &BigUint::from(3_u8) {
1101        return Err(KeyError::UnsupportedKeyType("RSA e=3".to_string()));
1102    }
1103    if n.bits() < MIN_RSA_KEY_BITS {
1104        return Err(KeyError::UnsupportedKeyType(format!(
1105            "RSA n ({}) < {} bits",
1106            n.bits(),
1107            MIN_RSA_KEY_BITS
1108        )));
1109    }
1110    Ok(())
1111}
1112
/// Policy hook for ES256 public keys. It currently accepts every key and never
/// returns an error.
///
/// NOTE(review): this presumably relies on `p256::PublicKey` construction having
/// already rejected invalid points. Confirm that before relying on it.
fn validate_ecdsa_pubkey(_pk: &p256::PublicKey) -> Result<(), KeyError> {
    // TODO: Should we validate more than this?
    Ok(())
}
1117
#[cfg(test)]
mod tests {
    use std::hash::{Hash, Hasher};

    use super::*;
    use rstest::*;

    /// Exercises the optional OpenSSL RSA generation path. This test uses
    /// `DEFAULT_GEN_RSA_KEY_BITS`, which only exists with the `keygen` feature.
    #[cfg(feature = "keygen")]
    #[test]
    fn test_fallback_rsa_keygen() {
        let rsa = optional_openssl_rsa_keygen(DEFAULT_GEN_RSA_KEY_BITS);
        if let Some(rsa) = rsa {
            println!("{}", rsa.to_pem());
        } else {
            println!("Failed to generate RSA key");
        }
    }

    /// Reads a fixture from `src/testcases` under the crate root.
    fn load_test_file(filename: &str) -> String {
        let path = std::path::Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("src/testcases")
            .join(filename);
        eprintln!("FILE: {}", path.display());
        std::fs::read_to_string(path).unwrap()
    }

    /// Hashes `value` with the std `DefaultHasher`, so the hashes of keys that
    /// should be equal can be compared.
    fn hash_of<T: Hash>(value: &T) -> u64 {
        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        value.hash(&mut hasher);
        hasher.finish()
    }

    #[rstest]
    #[case::ec_pk8("prime256v1-prv-pkcs8.pem")]
    #[case::ec_sec1("prime256v1-prv-sec1.pem")]
    #[case::rsa_pkcs1("rsa2048-prv-pkcs1.pem")]
    #[case::rsa_pkcs8("rsa2048-prv-pkcs8.pem")]
    fn test_from_pem_private(#[case] pem: &str) {
        let input = load_test_file(pem);
        eprintln!("IN:\n{input}");
        let key = BarePrivateKey::from_pem(&input).unwrap();
        eprintln!("OUT:\n{}", key.to_pem());
        let key = BarePrivateKey::from_pem(&key.to_pem()).expect("Failed to round-trip");

        let key_type = key.key_type();
        let encoding_key = (&key).try_into().unwrap();
        let token = match key_type {
            KeyType::RS256 => jsonwebtoken::encode(
                &jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256),
                &["claim"],
                &encoding_key,
            )
            .unwrap(),
            KeyType::ES256 => jsonwebtoken::encode(
                &jsonwebtoken::Header::new(jsonwebtoken::Algorithm::ES256),
                &["claim"],
                &encoding_key,
            )
            .unwrap(),
            _ => unreachable!(),
        };
        println!("{token}");
    }

    #[rstest]
    #[case::ec_pk8("prime256v1-prv-pkcs8.pem")]
    #[case::ec_sec1("prime256v1-prv-sec1.pem")]
    #[case::ec_spki_unc("prime256v1-pub-spki-uncompressed.pem")]
    #[case::ec_spki("prime256v1-pub-spki.pem")]
    fn test_from_pem_public_ec(#[case] pem: &str) {
        let key = BarePublicKey::from_pem(&load_test_file(pem)).unwrap();
        println!("{}", key.to_pem());
        BarePublicKey::from_pem(&key.to_pem()).expect("Failed to round-trip");
    }

    #[rstest]
    #[case::rsa_pkcs1("rsa2048-prv-pkcs1.pem")]
    #[case::rsa_pkcs8("rsa2048-prv-pkcs8.pem")]
    #[case::rsa_spki("rsa2048-pub-pkcs1.pem")]
    #[case::rsa_spki_pkcs8("rsa2048-pub-pkcs8.pem")]
    fn test_from_pem_public_rsa(#[case] pem: &str) {
        let key = BarePublicKey::from_pem(&load_test_file(pem)).unwrap();
        println!("{}", key.to_pem());
        BarePublicKey::from_pem(&key.to_pem()).expect("Failed to round-trip");
    }

    /// Test that the equality and hash functions work for BarePublicKey and BareKey. All
    /// key forms should be equal.
    #[test]
    fn test_eq_hash() {
        let key1 = BarePrivateKey::from_pem(&load_test_file("rsa2048-prv-pkcs1.pem")).unwrap();

        for key in [
            "rsa2048-prv-pkcs1.pem",
            "rsa2048-prv-pkcs8.pem",
            "rsa2048-pub-pkcs1.pem",
            "rsa2048-pub-pkcs8.pem",
        ] {
            if key.contains("pub") {
                let key1: BarePublicKey = key1.to_public().unwrap();
                let key2 = BarePublicKey::from_pem(&load_test_file(key)).unwrap();
                assert_eq!(key1, key2);
                assert_eq!(hash_of(&key1), hash_of(&key2));
            } else {
                let key2 = BarePrivateKey::from_pem(&load_test_file(key)).unwrap();
                assert_eq!(key1, key2);
                assert_eq!(hash_of(&key1), hash_of(&key2));
            }
        }
    }

    #[test]
    fn test_jwt_ec_key() {
        let key = BarePrivateKey::from_jwt_ec(
            "P-256",
            "w0pL1NOlKBOMtSOvUf6aFeEguWFCclQjWrWqHtHdEA8",
            "ZX_Ajm_22hdQbXImmtmaG-9TQ2z5Dt5Hbia0JzibvXc",
            "9r0Do-XFPyMYM6XCtOAT8AgY2xyRYLuS4U-_xXHDjeE",
        )
        .unwrap();
        let pem = key.to_pem();
        println!("{pem}");
        // A key loaded from JWK must survive a PEM round-trip unchanged.
        let reloaded = BarePrivateKey::from_pem(&pem).expect("Failed to round-trip");
        assert_eq!(key, reloaded);
    }

    #[test]
    fn test_jwt_rsa_key() {
        let e = "AQAB";
        let n = r#"oW-OMq9ATezmeSGLlTbp--Epar64s7qZSi2hTgmdmlaJdpDO8X_EunUIB4DLyPEsOH45-W
            P2xxmw9Uv0UHfvfHsqOKx6vyLjSkDcrUddBWLWhJ5vVm2iHW8FGtYmaLWcHyyh2QiVQUriUNo3HtQqGRKBw9V2X
            gIJ4tzIysuxiMM0uFs8IAvl6TX7MHgUnW4rohyDCJiWLs8UDHpdN3mBpIiokrRr_iTTWNb5m_HKWGJ7RBsLaRsX
            VhxgxZm2PrEEcgb5XlcBbRqOD-5LilCGw5IcX4y12vl_zGpdn-X63UjZmgjRyXKNLh7pOMyKDvWl5vp89w-DKTV
            5oN6CkVnI5w"#
            .replace(char::is_whitespace, "");
        let d = r#"QkfWhrnMeZIP6GDc-dUTiV5fTlvi4qv0vu9wIGWzRwhLpRn8VUwDnhhpxQbc5HIcmU8-B0
            ZDLmi-bmASfa1Ybu_0nFM4jFxLHJP35s77grgbYlTYWpBltJb97hBJsckKwgPlqYGsIiQYOmD1q5spc6TVEW4Fj
            MBihbnnWNf72q2_1CeYgBmLxaMDukUJ8gAaRXkGT0_4YBVBioPUpt_JrfX4dvtJlV3ehXnjN2KiH0xxXHinYdQr
            NSjrUSMUFRCNvSadmuYp1Aoxgsa43VoNAQqbvDRzxjX8eqjdXykVU_ILLwveH9NpZVho727Vd2ISvhwjtjDYMLY
            q6H_Rj6yrTQ"#
            .replace(char::is_whitespace, "");
        let p = r#"1Ce5utgQeHjSPQ_WbUzNt2wRCN8_VbH2LcmPzvxx1XfP7N8FpPs7isx5RpGnrAcVlxq9bI
            MgKq5wtEW2mK4rHB9n9kIxQwDGD7YGOSU3uK-Mi_ygm7ytTo3keMQ9Vj_W05UCT4l8RHvHwU6h-hvCIcN0TnHO0
            mX4JsAgRB-XmuU"#
            .replace(char::is_whitespace, "");
        let q = r#"wsx4ar__O_4dAva_emh7nOSAarF0UBrCuckHImCHwCM62mntXXhjAyY7t9BMQ4ccgYLNeW
            1l9lKpP3orkpYY1wsRMWGrQyDZlKqwNp-x5IG7c5RescuCJ4Yy5JO_PmtXOwukWH7YUTk7nWCCYNCxfHCsxvr-X
            T4oct9FZAtHu9s"#
            .replace(char::is_whitespace, "");
        let dp = r#"iKG49MM4AE5Xn-m2QBgpmIppghw87tS45g4cpsJgEYmjCDstqG4Aj8hWBoPBx4Gcfv9Cp
            ULhkXtcrE0FZtksfGUhkDBbB3rVE8M3yM_WTgQI8RLW4NWni6LIVJqVohllIkih_1VdCcHqCO26VZhQ82usWO
            TkvQ3cviAX56es_J0"#
            .replace(char::is_whitespace, "");
        let dq = r#"YqJl1qeg9R-WUQnfqnt9G9QXse5olqb2Mlw34JBALGmqQy2fotRyTgXt9wThmM-w_2Lb5
            8AdALyaNioGJhMaQMi5y-dIcJURltVWpFH4IVwPLlbSG_SP0rOA0Xx-OXzgjmU2shiIL5hrNyTG337MX9Ytph
            Mw-MWgdYnX-PA9QkE"#
            .replace(char::is_whitespace, "");
        let qinv = r#"CMSVnYipRlZJ0miceg9ECPNkAIvKUbaUYfccdJOl2ffP0Fs4FNxoJBoakyoNuJdYjV6
            syGSMON0A9OBpGWCL3A21X5BHw3JhsA3XGLMwXAjLA1_2mb_fV9HsaO9SqOZsU-Lo1w_g9PHK5EtqieJMP0iT
            fNdJIk8HyzKZVDPccJI"#
            .replace(char::is_whitespace, "");
        let key = BarePrivateKey::from_jwt_rsa(&n, e, &d, &p, &q, &dp, &dq, &qinv).unwrap();

        let json = serde_json::to_value(SerializedKey::Private(None, key)).unwrap();
        assert_eq!(json["kty"], "RSA");
        assert_eq!(json["n"], n);
        assert_eq!(json["e"], e);
        assert_eq!(json["d"], d);
        assert_eq!(json["p"], p);
        assert_eq!(json["q"], q);
        assert_eq!(json["dp"], dp);
        assert_eq!(json["dq"], dq);
        assert_eq!(json["qi"], qinv);
    }

    // Key generation is assumed to be gated by the `keygen` feature, like the
    // `DEFAULT_GEN_*` constants, so the generation tests are gated the same way.

    #[cfg(feature = "keygen")]
    #[test]
    fn test_hs256_key_generation() {
        let key = BarePrivateKey::generate(KeyType::HS256).unwrap();
        let pem = key.to_pem();
        println!("{pem}");
    }

    #[cfg(feature = "keygen")]
    #[test]
    fn test_es256_key_generation() {
        let key = BarePrivateKey::generate(KeyType::ES256).unwrap();
        let pem = key.to_pem();
        println!("{pem}");
        let key2 = BarePrivateKey::from_pem(&pem).expect("Failed to round-trip");
        println!("{}", key2.to_pem());
        assert_eq!(key, key2);
        assert_eq!(key.to_pem(), key2.to_pem());
    }

    #[cfg(feature = "keygen")]
    #[test]
    fn test_rs256_key_generation() {
        let key = BarePrivateKey::generate(KeyType::RS256).unwrap();
        let pem = key.to_pem();
        println!("{pem}");
        let key2 = BarePrivateKey::from_pem(&pem).expect("Failed to round-trip");
        println!("{}", key2.to_pem());
        assert_eq!(key, key2);
        assert_eq!(key.to_pem(), key2.to_pem());
    }

    #[test]
    fn test_deserialize_private_keys() {
        let json = load_test_file("jwkset-prv.json");
        let keys: SerializedKeys = serde_json::from_str(&json).unwrap();
        println!("{keys:?}");

        println!("{}", serde_json::to_string(&keys).unwrap());
    }

    #[test]
    fn test_deserialize_public_keys() {
        let json = load_test_file("jwkset-pub.json");
        let keys: SerializedKeys = serde_json::from_str(&json).unwrap();
        println!("{keys:?}");
        println!("{}", serde_json::to_string(&keys).unwrap());
    }
}