1use crate::drbg::HmacDrbgSha256;
19use crate::hash::{noxtls_sha1, noxtls_sha256, noxtls_sha384, noxtls_sha512};
20use crate::internal_alloc::Vec;
21use noxtls_core::{Error, Result};
22
23use super::bignum::BigUint;
24
/// Smallest modulus size, in bits, accepted by the key-generation backend.
const RSA_KEYGEN_MIN_BITS: usize = 1024;
/// Largest modulus size, in bits, accepted by the key-generation backend.
const RSA_KEYGEN_MAX_BITS: usize = 4096;
/// Minimum modulus size, in bits, for key import (unless the
/// `hazardous-legacy-crypto` feature is enabled) and for policy-checked key generation.
const RSA_MIN_SECURE_BITS: usize = 2048;
/// Minimum modulus size, in bits, required by `RsaKeySizePolicy::Minimum3072`.
const RSA_RECOMMENDED_SECURE_BITS: usize = 3072;
29
30#[derive(Debug, Clone)]
32pub struct RsaPrivateKey {
33 pub n: BigUint,
34 pub d: BigUint,
35 crt: Option<RsaPrivateCrtComponents>,
36}
37
/// RSA public key: modulus `n` and public exponent `e`.
#[derive(Debug, Clone)]
pub struct RsaPublicKey {
    /// Public modulus.
    pub n: BigUint,
    /// Public exponent.
    pub e: BigUint,
}
44
45#[derive(Debug, Clone)]
47struct RsaPrivateCrtComponents {
48 p: BigUint,
49 q: BigUint,
50 dp: BigUint,
51 dq: BigUint,
52 qinv: BigUint,
53}
54
/// Hash function selector for RSASSA-PSS.
///
/// Used separately for the message hash and for MGF1 by
/// `RsaPrivateKey::sign_pss_with_hashes` and `RsaPublicKey::verify_pss_with_hashes`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum RsaPssHashAlgorithm {
    Sha1,
    Sha256,
    Sha384,
    Sha512,
}
63
/// Minimum modulus size enforced by `noxtls_rsa_generate_keypair_with_policy_auto`.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
pub enum RsaKeySizePolicy {
    /// Require at least 2048-bit moduli.
    Minimum2048,
    /// Require at least 3072-bit moduli.
    Minimum3072,
}
71
72impl RsaKeySizePolicy {
73 fn min_bits(self) -> usize {
87 match self {
88 Self::Minimum2048 => RSA_MIN_SECURE_BITS,
89 Self::Minimum3072 => RSA_RECOMMENDED_SECURE_BITS,
90 }
91 }
92}
93
94impl RsaPrivateKey {
95 pub fn from_be_bytes(n: &[u8], d: &[u8]) -> Result<Self> {
110 if n.is_empty() || d.is_empty() {
111 return Err(Error::InvalidLength(
112 "rsa private key fields must not be empty",
113 ));
114 }
115 let key = Self {
116 n: BigUint::from_be_bytes(n),
117 d: BigUint::from_be_bytes(d),
118 crt: None,
119 };
120 if !cfg!(feature = "hazardous-legacy-crypto") && key.n.bit_len() < RSA_MIN_SECURE_BITS {
121 return Err(Error::InvalidLength(
122 "rsa private key modulus must be at least 2048 bits",
123 ));
124 }
125 validate_private_components(&key.n, &key.d)?;
126 Ok(key)
127 }
128
129 #[must_use]
138 pub fn from_u128(n: u128, d: u128) -> Self {
139 Self {
140 n: BigUint::from_u128(n),
141 d: BigUint::from_u128(d),
142 crt: None,
143 }
144 }
145
146 pub fn clear(&mut self) {
151 self.n.clear();
152 self.d.clear();
153 if let Some(crt) = self.crt.as_mut() {
154 crt.p.clear();
155 crt.q.clear();
156 crt.dp.clear();
157 crt.dq.clear();
158 crt.qinv.clear();
159 }
160 self.crt = None;
161 }
162
163 pub fn with_crt_components(
175 mut self,
176 p: &[u8],
177 q: &[u8],
178 dp: &[u8],
179 dq: &[u8],
180 qinv: &[u8],
181 ) -> Result<Self> {
182 let crt = RsaPrivateCrtComponents {
183 p: BigUint::from_be_bytes(p),
184 q: BigUint::from_be_bytes(q),
185 dp: BigUint::from_be_bytes(dp),
186 dq: BigUint::from_be_bytes(dq),
187 qinv: BigUint::from_be_bytes(qinv),
188 };
189 validate_crt_components(&self.n, &crt)?;
190 self.crt = Some(crt);
191 Ok(self)
192 }
193
194 pub fn sign_digest(&self, digest: &[u8]) -> Result<Vec<u8>> {
202 if digest.is_empty() {
203 return Err(Error::InvalidLength("digest must not be empty"));
204 }
205 validate_private_components(&self.n, &self.d)?;
206 let m = BigUint::from_be_bytes(digest).modulo(&self.n);
207 let s = BigUint::mod_exp(&m, &self.d, &self.n);
208 s.to_be_bytes_padded(self.modulus_len())
209 }
210
211 pub fn sign_pkcs1_v15_sha256(&self, msg: &[u8]) -> Result<Vec<u8>> {
219 validate_private_components(&self.n, &self.d)?;
220 let hash = noxtls_sha256(msg);
221 let em = emsa_pkcs1_v15_encode(
222 &hash,
223 PKCS1_V15_DIGESTINFO_SHA256_PREFIX,
224 self.modulus_len(),
225 )?;
226 let m = BigUint::from_be_bytes(&em);
227 let s = BigUint::mod_exp(&m, &self.d, &self.n);
228 s.to_be_bytes_padded(self.modulus_len())
229 }
230
231 pub fn sign_pkcs1_v15_sha1(&self, msg: &[u8]) -> Result<Vec<u8>> {
239 validate_private_components(&self.n, &self.d)?;
240 let hash = noxtls_sha1(msg);
241 let em =
242 emsa_pkcs1_v15_encode(&hash, PKCS1_V15_DIGESTINFO_SHA1_PREFIX, self.modulus_len())?;
243 let m = BigUint::from_be_bytes(&em);
244 let s = BigUint::mod_exp(&m, &self.d, &self.n);
245 s.to_be_bytes_padded(self.modulus_len())
246 }
247
248 pub fn sign_pkcs1_v15_sha384(&self, msg: &[u8]) -> Result<Vec<u8>> {
256 validate_private_components(&self.n, &self.d)?;
257 let hash = noxtls_sha384(msg);
258 let em = emsa_pkcs1_v15_encode(
259 &hash,
260 PKCS1_V15_DIGESTINFO_SHA384_PREFIX,
261 self.modulus_len(),
262 )?;
263 let m = BigUint::from_be_bytes(&em);
264 let s = BigUint::mod_exp(&m, &self.d, &self.n);
265 s.to_be_bytes_padded(self.modulus_len())
266 }
267
268 pub fn sign_pkcs1_v15_sha512(&self, msg: &[u8]) -> Result<Vec<u8>> {
276 validate_private_components(&self.n, &self.d)?;
277 let hash = noxtls_sha512(msg);
278 let em = emsa_pkcs1_v15_encode(
279 &hash,
280 PKCS1_V15_DIGESTINFO_SHA512_PREFIX,
281 self.modulus_len(),
282 )?;
283 let m = BigUint::from_be_bytes(&em);
284 let s = BigUint::mod_exp(&m, &self.d, &self.n);
285 s.to_be_bytes_padded(self.modulus_len())
286 }
287
288 pub fn sign_pss_sha256(&self, msg: &[u8], salt: &[u8]) -> Result<Vec<u8>> {
297 validate_private_components(&self.n, &self.d)?;
298 let em_bits = self.n.bit_len().saturating_sub(1);
299 let em_len = em_bits.div_ceil(8);
300 let m_hash = noxtls_sha256(msg);
301 let em = emsa_pss_encode_sha256(&m_hash, salt, em_bits, em_len)?;
302 let s = BigUint::mod_exp(&BigUint::from_be_bytes(&em), &self.d, &self.n);
303 s.to_be_bytes_padded(self.modulus_len())
304 }
305
306 pub fn sign_pss_sha384(&self, msg: &[u8], salt: &[u8]) -> Result<Vec<u8>> {
315 validate_private_components(&self.n, &self.d)?;
316 let em_bits = self.n.bit_len().saturating_sub(1);
317 let em_len = em_bits.div_ceil(8);
318 let m_hash = noxtls_sha384(msg);
319 let em = emsa_pss_encode_sha384(&m_hash, salt, em_bits, em_len)?;
320 let s = BigUint::mod_exp(&BigUint::from_be_bytes(&em), &self.d, &self.n);
321 s.to_be_bytes_padded(self.modulus_len())
322 }
323
324 pub fn sign_pss_with_hashes(
325 &self,
326 msg: &[u8],
327 salt: &[u8],
328 message_hash: RsaPssHashAlgorithm,
329 mgf_hash: RsaPssHashAlgorithm,
330 ) -> Result<Vec<u8>> {
331 validate_private_components(&self.n, &self.d)?;
332 let em_bits = self.n.bit_len().saturating_sub(1);
333 let em_len = em_bits.div_ceil(8);
334 let m_hash = rsa_pss_hash(message_hash, msg);
335 let em = emsa_pss_encode(&m_hash, salt, em_bits, em_len, message_hash, mgf_hash)?;
336 let s = BigUint::mod_exp(&BigUint::from_be_bytes(&em), &self.d, &self.n);
337 s.to_be_bytes_padded(self.modulus_len())
338 }
339
340 pub fn decrypt_pkcs1_v15(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
348 validate_private_components(&self.n, &self.d)?;
349 if ciphertext.len() != self.modulus_len() {
350 return Err(Error::CryptoFailure("rsa decryption failed"));
351 }
352 let em = BigUint::mod_exp(&BigUint::from_be_bytes(ciphertext), &self.d, &self.n)
353 .to_be_bytes_padded(self.modulus_len())?;
354 decode_pkcs1_v15_plaintext(&em)
355 }
356
357 pub fn decrypt_pkcs1_v15_crt_only(&self, ciphertext: &[u8]) -> Result<Vec<u8>> {
365 validate_private_components(&self.n, &self.d)?;
366 if ciphertext.len() != self.modulus_len() {
367 return Err(Error::CryptoFailure("rsa decryption failed"));
368 }
369 let crt = self
370 .crt
371 .as_ref()
372 .ok_or(Error::StateError("rsa crt parameters are not configured"))?;
373 let c = BigUint::from_be_bytes(ciphertext);
374 let m1 = BigUint::mod_exp(&c, &crt.dp, &crt.p);
375 let m2 = BigUint::mod_exp(&c, &crt.dq, &crt.q);
376 let diff = if m1.cmp(&m2).is_ge() {
377 m1.sub(&m2)
378 } else {
379 m1.add(&crt.p).sub(&m2)
380 };
381 let h = crt.qinv.mul(&diff).modulo(&crt.p);
382 let m = m2.add(&crt.q.mul(&h));
383 let em = m.to_be_bytes_padded(self.modulus_len())?;
384 decode_pkcs1_v15_plaintext(&em)
385 }
386
387 pub fn decrypt_oaep_sha256(&self, ciphertext: &[u8], label: &[u8]) -> Result<Vec<u8>> {
396 validate_private_components(&self.n, &self.d)?;
397 if ciphertext.len() != self.modulus_len() {
398 return Err(Error::CryptoFailure("rsa decryption failed"));
399 }
400 let em = BigUint::mod_exp(&BigUint::from_be_bytes(ciphertext), &self.d, &self.n)
401 .to_be_bytes_padded(self.modulus_len())?;
402 decode_oaep_sha256_plaintext(&em, label)
403 }
404
405 pub fn decrypt_oaep_sha256_crt_only(&self, ciphertext: &[u8], label: &[u8]) -> Result<Vec<u8>> {
414 validate_private_components(&self.n, &self.d)?;
415 if ciphertext.len() != self.modulus_len() {
416 return Err(Error::CryptoFailure("rsa decryption failed"));
417 }
418 let crt = self
419 .crt
420 .as_ref()
421 .ok_or(Error::StateError("rsa crt parameters are not configured"))?;
422 let c = BigUint::from_be_bytes(ciphertext);
423 let m1 = BigUint::mod_exp(&c, &crt.dp, &crt.p);
424 let m2 = BigUint::mod_exp(&c, &crt.dq, &crt.q);
425 let diff = if m1.cmp(&m2).is_ge() {
426 m1.sub(&m2)
427 } else {
428 m1.add(&crt.p).sub(&m2)
429 };
430 let h = crt.qinv.mul(&diff).modulo(&crt.p);
431 let m = m2.add(&crt.q.mul(&h));
432 let em = m.to_be_bytes_padded(self.modulus_len())?;
433 decode_oaep_sha256_plaintext(&em, label)
434 }
435
436 fn modulus_len(&self) -> usize {
450 self.n.to_be_bytes().len()
451 }
452}
453
454impl Drop for RsaPrivateKey {
455 fn drop(&mut self) {
456 self.clear();
457 }
458}
459
460impl RsaPublicKey {
461 pub fn from_be_bytes(n: &[u8], e: &[u8]) -> Result<Self> {
476 if n.is_empty() || e.is_empty() {
477 return Err(Error::InvalidLength(
478 "rsa public key fields must not be empty",
479 ));
480 }
481 let key = Self {
482 n: BigUint::from_be_bytes(n),
483 e: BigUint::from_be_bytes(e),
484 };
485 if !cfg!(feature = "hazardous-legacy-crypto") && key.n.bit_len() < RSA_MIN_SECURE_BITS {
486 return Err(Error::InvalidLength(
487 "rsa public key modulus must be at least 2048 bits",
488 ));
489 }
490 validate_public_components(&key.n, &key.e)?;
491 Ok(key)
492 }
493
494 #[must_use]
503 pub fn from_u128(n: u128, e: u128) -> Self {
504 Self {
505 n: BigUint::from_u128(n),
506 e: BigUint::from_u128(e),
507 }
508 }
509
510 pub fn clear(&mut self) {
515 self.n = BigUint::zero();
516 self.e = BigUint::zero();
517 }
518
519 pub fn verify_digest(&self, digest: &[u8], signature: &[u8]) -> Result<()> {
528 if digest.is_empty() {
529 return Err(Error::InvalidLength("digest must not be empty"));
530 }
531 validate_public_components(&self.n, &self.e)?;
532 let k = self.modulus_len();
533 let expected = BigUint::from_be_bytes(digest)
534 .modulo(&self.n)
535 .to_be_bytes_padded(k)?;
536 let recovered = BigUint::mod_exp(&BigUint::from_be_bytes(signature), &self.e, &self.n)
537 .to_be_bytes_padded(k)?;
538 if ct_bytes_eq(&recovered, &expected) {
539 Ok(())
540 } else {
541 Err(Error::CryptoFailure("RSA verification failed"))
542 }
543 }
544
545 pub fn verify_pkcs1_v15_sha256(&self, msg: &[u8], signature: &[u8]) -> Result<()> {
554 validate_public_components(&self.n, &self.e)?;
555 if signature.len() != self.modulus_len() {
556 return Err(Error::InvalidLength("rsa signature length mismatch"));
557 }
558 let recovered = BigUint::mod_exp(&BigUint::from_be_bytes(signature), &self.e, &self.n)
559 .to_be_bytes_padded(self.modulus_len())?;
560 let expected = emsa_pkcs1_v15_encode(
561 &noxtls_sha256(msg),
562 PKCS1_V15_DIGESTINFO_SHA256_PREFIX,
563 self.modulus_len(),
564 )?;
565 if ct_bytes_eq(&recovered, &expected) {
566 Ok(())
567 } else {
568 Err(Error::CryptoFailure("RSA verification failed"))
569 }
570 }
571
572 pub fn verify_pkcs1_v15_sha1(&self, msg: &[u8], signature: &[u8]) -> Result<()> {
581 validate_public_components(&self.n, &self.e)?;
582 if signature.len() != self.modulus_len() {
583 return Err(Error::InvalidLength("rsa signature length mismatch"));
584 }
585 let recovered = BigUint::mod_exp(&BigUint::from_be_bytes(signature), &self.e, &self.n)
586 .to_be_bytes_padded(self.modulus_len())?;
587 let expected = emsa_pkcs1_v15_encode(
588 &noxtls_sha1(msg),
589 PKCS1_V15_DIGESTINFO_SHA1_PREFIX,
590 self.modulus_len(),
591 )?;
592 if ct_bytes_eq(&recovered, &expected) {
593 Ok(())
594 } else {
595 Err(Error::CryptoFailure("RSA verification failed"))
596 }
597 }
598
599 pub fn verify_pkcs1_v15_sha384(&self, msg: &[u8], signature: &[u8]) -> Result<()> {
608 validate_public_components(&self.n, &self.e)?;
609 if signature.len() != self.modulus_len() {
610 return Err(Error::InvalidLength("rsa signature length mismatch"));
611 }
612 let recovered = BigUint::mod_exp(&BigUint::from_be_bytes(signature), &self.e, &self.n)
613 .to_be_bytes_padded(self.modulus_len())?;
614 let expected = emsa_pkcs1_v15_encode(
615 &noxtls_sha384(msg),
616 PKCS1_V15_DIGESTINFO_SHA384_PREFIX,
617 self.modulus_len(),
618 )?;
619 if ct_bytes_eq(&recovered, &expected) {
620 Ok(())
621 } else {
622 Err(Error::CryptoFailure("RSA verification failed"))
623 }
624 }
625
626 pub fn verify_pkcs1_v15_sha512(&self, msg: &[u8], signature: &[u8]) -> Result<()> {
635 validate_public_components(&self.n, &self.e)?;
636 if signature.len() != self.modulus_len() {
637 return Err(Error::InvalidLength("rsa signature length mismatch"));
638 }
639 let recovered = BigUint::mod_exp(&BigUint::from_be_bytes(signature), &self.e, &self.n)
640 .to_be_bytes_padded(self.modulus_len())?;
641 let expected = emsa_pkcs1_v15_encode(
642 &noxtls_sha512(msg),
643 PKCS1_V15_DIGESTINFO_SHA512_PREFIX,
644 self.modulus_len(),
645 )?;
646 if ct_bytes_eq(&recovered, &expected) {
647 Ok(())
648 } else {
649 Err(Error::CryptoFailure("RSA verification failed"))
650 }
651 }
652
653 pub fn verify_pss_sha256(&self, msg: &[u8], signature: &[u8], salt_len: usize) -> Result<()> {
663 validate_public_components(&self.n, &self.e)?;
664 if signature.len() != self.modulus_len() {
665 return Err(Error::InvalidLength("rsa signature length mismatch"));
666 }
667 let em_bits = self.n.bit_len().saturating_sub(1);
668 let em_len = em_bits.div_ceil(8);
669 let recovered = BigUint::mod_exp(&BigUint::from_be_bytes(signature), &self.e, &self.n)
670 .to_be_bytes_padded(self.modulus_len())?;
671 let em = &recovered[recovered.len() - em_len..];
672 emsa_pss_verify_sha256(&noxtls_sha256(msg), em, em_bits, salt_len)
673 }
674
675 pub fn verify_pss_sha384(&self, msg: &[u8], signature: &[u8], salt_len: usize) -> Result<()> {
685 validate_public_components(&self.n, &self.e)?;
686 if signature.len() != self.modulus_len() {
687 return Err(Error::InvalidLength("rsa signature length mismatch"));
688 }
689 let em_bits = self.n.bit_len().saturating_sub(1);
690 let em_len = em_bits.div_ceil(8);
691 let recovered = BigUint::mod_exp(&BigUint::from_be_bytes(signature), &self.e, &self.n)
692 .to_be_bytes_padded(self.modulus_len())?;
693 let em = &recovered[recovered.len() - em_len..];
694 emsa_pss_verify_sha384(&noxtls_sha384(msg), em, em_bits, salt_len)
695 }
696
697 pub fn verify_pss_with_hashes(
698 &self,
699 msg: &[u8],
700 signature: &[u8],
701 message_hash: RsaPssHashAlgorithm,
702 mgf_hash: RsaPssHashAlgorithm,
703 salt_len: usize,
704 ) -> Result<()> {
705 validate_public_components(&self.n, &self.e)?;
706 if signature.len() != self.modulus_len() {
707 return Err(Error::InvalidLength("rsa signature length mismatch"));
708 }
709 let em_bits = self.n.bit_len().saturating_sub(1);
710 let em_len = em_bits.div_ceil(8);
711 let recovered = BigUint::mod_exp(&BigUint::from_be_bytes(signature), &self.e, &self.n)
712 .to_be_bytes_padded(self.modulus_len())?;
713 let em = &recovered[recovered.len() - em_len..];
714 let digest = rsa_pss_hash(message_hash, msg);
715 emsa_pss_verify(&digest, em, em_bits, salt_len, message_hash, mgf_hash)
716 }
717
718 pub fn encrypt_pkcs1_v15_auto(
727 &self,
728 plaintext: &[u8],
729 drbg: &mut HmacDrbgSha256,
730 ) -> Result<Vec<u8>> {
731 validate_public_components(&self.n, &self.e)?;
732 let k = self.modulus_len();
733 if plaintext.len() > k.saturating_sub(11) {
734 return Err(Error::InvalidLength(
735 "rsa plaintext too long for pkcs1 v1.5 encryption",
736 ));
737 }
738 let ps_len = k - plaintext.len() - 3;
739 let ps = drbg_nonzero_padding(drbg, ps_len)?;
740 let mut em = Vec::with_capacity(k);
741 em.push(0x00);
742 em.push(0x02);
743 em.extend_from_slice(&ps);
744 em.push(0x00);
745 em.extend_from_slice(plaintext);
746 let c = BigUint::mod_exp(&BigUint::from_be_bytes(&em), &self.e, &self.n);
747 c.to_be_bytes_padded(k)
748 }
749
750 pub fn encrypt_oaep_sha256_auto(
760 &self,
761 plaintext: &[u8],
762 label: &[u8],
763 drbg: &mut HmacDrbgSha256,
764 ) -> Result<Vec<u8>> {
765 validate_public_components(&self.n, &self.e)?;
766 let k = self.modulus_len();
767 let seed = drbg.generate(32, b"rsa_oaep_sha256_seed")?;
768 let em = emea_oaep_encode_sha256(plaintext, label, &seed, k)?;
769 let c = BigUint::mod_exp(&BigUint::from_be_bytes(&em), &self.e, &self.n);
770 c.to_be_bytes_padded(k)
771 }
772
773 fn modulus_len(&self) -> usize {
787 self.n.to_be_bytes().len()
788 }
789}
790
/// Generates an RSA key pair with a caller-chosen public exponent.
///
/// Only compiled with `hazardous-legacy-crypto`, because it skips the secure
/// size policy and accepts moduli in `1024..=4096`. See
/// `rsa_generate_keypair_backend_auto` for the error conditions.
#[cfg(feature = "hazardous-legacy-crypto")]
pub fn noxtls_rsa_generate_keypair_with_exponent_auto(
    modulus_bits: usize,
    public_exponent: u32,
    drbg: &mut HmacDrbgSha256,
) -> Result<(RsaPrivateKey, RsaPublicKey)> {
    let keypair = rsa_generate_keypair_backend_auto(modulus_bits, public_exponent, drbg)?;
    Ok(keypair)
}
808
809fn rsa_generate_keypair_backend_auto(
829 modulus_bits: usize,
830 public_exponent: u32,
831 drbg: &mut HmacDrbgSha256,
832) -> Result<(RsaPrivateKey, RsaPublicKey)> {
833 if !(RSA_KEYGEN_MIN_BITS..=RSA_KEYGEN_MAX_BITS).contains(&modulus_bits) {
834 return Err(Error::InvalidLength(
835 "rsa modulus bits must be in supported range 1024..=4096",
836 ));
837 }
838 if public_exponent < 3 || (public_exponent & 1) == 0 {
839 return Err(Error::CryptoFailure(
840 "rsa public exponent must be odd and at least 3",
841 ));
842 }
843 let e = BigUint::from_u128(u128::from(public_exponent));
844 let one = BigUint::one();
845 let p_bits = modulus_bits / 2;
846 let q_bits = modulus_bits - p_bits;
847 let mut attempts = 0_u32;
848 while attempts < 256 {
849 let mut p = generate_rsa_prime_candidate_auto(p_bits, &e, drbg)?;
850 let mut q = generate_rsa_prime_candidate_auto(q_bits, &e, drbg)?;
851 let mut distinct_attempts = 0_u32;
852 while p.cmp(&q).is_eq() {
853 if distinct_attempts >= 32 {
854 break;
855 }
856 q = generate_rsa_prime_candidate_auto(q_bits, &e, drbg)?;
857 distinct_attempts = distinct_attempts.saturating_add(1);
858 }
859 if p.cmp(&q).is_eq() {
860 attempts = attempts.saturating_add(1);
861 continue;
862 }
863 if p.cmp(&q).is_gt() {
864 core::mem::swap(&mut p, &mut q);
865 }
866 let n = p.mul(&q);
867 if n.bit_len() != modulus_bits {
868 attempts = attempts.saturating_add(1);
869 continue;
870 }
871 let pm1 = p.sub(&one);
872 let qm1 = q.sub(&one);
873 let phi = pm1.mul(&qm1);
874 if BigUint::gcd(&e, &phi).cmp(&one).is_ne() {
875 attempts = attempts.saturating_add(1);
876 continue;
877 }
878 let Some(d) = BigUint::mod_inverse(&e, &phi) else {
879 attempts = attempts.saturating_add(1);
880 continue;
881 };
882 let dp = d.modulo(&pm1);
883 let dq = d.modulo(&qm1);
884 let Some(qinv) = BigUint::mod_inverse(&q, &p) else {
885 attempts = attempts.saturating_add(1);
886 continue;
887 };
888 let private = RsaPrivateKey {
889 n: n.clone(),
890 d,
891 crt: Some(RsaPrivateCrtComponents { p, q, dp, dq, qinv }),
892 };
893 let public = RsaPublicKey { n, e };
894 validate_private_components(&private.n, &private.d)?;
895 validate_public_components(&public.n, &public.e)?;
896 validate_crt_components(&private.n, private.crt.as_ref().expect("crt must exist"))?;
897 return Ok((private, public));
898 }
899 Err(Error::StateError(
900 "rsa key generation exhausted attempt budget",
901 ))
902}
903
/// Generates an RSA key pair with the public exponent 65537.
///
/// Only compiled with `hazardous-legacy-crypto`, because it accepts moduli in
/// `1024..=4096` without a secure size policy.
#[cfg(feature = "hazardous-legacy-crypto")]
pub fn noxtls_rsa_generate_keypair_auto(
    modulus_bits: usize,
    drbg: &mut HmacDrbgSha256,
) -> Result<(RsaPrivateKey, RsaPublicKey)> {
    const DEFAULT_PUBLIC_EXPONENT: u32 = 65_537;
    rsa_generate_keypair_backend_auto(modulus_bits, DEFAULT_PUBLIC_EXPONENT, drbg)
}
919
920pub fn noxtls_rsa_generate_keypair_with_policy_auto(
931 modulus_bits: usize,
932 public_exponent: u32,
933 policy: RsaKeySizePolicy,
934 drbg: &mut HmacDrbgSha256,
935) -> Result<(RsaPrivateKey, RsaPublicKey)> {
936 if !(RSA_MIN_SECURE_BITS..=RSA_KEYGEN_MAX_BITS).contains(&modulus_bits) {
937 return Err(Error::InvalidLength(
938 "secure rsa modulus bits must be in supported range 2048..=4096",
939 ));
940 }
941 if modulus_bits < policy.min_bits() {
942 return Err(Error::InvalidLength(
943 "rsa modulus bits do not satisfy configured secure policy minimum",
944 ));
945 }
946 rsa_generate_keypair_backend_auto(modulus_bits, public_exponent, drbg)
947}
948
949pub fn noxtls_rsa_generate_keypair_secure_auto(
959 modulus_bits: usize,
960 policy: RsaKeySizePolicy,
961 drbg: &mut HmacDrbgSha256,
962) -> Result<(RsaPrivateKey, RsaPublicKey)> {
963 noxtls_rsa_generate_keypair_with_policy_auto(modulus_bits, 65_537, policy, drbg)
964}
965
966pub fn noxtls_rsassa_sha256_sign(private: &RsaPrivateKey, msg: &[u8]) -> Result<Vec<u8>> {
975 private.sign_pkcs1_v15_sha256(msg)
976}
977
978pub fn noxtls_rsassa_sha256_verify(
988 public: &RsaPublicKey,
989 msg: &[u8],
990 signature: &[u8],
991) -> Result<()> {
992 public.verify_pkcs1_v15_sha256(msg, signature)
993}
994
995pub fn noxtls_rsassa_sha1_sign(private: &RsaPrivateKey, msg: &[u8]) -> Result<Vec<u8>> {
1004 private.sign_pkcs1_v15_sha1(msg)
1005}
1006
1007pub fn noxtls_rsassa_sha1_verify(
1017 public: &RsaPublicKey,
1018 msg: &[u8],
1019 signature: &[u8],
1020) -> Result<()> {
1021 public.verify_pkcs1_v15_sha1(msg, signature)
1022}
1023
1024pub fn noxtls_rsassa_sha384_sign(private: &RsaPrivateKey, msg: &[u8]) -> Result<Vec<u8>> {
1033 private.sign_pkcs1_v15_sha384(msg)
1034}
1035
1036pub fn noxtls_rsassa_sha384_verify(
1046 public: &RsaPublicKey,
1047 msg: &[u8],
1048 signature: &[u8],
1049) -> Result<()> {
1050 public.verify_pkcs1_v15_sha384(msg, signature)
1051}
1052
1053pub fn noxtls_rsassa_sha512_sign(private: &RsaPrivateKey, msg: &[u8]) -> Result<Vec<u8>> {
1062 private.sign_pkcs1_v15_sha512(msg)
1063}
1064
1065pub fn noxtls_rsassa_sha512_verify(
1075 public: &RsaPublicKey,
1076 msg: &[u8],
1077 signature: &[u8],
1078) -> Result<()> {
1079 public.verify_pkcs1_v15_sha512(msg, signature)
1080}
1081
1082pub fn noxtls_rsassa_pss_sha256_sign(
1092 private: &RsaPrivateKey,
1093 msg: &[u8],
1094 salt: &[u8],
1095) -> Result<Vec<u8>> {
1096 private.sign_pss_sha256(msg, salt)
1097}
1098
1099pub fn noxtls_rsassa_pss_sha256_sign_auto(
1110 private: &RsaPrivateKey,
1111 msg: &[u8],
1112 drbg: &mut HmacDrbgSha256,
1113 salt_len: usize,
1114) -> Result<Vec<u8>> {
1115 let salt = drbg.generate(salt_len, b"rsa_pss_sha256_salt")?;
1116 private.sign_pss_sha256(msg, &salt)
1117}
1118
1119pub fn noxtls_rsassa_pss_sha256_verify(
1130 public: &RsaPublicKey,
1131 msg: &[u8],
1132 signature: &[u8],
1133 salt_len: usize,
1134) -> Result<()> {
1135 public.verify_pss_sha256(msg, signature, salt_len)
1136}
1137
1138pub fn noxtls_rsassa_pss_sha384_sign(
1148 private: &RsaPrivateKey,
1149 msg: &[u8],
1150 salt: &[u8],
1151) -> Result<Vec<u8>> {
1152 private.sign_pss_sha384(msg, salt)
1153}
1154
1155pub fn noxtls_rsassa_pss_sha384_sign_auto(
1166 private: &RsaPrivateKey,
1167 msg: &[u8],
1168 drbg: &mut HmacDrbgSha256,
1169 salt_len: usize,
1170) -> Result<Vec<u8>> {
1171 let salt = drbg.generate(salt_len, b"rsa_pss_sha384_salt")?;
1172 private.sign_pss_sha384(msg, &salt)
1173}
1174
1175pub fn noxtls_rsassa_pss_sha384_verify(
1186 public: &RsaPublicKey,
1187 msg: &[u8],
1188 signature: &[u8],
1189 salt_len: usize,
1190) -> Result<()> {
1191 public.verify_pss_sha384(msg, signature, salt_len)
1192}
1193
1194pub fn noxtls_rsassa_pss_sign(
1195 private: &RsaPrivateKey,
1196 msg: &[u8],
1197 salt: &[u8],
1198 message_hash: RsaPssHashAlgorithm,
1199 mgf_hash: RsaPssHashAlgorithm,
1200) -> Result<Vec<u8>> {
1201 private.sign_pss_with_hashes(msg, salt, message_hash, mgf_hash)
1202}
1203
1204pub fn noxtls_rsassa_pss_verify(
1205 public: &RsaPublicKey,
1206 msg: &[u8],
1207 signature: &[u8],
1208 message_hash: RsaPssHashAlgorithm,
1209 mgf_hash: RsaPssHashAlgorithm,
1210 salt_len: usize,
1211) -> Result<()> {
1212 public.verify_pss_with_hashes(msg, signature, message_hash, mgf_hash, salt_len)
1213}
1214
1215pub fn noxtls_rsaes_pkcs1_v15_encrypt_auto(
1225 public: &RsaPublicKey,
1226 plaintext: &[u8],
1227 drbg: &mut HmacDrbgSha256,
1228) -> Result<Vec<u8>> {
1229 public.encrypt_pkcs1_v15_auto(plaintext, drbg)
1230}
1231
1232pub fn noxtls_rsaes_pkcs1_v15_decrypt(
1241 private: &RsaPrivateKey,
1242 ciphertext: &[u8],
1243) -> Result<Vec<u8>> {
1244 private.decrypt_pkcs1_v15(ciphertext)
1245}
1246
1247pub fn noxtls_rsaes_pkcs1_v15_decrypt_crt_only(
1261 private: &RsaPrivateKey,
1262 ciphertext: &[u8],
1263) -> Result<Vec<u8>> {
1264 private.decrypt_pkcs1_v15_crt_only(ciphertext)
1265}
1266
1267pub fn noxtls_rsaes_oaep_sha256_encrypt_auto(
1278 public: &RsaPublicKey,
1279 plaintext: &[u8],
1280 label: &[u8],
1281 drbg: &mut HmacDrbgSha256,
1282) -> Result<Vec<u8>> {
1283 public.encrypt_oaep_sha256_auto(plaintext, label, drbg)
1284}
1285
1286pub fn noxtls_rsaes_oaep_sha256_decrypt(
1296 private: &RsaPrivateKey,
1297 ciphertext: &[u8],
1298 label: &[u8],
1299) -> Result<Vec<u8>> {
1300 private.decrypt_oaep_sha256(ciphertext, label)
1301}
1302
1303pub fn noxtls_rsaes_oaep_sha256_decrypt_crt_only(
1313 private: &RsaPrivateKey,
1314 ciphertext: &[u8],
1315 label: &[u8],
1316) -> Result<Vec<u8>> {
1317 private.decrypt_oaep_sha256_crt_only(ciphertext, label)
1318}
1319
// DER-encoded DigestInfo prefixes for EMSA-PKCS1-v1_5 (RFC 8017 §9.2, note 1).
// Each one is SEQUENCE { AlgorithmIdentifier { OID, NULL }, OCTET STRING header };
// the raw digest bytes are appended directly after the prefix.

/// DigestInfo prefix for SHA-1 (20-byte digest, OID 1.3.14.3.2.26).
const PKCS1_V15_DIGESTINFO_SHA1_PREFIX: &[u8] = &[
    0x30, 0x21, 0x30, 0x09, 0x06, 0x05, 0x2B, 0x0E, 0x03, 0x02, 0x1A, 0x05, 0x00, 0x04, 0x14,
];
/// DigestInfo prefix for SHA-256 (32-byte digest, OID 2.16.840.1.101.3.4.2.1).
const PKCS1_V15_DIGESTINFO_SHA256_PREFIX: &[u8] = &[
    0x30, 0x31, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01, 0x05,
    0x00, 0x04, 0x20,
];
/// DigestInfo prefix for SHA-384 (48-byte digest, OID 2.16.840.1.101.3.4.2.2).
const PKCS1_V15_DIGESTINFO_SHA384_PREFIX: &[u8] = &[
    0x30, 0x41, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02, 0x05,
    0x00, 0x04, 0x30,
];
/// DigestInfo prefix for SHA-512 (64-byte digest, OID 2.16.840.1.101.3.4.2.3).
const PKCS1_V15_DIGESTINFO_SHA512_PREFIX: &[u8] = &[
    0x30, 0x51, 0x30, 0x0d, 0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03, 0x05,
    0x00, 0x04, 0x40,
];
1335
1336fn emsa_pkcs1_v15_encode(hash: &[u8], digest_info_prefix: &[u8], k: usize) -> Result<Vec<u8>> {
1356 let t_len = digest_info_prefix.len() + hash.len();
1357 if k < t_len + 11 {
1358 return Err(Error::InvalidLength("rsa modulus too short for pkcs1 v1.5"));
1359 }
1360 let ps_len = k - t_len - 3;
1361 let mut em = Vec::with_capacity(k);
1362 em.push(0x00);
1363 em.push(0x01);
1364 em.extend(core::iter::repeat(0xff_u8).take(ps_len));
1365 em.push(0x00);
1366 em.extend_from_slice(digest_info_prefix);
1367 em.extend_from_slice(hash);
1368 Ok(em)
1369}
1370
1371fn emsa_pss_encode_sha256(
1392 m_hash: &[u8; 32],
1393 salt: &[u8],
1394 em_bits: usize,
1395 em_len: usize,
1396) -> Result<Vec<u8>> {
1397 const HASH_LEN: usize = 32;
1398 if em_len < HASH_LEN + salt.len() + 2 {
1399 return Err(Error::InvalidLength("rsa modulus too short for pss"));
1400 }
1401
1402 let mut m_prime = vec![0_u8; 8];
1403 m_prime.extend_from_slice(m_hash);
1404 m_prime.extend_from_slice(salt);
1405 let h = noxtls_sha256(&m_prime);
1406
1407 let ps_len = em_len - salt.len() - HASH_LEN - 2;
1408 let mut db = vec![0_u8; ps_len];
1409 db.push(0x01);
1410 db.extend_from_slice(salt);
1411
1412 let db_mask = mgf1_sha256(&h, em_len - HASH_LEN - 1)?;
1413 for (byte, mask) in db.iter_mut().zip(db_mask.iter()) {
1414 *byte ^= *mask;
1415 }
1416
1417 let unused_bits = 8 * em_len - em_bits;
1418 if unused_bits > 0 {
1419 db[0] &= 0xff_u8 >> unused_bits;
1420 }
1421
1422 let mut em = db;
1423 em.extend_from_slice(&h);
1424 em.push(0xbc);
1425 Ok(em)
1426}
1427
/// EMSA-PSS verification (RFC 8017 §9.1.2) using SHA-256 for both the message
/// hash and MGF1.
///
/// `m_hash` is SHA-256(message), `em` is the encoded message recovered from the
/// signature, `em_bits` is the encoded-message bit length, and `salt_len` is the
/// expected salt length.
///
/// # Errors
/// - `Error::InvalidLength` when `em` is too short for the hash, the salt and two
///   framing bytes.
/// - `Error::CryptoFailure("RSA verification failed")` on any encoding mismatch.
/// - Errors from `mgf1_sha256` are propagated.
fn emsa_pss_verify_sha256(
    m_hash: &[u8; 32],
    em: &[u8],
    em_bits: usize,
    salt_len: usize,
) -> Result<()> {
    const HASH_LEN: usize = 32;
    if em.len() < HASH_LEN + salt_len + 2 {
        return Err(Error::InvalidLength("rsa modulus too short for pss"));
    }
    // The trailer field must be 0xbc.
    if em.last().copied() != Some(0xbc) {
        return Err(Error::CryptoFailure("RSA verification failed"));
    }

    // EM = maskedDB || H || 0xbc
    let db_len = em.len() - HASH_LEN - 1;
    let (masked_db, rest) = em.split_at(db_len);
    let h = &rest[..HASH_LEN];

    // The leftmost 8*emLen - emBits bits of maskedDB must be zero.
    // NOTE(review): assumes em.len() == ceil(em_bits / 8) as arranged by callers;
    // otherwise this subtraction/shift is not meaningful.
    let unused_bits = 8 * em.len() - em_bits;
    if unused_bits > 0 {
        let mask = 0xff_u8 << (8 - unused_bits);
        if masked_db[0] & mask != 0 {
            return Err(Error::CryptoFailure("RSA verification failed"));
        }
    }

    // DB = maskedDB XOR MGF1(H), with the same leading bits cleared.
    let db_mask = mgf1_sha256(h, db_len)?;
    let mut db = masked_db.to_vec();
    for (byte, mask) in db.iter_mut().zip(db_mask.iter()) {
        *byte ^= *mask;
    }
    if unused_bits > 0 {
        db[0] &= 0xff_u8 >> unused_bits;
    }

    // DB must be PS (all zero) || 0x01 || salt.
    let ps_len = em.len() - HASH_LEN - salt_len - 2;
    if !ct_all_zero(&db[..ps_len]) || db[ps_len] != 0x01 {
        return Err(Error::CryptoFailure("RSA verification failed"));
    }
    let salt = &db[db.len() - salt_len..];

    // Recompute H' = Hash(0x00 * 8 || mHash || salt) and compare it with H
    // through `ct_bytes_eq` (constant-time by name; implementation not visible here).
    let mut m_prime = vec![0_u8; 8];
    m_prime.extend_from_slice(m_hash);
    m_prime.extend_from_slice(salt);
    let expected_h = noxtls_sha256(&m_prime);
    if ct_bytes_eq(expected_h.as_slice(), h) {
        Ok(())
    } else {
        Err(Error::CryptoFailure("RSA verification failed"))
    }
}
1499
1500fn emsa_pss_encode_sha384(
1521 m_hash: &[u8; 48],
1522 salt: &[u8],
1523 em_bits: usize,
1524 em_len: usize,
1525) -> Result<Vec<u8>> {
1526 const HASH_LEN: usize = 48;
1527 if em_len < HASH_LEN + salt.len() + 2 {
1528 return Err(Error::InvalidLength("rsa modulus too short for pss"));
1529 }
1530
1531 let mut m_prime = vec![0_u8; 8];
1532 m_prime.extend_from_slice(m_hash);
1533 m_prime.extend_from_slice(salt);
1534 let h = noxtls_sha384(&m_prime);
1535
1536 let ps_len = em_len - salt.len() - HASH_LEN - 2;
1537 let mut db = vec![0_u8; ps_len];
1538 db.push(0x01);
1539 db.extend_from_slice(salt);
1540
1541 let db_mask = mgf1_sha384(&h, em_len - HASH_LEN - 1)?;
1542 for (byte, mask) in db.iter_mut().zip(db_mask.iter()) {
1543 *byte ^= *mask;
1544 }
1545
1546 let unused_bits = 8 * em_len - em_bits;
1547 if unused_bits > 0 {
1548 db[0] &= 0xff_u8 >> unused_bits;
1549 }
1550
1551 let mut em = db;
1552 em.extend_from_slice(&h);
1553 em.push(0xbc);
1554 Ok(em)
1555}
1556
1557fn emsa_pss_encode(
1558 m_hash: &[u8],
1559 salt: &[u8],
1560 em_bits: usize,
1561 em_len: usize,
1562 message_hash: RsaPssHashAlgorithm,
1563 mgf_hash: RsaPssHashAlgorithm,
1564) -> Result<Vec<u8>> {
1565 let hash_len = m_hash.len();
1566 if em_len < hash_len + salt.len() + 2 {
1567 return Err(Error::InvalidLength("rsa modulus too short for pss"));
1568 }
1569
1570 let mut m_prime = vec![0_u8; 8];
1571 m_prime.extend_from_slice(m_hash);
1572 m_prime.extend_from_slice(salt);
1573 let h = rsa_pss_hash(message_hash, &m_prime);
1574
1575 let ps_len = em_len - salt.len() - hash_len - 2;
1576 let mut db = vec![0_u8; ps_len];
1577 db.push(0x01);
1578 db.extend_from_slice(salt);
1579
1580 let db_mask = mgf1_hash(mgf_hash, &h, em_len - hash_len - 1)?;
1581 for (byte, mask) in db.iter_mut().zip(db_mask.iter()) {
1582 *byte ^= *mask;
1583 }
1584
1585 let unused_bits = 8 * em_len - em_bits;
1586 if unused_bits > 0 {
1587 db[0] &= 0xff_u8 >> unused_bits;
1588 }
1589
1590 let mut em = db;
1591 em.extend_from_slice(&h);
1592 em.push(0xbc);
1593 Ok(em)
1594}
1595
/// EMSA-PSS verification (RFC 8017 §9.1.2) using SHA-384 for both the message
/// hash and MGF1.
///
/// `m_hash` is SHA-384(message), `em` is the encoded message recovered from the
/// signature, `em_bits` is the encoded-message bit length, and `salt_len` is the
/// expected salt length.
///
/// # Errors
/// - `Error::InvalidLength` when `em` is too short for the hash, the salt and two
///   framing bytes.
/// - `Error::CryptoFailure("RSA verification failed")` on any encoding mismatch.
/// - Errors from `mgf1_sha384` are propagated.
fn emsa_pss_verify_sha384(
    m_hash: &[u8; 48],
    em: &[u8],
    em_bits: usize,
    salt_len: usize,
) -> Result<()> {
    const HASH_LEN: usize = 48;
    if em.len() < HASH_LEN + salt_len + 2 {
        return Err(Error::InvalidLength("rsa modulus too short for pss"));
    }
    // The trailer field must be 0xbc.
    if em.last().copied() != Some(0xbc) {
        return Err(Error::CryptoFailure("RSA verification failed"));
    }

    // EM = maskedDB || H || 0xbc
    let db_len = em.len() - HASH_LEN - 1;
    let (masked_db, rest) = em.split_at(db_len);
    let h = &rest[..HASH_LEN];

    // The leftmost 8*emLen - emBits bits of maskedDB must be zero.
    // NOTE(review): assumes em.len() == ceil(em_bits / 8) as arranged by callers;
    // otherwise this subtraction/shift is not meaningful.
    let unused_bits = 8 * em.len() - em_bits;
    if unused_bits > 0 {
        let mask = 0xff_u8 << (8 - unused_bits);
        if masked_db[0] & mask != 0 {
            return Err(Error::CryptoFailure("RSA verification failed"));
        }
    }

    // DB = maskedDB XOR MGF1(H), with the same leading bits cleared.
    let db_mask = mgf1_sha384(h, db_len)?;
    let mut db = masked_db.to_vec();
    for (byte, mask) in db.iter_mut().zip(db_mask.iter()) {
        *byte ^= *mask;
    }
    if unused_bits > 0 {
        db[0] &= 0xff_u8 >> unused_bits;
    }

    // DB must be PS (all zero) || 0x01 || salt.
    let ps_len = em.len() - HASH_LEN - salt_len - 2;
    if !ct_all_zero(&db[..ps_len]) || db[ps_len] != 0x01 {
        return Err(Error::CryptoFailure("RSA verification failed"));
    }
    let salt = &db[db.len() - salt_len..];

    // Recompute H' = Hash(0x00 * 8 || mHash || salt) and compare it with H
    // through `ct_bytes_eq` (constant-time by name; implementation not visible here).
    let mut m_prime = vec![0_u8; 8];
    m_prime.extend_from_slice(m_hash);
    m_prime.extend_from_slice(salt);
    let expected_h = noxtls_sha384(&m_prime);
    if ct_bytes_eq(expected_h.as_slice(), h) {
        Ok(())
    } else {
        Err(Error::CryptoFailure("RSA verification failed"))
    }
}
1667
1668fn emsa_pss_verify(
1669 m_hash: &[u8],
1670 em: &[u8],
1671 em_bits: usize,
1672 salt_len: usize,
1673 message_hash: RsaPssHashAlgorithm,
1674 mgf_hash: RsaPssHashAlgorithm,
1675) -> Result<()> {
1676 let hash_len = m_hash.len();
1677 if em.len() < hash_len + salt_len + 2 {
1678 return Err(Error::InvalidLength("rsa modulus too short for pss"));
1679 }
1680 if em.last().copied() != Some(0xbc) {
1681 return Err(Error::CryptoFailure("RSA verification failed"));
1682 }
1683
1684 let db_len = em.len() - hash_len - 1;
1685 let (masked_db, rest) = em.split_at(db_len);
1686 let h = &rest[..hash_len];
1687
1688 let unused_bits = 8 * em.len() - em_bits;
1689 if unused_bits > 0 {
1690 let mask = 0xff_u8 << (8 - unused_bits);
1691 if masked_db[0] & mask != 0 {
1692 return Err(Error::CryptoFailure("RSA verification failed"));
1693 }
1694 }
1695
1696 let db_mask = mgf1_hash(mgf_hash, h, db_len)?;
1697 let mut db = masked_db.to_vec();
1698 for (byte, mask) in db.iter_mut().zip(db_mask.iter()) {
1699 *byte ^= *mask;
1700 }
1701 if unused_bits > 0 {
1702 db[0] &= 0xff_u8 >> unused_bits;
1703 }
1704
1705 let ps_len = em.len() - hash_len - salt_len - 2;
1706 if !ct_all_zero(&db[..ps_len]) || db[ps_len] != 0x01 {
1707 return Err(Error::CryptoFailure("RSA verification failed"));
1708 }
1709 let salt = &db[db.len() - salt_len..];
1710
1711 let mut m_prime = vec![0_u8; 8];
1712 m_prime.extend_from_slice(m_hash);
1713 m_prime.extend_from_slice(salt);
1714 let expected_h = rsa_pss_hash(message_hash, &m_prime);
1715 if ct_bytes_eq(expected_h.as_slice(), h) {
1716 Ok(())
1717 } else {
1718 Err(Error::CryptoFailure("RSA verification failed"))
1719 }
1720}
1721
1722fn rsa_pss_hash(hash: RsaPssHashAlgorithm, input: &[u8]) -> Vec<u8> {
1723 match hash {
1724 RsaPssHashAlgorithm::Sha1 => noxtls_sha1(input).to_vec(),
1725 RsaPssHashAlgorithm::Sha256 => noxtls_sha256(input).to_vec(),
1726 RsaPssHashAlgorithm::Sha384 => noxtls_sha384(input).to_vec(),
1727 RsaPssHashAlgorithm::Sha512 => noxtls_sha512(input).to_vec(),
1728 }
1729}
1730
1731fn mgf1_hash(hash: RsaPssHashAlgorithm, seed: &[u8], out_len: usize) -> Result<Vec<u8>> {
1732 match hash {
1733 RsaPssHashAlgorithm::Sha1 => mgf1(seed, out_len, RsaPssHashAlgorithm::Sha1),
1734 RsaPssHashAlgorithm::Sha256 => mgf1_sha256(seed, out_len),
1735 RsaPssHashAlgorithm::Sha384 => mgf1_sha384(seed, out_len),
1736 RsaPssHashAlgorithm::Sha512 => mgf1(seed, out_len, RsaPssHashAlgorithm::Sha512),
1737 }
1738}
1739
1740fn mgf1(seed: &[u8], out_len: usize, hash: RsaPssHashAlgorithm) -> Result<Vec<u8>> {
1741 let mut out = Vec::with_capacity(out_len);
1742 let mut counter = 0_u32;
1743 while out.len() < out_len {
1744 if counter == u32::MAX {
1745 return Err(Error::InvalidLength("mgf1 output too large"));
1746 }
1747 let mut block_input = Vec::with_capacity(seed.len() + 4);
1748 block_input.extend_from_slice(seed);
1749 block_input.extend_from_slice(&counter.to_be_bytes());
1750 out.extend_from_slice(&rsa_pss_hash(hash, &block_input));
1751 counter = counter.wrapping_add(1);
1752 }
1753 out.truncate(out_len);
1754 Ok(out)
1755}
1756
1757fn mgf1_sha256(seed: &[u8], out_len: usize) -> Result<Vec<u8>> {
1776 let mut out = Vec::with_capacity(out_len);
1777 let mut counter = 0_u32;
1778 while out.len() < out_len {
1779 if counter == u32::MAX {
1780 return Err(Error::InvalidLength("mgf1 output too large"));
1781 }
1782 let mut block_input = Vec::with_capacity(seed.len() + 4);
1783 block_input.extend_from_slice(seed);
1784 block_input.extend_from_slice(&counter.to_be_bytes());
1785 out.extend_from_slice(&noxtls_sha256(&block_input));
1786 counter = counter.wrapping_add(1);
1787 }
1788 out.truncate(out_len);
1789 Ok(out)
1790}
1791
1792fn mgf1_sha384(seed: &[u8], out_len: usize) -> Result<Vec<u8>> {
1811 let mut out = Vec::with_capacity(out_len);
1812 let mut counter = 0_u32;
1813 while out.len() < out_len {
1814 if counter == u32::MAX {
1815 return Err(Error::InvalidLength("mgf1 output too large"));
1816 }
1817 let mut block_input = Vec::with_capacity(seed.len() + 4);
1818 block_input.extend_from_slice(seed);
1819 block_input.extend_from_slice(&counter.to_be_bytes());
1820 out.extend_from_slice(&noxtls_sha384(&block_input));
1821 counter = counter.wrapping_add(1);
1822 }
1823 out.truncate(out_len);
1824 Ok(out)
1825}
1826
1827fn drbg_nonzero_padding(drbg: &mut HmacDrbgSha256, len: usize) -> Result<Vec<u8>> {
1846 let mut out = Vec::with_capacity(len);
1847 while out.len() < len {
1848 let block = drbg.generate(len.saturating_sub(out.len()), b"rsa_pkcs1_v15_ps")?;
1849 for byte in block {
1850 if byte != 0 {
1851 out.push(byte);
1852 if out.len() == len {
1853 break;
1854 }
1855 }
1856 }
1857 }
1858 Ok(out)
1859}
1860
1861fn emea_oaep_encode_sha256(
1882 plaintext: &[u8],
1883 label: &[u8],
1884 seed: &[u8],
1885 k: usize,
1886) -> Result<Vec<u8>> {
1887 const HASH_LEN: usize = 32;
1888 if seed.len() != HASH_LEN {
1889 return Err(Error::InvalidLength("rsa oaep seed must be 32 bytes"));
1890 }
1891 if k < (2 * HASH_LEN + 2) {
1892 return Err(Error::InvalidLength(
1893 "rsa modulus too short for oaep sha256",
1894 ));
1895 }
1896 if plaintext.len() > k - (2 * HASH_LEN + 2) {
1897 return Err(Error::InvalidLength(
1898 "rsa plaintext too long for oaep sha256",
1899 ));
1900 }
1901 let l_hash = noxtls_sha256(label);
1902 let ps_len = k - plaintext.len() - (2 * HASH_LEN + 2);
1903 let mut db = Vec::with_capacity(k - HASH_LEN - 1);
1904 db.extend_from_slice(&l_hash);
1905 db.extend(core::iter::repeat(0_u8).take(ps_len));
1906 db.push(0x01);
1907 db.extend_from_slice(plaintext);
1908 let db_mask = mgf1_sha256(seed, k - HASH_LEN - 1)?;
1909 for (byte, mask) in db.iter_mut().zip(db_mask.iter()) {
1910 *byte ^= *mask;
1911 }
1912 let seed_mask = mgf1_sha256(&db, HASH_LEN)?;
1913 let mut masked_seed = seed.to_vec();
1914 for (byte, mask) in masked_seed.iter_mut().zip(seed_mask.iter()) {
1915 *byte ^= *mask;
1916 }
1917 let mut em = Vec::with_capacity(k);
1918 em.push(0x00);
1919 em.extend_from_slice(&masked_seed);
1920 em.extend_from_slice(&db);
1921 Ok(em)
1922}
1923
/// Decodes an RSAES-OAEP encoded message that used SHA-256 for both the label hash and MGF1
/// (RFC 8017, section 7.1.2, step 3) and returns the embedded plaintext.
///
/// `encoded` is the encoded message `EM = Y || maskedSeed || maskedDB`. `label` is the label
/// whose SHA-256 hash is expected at the start of `DB`.
///
/// # Errors
///
/// * [`Error::InvalidLength`] when `encoded` is shorter than `2 * 32 + 2` bytes.
/// * [`Error::CryptoFailure`] with one generic message for every padding defect: a non-zero
///   `Y`, a label hash mismatch, a non-zero byte before the `0x01` separator, or a missing
///   separator.
/// * Errors from `mgf1_sha256` are propagated.
fn decode_oaep_sha256_plaintext(encoded: &[u8], label: &[u8]) -> Result<Vec<u8>> {
    const HASH_LEN: usize = 32;
    if encoded.len() < (2 * HASH_LEN + 2) {
        return Err(Error::InvalidLength(
            "rsa modulus too short for oaep sha256",
        ));
    }
    // Every padding check ORs into `invalid`, which is inspected only once after all checks.
    // A single generic error is returned, so the individual failure reason is not exposed.
    let mut invalid = 0_u8;
    // The leading byte Y must be zero.
    invalid |= encoded[0];
    let (masked_seed, masked_db) = encoded[1..].split_at(HASH_LEN);
    // seed = maskedSeed XOR MGF1(maskedDB, hLen)
    let seed_mask = mgf1_sha256(masked_db, HASH_LEN)?;
    let mut seed = masked_seed.to_vec();
    for (byte, mask) in seed.iter_mut().zip(seed_mask.iter()) {
        *byte ^= *mask;
    }
    // DB = maskedDB XOR MGF1(seed, k - hLen - 1)
    let db_mask = mgf1_sha256(&seed, masked_db.len())?;
    let mut db = masked_db.to_vec();
    for (byte, mask) in db.iter_mut().zip(db_mask.iter()) {
        *byte ^= *mask;
    }
    // The first hLen bytes of DB must equal SHA-256(label).
    let expected_l_hash = noxtls_sha256(label);
    invalid |= u8::from(!ct_bytes_eq(&db[..HASH_LEN], expected_l_hash.as_slice()));
    // The remainder is PS (zero bytes) || 0x01 || M. The scan uses mask arithmetic rather than
    // `if` on byte values. It records the index of the first 0x01 and flags any byte that is
    // neither 0x00 nor 0x01 and appears before that separator.
    let rest = &db[HASH_LEN..];
    let mut marker_idx = 0_usize;
    let mut found_marker = 0_u8;
    let mut invalid_ps = 0_u8;
    for (idx, &byte) in rest.iter().enumerate() {
        let is_zero = u8::from(byte == 0);
        let is_one = u8::from(byte == 1);
        let before_marker = 1_u8 ^ found_marker;
        let should_set = before_marker & is_one;
        marker_idx = ct_select_usize(should_set, idx, marker_idx);
        invalid_ps |= before_marker & (1_u8 ^ is_zero) & (1_u8 ^ is_one);
        found_marker |= is_one;
    }
    invalid |= invalid_ps;
    // A missing 0x01 separator is also a padding error.
    invalid |= 1_u8 ^ found_marker;
    if invalid != 0 {
        return Err(Error::CryptoFailure("rsa decryption failed"));
    }
    // Skip the 0x01 separator; everything after it is the plaintext.
    Ok(rest[marker_idx.saturating_add(1)..].to_vec())
}
1984
1985fn decode_pkcs1_v15_plaintext(encoded: &[u8]) -> Result<Vec<u8>> {
2003 if encoded.len() < 11 {
2004 return Err(Error::CryptoFailure("rsa decryption failed"));
2005 }
2006 let mut invalid = 0_u8;
2007 invalid |= encoded[0];
2008 invalid |= encoded[1] ^ 0x02;
2009
2010 let mut sep_idx = 0_usize;
2011 let mut found_sep = 0_u8;
2012 for (idx, &byte) in encoded.iter().enumerate().skip(2) {
2013 let is_zero = u8::from(byte == 0);
2014 let should_set = is_zero & (1_u8 ^ found_sep);
2015 sep_idx = ct_select_usize(should_set, idx, sep_idx);
2016 found_sep |= is_zero;
2017 }
2018 if found_sep == 0 {
2019 invalid |= 1;
2020 }
2021 if sep_idx < 10 {
2022 invalid |= 1;
2023 }
2024 if invalid != 0 {
2025 return Err(Error::CryptoFailure("rsa decryption failed"));
2026 }
2027 Ok(encoded[sep_idx + 1..].to_vec())
2028}
2029
/// Compares two byte slices without exiting early on the first differing byte.
///
/// Lengths are treated as public: slices of different length compare unequal immediately.
fn ct_bytes_eq(left: &[u8], right: &[u8]) -> bool {
    left.len() == right.len()
        && left
            .iter()
            .zip(right)
            .fold(0_u8, |acc, (&l, &r)| acc | (l ^ r))
            == 0
}
2054
/// Returns `true` when every byte of `bytes` is zero (vacuously true for an empty slice).
///
/// Every byte is ORed into an accumulator, so the scan does not stop at the first non-zero
/// byte.
fn ct_all_zero(bytes: &[u8]) -> bool {
    bytes.iter().fold(0_u8, |acc, &byte| acc | byte) == 0
}
2075
/// Returns `if_one` when `selector` is 1 and `if_zero` when it is 0, using a bit mask instead
/// of a branch.
///
/// `selector` is expected to be exactly 0 or 1. Other values yield a per-bit mix of the two
/// inputs.
fn ct_select_usize(selector: u8, if_one: usize, if_zero: usize) -> usize {
    // 1 -> all ones, 0 -> all zeros.
    let mask = usize::from(selector).wrapping_neg();
    if_zero ^ ((if_zero ^ if_one) & mask)
}
2095
2096fn validate_private_components(n: &BigUint, d: &BigUint) -> Result<()> {
2115 validate_modulus(n)?;
2116 if d.is_zero() {
2117 return Err(Error::CryptoFailure(
2118 "rsa private exponent must be non-zero",
2119 ));
2120 }
2121 if !d.is_odd() {
2122 return Err(Error::CryptoFailure("rsa private exponent must be odd"));
2123 }
2124 if d.cmp(n).is_ge() {
2125 return Err(Error::CryptoFailure(
2126 "rsa private exponent must be smaller than modulus",
2127 ));
2128 }
2129 Ok(())
2130}
2131
2132fn validate_public_components(n: &BigUint, e: &BigUint) -> Result<()> {
2151 validate_modulus(n)?;
2152 let three = BigUint::from_u128(3);
2153 if e.cmp(&three).is_lt() {
2154 return Err(Error::CryptoFailure(
2155 "rsa public exponent must be at least 3",
2156 ));
2157 }
2158 if !e.is_odd() {
2159 return Err(Error::CryptoFailure("rsa public exponent must be odd"));
2160 }
2161 if e.cmp(n).is_ge() {
2162 return Err(Error::CryptoFailure(
2163 "rsa public exponent must be smaller than modulus",
2164 ));
2165 }
2166 Ok(())
2167}
2168
2169fn validate_modulus(n: &BigUint) -> Result<()> {
2187 let three = BigUint::from_u128(3);
2188 if n.cmp(&three).is_lt() {
2189 return Err(Error::CryptoFailure("rsa modulus must be greater than 3"));
2190 }
2191 if !n.is_odd() {
2192 return Err(Error::CryptoFailure("rsa modulus must be odd"));
2193 }
2194 Ok(())
2195}
2196
2197fn validate_crt_components(n: &BigUint, crt: &RsaPrivateCrtComponents) -> Result<()> {
2216 if crt.p.is_zero()
2217 || crt.q.is_zero()
2218 || crt.dp.is_zero()
2219 || crt.dq.is_zero()
2220 || crt.qinv.is_zero()
2221 {
2222 return Err(Error::CryptoFailure("rsa crt parameters must be non-zero"));
2223 }
2224 if !crt.p.is_odd() || !crt.q.is_odd() {
2225 return Err(Error::CryptoFailure("rsa crt primes must be odd"));
2226 }
2227 if crt.p.mul(&crt.q).cmp(n).is_ne() {
2228 return Err(Error::CryptoFailure(
2229 "rsa crt prime product must equal modulus",
2230 ));
2231 }
2232 if crt.dp.cmp(&crt.p).is_ge() || crt.dq.cmp(&crt.q).is_ge() {
2233 return Err(Error::CryptoFailure("rsa crt exponents must be reduced"));
2234 }
2235 if crt.qinv.cmp(&crt.p).is_ge() {
2236 return Err(Error::CryptoFailure(
2237 "rsa crt coefficient must be smaller than p",
2238 ));
2239 }
2240 let one = BigUint::one();
2241 if crt.q.mul(&crt.qinv).modulo(&crt.p).cmp(&one).is_ne() {
2242 return Err(Error::CryptoFailure(
2243 "rsa crt coefficient must be inverse of q modulo p",
2244 ));
2245 }
2246 Ok(())
2247}
2248
2249fn generate_rsa_prime_candidate_auto(
2269 bits: usize,
2270 e: &BigUint,
2271 drbg: &mut HmacDrbgSha256,
2272) -> Result<BigUint> {
2273 let one = BigUint::one();
2274 let mut attempts = 0_u32;
2275 while attempts < 20_000 {
2276 let candidate = random_biguint_with_bits(bits, drbg, b"rsa_prime_candidate")?;
2277 if candidate.bit_len() != bits {
2278 attempts = attempts.saturating_add(1);
2279 continue;
2280 }
2281 if !is_probable_prime(&candidate) {
2282 attempts = attempts.saturating_add(1);
2283 continue;
2284 }
2285 let pm1 = candidate.sub(&one);
2286 if BigUint::gcd(e, &pm1).cmp(&one).is_eq() {
2287 return Ok(candidate);
2288 }
2289 attempts = attempts.saturating_add(1);
2290 }
2291 Err(Error::StateError(
2292 "rsa prime generation exhausted attempt budget",
2293 ))
2294}
2295
2296fn random_biguint_with_bits(
2316 bits: usize,
2317 drbg: &mut HmacDrbgSha256,
2318 label: &[u8],
2319) -> Result<BigUint> {
2320 if bits < 2 {
2321 return Err(Error::InvalidLength(
2322 "rsa prime candidate bits must be at least 2",
2323 ));
2324 }
2325 let byte_len = bits.div_ceil(8);
2326 let mut random = drbg.generate(byte_len, label)?;
2327 let top_bits = bits % 8;
2328 if top_bits != 0 {
2329 random[0] &= (1_u8 << top_bits) - 1;
2330 }
2331 let high_bit_index = (bits - 1) % 8;
2332 random[0] |= 1_u8 << high_bit_index;
2333 let last = random.len() - 1;
2334 random[last] |= 1;
2335 Ok(BigUint::from_be_bytes(&random))
2336}
2337
2338fn is_probable_prime(n: &BigUint) -> bool {
2352 let two = BigUint::from_u128(2);
2353 if n.cmp(&two).is_lt() {
2354 return false;
2355 }
2356 for small in [2_u32, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37] {
2357 let small_bn = BigUint::from_u128(u128::from(small));
2358 if n.cmp(&small_bn).is_eq() {
2359 return true;
2360 }
2361 if n.mod_u32(small) == 0 {
2362 return false;
2363 }
2364 }
2365 let one = BigUint::one();
2366 let n_minus_one = n.sub(&one);
2367 let mut d = n_minus_one.clone();
2368 let mut s = 0_u32;
2369 while d.is_even() {
2370 d = d.shr1();
2371 s = s.saturating_add(1);
2372 }
2373 for witness in [2_u32, 3, 5, 7, 11, 13, 17, 19, 23, 29] {
2374 if !miller_rabin_round(n, &d, s, witness) {
2375 return false;
2376 }
2377 }
2378 true
2379}
2380
2381fn miller_rabin_round(n: &BigUint, d: &BigUint, s: u32, witness: u32) -> bool {
2398 let a = BigUint::from_u128(u128::from(witness)).modulo(n);
2399 if a.is_zero() {
2400 return true;
2401 }
2402 let one = BigUint::one();
2403 let n_minus_one = n.sub(&one);
2404 let mut x = BigUint::mod_exp(&a, d, n);
2405 if x.cmp(&one).is_eq() || x.cmp(&n_minus_one).is_eq() {
2406 return true;
2407 }
2408 for _ in 1..s {
2409 x = x.mul(&x).modulo(n);
2410 if x.cmp(&n_minus_one).is_eq() {
2411 return true;
2412 }
2413 }
2414 false
2415}