jwt_compact_frame/alg/
rsa.rs

1//! RSA-based JWT algorithms: `RS*` and `PS*`.
2
3pub use rsa::{errors::Error as RsaError, RsaPrivateKey, RsaPublicKey};
4
5use rand_core::{CryptoRng, RngCore};
6use rsa::{
7	traits::{PrivateKeyParts, PublicKeyParts},
8	BigUint, Pkcs1v15Sign, Pss,
9};
10use sha2::{Digest, Sha256, Sha384, Sha512};
11
12use core::{fmt, str::FromStr};
13
14use crate::{
15	alg::{SecretBytes, StrongKey, WeakKeyError},
16	alloc::{Cow, String, ToOwned, Vec},
17	jwk::{JsonWebKey, JwkError, KeyType, RsaPrimeFactor, RsaPrivateParts},
18	Algorithm, AlgorithmSignature,
19};
20use scale_info::prelude::boxed::Box;
21
22/// RSA signature.
23#[derive(Debug)]
24#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
25pub struct RsaSignature(Vec<u8>);
26
27impl AlgorithmSignature for RsaSignature {
28	const LENGTH: Option<core::num::NonZeroUsize> = None;
29
30	fn try_from_slice(bytes: &[u8]) -> anyhow::Result<Self> {
31		Ok(RsaSignature(bytes.to_vec()))
32	}
33
34	fn as_bytes(&self) -> Cow<'_, [u8]> {
35		Cow::Borrowed(&self.0)
36	}
37}
38
39/// RSA hash algorithm.
40#[derive(Debug, Copy, Clone, Eq, PartialEq)]
41enum HashAlg {
42	Sha256,
43	Sha384,
44	Sha512,
45}
46
47impl HashAlg {
48	fn digest(self, message: &[u8]) -> Box<[u8]> {
49		match self {
50			Self::Sha256 => {
51				let digest: [u8; 32] = *(Sha256::digest(message).as_ref());
52				Box::new(digest)
53			},
54			Self::Sha384 => {
55				let mut digest = [0_u8; 48];
56				digest.copy_from_slice(Sha384::digest(message).as_ref());
57				Box::new(digest)
58			},
59			Self::Sha512 => {
60				let mut digest = [0_u8; 64];
61				digest.copy_from_slice(Sha512::digest(message).as_ref());
62				Box::new(digest)
63			},
64		}
65	}
66}
67
68/// RSA padding algorithm.
69#[derive(Debug, Copy, Clone, Eq, PartialEq)]
70enum Padding {
71	Pkcs1v15,
72	Pss,
73}
74
75#[derive(Debug)]
76enum PaddingScheme {
77	Pkcs1v15(Pkcs1v15Sign),
78	Pss(Pss),
79}
80
/// Bit length of an RSA key modulus (aka RSA key length).
///
/// Note that despite their names, the variants denote lengths in *bits*
/// (e.g., `TwoKibibytes` is 2048 bits), not bytes.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
#[non_exhaustive]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub enum ModulusBits {
	/// 2048 bits. This is the minimum recommended key length as of 2020.
	TwoKibibytes,
	/// 3072 bits.
	ThreeKibibytes,
	/// 4096 bits.
	FourKibibytes,
}

impl ModulusBits {
	/// All supported modulus lengths, in ascending order.
	const ALL: [Self; 3] = [Self::TwoKibibytes, Self::ThreeKibibytes, Self::FourKibibytes];

	/// Converts this length to the numeric value.
	pub const fn bits(self) -> usize {
		match self {
			Self::TwoKibibytes => 2_048,
			Self::ThreeKibibytes => 3_072,
			Self::FourKibibytes => 4_096,
		}
	}

	/// Checks whether `bits` is exactly one of the supported modulus lengths.
	fn is_valid_bits(bits: usize) -> bool {
		Self::try_from(bits).is_ok()
	}
}

impl TryFrom<usize> for ModulusBits {
	type Error = ModulusBitsError;

	/// Converts a numeric bit length into [`ModulusBits`].
	///
	/// # Errors
	///
	/// Returns [`ModulusBitsError`] unless `value` is exactly 2048, 3072 or 4096.
	fn try_from(value: usize) -> Result<Self, Self::Error> {
		// `bits()` is the single source of truth for the variant <-> length mapping.
		Self::ALL
			.iter()
			.copied()
			.find(|modulus| modulus.bits() == value)
			.ok_or(ModulusBitsError(()))
	}
}

/// Error type returned when a conversion of an integer into `ModulusBits` fails.
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
pub struct ModulusBitsError(());

impl fmt::Display for ModulusBitsError {
	fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
		formatter.write_str(
			"Unsupported bit length of RSA modulus; only lengths 2048, 3072 and 4096 \
			 are supported.",
		)
	}
}

#[cfg(feature = "std")]
impl std::error::Error for ModulusBitsError {}
138
139/// Integrity algorithm using [RSA] digital signatures.
140///
141/// Depending on the variation, the algorithm employs PKCS#1 v1.5 or PSS padding and
142/// one of the hash functions from the SHA-2 family: SHA-256, SHA-384, or SHA-512.
143/// See [RFC 7518] for more details. Depending on the chosen parameters,
144/// the name of the algorithm is one of `RS256`, `RS384`, `RS512`, `PS256`, `PS384`, `PS512`:
145///
146/// - `R` / `P` denote the padding scheme: PKCS#1 v1.5 for `R`, PSS for `P`
147/// - `256` / `384` / `512` denote the hash function
148///
149/// The length of RSA keys is not unequivocally specified by the algorithm; nevertheless,
150/// it **MUST** be at least 2048 bits as per RFC 7518. To minimize risks of misconfiguration,
151/// use [`StrongAlg`](super::StrongAlg) wrapper around `Rsa`:
152///
153/// ```
154/// # use jwt_compact_frame::alg::{StrongAlg, Rsa};
155/// const ALG: StrongAlg<Rsa> = StrongAlg(Rsa::rs256());
156/// // `ALG` will not support RSA keys with unsecure lengths by design!
157/// ```
158///
159/// [RSA]: https://en.wikipedia.org/wiki/RSA_(cryptosystem)
160/// [RFC 7518]: https://www.rfc-editor.org/rfc/rfc7518.html
161#[derive(Debug, Clone, Copy, PartialEq, Eq)]
162#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
163pub struct Rsa {
164	hash_alg: HashAlg,
165	padding_alg: Padding,
166}
167
168impl Algorithm for Rsa {
169	type Signature = RsaSignature;
170	type SigningKey = RsaPrivateKey;
171	type VerifyingKey = RsaPublicKey;
172
173	fn name(&self) -> Cow<'static, str> {
174		Cow::Borrowed(self.alg_name())
175	}
176
177	fn sign(&self, signing_key: &Self::SigningKey, message: &[u8]) -> Self::Signature {
178		let digest = self.hash_alg.digest(message);
179		let digest = digest.as_ref();
180		let signing_result = match self.padding_scheme() {
181			PaddingScheme::Pkcs1v15(padding) => signing_key.sign_with_rng(&mut rand_core::OsRng, padding, digest),
182			PaddingScheme::Pss(padding) => signing_key.sign_with_rng(&mut rand_core::OsRng, padding, digest),
183		};
184		RsaSignature(signing_result.expect("Unexpected RSA signature failure"))
185	}
186
187	fn verify_signature(
188		&self,
189		signature: &Self::Signature,
190		verifying_key: &Self::VerifyingKey,
191		message: &[u8],
192	) -> bool {
193		let digest = self.hash_alg.digest(message);
194		let verify_result = match self.padding_scheme() {
195			PaddingScheme::Pkcs1v15(padding) => verifying_key.verify(padding, &digest, &signature.0),
196			PaddingScheme::Pss(padding) => verifying_key.verify(padding, &digest, &signature.0),
197		};
198		verify_result.is_ok()
199	}
200}
201
202impl Rsa {
203	const fn new(hash_alg: HashAlg, padding_alg: Padding) -> Self {
204		Rsa { hash_alg, padding_alg }
205	}
206
207	/// RSA with SHA-256 and PKCS#1 v1.5 padding.
208	pub const fn rs256() -> Rsa {
209		Rsa::new(HashAlg::Sha256, Padding::Pkcs1v15)
210	}
211
212	/// RSA with SHA-384 and PKCS#1 v1.5 padding.
213	pub const fn rs384() -> Rsa {
214		Rsa::new(HashAlg::Sha384, Padding::Pkcs1v15)
215	}
216
217	/// RSA with SHA-512 and PKCS#1 v1.5 padding.
218	pub const fn rs512() -> Rsa {
219		Rsa::new(HashAlg::Sha512, Padding::Pkcs1v15)
220	}
221
222	/// RSA with SHA-256 and PSS padding.
223	pub const fn ps256() -> Rsa {
224		Rsa::new(HashAlg::Sha256, Padding::Pss)
225	}
226
227	/// RSA with SHA-384 and PSS padding.
228	pub const fn ps384() -> Rsa {
229		Rsa::new(HashAlg::Sha384, Padding::Pss)
230	}
231
232	/// RSA with SHA-512 and PSS padding.
233	pub const fn ps512() -> Rsa {
234		Rsa::new(HashAlg::Sha512, Padding::Pss)
235	}
236
237	/// RSA based on the specified algorithm name.
238	///
239	/// # Panics
240	///
241	/// - Panics if the name is not one of the six RSA-based JWS algorithms. Prefer using
242	///   the [`FromStr`] trait if the conversion is potentially fallible.
243	pub fn with_name(name: &str) -> Self {
244		name.parse().unwrap()
245	}
246
247	fn padding_scheme(self) -> PaddingScheme {
248		match self.padding_alg {
249			Padding::Pkcs1v15 => PaddingScheme::Pkcs1v15(match self.hash_alg {
250				HashAlg::Sha256 => Pkcs1v15Sign::new::<Sha256>(),
251				HashAlg::Sha384 => Pkcs1v15Sign::new::<Sha384>(),
252				HashAlg::Sha512 => Pkcs1v15Sign::new::<Sha512>(),
253			}),
254			Padding::Pss => {
255				// The salt length needs to be set to the size of hash function output;
256				// see https://www.rfc-editor.org/rfc/rfc7518.html#section-3.5.
257				PaddingScheme::Pss(match self.hash_alg {
258					HashAlg::Sha256 => Pss::new_with_salt::<Sha256>(Sha256::output_size()),
259					HashAlg::Sha384 => Pss::new_with_salt::<Sha384>(Sha384::output_size()),
260					HashAlg::Sha512 => Pss::new_with_salt::<Sha512>(Sha512::output_size()),
261				})
262			},
263		}
264	}
265
266	fn alg_name(self) -> &'static str {
267		match (self.padding_alg, self.hash_alg) {
268			(Padding::Pkcs1v15, HashAlg::Sha256) => "RS256",
269			(Padding::Pkcs1v15, HashAlg::Sha384) => "RS384",
270			(Padding::Pkcs1v15, HashAlg::Sha512) => "RS512",
271			(Padding::Pss, HashAlg::Sha256) => "PS256",
272			(Padding::Pss, HashAlg::Sha384) => "PS384",
273			(Padding::Pss, HashAlg::Sha512) => "PS512",
274		}
275	}
276
277	/// Generates a new key pair with the specified modulus bit length (aka key length).
278	pub fn generate<R: CryptoRng + RngCore>(
279		rng: &mut R,
280		modulus_bits: ModulusBits,
281	) -> rsa::errors::Result<(StrongKey<RsaPrivateKey>, StrongKey<RsaPublicKey>)> {
282		let signing_key = RsaPrivateKey::new(rng, modulus_bits.bits())?;
283		let verifying_key = signing_key.to_public_key();
284		Ok((StrongKey(signing_key), StrongKey(verifying_key)))
285	}
286}
287
288impl FromStr for Rsa {
289	type Err = RsaParseError;
290
291	fn from_str(s: &str) -> Result<Self, Self::Err> {
292		Ok(match s {
293			"RS256" => Self::rs256(),
294			"RS384" => Self::rs384(),
295			"RS512" => Self::rs512(),
296			"PS256" => Self::ps256(),
297			"PS384" => Self::ps384(),
298			"PS512" => Self::ps512(),
299			_ => return Err(RsaParseError(s.to_owned())),
300		})
301	}
302}
303
304/// Errors that can occur when parsing an [`Rsa`] algorithm from a string.
305#[derive(Debug)]
306#[cfg_attr(docsrs, doc(cfg(feature = "rsa")))]
307pub struct RsaParseError(String);
308
309impl fmt::Display for RsaParseError {
310	fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
311		write!(formatter, "Invalid RSA algorithm name: {}", self.0)
312	}
313}
314
315#[cfg(feature = "std")]
316impl std::error::Error for RsaParseError {}
317
318impl StrongKey<RsaPrivateKey> {
319	/// Converts this private key to a public key.
320	pub fn to_public_key(&self) -> StrongKey<RsaPublicKey> {
321		StrongKey(self.0.to_public_key())
322	}
323}
324
325impl TryFrom<RsaPrivateKey> for StrongKey<RsaPrivateKey> {
326	type Error = WeakKeyError<RsaPrivateKey>;
327
328	fn try_from(key: RsaPrivateKey) -> Result<Self, Self::Error> {
329		if ModulusBits::is_valid_bits(key.n().bits()) {
330			Ok(StrongKey(key))
331		} else {
332			Err(WeakKeyError(key))
333		}
334	}
335}
336
337impl TryFrom<RsaPublicKey> for StrongKey<RsaPublicKey> {
338	type Error = WeakKeyError<RsaPublicKey>;
339
340	fn try_from(key: RsaPublicKey) -> Result<Self, Self::Error> {
341		if ModulusBits::is_valid_bits(key.n().bits()) {
342			Ok(StrongKey(key))
343		} else {
344			Err(WeakKeyError(key))
345		}
346	}
347}
348
349impl<'a> From<&'a RsaPublicKey> for JsonWebKey<'a> {
350	fn from(key: &'a RsaPublicKey) -> JsonWebKey<'a> {
351		JsonWebKey::Rsa {
352			modulus: Cow::Owned(key.n().to_bytes_be()),
353			public_exponent: Cow::Owned(key.e().to_bytes_be()),
354			private_parts: None,
355		}
356	}
357}
358
359impl TryFrom<&JsonWebKey<'_>> for RsaPublicKey {
360	type Error = JwkError;
361
362	fn try_from(jwk: &JsonWebKey<'_>) -> Result<Self, Self::Error> {
363		let JsonWebKey::Rsa { modulus, public_exponent, .. } = jwk else {
364			return Err(JwkError::key_type(jwk, KeyType::Rsa));
365		};
366
367		let e = BigUint::from_bytes_be(public_exponent);
368		let n = BigUint::from_bytes_be(modulus);
369		Self::new(n, e).map_err(|err| JwkError::custom(anyhow::anyhow!(err)))
370	}
371}
372
373/// ⚠ **Warning.** Contrary to [RFC 7518], this implementation does not set `dp`, `dq`, and `qi`
374/// fields in the JWK root object, as well as `d` and `t` fields for additional factors
375/// (i.e., in the `oth` array).
376///
377/// [RFC 7518]: https://tools.ietf.org/html/rfc7518#section-6.3.2
378impl<'a> From<&'a RsaPrivateKey> for JsonWebKey<'a> {
379	fn from(key: &'a RsaPrivateKey) -> JsonWebKey<'a> {
380		const MSG: &str = "RsaPrivateKey must have at least 2 prime factors";
381
382		let p = key.primes().get(0).expect(MSG);
383		let q = key.primes().get(1).expect(MSG);
384
385		let private_parts = RsaPrivateParts {
386			private_exponent: SecretBytes::owned(key.d().to_bytes_be()),
387			prime_factor_p: SecretBytes::owned(p.to_bytes_be()),
388			prime_factor_q: SecretBytes::owned(q.to_bytes_be()),
389			p_crt_exponent: None,
390			q_crt_exponent: None,
391			q_crt_coefficient: None,
392			other_prime_factors: key.primes()[2..]
393				.iter()
394				.map(|factor| RsaPrimeFactor {
395					factor: SecretBytes::owned(factor.to_bytes_be()),
396					crt_exponent: None,
397					crt_coefficient: None,
398				})
399				.collect(),
400		};
401
402		JsonWebKey::Rsa {
403			modulus: Cow::Owned(key.n().to_bytes_be()),
404			public_exponent: Cow::Owned(key.e().to_bytes_be()),
405			private_parts: Some(private_parts),
406		}
407	}
408}
409
410/// ⚠ **Warning.** Contrary to [RFC 7518] (at least, in spirit), this conversion ignores
411/// `dp`, `dq`, and `qi` fields from JWK, as well as `d` and `t` fields for additional factors.
412///
413/// [RFC 7518]: https://www.rfc-editor.org/rfc/rfc7518.html
414impl TryFrom<&JsonWebKey<'_>> for RsaPrivateKey {
415	type Error = JwkError;
416
417	fn try_from(jwk: &JsonWebKey<'_>) -> Result<Self, Self::Error> {
418		let JsonWebKey::Rsa { modulus, public_exponent, private_parts } = jwk else {
419			return Err(JwkError::key_type(jwk, KeyType::Rsa));
420		};
421
422		let RsaPrivateParts { private_exponent: d, prime_factor_p, prime_factor_q, other_prime_factors, .. } =
423			private_parts.as_ref().ok_or_else(|| JwkError::NoField("d".into()))?;
424
425		let e = BigUint::from_bytes_be(public_exponent);
426		let n = BigUint::from_bytes_be(modulus);
427		let d = BigUint::from_bytes_be(d);
428
429		let mut factors = Vec::with_capacity(2 + other_prime_factors.len());
430		factors.push(BigUint::from_bytes_be(prime_factor_p));
431		factors.push(BigUint::from_bytes_be(prime_factor_q));
432		factors.extend(other_prime_factors.iter().map(|prime| BigUint::from_bytes_be(&prime.factor)));
433
434		let key = Self::from_components(n, e, d, factors);
435		let key = key.map_err(|err| JwkError::custom(anyhow::anyhow!(err)))?;
436		key.validate().map_err(|err| JwkError::custom(anyhow::anyhow!(err)))?;
437		Ok(key)
438	}
439}