1use std::{ops::Deref, sync::Arc};
18
19use der::{zeroize::Zeroizing, Decode, Encode, EncodePem};
20use elliptic_curve::{pkcs8::EncodePrivateKey, sec1::ToEncodedPoint};
21use mas_iana::jose::{JsonWebKeyType, JsonWebSignatureAlg};
22pub use mas_jose::jwk::{JsonWebKey, JsonWebKeySet};
23use mas_jose::{
24 jwa::{AsymmetricSigningKey, AsymmetricVerifyingKey},
25 jwk::{JsonWebKeyPublicParameters, ParametersInfo, PublicJsonWebKeySet},
26};
27use pem_rfc7468::PemLabel;
28use pkcs1::EncodeRsaPrivateKey;
29use pkcs8::{AssociatedOid, PrivateKeyInfo};
30use rand::{CryptoRng, RngCore};
31use rsa::BigUint;
32use thiserror::Error;
33
34mod encrypter;
35
36pub use aead;
37
38pub use self::encrypter::{DecryptError, Encrypter};
39
/// Error returned when loading a private key from PEM or DER data fails.
#[derive(Debug, Error)]
pub enum LoadError {
    /// The input could not be decoded as a PEM document.
    ///
    /// The `load` helpers treat this variant as "not PEM" and fall back to
    /// decoding the input as DER.
    #[error("Failed to read PEM document")]
    Pem {
        #[from]
        inner: pem_rfc7468::Error,
    },

    /// The RSA key components were rejected by the `rsa` crate when building
    /// the key.
    #[error("Invalid RSA private key")]
    Rsa {
        #[from]
        inner: rsa::errors::Error,
    },

    /// A PKCS#1 RSA key could not be used, e.g. because it is not a two-prime
    /// key.
    #[error("Failed to decode PKCS1-encoded RSA key")]
    Pkcs1 {
        #[from]
        inner: pkcs1::Error,
    },

    /// An error reported by the `pkcs8` crate, e.g. while converting a
    /// `PrivateKeyInfo` into a concrete key or while decrypting an encrypted
    /// key.
    #[error("Failed to decode PKCS8-encoded key")]
    Pkcs8 {
        #[from]
        inner: pkcs8::Error,
    },

    /// A low-level DER decoding error.
    #[error(transparent)]
    Der {
        #[from]
        inner: der::Error,
    },

    /// An error while reading the key's `AlgorithmIdentifier`, e.g. when its
    /// parameters are not an OID.
    #[error(transparent)]
    Spki {
        #[from]
        inner: spki::Error,
    },

    /// The key is an elliptic-curve key on a curve that is not supported.
    #[error("Unknown Elliptic Curve OID {oid}")]
    UnknownEllipticCurveOid { oid: const_oid::ObjectIdentifier },

    /// The PKCS#8 algorithm OID is neither RSA nor elliptic-curve.
    #[error("Unknown algorithm OID {oid}")]
    UnknownAlgorithmOid { oid: const_oid::ObjectIdentifier },

    /// The PEM document has a label that is not a recognised private-key label.
    #[error("Unsupported PEM label {label:?}")]
    UnsupportedPemLabel { label: String },

    /// A SEC1 `ECPrivateKey` has no parameters, so its curve is unknown.
    #[error("Missing parameters in SEC1 key")]
    MissingSec1Parameters,

    /// The SEC1 parameters are present but are not a named curve.
    #[error("Missing curve name in SEC1 parameters")]
    MissingSec1CurveName,

    /// The key is encrypted, but it was loaded through an unencrypted loader.
    #[error("Key is encrypted and no password was provided")]
    Encrypted,

    /// The key is not encrypted, but it was loaded through an encrypted loader.
    #[error("Key is not encrypted but a password was provided")]
    Unencrypted,

    /// The input matched none of the supported DER formats.
    #[error("Unsupported format")]
    UnsupportedFormat,

    /// Decryption succeeded, but the decrypted payload could not be loaded.
    #[error("Could not decode encrypted payload")]
    InEncrypted {
        #[source]
        inner: Box<LoadError>,
    },
}
109
110impl LoadError {
111 #[must_use]
115 pub fn is_encrypted(&self) -> bool {
116 matches!(self, Self::Encrypted)
117 }
118
119 #[must_use]
123 pub fn is_unencrypted(&self) -> bool {
124 matches!(self, Self::Unencrypted)
125 }
126}
127
/// A private key usable for JWS signing: RSA, or elliptic-curve on P-256,
/// P-384 or secp256k1.
#[non_exhaustive]
#[derive(Debug)]
pub enum PrivateKey {
    /// An RSA private key (usable with the `RS*` and `PS*` algorithms).
    Rsa(Box<rsa::RsaPrivateKey>),
    /// A NIST P-256 secret key (usable with `ES256`).
    EcP256(Box<elliptic_curve::SecretKey<p256::NistP256>>),
    /// A NIST P-384 secret key (usable with `ES384`).
    EcP384(Box<elliptic_curve::SecretKey<p384::NistP384>>),
    /// A secp256k1 secret key (usable with `ES256K`).
    EcK256(Box<elliptic_curve::SecretKey<k256::Secp256k1>>),
}
137
/// Error returned by [`PrivateKey::signing_key_for_alg`] and
/// [`PrivateKey::verifying_key_for_alg`] when the requested signature
/// algorithm cannot be used with the key's type.
#[derive(Debug, Error)]
#[error("Wrong algorithm for key")]
pub struct WrongAlgorithmError;
142
143impl PrivateKey {
144 fn from_pkcs1_private_key(pkcs1_key: &pkcs1::RsaPrivateKey) -> Result<Self, LoadError> {
145 if pkcs1_key.version() != pkcs1::Version::TwoPrime {
149 return Err(pkcs1::Error::Version.into());
150 }
151
152 let n = BigUint::from_bytes_be(pkcs1_key.modulus.as_bytes());
153 let e = BigUint::from_bytes_be(pkcs1_key.public_exponent.as_bytes());
154 let d = BigUint::from_bytes_be(pkcs1_key.private_exponent.as_bytes());
155 let first_prime = BigUint::from_bytes_be(pkcs1_key.prime1.as_bytes());
156 let second_prime = BigUint::from_bytes_be(pkcs1_key.prime2.as_bytes());
157 let primes = vec![first_prime, second_prime];
158 let key = rsa::RsaPrivateKey::from_components(n, e, d, primes)?;
159 Ok(Self::Rsa(Box::new(key)))
160 }
161
162 fn from_private_key_info(info: PrivateKeyInfo) -> Result<Self, LoadError> {
163 match info.algorithm.oid {
164 pkcs1::ALGORITHM_OID => Ok(Self::Rsa(Box::new(info.try_into()?))),
165 elliptic_curve::ALGORITHM_OID => match info.algorithm.parameters_oid()? {
166 p256::NistP256::OID => Ok(Self::EcP256(Box::new(info.try_into()?))),
167 p384::NistP384::OID => Ok(Self::EcP384(Box::new(info.try_into()?))),
168 k256::Secp256k1::OID => Ok(Self::EcK256(Box::new(info.try_into()?))),
169 oid => Err(LoadError::UnknownEllipticCurveOid { oid }),
170 },
171 oid => Err(LoadError::UnknownAlgorithmOid { oid }),
172 }
173 }
174
175 fn from_ec_private_key(key: sec1::EcPrivateKey) -> Result<Self, LoadError> {
176 let curve = key
177 .parameters
178 .ok_or(LoadError::MissingSec1Parameters)?
179 .named_curve()
180 .ok_or(LoadError::MissingSec1CurveName)?;
181
182 match curve {
183 p256::NistP256::OID => Ok(Self::EcP256(Box::new(key.try_into()?))),
184 p384::NistP384::OID => Ok(Self::EcP384(Box::new(key.try_into()?))),
185 k256::Secp256k1::OID => Ok(Self::EcK256(Box::new(key.try_into()?))),
186 oid => Err(LoadError::UnknownEllipticCurveOid { oid }),
187 }
188 }
189
190 pub fn to_der(&self) -> Result<Zeroizing<Vec<u8>>, pkcs1::Error> {
199 let der = match self {
200 PrivateKey::Rsa(key) => key.to_pkcs1_der()?.to_bytes(),
201 PrivateKey::EcP256(key) => to_sec1_der(key)?,
202 PrivateKey::EcP384(key) => to_sec1_der(key)?,
203 PrivateKey::EcK256(key) => to_sec1_der(key)?,
204 };
205
206 Ok(der)
207 }
208
209 pub fn to_pkcs8_der(&self) -> Result<Zeroizing<Vec<u8>>, pkcs8::Error> {
215 let der = match self {
216 PrivateKey::Rsa(key) => key.to_pkcs8_der()?,
217 PrivateKey::EcP256(key) => key.to_pkcs8_der()?,
218 PrivateKey::EcP384(key) => key.to_pkcs8_der()?,
219 PrivateKey::EcK256(key) => key.to_pkcs8_der()?,
220 };
221
222 Ok(der.to_bytes())
223 }
224
225 pub fn to_pem(
234 &self,
235 line_ending: pem_rfc7468::LineEnding,
236 ) -> Result<Zeroizing<String>, pkcs1::Error> {
237 let pem = match self {
238 PrivateKey::Rsa(key) => key.to_pkcs1_pem(line_ending)?,
239 PrivateKey::EcP256(key) => to_sec1_pem(key, line_ending)?,
240 PrivateKey::EcP384(key) => to_sec1_pem(key, line_ending)?,
241 PrivateKey::EcK256(key) => to_sec1_pem(key, line_ending)?,
242 };
243
244 Ok(pem)
245 }
246
247 pub fn load(bytes: &[u8]) -> Result<Self, LoadError> {
254 if let Ok(pem) = std::str::from_utf8(bytes) {
255 match Self::load_pem(pem) {
256 Ok(s) => return Ok(s),
257 Err(LoadError::Pem { .. }) => {}
260 Err(e) => return Err(e),
261 }
262 }
263
264 Self::load_der(bytes)
265 }
266
267 pub fn load_encrypted(bytes: &[u8], password: impl AsRef<[u8]>) -> Result<Self, LoadError> {
275 if let Ok(pem) = std::str::from_utf8(bytes) {
276 match Self::load_encrypted_pem(pem, password.as_ref()) {
277 Ok(s) => return Ok(s),
278 Err(LoadError::Pem { .. }) => {}
281 Err(e) => return Err(e),
282 }
283 }
284
285 Self::load_encrypted_der(bytes, password)
286 }
287
288 pub fn load_encrypted_der(der: &[u8], password: impl AsRef<[u8]>) -> Result<Self, LoadError> {
298 if let Ok(info) = pkcs8::EncryptedPrivateKeyInfo::from_der(der) {
299 let decrypted = info.decrypt(password)?;
300 return Self::load_der(decrypted.as_bytes()).map_err(|inner| LoadError::InEncrypted {
301 inner: Box::new(inner),
302 });
303 }
304
305 if pkcs8::PrivateKeyInfo::from_der(der).is_ok()
306 || sec1::EcPrivateKey::from_der(der).is_ok()
307 || pkcs1::RsaPrivateKey::from_der(der).is_ok()
308 {
309 return Err(LoadError::Unencrypted);
310 }
311
312 Err(LoadError::UnsupportedFormat)
313 }
314
315 pub fn load_der(der: &[u8]) -> Result<Self, LoadError> {
327 if pkcs8::EncryptedPrivateKeyInfo::from_der(der).is_ok() {
329 return Err(LoadError::Encrypted);
330 }
331
332 if let Ok(info) = pkcs8::PrivateKeyInfo::from_der(der) {
333 return Self::from_private_key_info(info);
334 }
335
336 if let Ok(info) = sec1::EcPrivateKey::from_der(der) {
337 return Self::from_ec_private_key(info);
338 }
339
340 if let Ok(pkcs1_key) = pkcs1::RsaPrivateKey::from_der(der) {
341 return Self::from_pkcs1_private_key(&pkcs1_key);
342 }
343
344 Err(LoadError::UnsupportedFormat)
345 }
346
347 pub fn load_encrypted_pem(pem: &str, password: impl AsRef<[u8]>) -> Result<Self, LoadError> {
359 let (label, doc) = pem_rfc7468::decode_vec(pem.as_bytes())?;
360
361 match label {
362 pkcs8::EncryptedPrivateKeyInfo::PEM_LABEL => {
363 let info = pkcs8::EncryptedPrivateKeyInfo::from_der(&doc)?;
364 let decrypted = info.decrypt(password)?;
365 return Self::load_der(decrypted.as_bytes()).map_err(|inner| {
366 LoadError::InEncrypted {
367 inner: Box::new(inner),
368 }
369 });
370 }
371
372 pkcs1::RsaPrivateKey::PEM_LABEL
373 | pkcs8::PrivateKeyInfo::PEM_LABEL
374 | sec1::EcPrivateKey::PEM_LABEL => Err(LoadError::Unencrypted),
375
376 label => Err(LoadError::UnsupportedPemLabel {
377 label: label.to_owned(),
378 }),
379 }
380 }
381
382 pub fn load_pem(pem: &str) -> Result<Self, LoadError> {
393 let (label, doc) = pem_rfc7468::decode_vec(pem.as_bytes())?;
394
395 match label {
396 pkcs1::RsaPrivateKey::PEM_LABEL => {
397 let pkcs1_key = pkcs1::RsaPrivateKey::from_der(&doc)?;
398 Self::from_pkcs1_private_key(&pkcs1_key)
399 }
400
401 pkcs8::PrivateKeyInfo::PEM_LABEL => {
402 let info = pkcs8::PrivateKeyInfo::from_der(&doc)?;
403 Self::from_private_key_info(info)
404 }
405
406 sec1::EcPrivateKey::PEM_LABEL => {
407 let key = sec1::EcPrivateKey::from_der(&doc)?;
408 Self::from_ec_private_key(key)
409 }
410
411 pkcs8::EncryptedPrivateKeyInfo::PEM_LABEL => Err(LoadError::Encrypted),
412
413 label => Err(LoadError::UnsupportedPemLabel {
414 label: label.to_owned(),
415 }),
416 }
417 }
418
419 pub fn verifying_key_for_alg(
426 &self,
427 alg: &JsonWebSignatureAlg,
428 ) -> Result<AsymmetricVerifyingKey, WrongAlgorithmError> {
429 let key = match (self, alg) {
430 (Self::Rsa(key), _) => {
431 let key: rsa::RsaPublicKey = key.to_public_key();
432 match alg {
433 JsonWebSignatureAlg::Rs256 => AsymmetricVerifyingKey::rs256(key),
434 JsonWebSignatureAlg::Rs384 => AsymmetricVerifyingKey::rs384(key),
435 JsonWebSignatureAlg::Rs512 => AsymmetricVerifyingKey::rs512(key),
436 JsonWebSignatureAlg::Ps256 => AsymmetricVerifyingKey::ps256(key),
437 JsonWebSignatureAlg::Ps384 => AsymmetricVerifyingKey::ps384(key),
438 JsonWebSignatureAlg::Ps512 => AsymmetricVerifyingKey::ps512(key),
439 _ => return Err(WrongAlgorithmError),
440 }
441 }
442
443 (Self::EcP256(key), JsonWebSignatureAlg::Es256) => {
444 AsymmetricVerifyingKey::es256(key.public_key())
445 }
446
447 (Self::EcP384(key), JsonWebSignatureAlg::Es384) => {
448 AsymmetricVerifyingKey::es384(key.public_key())
449 }
450
451 (Self::EcK256(key), JsonWebSignatureAlg::Es256K) => {
452 AsymmetricVerifyingKey::es256k(key.public_key())
453 }
454
455 _ => return Err(WrongAlgorithmError),
456 };
457
458 Ok(key)
459 }
460
461 pub fn signing_key_for_alg(
468 &self,
469 alg: &JsonWebSignatureAlg,
470 ) -> Result<AsymmetricSigningKey, WrongAlgorithmError> {
471 let key = match (self, alg) {
472 (Self::Rsa(key), _) => {
473 let key: rsa::RsaPrivateKey = *key.clone();
474 match alg {
475 JsonWebSignatureAlg::Rs256 => AsymmetricSigningKey::rs256(key),
476 JsonWebSignatureAlg::Rs384 => AsymmetricSigningKey::rs384(key),
477 JsonWebSignatureAlg::Rs512 => AsymmetricSigningKey::rs512(key),
478 JsonWebSignatureAlg::Ps256 => AsymmetricSigningKey::ps256(key),
479 JsonWebSignatureAlg::Ps384 => AsymmetricSigningKey::ps384(key),
480 JsonWebSignatureAlg::Ps512 => AsymmetricSigningKey::ps512(key),
481 _ => return Err(WrongAlgorithmError),
482 }
483 }
484
485 (Self::EcP256(key), JsonWebSignatureAlg::Es256) => {
486 AsymmetricSigningKey::es256(*key.clone())
487 }
488
489 (Self::EcP384(key), JsonWebSignatureAlg::Es384) => {
490 AsymmetricSigningKey::es384(*key.clone())
491 }
492
493 (Self::EcK256(key), JsonWebSignatureAlg::Es256K) => {
494 AsymmetricSigningKey::es256k(*key.clone())
495 }
496
497 _ => return Err(WrongAlgorithmError),
498 };
499
500 Ok(key)
501 }
502
503 pub fn generate_rsa<R: RngCore + CryptoRng>(mut rng: R) -> Result<Self, rsa::errors::Error> {
509 let key = rsa::RsaPrivateKey::new(&mut rng, 2048)?;
510 Ok(Self::Rsa(Box::new(key)))
511 }
512
513 pub fn generate_ec_p256<R: RngCore + CryptoRng>(mut rng: R) -> Self {
515 let key = elliptic_curve::SecretKey::random(&mut rng);
516 Self::EcP256(Box::new(key))
517 }
518
519 pub fn generate_ec_p384<R: RngCore + CryptoRng>(mut rng: R) -> Self {
521 let key = elliptic_curve::SecretKey::random(&mut rng);
522 Self::EcP384(Box::new(key))
523 }
524
525 pub fn generate_ec_k256<R: RngCore + CryptoRng>(mut rng: R) -> Self {
527 let key = elliptic_curve::SecretKey::random(&mut rng);
528 Self::EcK256(Box::new(key))
529 }
530}
531
532fn to_sec1_der<C>(key: &elliptic_curve::SecretKey<C>) -> Result<Zeroizing<Vec<u8>>, der::Error>
536where
537 C: elliptic_curve::Curve + elliptic_curve::CurveArithmetic + AssociatedOid,
538 elliptic_curve::PublicKey<C>: elliptic_curve::sec1::ToEncodedPoint<C>,
539 C::FieldBytesSize: elliptic_curve::sec1::ModulusSize,
540{
541 let private_key_bytes = Zeroizing::new(key.to_bytes());
542 let public_key_bytes = key.public_key().to_encoded_point(false);
543 Ok(Zeroizing::new(
544 sec1::EcPrivateKey {
545 private_key: &private_key_bytes,
546 parameters: Some(sec1::EcParameters::NamedCurve(C::OID)),
547 public_key: Some(public_key_bytes.as_bytes()),
548 }
549 .to_der()?,
550 ))
551}
552
553fn to_sec1_pem<C>(
554 key: &elliptic_curve::SecretKey<C>,
555 line_ending: pem_rfc7468::LineEnding,
556) -> Result<Zeroizing<String>, der::Error>
557where
558 C: elliptic_curve::Curve + elliptic_curve::CurveArithmetic + AssociatedOid,
559 elliptic_curve::PublicKey<C>: elliptic_curve::sec1::ToEncodedPoint<C>,
560 C::FieldBytesSize: elliptic_curve::sec1::ModulusSize,
561{
562 let private_key_bytes = Zeroizing::new(key.to_bytes());
563 let public_key_bytes = key.public_key().to_encoded_point(false);
564 Ok(Zeroizing::new(
565 sec1::EcPrivateKey {
566 private_key: &private_key_bytes,
567 parameters: Some(sec1::EcParameters::NamedCurve(C::OID)),
568 public_key: Some(public_key_bytes.as_bytes()),
569 }
570 .to_pem(line_ending)?,
571 ))
572}
573
574impl From<&PrivateKey> for JsonWebKeyPublicParameters {
575 fn from(val: &PrivateKey) -> Self {
576 match val {
577 PrivateKey::Rsa(key) => key.to_public_key().into(),
578 PrivateKey::EcP256(key) => key.public_key().into(),
579 PrivateKey::EcP384(key) => key.public_key().into(),
580 PrivateKey::EcK256(key) => key.public_key().into(),
581 }
582 }
583}
584
585impl ParametersInfo for PrivateKey {
586 fn kty(&self) -> JsonWebKeyType {
587 match self {
588 PrivateKey::Rsa(_) => JsonWebKeyType::Rsa,
589 PrivateKey::EcP256(_) | PrivateKey::EcP384(_) | PrivateKey::EcK256(_) => {
590 JsonWebKeyType::Ec
591 }
592 }
593 }
594
595 fn possible_algs(&self) -> &'static [JsonWebSignatureAlg] {
596 match self {
597 PrivateKey::Rsa(_) => &[
598 JsonWebSignatureAlg::Rs256,
599 JsonWebSignatureAlg::Rs384,
600 JsonWebSignatureAlg::Rs512,
601 JsonWebSignatureAlg::Ps256,
602 JsonWebSignatureAlg::Ps384,
603 JsonWebSignatureAlg::Ps512,
604 ],
605 PrivateKey::EcP256(_) => &[JsonWebSignatureAlg::Es256],
606 PrivateKey::EcP384(_) => &[JsonWebSignatureAlg::Es384],
607 PrivateKey::EcK256(_) => &[JsonWebSignatureAlg::Es256K],
608 }
609 }
610}
611
/// A set of private JSON Web Keys.
///
/// The set is held behind an [`Arc`], so cloning a `Keystore` shares the
/// same keys instead of copying them.
#[derive(Clone, Default)]
pub struct Keystore {
    keys: Arc<JsonWebKeySet<PrivateKey>>,
}
619
620impl Keystore {
621 #[must_use]
641 pub fn new(keys: JsonWebKeySet<PrivateKey>) -> Self {
642 let keys = Arc::new(keys);
643 Self { keys }
644 }
645
646 #[must_use]
648 pub fn public_jwks(&self) -> PublicJsonWebKeySet {
649 self.keys
650 .iter()
651 .map(|key| {
652 key.cloned_map(|params: &PrivateKey| JsonWebKeyPublicParameters::from(params))
653 })
654 .collect()
655 }
656}
657
658impl Deref for Keystore {
659 type Target = JsonWebKeySet<PrivateKey>;
660
661 fn deref(&self) -> &Self::Target {
662 &self.keys
663 }
664}