1pub use rsa::{errors::Error as RsaError, RsaPrivateKey, RsaPublicKey};
4
5use rand_core::{CryptoRng, RngCore};
6use rsa::{
7 traits::{PrivateKeyParts, PublicKeyParts},
8 BigUint, Pkcs1v15Sign, Pss,
9};
10use sha2::{Digest, Sha256, Sha384, Sha512};
11
12use core::{fmt, str::FromStr};
13
14use crate::{
15 alg::{SecretBytes, StrongKey, WeakKeyError},
16 alloc::{Cow, String, ToOwned, Vec},
17 jwk::{JsonWebKey, JwkError, KeyType, RsaPrimeFactor, RsaPrivateParts},
18 Algorithm, AlgorithmSignature,
19};
20use scale_info::prelude::boxed::Box;
21
/// Signature produced by the RSA-based algorithms in this module.
///
/// The wrapped bytes are stored exactly as they were produced by signing or passed to
/// `try_from_slice`.
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub struct RsaSignature(Vec<u8>);
26
27impl AlgorithmSignature for RsaSignature {
28 const LENGTH: Option<core::num::NonZeroUsize> = None;
29
30 fn try_from_slice(bytes: &[u8]) -> anyhow::Result<Self> {
31 Ok(RsaSignature(bytes.to_vec()))
32 }
33
34 fn as_bytes(&self) -> Cow<'_, [u8]> {
35 Cow::Borrowed(&self.0)
36 }
37}
38
/// Hash function applied to a message before it is signed or verified.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HashAlg {
    /// SHA-256.
    Sha256,
    /// SHA-384.
    Sha384,
    /// SHA-512.
    Sha512,
}
46
47impl HashAlg {
48 fn digest(self, message: &[u8]) -> Box<[u8]> {
49 match self {
50 Self::Sha256 => {
51 let digest: [u8; 32] = *(Sha256::digest(message).as_ref());
52 Box::new(digest)
53 },
54 Self::Sha384 => {
55 let mut digest = [0_u8; 48];
56 digest.copy_from_slice(Sha384::digest(message).as_ref());
57 Box::new(digest)
58 },
59 Self::Sha512 => {
60 let mut digest = [0_u8; 64];
61 digest.copy_from_slice(Sha512::digest(message).as_ref());
62 Box::new(digest)
63 },
64 }
65 }
66}
67
/// Signature padding selected for an [`Rsa`] algorithm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Padding {
    /// PKCS#1 v1.5 padding (used by the `RS*` algorithms).
    Pkcs1v15,
    /// PSS padding (used by the `PS*` algorithms).
    Pss,
}
74
75#[derive(Debug)]
76enum PaddingScheme {
77 Pkcs1v15(Pkcs1v15Sign),
78 Pss(Pss),
79}
80
/// Supported bit lengths of an RSA key modulus.
///
/// Despite the `…Kibibytes` suffix of the variant names, every length is measured in *bits*;
/// see [`Self::bits()`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub enum ModulusBits {
    /// 2,048-bit modulus.
    TwoKibibytes,
    /// 3,072-bit modulus.
    ThreeKibibytes,
    /// 4,096-bit modulus.
    FourKibibytes,
}

impl ModulusBits {
    /// Returns the modulus length in bits.
    pub fn bits(self) -> usize {
        const BITS_PER_KIBIBIT: usize = 1_024;

        let kibibits = match self {
            Self::TwoKibibytes => 2,
            Self::ThreeKibibytes => 3,
            Self::FourKibibytes => 4,
        };
        kibibits * BITS_PER_KIBIBIT
    }

    /// Checks whether `bits` is one of the supported modulus lengths.
    fn is_valid_bits(bits: usize) -> bool {
        [2_048, 3_072, 4_096].contains(&bits)
    }
}
108
109impl TryFrom<usize> for ModulusBits {
110 type Error = ModulusBitsError;
111
112 fn try_from(value: usize) -> Result<Self, Self::Error> {
113 match value {
114 2_048 => Ok(Self::TwoKibibytes),
115 3_072 => Ok(Self::ThreeKibibytes),
116 4_096 => Ok(Self::FourKibibytes),
117 _ => Err(ModulusBitsError(())),
118 }
119 }
120}
121
/// Error returned when an integer cannot be converted into [`ModulusBits`]
/// because it is not one of the supported modulus lengths.
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub struct ModulusBitsError(());

impl fmt::Display for ModulusBitsError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        const MESSAGE: &str = "Unsupported bit length of RSA modulus; only lengths 2048, 3072 and 4096 \
                               are supported.";

        formatter.write_str(MESSAGE)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for ModulusBitsError {}
138
139#[derive(Debug, Clone, Copy, PartialEq, Eq)]
162#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
163pub struct Rsa {
164 hash_alg: HashAlg,
165 padding_alg: Padding,
166}
167
168impl Algorithm for Rsa {
169 type Signature = RsaSignature;
170 type SigningKey = RsaPrivateKey;
171 type VerifyingKey = RsaPublicKey;
172
173 fn name(&self) -> Cow<'static, str> {
174 Cow::Borrowed(self.alg_name())
175 }
176
177 fn sign(&self, signing_key: &Self::SigningKey, message: &[u8]) -> Self::Signature {
178 let digest = self.hash_alg.digest(message);
179 let digest = digest.as_ref();
180 let signing_result = match self.padding_scheme() {
181 PaddingScheme::Pkcs1v15(padding) => signing_key.sign_with_rng(&mut rand_core::OsRng, padding, digest),
182 PaddingScheme::Pss(padding) => signing_key.sign_with_rng(&mut rand_core::OsRng, padding, digest),
183 };
184 RsaSignature(signing_result.expect("Unexpected RSA signature failure"))
185 }
186
187 fn verify_signature(
188 &self,
189 signature: &Self::Signature,
190 verifying_key: &Self::VerifyingKey,
191 message: &[u8],
192 ) -> bool {
193 let digest = self.hash_alg.digest(message);
194 let verify_result = match self.padding_scheme() {
195 PaddingScheme::Pkcs1v15(padding) => verifying_key.verify(padding, &digest, &signature.0),
196 PaddingScheme::Pss(padding) => verifying_key.verify(padding, &digest, &signature.0),
197 };
198 verify_result.is_ok()
199 }
200}
201
202impl Rsa {
203 const fn new(hash_alg: HashAlg, padding_alg: Padding) -> Self {
204 Rsa { hash_alg, padding_alg }
205 }
206
207 pub const fn rs256() -> Rsa {
209 Rsa::new(HashAlg::Sha256, Padding::Pkcs1v15)
210 }
211
212 pub const fn rs384() -> Rsa {
214 Rsa::new(HashAlg::Sha384, Padding::Pkcs1v15)
215 }
216
217 pub const fn rs512() -> Rsa {
219 Rsa::new(HashAlg::Sha512, Padding::Pkcs1v15)
220 }
221
222 pub const fn ps256() -> Rsa {
224 Rsa::new(HashAlg::Sha256, Padding::Pss)
225 }
226
227 pub const fn ps384() -> Rsa {
229 Rsa::new(HashAlg::Sha384, Padding::Pss)
230 }
231
232 pub const fn ps512() -> Rsa {
234 Rsa::new(HashAlg::Sha512, Padding::Pss)
235 }
236
237 pub fn with_name(name: &str) -> Self {
244 name.parse().unwrap()
245 }
246
247 fn padding_scheme(self) -> PaddingScheme {
248 match self.padding_alg {
249 Padding::Pkcs1v15 => PaddingScheme::Pkcs1v15(match self.hash_alg {
250 HashAlg::Sha256 => Pkcs1v15Sign::new::<Sha256>(),
251 HashAlg::Sha384 => Pkcs1v15Sign::new::<Sha384>(),
252 HashAlg::Sha512 => Pkcs1v15Sign::new::<Sha512>(),
253 }),
254 Padding::Pss => {
255 PaddingScheme::Pss(match self.hash_alg {
258 HashAlg::Sha256 => Pss::new_with_salt::<Sha256>(Sha256::output_size()),
259 HashAlg::Sha384 => Pss::new_with_salt::<Sha384>(Sha384::output_size()),
260 HashAlg::Sha512 => Pss::new_with_salt::<Sha512>(Sha512::output_size()),
261 })
262 },
263 }
264 }
265
266 fn alg_name(self) -> &'static str {
267 match (self.padding_alg, self.hash_alg) {
268 (Padding::Pkcs1v15, HashAlg::Sha256) => "RS256",
269 (Padding::Pkcs1v15, HashAlg::Sha384) => "RS384",
270 (Padding::Pkcs1v15, HashAlg::Sha512) => "RS512",
271 (Padding::Pss, HashAlg::Sha256) => "PS256",
272 (Padding::Pss, HashAlg::Sha384) => "PS384",
273 (Padding::Pss, HashAlg::Sha512) => "PS512",
274 }
275 }
276
277 pub fn generate<R: CryptoRng + RngCore>(
279 rng: &mut R,
280 modulus_bits: ModulusBits,
281 ) -> rsa::errors::Result<(StrongKey<RsaPrivateKey>, StrongKey<RsaPublicKey>)> {
282 let signing_key = RsaPrivateKey::new(rng, modulus_bits.bits())?;
283 let verifying_key = signing_key.to_public_key();
284 Ok((StrongKey(signing_key), StrongKey(verifying_key)))
285 }
286}
287
288impl FromStr for Rsa {
289 type Err = RsaParseError;
290
291 fn from_str(s: &str) -> Result<Self, Self::Err> {
292 Ok(match s {
293 "RS256" => Self::rs256(),
294 "RS384" => Self::rs384(),
295 "RS512" => Self::rs512(),
296 "PS256" => Self::ps256(),
297 "PS384" => Self::ps384(),
298 "PS512" => Self::ps512(),
299 _ => return Err(RsaParseError(s.to_owned())),
300 })
301 }
302}
303
/// Error returned when a string does not name one of the supported RSA algorithms.
///
/// Wraps the name that failed to parse.
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub struct RsaParseError(String);

impl fmt::Display for RsaParseError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(invalid_name) = self;
        write!(formatter, "Invalid RSA algorithm name: {}", invalid_name)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for RsaParseError {}
317
318impl StrongKey<RsaPrivateKey> {
319 pub fn to_public_key(&self) -> StrongKey<RsaPublicKey> {
321 StrongKey(self.0.to_public_key())
322 }
323}
324
325impl TryFrom<RsaPrivateKey> for StrongKey<RsaPrivateKey> {
326 type Error = WeakKeyError<RsaPrivateKey>;
327
328 fn try_from(key: RsaPrivateKey) -> Result<Self, Self::Error> {
329 if ModulusBits::is_valid_bits(key.n().bits()) {
330 Ok(StrongKey(key))
331 } else {
332 Err(WeakKeyError(key))
333 }
334 }
335}
336
337impl TryFrom<RsaPublicKey> for StrongKey<RsaPublicKey> {
338 type Error = WeakKeyError<RsaPublicKey>;
339
340 fn try_from(key: RsaPublicKey) -> Result<Self, Self::Error> {
341 if ModulusBits::is_valid_bits(key.n().bits()) {
342 Ok(StrongKey(key))
343 } else {
344 Err(WeakKeyError(key))
345 }
346 }
347}
348
349impl<'a> From<&'a RsaPublicKey> for JsonWebKey<'a> {
350 fn from(key: &'a RsaPublicKey) -> JsonWebKey<'a> {
351 JsonWebKey::Rsa {
352 modulus: Cow::Owned(key.n().to_bytes_be()),
353 public_exponent: Cow::Owned(key.e().to_bytes_be()),
354 private_parts: None,
355 }
356 }
357}
358
359impl TryFrom<&JsonWebKey<'_>> for RsaPublicKey {
360 type Error = JwkError;
361
362 fn try_from(jwk: &JsonWebKey<'_>) -> Result<Self, Self::Error> {
363 let JsonWebKey::Rsa { modulus, public_exponent, .. } = jwk else {
364 return Err(JwkError::key_type(jwk, KeyType::Rsa));
365 };
366
367 let e = BigUint::from_bytes_be(public_exponent);
368 let n = BigUint::from_bytes_be(modulus);
369 Self::new(n, e).map_err(|err| JwkError::custom(anyhow::anyhow!(err)))
370 }
371}
372
373impl<'a> From<&'a RsaPrivateKey> for JsonWebKey<'a> {
379 fn from(key: &'a RsaPrivateKey) -> JsonWebKey<'a> {
380 const MSG: &str = "RsaPrivateKey must have at least 2 prime factors";
381
382 let p = key.primes().get(0).expect(MSG);
383 let q = key.primes().get(1).expect(MSG);
384
385 let private_parts = RsaPrivateParts {
386 private_exponent: SecretBytes::owned(key.d().to_bytes_be()),
387 prime_factor_p: SecretBytes::owned(p.to_bytes_be()),
388 prime_factor_q: SecretBytes::owned(q.to_bytes_be()),
389 p_crt_exponent: None,
390 q_crt_exponent: None,
391 q_crt_coefficient: None,
392 other_prime_factors: key.primes()[2..]
393 .iter()
394 .map(|factor| RsaPrimeFactor {
395 factor: SecretBytes::owned(factor.to_bytes_be()),
396 crt_exponent: None,
397 crt_coefficient: None,
398 })
399 .collect(),
400 };
401
402 JsonWebKey::Rsa {
403 modulus: Cow::Owned(key.n().to_bytes_be()),
404 public_exponent: Cow::Owned(key.e().to_bytes_be()),
405 private_parts: Some(private_parts),
406 }
407 }
408}
409
410impl TryFrom<&JsonWebKey<'_>> for RsaPrivateKey {
415 type Error = JwkError;
416
417 fn try_from(jwk: &JsonWebKey<'_>) -> Result<Self, Self::Error> {
418 let JsonWebKey::Rsa { modulus, public_exponent, private_parts } = jwk else {
419 return Err(JwkError::key_type(jwk, KeyType::Rsa));
420 };
421
422 let RsaPrivateParts { private_exponent: d, prime_factor_p, prime_factor_q, other_prime_factors, .. } =
423 private_parts.as_ref().ok_or_else(|| JwkError::NoField("d".into()))?;
424
425 let e = BigUint::from_bytes_be(public_exponent);
426 let n = BigUint::from_bytes_be(modulus);
427 let d = BigUint::from_bytes_be(d);
428
429 let mut factors = Vec::with_capacity(2 + other_prime_factors.len());
430 factors.push(BigUint::from_bytes_be(prime_factor_p));
431 factors.push(BigUint::from_bytes_be(prime_factor_q));
432 factors.extend(other_prime_factors.iter().map(|prime| BigUint::from_bytes_be(&prime.factor)));
433
434 let key = Self::from_components(n, e, d, factors);
435 let key = key.map_err(|err| JwkError::custom(anyhow::anyhow!(err)))?;
436 key.validate().map_err(|err| JwkError::custom(anyhow::anyhow!(err)))?;
437 Ok(key)
438 }
439}