1use core::fmt;
9
10use crate::public_key::bigint::{BigUint, MontgomeryCtx};
11use crate::public_key::primes::{
12 gcd, is_probable_prime, lcm, mod_inverse, mod_pow, random_probable_prime,
13};
14use crate::Csprng;
15
16#[derive(Clone, Debug, Eq, PartialEq)]
18pub struct RsaPublicKey {
19 e: BigUint,
20 n: BigUint,
21}
22
23#[derive(Clone, Eq, PartialEq)]
25pub struct RsaPrivateKey {
26 e: BigUint,
27 d: BigUint,
28 n: BigUint,
29 p: BigUint,
30 q: BigUint,
31 d_p: BigUint,
32 d_q: BigUint,
33 q_inv: BigUint,
34 p_ctx: MontgomeryCtx,
35 q_ctx: MontgomeryCtx,
36}
37
/// Stateless namespace for the RSA key-construction and key-generation
/// functions; it carries no data of its own.
pub struct Rsa;
40
41impl RsaPublicKey {
42 #[must_use]
43 pub(crate) fn from_components(e: BigUint, n: BigUint) -> Self {
44 Self { e, n }
45 }
46
47 #[must_use]
49 pub fn exponent(&self) -> &BigUint {
50 &self.e
51 }
52
53 #[must_use]
55 pub fn modulus(&self) -> &BigUint {
56 &self.n
57 }
58
59 #[must_use]
70 pub fn encrypt_raw(&self, message: &BigUint) -> BigUint {
71 mod_pow(message, &self.e, &self.n)
72 }
73}
74
75impl RsaPrivateKey {
76 #[must_use]
78 pub(crate) fn public_exponent(&self) -> &BigUint {
79 &self.e
80 }
81
82 #[must_use]
84 pub fn exponent(&self) -> &BigUint {
85 &self.d
86 }
87
88 #[must_use]
90 pub fn modulus(&self) -> &BigUint {
91 &self.n
92 }
93
94 #[must_use]
96 pub(crate) fn prime1(&self) -> &BigUint {
97 &self.p
98 }
99
100 #[must_use]
102 pub(crate) fn prime2(&self) -> &BigUint {
103 &self.q
104 }
105
106 #[must_use]
108 pub(crate) fn crt_exponent1(&self) -> &BigUint {
109 &self.d_p
110 }
111
112 #[must_use]
114 pub(crate) fn crt_exponent2(&self) -> &BigUint {
115 &self.d_q
116 }
117
118 #[must_use]
120 pub(crate) fn crt_coefficient(&self) -> &BigUint {
121 &self.q_inv
122 }
123
124 #[must_use]
131 pub fn decrypt_raw(&self, ciphertext: &BigUint) -> BigUint {
132 let c_mod_p = ciphertext.modulo(&self.p);
138 let c_mod_q = ciphertext.modulo(&self.q);
139 let m1 = self.p_ctx.pow(&c_mod_p, &self.d_p);
140 let m2 = self.q_ctx.pow(&c_mod_q, &self.d_q);
141
142 let m2_mod_p = m2.modulo(&self.p);
147 let delta = if m1 >= m2_mod_p {
148 m1.sub_ref(&m2_mod_p)
149 } else {
150 m1.add_ref(&self.p).sub_ref(&m2_mod_p)
151 };
152 let h = BigUint::mod_mul(&self.q_inv, &delta, &self.p);
153 m2.add_ref(&self.q.mul_ref(&h))
154 }
155}
156
157impl fmt::Debug for RsaPrivateKey {
158 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
159 f.write_str("RsaPrivateKey(<redacted>)")
160 }
161}
162
163impl Rsa {
164 #[must_use]
170 pub fn from_primes_with_exponent(
171 p: &BigUint,
172 q: &BigUint,
173 exponent: &BigUint,
174 ) -> Option<(RsaPublicKey, RsaPrivateKey)> {
175 if p == q || !is_probable_prime(p) || !is_probable_prime(q) {
176 return None;
177 }
178 if exponent <= &BigUint::one() {
179 return None;
180 }
181
182 let p_minus_one = p.sub_ref(&BigUint::one());
183 let q_minus_one = q.sub_ref(&BigUint::one());
184 let lambda = lcm(&p_minus_one, &q_minus_one);
185 if gcd(exponent, &lambda) != BigUint::one() {
186 return None;
187 }
188
189 let d = mod_inverse(exponent, &lambda)?;
190 let n = p.mul_ref(q);
191 let d_p = d.modulo(&p_minus_one);
192 let d_q = d.modulo(&q_minus_one);
193 let q_inv = mod_inverse(q, p)?;
194 let p_ctx = MontgomeryCtx::new(p)?;
195 let q_ctx = MontgomeryCtx::new(q)?;
196
197 Some((
198 RsaPublicKey {
199 e: exponent.clone(),
200 n: n.clone(),
201 },
202 RsaPrivateKey {
203 e: exponent.clone(),
204 d,
205 n,
206 p: p.clone(),
207 q: q.clone(),
208 d_p,
209 d_q,
210 q_inv,
211 p_ctx,
212 q_ctx,
213 },
214 ))
215 }
216
217 #[must_use]
228 pub fn from_primes(p: &BigUint, q: &BigUint) -> Option<(RsaPublicKey, RsaPrivateKey)> {
229 if p == q || !is_probable_prime(p) || !is_probable_prime(q) {
230 return None;
231 }
232
233 let p_minus_one = p.sub_ref(&BigUint::one());
234 let q_minus_one = q.sub_ref(&BigUint::one());
235 let lambda = lcm(&p_minus_one, &q_minus_one);
236
237 let mut exponent_bit = 16usize;
238 loop {
239 let mut exponent = BigUint::zero();
240 exponent.set_bit(exponent_bit);
241 exponent = exponent.add_ref(&BigUint::one());
242 if gcd(&exponent, &lambda) == BigUint::one() {
243 return Self::from_primes_with_exponent(p, q, &exponent);
244 }
245 exponent_bit += 1;
246 }
247 }
248
249 #[must_use]
257 pub fn generate_with_exponent<R: Csprng>(
258 rng: &mut R,
259 bits: usize,
260 exponent: &BigUint,
261 ) -> Option<(RsaPublicKey, RsaPrivateKey)> {
262 if bits < 32 {
266 return None;
267 }
268
269 let p_bits = bits / 2;
270 let q_bits = bits - p_bits;
271 loop {
272 let p = random_probable_prime(rng, p_bits)?;
273 let q = random_probable_prime(rng, q_bits)?;
274 if let Some(keypair) = Self::from_primes_with_exponent(&p, &q, exponent) {
275 return Some(keypair);
276 }
277 }
278 }
279
280 #[must_use]
283 pub fn generate<R: Csprng>(rng: &mut R, bits: usize) -> Option<(RsaPublicKey, RsaPrivateKey)> {
284 if bits < 32 {
287 return None;
288 }
289
290 let p_bits = bits / 2;
291 let q_bits = bits - p_bits;
292 loop {
293 let p = random_probable_prime(rng, p_bits)?;
294 let q = random_probable_prime(rng, q_bits)?;
295 if let Some(keypair) = Self::from_primes(&p, &q) {
296 return Some(keypair);
297 }
298 }
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::{Rsa, RsaPrivateKey, RsaPublicKey};
    use crate::public_key::bigint::BigUint;
    use crate::CtrDrbgAes256;

    /// Shorthand for building a `BigUint` from a small literal.
    fn num(value: u64) -> BigUint {
        BigUint::from_u64(value)
    }

    /// Key pair derived from the primes 61 and 53 with the automatically
    /// chosen exponent.
    fn textbook_keypair() -> (RsaPublicKey, RsaPrivateKey) {
        Rsa::from_primes(&num(61), &num(53)).expect("valid RSA key")
    }

    #[test]
    fn derive_reference_key_with_default_exponent() {
        let (public, private) = textbook_keypair();
        let modulus = num(3_233);

        assert_eq!(public.modulus(), &modulus);
        assert_eq!(public.exponent(), &num(65_537));
        assert_eq!(private.exponent(), &num(413));
        assert_eq!(private.modulus(), &modulus);
    }

    #[test]
    fn roundtrip_small_messages() {
        let (public, private) = textbook_keypair();

        for message in [0u64, 1, 2, 65, 123, 3_232].iter().copied().map(num) {
            let recovered = private.decrypt_raw(&public.encrypt_raw(&message));
            assert_eq!(recovered, message);
        }
    }

    #[test]
    fn exact_small_ciphertext_matches_reference() {
        let (public, private) = textbook_keypair();
        let message = num(65);

        let ciphertext = public.encrypt_raw(&message);

        assert_eq!(ciphertext, num(2_790));
        assert_eq!(private.decrypt_raw(&ciphertext), message);
    }

    #[test]
    fn raw_rsa_is_multiplicatively_homomorphic() {
        let (public, private) = textbook_keypair();
        let (left, right) = (num(12), num(17));

        // Multiplying two ciphertexts modulo n must decrypt to the product of
        // the plaintexts modulo n.
        let combined_cipher = BigUint::mod_mul(
            &public.encrypt_raw(&left),
            &public.encrypt_raw(&right),
            public.modulus(),
        );
        let expected = BigUint::mod_mul(&left, &right, public.modulus());

        assert_eq!(private.decrypt_raw(&combined_cipher), expected);
    }

    #[test]
    fn explicit_exponent_matches_classic_example() {
        let seventeen = num(17);
        let (public, private) = Rsa::from_primes_with_exponent(&num(61), &num(53), &seventeen)
            .expect("valid RSA key");

        assert_eq!(public.exponent(), &num(17));
        assert_eq!(private.exponent(), &num(413));
    }

    #[test]
    fn rejects_non_invertible_exponent() {
        // lcm(10, 12) = 60 is divisible by 3, so 3 has no inverse.
        let outcome = Rsa::from_primes_with_exponent(&num(11), &num(13), &num(3));
        assert!(outcome.is_none());
    }

    #[test]
    fn generate_keypair_roundtrip() {
        let seed = [0x55u8; 48];
        let mut drbg = CtrDrbgAes256::new(&seed);
        let (public, private) = Rsa::generate(&mut drbg, 64).expect("generated RSA key");
        assert!(public.modulus().bits() >= 63);

        let message = num(42);
        let ciphertext = public.encrypt_raw(&message);
        assert_eq!(private.decrypt_raw(&ciphertext), message);
    }
}