1#![deny(rust_2018_idioms, unreachable_pub)]
2#![forbid(unsafe_code)]
3
4mod byte_array;
66mod byte_vec;
67mod key_ops;
68#[cfg(test)]
69mod tests;
70mod utils;
71
72use std::{borrow::Cow, fmt};
73
74use generic_array::typenum::{U32, U48};
75use serde::{Deserialize, Serialize};
76
77pub use byte_array::ByteArray;
78pub use byte_vec::ByteVec;
79pub use key_ops::KeyOps;
80
/// A JSON Web Key: the key material together with the optional JWK parameters
/// that accompany it.
///
/// (De)serializes as a single flat JSON object. The members of [`Key`] and
/// [`X509Params`] are flattened into it, and every optional member is omitted
/// on output when unset.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct JsonWebKey {
    /// The key material. Its members (`kty` plus the type-specific ones) are
    /// flattened into the top-level JSON object.
    #[serde(flatten)]
    pub key: Box<Key>,

    /// The `use` member; omitted when `None`.
    #[serde(default, rename = "use", skip_serializing_if = "Option::is_none")]
    pub key_use: Option<KeyUse>,

    /// The `key_ops` member; omitted when empty.
    #[serde(default, skip_serializing_if = "KeyOps::is_empty")]
    pub key_ops: KeyOps,

    /// The `kid` member; omitted when `None`.
    #[serde(default, rename = "kid", skip_serializing_if = "Option::is_none")]
    pub key_id: Option<String>,

    /// The `alg` member; omitted when `None`.
    ///
    /// Prefer [`JsonWebKey::set_algorithm`], which checks that the algorithm fits
    /// the key type. Assigning this field directly, or parsing with
    /// [`JsonWebKey::from_slice`], skips that check. Parsing through `FromStr`
    /// performs it.
    #[serde(default, rename = "alg", skip_serializing_if = "Option::is_none")]
    pub algorithm: Option<Algorithm>,

    /// The X.509 members (`x5u`, `x5c`, `x5t`, `x5t#S256`). They are flattened into
    /// the top-level object and omitted entirely when none is set.
    #[serde(default, flatten, skip_serializing_if = "X509Params::is_empty")]
    pub x5: X509Params,
}
101
/// The X.509 certificate members of a JWK.
///
/// Every member is optional and is left out of the serialized JSON when `None`.
/// The values are stored as plain strings; nothing in this type validates them.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub struct X509Params {
    /// The `x5u` member.
    #[serde(default, rename = "x5u", skip_serializing_if = "Option::is_none")]
    url: Option<String>,

    /// The `x5c` member: a list of strings, presumably encoded certificates
    /// (not checked here).
    #[serde(default, rename = "x5c", skip_serializing_if = "Option::is_none")]
    cert_chain: Option<Vec<String>>,

    /// The `x5t` member.
    #[serde(default, rename = "x5t", skip_serializing_if = "Option::is_none")]
    thumbprint: Option<String>,

    /// The `x5t#S256` member.
    #[serde(default, rename = "x5t#S256", skip_serializing_if = "Option::is_none")]
    thumbprint_sha256: Option<String>,
}
120
121impl X509Params {
122 fn is_empty(&self) -> bool {
123 matches!(
124 self,
125 X509Params {
126 url: None,
127 cert_chain: None,
128 thumbprint: None,
129 thumbprint_sha256: None,
130 }
131 )
132 }
133}
134
135impl JsonWebKey {
136 pub fn new(key: Key) -> Self {
137 Self {
138 key: Box::new(key),
139 key_use: None,
140 key_ops: KeyOps::empty(),
141 key_id: None,
142 algorithm: None,
143 x5: Default::default(),
144 }
145 }
146
147 pub fn set_algorithm(&mut self, alg: Algorithm) -> Result<(), Error> {
148 Self::validate_algorithm(alg, &self.key)?;
149 self.algorithm = Some(alg);
150 Ok(())
151 }
152
153 pub fn from_slice(bytes: impl AsRef<[u8]>) -> Result<Self, Error> {
154 Ok(serde_json::from_slice(bytes.as_ref())?)
155 }
156
157 fn validate_algorithm(alg: Algorithm, key: &Key) -> Result<(), Error> {
158 use Algorithm::*;
159 use Key::*;
160 match (alg, key) {
161 (
162 ES256,
163 EC {
164 curve: Curve::P256 { .. },
165 ..
166 },
167 )
168 | (
169 ES384,
170 EC {
171 curve: Curve::P384 { .. },
172 },
173 )
174 | (RS256, RSA { .. })
175 | (RS384, RSA { .. })
176 | (RS512, RSA { .. })
177 | (HS256, Symmetric { .. }) => Ok(()),
178 (HS384, Symmetric { .. }) => Ok(()),
179 (HS512, Symmetric { .. }) => Ok(()),
180 _ => Err(Error::MismatchedAlgorithm),
181 }
182 }
183}
184
185impl std::str::FromStr for JsonWebKey {
186 type Err = Error;
187 fn from_str(json: &str) -> Result<Self, Self::Err> {
188 let jwk = Self::from_slice(json.as_bytes())?;
189
190 let alg = match jwk.algorithm {
191 Some(alg) => alg,
192 None => return Ok(jwk),
193 };
194 Self::validate_algorithm(alg, &jwk.key).map(|_| jwk)
195 }
196}
197
198impl std::fmt::Display for JsonWebKey {
199 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200 if f.alternate() {
201 write!(f, "{}", serde_json::to_string_pretty(self).unwrap())
202 } else {
203 write!(f, "{}", serde_json::to_string(self).unwrap())
204 }
205 }
206}
207
/// The key material of a JWK, internally tagged by the JSON `kty` member.
///
/// The tag is the variant name for `EC` and `RSA` and `"oct"` for
/// [`Key::Symmetric`].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kty")]
// `EC` and `RSA` double as the serialized `kty` values, so they keep that spelling.
#[allow(clippy::upper_case_acronyms)]
pub enum Key {
    /// An elliptic-curve key.
    EC {
        /// The curve (`crv`) and its coordinates, flattened into the JWK object.
        #[serde(flatten)]
        curve: Curve,
    },
    /// An RSA key.
    RSA {
        /// The public members (`e`, `n`), flattened into the JWK object.
        #[serde(flatten)]
        public: RsaPublic,
        /// The private members, flattened into the JWK object. `None` for a public
        /// key, in which case nothing is written.
        #[serde(flatten, default, skip_serializing_if = "Option::is_none")]
        private: Option<RsaPrivate>,
    },
    /// A symmetric secret (`kty` = `"oct"`).
    #[serde(rename = "oct")]
    Symmetric {
        /// The raw secret, stored under the `k` member.
        // NOTE(review): the derived `Debug` on `Key` will print this secret unless
        // `ByteVec`'s `Debug` redacts it — confirm.
        #[serde(rename = "k")]
        key: ByteVec,
    },
}
232
#[cfg(feature = "thumbprint")]
impl Key {
    /// Returns the JWK thumbprint of this key, hashed with SHA-256.
    ///
    /// The result is base64-encoded by `crate::utils::base64_encode`.
    ///
    /// # Panics
    /// Panics if serializing the thumbprint members fails.
    pub fn thumbprint(&self) -> String {
        self.try_thumbprint_using_hasher::<sha2::Sha256>().unwrap()
    }

    /// Computes the JWK thumbprint of this key with the digest `H`.
    ///
    /// The thumbprint follows the RFC 7638 construction:
    /// - only the required public members for the key type are serialized as a
    ///   compact JSON object, in lexicographic order of member name;
    /// - private members (`d`, the RSA private parameters) are never included;
    /// - the resulting bytes are hashed with `H`, and the digest is base64-encoded
    ///   by `crate::utils::base64_encode`.
    ///
    /// For symmetric keys the `k` member, which is the secret itself, is part of the
    /// hash input.
    ///
    /// # Errors
    /// Returns any error raised while serializing the members.
    pub fn try_thumbprint_using_hasher<H: sha2::digest::Digest>(
        &self,
    ) -> Result<String, serde_json::Error> {
        use serde::ser::{SerializeStruct, Serializer};
        let mut s = serde_json::Serializer::new(Vec::new());
        // Field order below is the canonical (sorted) order and must not change.
        match self {
            // NOTE(review): `crv` comes from `Curve::name()`, which must agree with the
            // curve's serialized `crv` value for the thumbprint to be correct.
            Self::EC {
                curve: curve @ Curve::P256 { x, y, .. },
            } => {
                let mut ss = s.serialize_struct("", 4)?;
                ss.serialize_field("crv", curve.name())?;
                ss.serialize_field("kty", "EC")?;
                ss.serialize_field("x", x)?;
                ss.serialize_field("y", y)?;
                ss.end()?;
            }
            Self::EC {
                curve: curve @ Curve::P384 { x, y, .. },
            } => {
                let mut ss = s.serialize_struct("", 4)?;
                ss.serialize_field("crv", curve.name())?;
                ss.serialize_field("kty", "EC")?;
                ss.serialize_field("x", x)?;
                ss.serialize_field("y", y)?;
                ss.end()?;
            }
            Self::RSA {
                public: RsaPublic { e, n },
                ..
            } => {
                let mut ss = s.serialize_struct("", 3)?;
                ss.serialize_field("e", e)?;
                ss.serialize_field("kty", "RSA")?;
                ss.serialize_field("n", n)?;
                ss.end()?;
            }
            Self::Symmetric { key } => {
                let mut ss = s.serialize_struct("", 2)?;
                ss.serialize_field("k", key)?;
                ss.serialize_field("kty", "oct")?;
                ss.end()?;
            }
        }
        Ok(crate::utils::base64_encode(H::digest(s.into_inner())))
    }
}
289
impl Key {
    /// Returns `true` if this key contains private material.
    ///
    /// - Symmetric keys always count as private.
    /// - EC keys count as private when `d` is present.
    /// - RSA keys count as private when `private` is `Some`.
    pub fn is_private(&self) -> bool {
        matches!(
            self,
            Self::Symmetric { .. }
                | Self::EC {
                    curve: Curve::P256 { d: Some(_), .. },
                    ..
                }
                | Self::EC {
                    curve: Curve::P384 { d: Some(_), .. },
                    ..
                }
                | Self::RSA {
                    private: Some(_),
                    ..
                }
        )
    }

    /// Returns the public part of this key.
    ///
    /// - A key that is already public is returned borrowed.
    /// - A private EC or RSA key is returned as an owned copy with `d` (or the RSA
    ///   `private` parameters) cleared.
    /// - Symmetric keys have no public part, so `None` is returned.
    pub fn to_public(&self) -> Option<Cow<'_, Self>> {
        if !self.is_private() {
            return Some(Cow::Borrowed(self));
        }
        Some(Cow::Owned(match self {
            Self::Symmetric { .. } => return None,
            Self::EC {
                curve: Curve::P256 { x, y, .. },
            } => Self::EC {
                curve: Curve::P256 {
                    x: x.clone(),
                    y: y.clone(),
                    d: None,
                },
            },
            Self::EC {
                curve: Curve::P384 { x, y, .. },
            } => Self::EC {
                curve: Curve::P384 {
                    x: x.clone(),
                    y: y.clone(),
                    d: None,
                },
            },
            Self::RSA { public, .. } => Self::RSA {
                public: public.clone(),
                private: None,
            },
        }))
    }

    /// Encodes this asymmetric key as DER.
    ///
    /// Private keys go through `utils::pkcs8::write_private` and public keys through
    /// `utils::pkcs8::write_public`. The latter is presumably a SubjectPublicKeyInfo
    /// writer — confirm in `utils`.
    ///
    /// # Errors
    /// - [`ConversionError::NotAsymmetric`] for symmetric keys.
    /// - [`ConversionError::MissingRsaParams`] for an RSA private key that lacks any
    ///   of `p`, `q`, `dp`, `dq`, `qi`.
    #[cfg(feature = "pkcs-convert")]
    pub fn try_to_der(&self) -> Result<Vec<u8>, ConversionError> {
        use num_bigint::BigUint;
        use yasna::{models::ObjectIdentifier, DERWriter, DERWriterSeq, Tag};

        use crate::utils::pkcs8;

        if let Self::Symmetric { .. } = self {
            return Err(ConversionError::NotAsymmetric);
        }

        Ok(match self {
            Self::EC {
                curve: Curve::P256 { d, x, y },
            } => {
                // id-ecPublicKey and the prime256v1 (P-256) named curve.
                let ec_public_oid = ObjectIdentifier::from_slice(&[1, 2, 840, 10045, 2, 1]);
                let prime256v1_oid = ObjectIdentifier::from_slice(&[1, 2, 840, 10045, 3, 1, 7]);
                let oids = &[Some(&ec_public_oid), Some(&prime256v1_oid)];

                // Public point as a BIT STRING: 0x04 || x || y (32-byte coordinates).
                let write_public = |writer: DERWriter<'_>| {
                    let public_bytes: Vec<u8> = [0x04 /* uncompressed */]
                        .iter()
                        .chain(x.iter())
                        .chain(y.iter())
                        .copied()
                        .collect();
                    writer.write_bitvec_bytes(&public_bytes, 8 * (32 * 2 + 1));
                };

                match d {
                    Some(private_point) => {
                        // Inner EC private key: version 1, the private scalar, then the
                        // public key under context tag [1]. The optional [0] curve
                        // parameters are not written.
                        pkcs8::write_private(oids, |writer: &mut DERWriterSeq<'_>| {
                            writer.next().write_i8(1); // version
                            writer.next().write_bytes(private_point);
                            writer.next().write_tagged(Tag::context(1), write_public);
                        })
                    }
                    None => pkcs8::write_public(oids, write_public),
                }
            }
            Self::EC {
                curve: Curve::P384 { d, x, y },
            } => {
                // id-ecPublicKey and the secp384r1 (P-384) named curve.
                let ec_public_oid = ObjectIdentifier::from_slice(&[1, 2, 840, 10045, 2, 1]);
                let prime384v1_oid = ObjectIdentifier::from_slice(&[1, 3, 132, 0, 34]);
                let oids = &[Some(&ec_public_oid), Some(&prime384v1_oid)];

                // Public point as a BIT STRING: 0x04 || x || y (48-byte coordinates).
                let write_public = |writer: DERWriter<'_>| {
                    let public_bytes: Vec<u8> = [0x04 /* uncompressed */]
                        .iter()
                        .chain(x.iter())
                        .chain(y.iter())
                        .copied()
                        .collect();
                    writer.write_bitvec_bytes(&public_bytes, 8 * (48 * 2 + 1));
                };

                match d {
                    Some(private_point) => {
                        // Same layout as the P-256 case: version 1, private scalar,
                        // [1] public key.
                        pkcs8::write_private(oids, |writer: &mut DERWriterSeq<'_>| {
                            writer.next().write_i8(1); // version
                            writer.next().write_bytes(private_point);
                            writer.next().write_tagged(Tag::context(1), write_public);
                        })
                    }
                    None => pkcs8::write_public(oids, write_public),
                }
            }
            Self::RSA { public, private } => {
                let rsa_encryption_oid = ObjectIdentifier::from_slice(&[
                    1, 2, 840, 113549, 1, 1, 1, // rsaEncryption
                ]);
                // RSA has no algorithm parameter OID, hence the `None`.
                let oids = &[Some(&rsa_encryption_oid), None];
                // Each JWK byte string is written as an unsigned big-endian INTEGER.
                let write_bytevec = |writer: DERWriter<'_>, vec: &ByteVec| {
                    let bigint = BigUint::from_bytes_be(vec);
                    writer.write_biguint(&bigint);
                };

                // Public key body: modulus n, then the fixed exponent.
                let write_public = |writer: &mut DERWriterSeq<'_>| {
                    write_bytevec(writer.next(), &public.n);
                    writer.next().write_u32(PUBLIC_EXPONENT);
                };

                // Private key body: version 0, n, e, d, p, q, dp, dq, qi, in that order.
                let write_private = |writer: &mut DERWriterSeq<'_>, private: &RsaPrivate| {
                    writer.next().write_i8(0); // version
                    write_public(writer);
                    write_bytevec(writer.next(), &private.d);
                    // The `unwrap`s are guarded by the `Some(..)` pattern in the match
                    // below, which is the only caller.
                    macro_rules! write_opt_bytevecs {
                        ($($param:ident),+) => {{
                            $(write_bytevec(writer.next(), private.$param.as_ref().unwrap());)+
                        }};
                    }
                    write_opt_bytevecs!(p, q, dp, dq, qi);
                };

                match private {
                    Some(
                        private @ RsaPrivate {
                            d: _,
                            p: Some(_),
                            q: Some(_),
                            dp: Some(_),
                            dq: Some(_),
                            qi: Some(_),
                        },
                    ) => pkcs8::write_private(oids, |writer| write_private(writer, private)),
                    Some(_) => return Err(ConversionError::MissingRsaParams),
                    // A public key embeds the DER of the (n, e) SEQUENCE as a BIT STRING.
                    None => pkcs8::write_public(oids, |writer| {
                        let body =
                            yasna::construct_der(|writer| writer.write_sequence(write_public));
                        writer.write_bitvec_bytes(&body, body.len() * 8);
                    }),
                }
            }
            Self::Symmetric { .. } => unreachable!("checked above"),
        })
    }

    /// Encodes this key as DER.
    ///
    /// # Panics
    /// Panics wherever [`Key::try_to_der`] would return an error.
    #[cfg(feature = "pkcs-convert")]
    pub fn to_der(&self) -> Vec<u8> {
        self.try_to_der().unwrap()
    }

    /// Encodes this key as PEM.
    ///
    /// The DER is base64-encoded (standard alphabet) and wrapped in 64-character
    /// lines between `-----BEGIN/END PRIVATE KEY-----` markers, or `PUBLIC KEY`
    /// markers for a public key. Every line, including the last, ends with `\n`.
    ///
    /// # Errors
    /// Returns any error from [`Key::try_to_der`].
    #[cfg(feature = "pkcs-convert")]
    pub fn try_to_pem(&self) -> Result<String, ConversionError> {
        use base64::{engine::general_purpose::STANDARD, Engine};
        use std::fmt::Write;
        let der_b64 = STANDARD.encode(self.try_to_der()?);
        let key_ty = if self.is_private() {
            "PRIVATE"
        } else {
            "PUBLIC"
        };
        let mut pem = String::new();
        // Writing into a `String` cannot fail, so these `unwrap`s never panic.
        writeln!(&mut pem, "-----BEGIN {} KEY-----", key_ty).unwrap();
        const MAX_LINE_LEN: usize = 64;
        // Byte-index slicing is safe here: base64 output is pure ASCII.
        for i in (0..der_b64.len()).step_by(MAX_LINE_LEN) {
            writeln!(
                &mut pem,
                "{}",
                &der_b64[i..std::cmp::min(i + MAX_LINE_LEN, der_b64.len())]
            )
            .unwrap();
        }
        writeln!(&mut pem, "-----END {} KEY-----", key_ty).unwrap();
        Ok(pem)
    }

    /// Encodes this key as PEM.
    ///
    /// # Panics
    /// Panics wherever [`Key::try_to_pem`] would return an error.
    #[cfg(feature = "pkcs-convert")]
    pub fn to_pem(&self) -> String {
        self.try_to_pem().unwrap()
    }

    /// Generates a random symmetric key from `rand::thread_rng()`.
    ///
    /// The key is `num_bits / 8` bytes long. Because this is integer division, a
    /// bit count that is not a multiple of 8 is rounded down.
    #[cfg(feature = "generate")]
    pub fn generate_symmetric(num_bits: usize) -> Self {
        use rand::RngCore;
        let mut bytes = vec![0; num_bits / 8];
        rand::thread_rng().fill_bytes(&mut bytes);
        Self::Symmetric { key: bytes.into() }
    }

    /// Generates a random P-256 private key from `rand::thread_rng()`.
    ///
    /// The public point is derived as `d * G`, where `G` is the curve's generator.
    #[cfg(feature = "generate")]
    pub fn generate_p256() -> Self {
        use p256::elliptic_curve::{self as elliptic_curve, sec1::ToEncodedPoint};

        let sk = elliptic_curve::SecretKey::random(&mut rand::thread_rng());
        let sk_scalar = p256::Scalar::from(&sk);

        let pk = p256::ProjectivePoint::GENERATOR * sk_scalar;
        // Drop the leading uncompressed-point tag byte; what remains is x || y.
        let pk_bytes = &pk
            .to_affine()
            .to_encoded_point(false /* compress */)
            .to_bytes()[1..];
        let (x_bytes, y_bytes) = pk_bytes.split_at(32);

        Self::EC {
            curve: Curve::P256 {
                d: Some(sk_scalar.to_bytes().into()),
                x: ByteArray::from_slice(x_bytes),
                y: ByteArray::from_slice(y_bytes),
            },
        }
    }
}
546
/// An elliptic curve together with its key coordinates.
///
/// Internally tagged by the JSON `crv` member, which is `"P-256"` or `"P-384"`.
///
/// - `x` and `y` are the coordinates of the public point.
/// - `d` is the private scalar. It is present only for private keys and omitted
///   from the JSON when `None`.
// NOTE(review): the derived `Debug` includes `d`, whereas `RsaPrivate` hand-writes a
// redacting `Debug`; unless `ByteArray`'s `Debug` redacts, private EC scalars can
// appear in debug output — consider redacting.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "crv")]
pub enum Curve {
    /// NIST P-256; every value is 32 bytes.
    #[serde(rename = "P-256")]
    P256 {
        /// Private scalar (private keys only).
        #[serde(skip_serializing_if = "Option::is_none")]
        d: Option<ByteArray<U32>>,
        /// Public point x-coordinate.
        x: ByteArray<U32>,
        /// Public point y-coordinate.
        y: ByteArray<U32>,
    },
    /// NIST P-384; every value is 48 bytes.
    #[serde(rename = "P-384")]
    P384 {
        /// Private scalar (private keys only).
        #[serde(skip_serializing_if = "Option::is_none")]
        d: Option<ByteArray<U48>>,
        /// Public point x-coordinate.
        x: ByteArray<U48>,
        /// Public point y-coordinate.
        y: ByteArray<U48>,
    },
}
573
574impl Curve {
575 pub fn name(&self) -> &'static str {
576 match self {
577 Self::P256 { .. } => "P-256",
578 Self::P384 { .. } => "P-256",
579 }
580 }
581}
582
583impl fmt::Display for Curve {
584 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
585 match self {
586 Self::P256 { x, y, .. } => f
587 .debug_struct("Curve::P256")
588 .field("x", x)
589 .field("y", y)
590 .finish(),
591 Self::P384 { x, y, .. } => f
592 .debug_struct("Curve::P384")
593 .field("x", x)
594 .field("y", y)
595 .finish(),
596 }
597 }
598}
599
/// The public members of an RSA JWK.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct RsaPublic {
    /// The public exponent `e`. Only 65537 can be represented; see [`PublicExponent`].
    pub e: PublicExponent,
    /// The modulus `n`.
    pub n: ByteVec,
}
607
/// The only RSA public exponent this crate supports.
const PUBLIC_EXPONENT: u32 = 65537;
/// `PUBLIC_EXPONENT` as base64 of its minimal big-endian bytes `01 00 01`.
const PUBLIC_EXPONENT_B64: &str = "AQAB";
/// An alternative spelling accepted on input: base64 of the four bytes
/// `01 00 01 00`, i.e. 65537 as a little-endian `u32`.
// NOTE(review): JWK `e` is big-endian, under which these bytes are not 65537; this
// looks like a compatibility allowance for some producers — confirm intent.
const PUBLIC_EXPONENT_B64_PADDED: &str = "AQABAA==";
611
/// Marker for the RSA public exponent `e`, which is always 65537.
///
/// It serializes as `"AQAB"`. Deserialization accepts `"AQAB"` or `"AQABAA=="` and
/// rejects every other string.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct PublicExponent;
615
616impl Serialize for PublicExponent {
617 fn serialize<S: serde::ser::Serializer>(&self, s: S) -> Result<S::Ok, S::Error> {
618 PUBLIC_EXPONENT_B64.serialize(s)
619 }
620}
621
622impl<'de> Deserialize<'de> for PublicExponent {
623 fn deserialize<D: serde::de::Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
624 let e = String::deserialize(d)?;
625 if e == PUBLIC_EXPONENT_B64 || e == PUBLIC_EXPONENT_B64_PADDED {
626 Ok(Self)
627 } else {
628 Err(serde::de::Error::custom(&format!(
629 "public exponent must be {}",
630 PUBLIC_EXPONENT
631 )))
632 }
633 }
634}
635
/// The private members of an RSA JWK.
///
/// Only `d` is required here. `Key::try_to_der` (feature `pkcs-convert`) does
/// need all of the CRT parameters.
///
/// `Debug` is implemented by hand and prints only the type name, presumably to keep
/// key material out of debug output.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct RsaPrivate {
    /// The private exponent `d`.
    pub d: ByteVec,
    /// The first prime factor `p`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub p: Option<ByteVec>,
    /// The second prime factor `q`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub q: Option<ByteVec>,
    /// The first factor CRT exponent `dp`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dp: Option<ByteVec>,
    /// The second factor CRT exponent `dq`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dq: Option<ByteVec>,
    /// The first CRT coefficient `qi`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub qi: Option<ByteVec>,
}
656
657impl fmt::Debug for RsaPrivate {
658 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
659 f.write_str("RsaPrivate")
660 }
661}
662
/// The JWK `use` member.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
pub enum KeyUse {
    /// Serialized as `"sig"`.
    #[serde(rename = "sig")]
    Signing,
    /// Serialized as `"enc"`.
    #[serde(rename = "enc")]
    Encryption,
}
670
/// An algorithm for the JWK `alg` member, serialized by variant name (e.g. `"HS256"`).
///
/// `JsonWebKey::set_algorithm` enforces these key-type pairings:
/// - `HS*` requires a symmetric key;
/// - `RS*` requires an RSA key;
/// - `ES256` requires a P-256 key and `ES384` a P-384 key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
// Variant names are the serialized `alg` values, so they keep the upper-case spelling.
#[allow(clippy::upper_case_acronyms)]
pub enum Algorithm {
    HS256,
    HS384,
    HS512,
    RS256,
    RS384,
    RS512,
    ES256,
    ES384,
}
683
684impl Algorithm {
685 pub fn name(&self) -> &'static str {
686 match self {
687 Self::HS256 => "hs256",
688 Self::HS384 => "hs384",
689 Self::HS512 => "hs512",
690 Self::RS256 => "rs256",
691 Self::RS384 => "rs384",
692 Self::RS512 => "rs512",
693 Self::ES256 => "es256",
694 Self::ES384 => "es384",
695 }
696 }
697}
698
#[cfg(feature = "jwt-convert")]
const _IMPL_JWT_CONVERSIONS: () = {
    // The const item only scopes the `jwt` alias to these impls.
    use jsonwebtoken as jwt;

    impl From<Algorithm> for jwt::Algorithm {
        /// Maps each algorithm to the `jsonwebtoken` variant of the same name.
        fn from(alg: Algorithm) -> Self {
            match alg {
                Algorithm::HS256 => jwt::Algorithm::HS256,
                Algorithm::HS384 => jwt::Algorithm::HS384,
                Algorithm::HS512 => jwt::Algorithm::HS512,
                Algorithm::RS256 => jwt::Algorithm::RS256,
                Algorithm::RS384 => jwt::Algorithm::RS384,
                Algorithm::RS512 => jwt::Algorithm::RS512,
                Algorithm::ES256 => jwt::Algorithm::ES256,
                Algorithm::ES384 => jwt::Algorithm::ES384,
            }
        }
    }

    impl Key {
        /// Converts this private key into a `jsonwebtoken::EncodingKey` for signing.
        ///
        /// Symmetric keys use the raw secret. EC and RSA keys are passed through
        /// their PEM encoding.
        ///
        /// # Errors
        /// - [`ConversionError::NotPrivate`] for public keys.
        /// - Any error from [`Key::try_to_pem`].
        ///
        /// # Panics
        /// Panics if `jsonwebtoken` rejects the PEM produced by this crate.
        pub fn try_to_encoding_key(&self) -> Result<jwt::EncodingKey, ConversionError> {
            if !self.is_private() {
                return Err(ConversionError::NotPrivate);
            }
            let encoding_key = match self {
                Self::Symmetric { key } => jwt::EncodingKey::from_secret(key),
                Self::EC { .. } => {
                    let pem = self.try_to_pem()?;
                    jwt::EncodingKey::from_ec_pem(pem.as_bytes()).unwrap()
                }
                Self::RSA { .. } => {
                    let pem = self.try_to_pem()?;
                    jwt::EncodingKey::from_rsa_pem(pem.as_bytes()).unwrap()
                }
            };
            Ok(encoding_key)
        }

        /// Converts this private key into a `jsonwebtoken::EncodingKey`.
        ///
        /// # Panics
        /// Panics wherever [`Key::try_to_encoding_key`] would fail.
        pub fn to_encoding_key(&self) -> jwt::EncodingKey {
            self.try_to_encoding_key().unwrap()
        }

        /// Converts this key into a `jsonwebtoken::DecodingKey` for verification.
        ///
        /// Symmetric keys use the raw secret. EC and RSA keys use the PEM of their
        /// public half.
        ///
        /// # Panics
        /// Panics if PEM encoding fails or `jsonwebtoken` rejects the PEM.
        pub fn to_decoding_key(&self) -> jwt::DecodingKey {
            match self {
                Self::Symmetric { key } => jwt::DecodingKey::from_secret(key),
                Self::EC { .. } => {
                    let public_pem = self.to_public().unwrap().to_pem();
                    jwt::DecodingKey::from_ec_pem(public_pem.as_bytes()).unwrap()
                }
                Self::RSA { .. } => {
                    let public_pem = self.to_public().unwrap().to_pem();
                    jwt::DecodingKey::from_rsa_pem(public_pem.as_bytes()).unwrap()
                }
            }
        }
    }
};
760
/// Errors returned when parsing a [`JsonWebKey`] or setting its algorithm.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// JSON (de)serialization failed.
    #[error(transparent)]
    Serde(#[from] serde_json::Error),

    /// A base64 decoding error, converted via `From`.
    #[error(transparent)]
    Base64Decode(#[from] base64::DecodeError),

    /// The `alg` value cannot be used with the key's type.
    #[error("mismatched algorithm for key type")]
    MismatchedAlgorithm,
}
772
773#[derive(Debug, thiserror::Error)]
774pub enum ConversionError {
775 #[error("encoding RSA JWK as PKCS#8 requires specifing all of p, q, dp, dq, qi")]
776 MissingRsaParams,
777
778 #[error("a symmetric key can not be encoded using PKCS#8")]
779 NotAsymmetric,
780
781 #[cfg(feature = "jwt-convert")]
782 #[error("a public key cannot be converted to a `jsonwebtoken::EncodingKey`")]
783 NotPrivate,
784}