1#![allow(missing_docs)]
5
6#[cfg(feature = "alloc")]
7use alloc::boxed::Box;
8use core::ops::{Rem, Shr, ShrAssign};
9
10use modmath::{
11 compute_n_prime_newton, compute_r2_mod_n, compute_r_mod_n, type_bit_width, CiosMontMul, Parity,
12 WideMul,
13};
14use num_traits::ops::overflowing::OverflowingAdd;
15use num_traits::ops::wrapping::{WrappingAdd, WrappingMul, WrappingSub};
16use num_traits::{One, Zero};
17use zeroize::Zeroize;
18
19use crate::{
20 algorithms::rsa::rsa_encrypt,
21 errors::{Error, Result},
22 key::GenericRsaPublicKey,
23 traits::{
24 modular::{
25 IntegerResize, IntoMontyForm, ModulusParams, NonZero, Odd, Pow, PowBoundedExp,
26 TryFromBeBytes, UnsignedModularInt,
27 },
28 FixedWidthUnsignedInt,
29 },
30};
31
32pub trait ModMathInt:
33 FixedWidthUnsignedInt
34 + From<u8>
35 + PartialOrd
36 + One
37 + Zero
38 + Parity
39 + OverflowingAdd
40 + WideMul
41 + CiosMontMul
42 + WrappingAdd
43 + WrappingMul
44 + WrappingSub
45 + Rem<Output = Self>
46 + Shr<usize, Output = Self>
47 + ShrAssign<usize>
48{
49}
50
51impl<T> ModMathInt for T where
52 T: FixedWidthUnsignedInt
53 + From<u8>
54 + PartialOrd
55 + One
56 + Zero
57 + Parity
58 + OverflowingAdd
59 + WideMul
60 + CiosMontMul
61 + WrappingAdd
62 + WrappingMul
63 + WrappingSub
64 + Rem<Output = Self>
65 + Shr<usize, Output = Self>
66 + ShrAssign<usize>
67{
68}
69
70#[cfg(feature = "alloc")]
71fn wrap_value<T>(value: T) -> ModMathValue<T> {
72 ModMathValue(value)
73}
74
75#[cfg(not(feature = "alloc"))]
76fn wrap_value<T>(value: T) -> ModMathValue<T> {
77 value
78}
79
80#[cfg(feature = "alloc")]
81fn unwrap_value<T: Copy>(value: &ModMathValue<T>) -> T {
82 value.0
83}
84
85#[cfg(feature = "alloc")]
86fn unwrap_value_ref<T>(value: &ModMathValue<T>) -> &T {
87 &value.0
88}
89
90#[cfg(not(feature = "alloc"))]
91fn unwrap_value_ref<T>(value: &ModMathValue<T>) -> &T {
92 value
93}
94
95#[cfg(not(feature = "alloc"))]
96fn unwrap_value<T: Copy>(value: &ModMathValue<T>) -> T {
97 *value
98}
99
/// Newtype around a fixed-width integer; the value type of this backend when `alloc` is enabled.
#[cfg(feature = "alloc")]
#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd, Ord)]
pub struct ModMathValue<T>(pub T);

#[cfg(feature = "alloc")]
impl<T> ModMathValue<T> {
    /// Wraps `inner` as-is; no reduction or validation is performed.
    pub fn from_inner(inner: T) -> Self {
        ModMathValue(inner)
    }

    /// Borrows the wrapped integer.
    pub fn inner(&self) -> &T {
        let ModMathValue(wrapped) = self;
        wrapped
    }
}
114
/// Zeroizing a [`ModMathValue`] zeroizes the wrapped integer in place.
#[cfg(feature = "alloc")]
impl<T> Zeroize for ModMathValue<T>
where
    T: Zeroize,
{
    fn zeroize(&mut self) {
        let Self(inner) = self;
        inner.zeroize();
    }
}
124
/// Builds a [`ModMathValue`] from a small constant via the inner type's `From<u8>`.
#[cfg(feature = "alloc")]
impl<T> From<u8> for ModMathValue<T>
where
    T: ModMathInt,
{
    fn from(value: u8) -> Self {
        let inner: T = value.into();
        Self(inner)
    }
}
134
/// Resizing for a fixed-width value: the width can never actually change.
#[cfg(feature = "alloc")]
impl<T> IntegerResize for ModMathValue<T>
where
    T: ModMathInt,
{
    type Output = Self;

    /// Returns `self` unchanged; the requested precision is ignored.
    fn resize_unchecked(self, _at_least_bits_precision: u32) -> Self::Output {
        self
    }

    /// Returns `Some(self)` (unchanged) when the requested precision is greater than or equal
    /// to this type's fixed precision, and `None` otherwise.
    ///
    /// NOTE(review): a request *smaller* than the fixed width yields `None` even though `self`
    /// already has at least that many bits; confirm this matches the `IntegerResize` contract.
    fn try_resize(self, at_least_bits_precision: u32) -> Option<Self::Output> {
        let fixed_precision = UnsignedModularInt::bits_precision(&self);
        (at_least_bits_precision >= fixed_precision).then_some(self)
    }
}
154
/// [`UnsignedModularInt`] for the newtype, delegating to the wrapped [`FixedWidthUnsignedInt`].
#[cfg(feature = "alloc")]
impl<T> UnsignedModularInt for ModMathValue<T>
where
    T: ModMathInt,
{
    type Bytes = <T as FixedWidthUnsignedInt>::Bytes;

    /// Leading zero bits of the wrapped integer.
    fn leading_zeros(&self) -> u32 {
        let Self(inner) = self;
        FixedWidthUnsignedInt::leading_zeros(inner)
    }

    /// Big-endian encoding of the wrapped integer, as produced by the inner type.
    fn to_be_bytes(&self) -> Self::Bytes {
        let Self(inner) = self;
        FixedWidthUnsignedInt::to_be_bytes(inner)
    }

    /// Big-endian encoding with leading zero bytes stripped (variable-time).
    ///
    /// At least one byte is always kept, so zero encodes as `[0]`.
    #[cfg(feature = "alloc")]
    fn to_be_bytes_trimmed_vartime(&self) -> Box<[u8]> {
        let encoded = UnsignedModularInt::to_be_bytes(self);
        let encoded: &[u8] = encoded.as_ref();
        let start = match encoded.iter().position(|&byte| byte != 0) {
            Some(index) => index,
            None => encoded.len().saturating_sub(1),
        };
        Box::from(&encoded[start..])
    }

    /// `self mod modulus`, computed with the inner type's `%` operator.
    fn rem_vartime(&self, modulus: &NonZero<Self>) -> Self {
        let divisor: &Self = modulus.as_ref();
        Self(self.0 % divisor.0)
    }

    /// Re-wraps `self` as [`NonZero`].
    ///
    /// # Panics
    ///
    /// Panics if the value is zero.
    fn as_nz_ref(&self) -> NonZero<Self> {
        NonZero::new(*self).expect("value is non-zero")
    }

    /// Number of significant bits: fixed precision minus leading zeros.
    fn bits(&self) -> u32 {
        let precision = self.bits_precision();
        precision - self.leading_zeros()
    }

    /// Fixed bit width of the wrapped integer type.
    fn bits_precision(&self) -> u32 {
        let Self(inner) = self;
        FixedWidthUnsignedInt::bits_precision(inner)
    }
}
197
/// Parses big-endian bytes into the wrapped fixed-width integer.
#[cfg(feature = "alloc")]
impl<T> TryFromBeBytes for ModMathValue<T>
where
    T: ModMathInt,
{
    /// Variable-time parse; errors come from the inner type's parser.
    fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self> {
        let inner = <T as FixedWidthUnsignedInt>::try_from_be_bytes_vartime(bytes)?;
        Ok(Self(inner))
    }
}
208}
209
/// Without the `alloc` feature, `ModMathValue<T>` is simply the raw integer type `T`.
#[cfg(not(feature = "alloc"))]
pub type ModMathValue<Raw> = Raw;
212
/// Precomputed Montgomery parameters for a fixed odd modulus `n`.
#[derive(Clone, Debug)]
pub struct ModMathParams<T: ModMathInt> {
    /// The modulus `n`; `ModMathParams::new` rejects even values.
    modulus: Odd<ModMathValue<T>>,
    /// Montgomery constant from `modmath::compute_n_prime_newton`, passed to every
    /// `cios_mont_mul` call (NOTE(review): sign convention defined by `modmath`).
    n_prime: T,
    /// `R mod n`: the Montgomery representation of 1, used as the starting accumulator
    /// of exponentiation.
    r_mod_n: T,
    /// `R^2 mod n`: Montgomery-multiplying a reduced integer by this enters Montgomery form.
    r2_mod_n: T,
}
224
225impl<T: ModMathInt> ModMathParams<T> {
226 pub fn new(modulus: T) -> Result<Self> {
228 let modulus_odd = Odd::new(wrap_value(modulus)).ok_or(Error::InvalidModulus)?;
229 let w = type_bit_width::<T>();
230 let n_prime = compute_n_prime_newton(modulus, w);
231 let r_mod_n = compute_r_mod_n(modulus, w);
232 let r2_mod_n = compute_r2_mod_n(r_mod_n, modulus, w);
233 Ok(Self {
234 modulus: modulus_odd,
235 n_prime,
236 r_mod_n,
237 r2_mod_n,
238 })
239 }
240}
241
242pub fn public_key_from_be_bytes<T>(
245 modulus: &[u8],
246 exponent: u32,
247) -> Result<GenericRsaPublicKey<ModMathValue<T>, ModMathParams<T>>>
248where
249 T: ModMathInt,
250{
251 let n = wrap_value(<T as FixedWidthUnsignedInt>::try_from_be_bytes_vartime(
252 modulus,
253 )?);
254 let exponent = exponent.to_be_bytes();
255 let e = wrap_value(<T as FixedWidthUnsignedInt>::try_from_be_bytes_vartime(
256 &exponent,
257 )?);
258 GenericRsaPublicKey::from_components(n, e, ModMathParams::new(unwrap_value(&n))?)
259}
260
261pub fn rsa_public_op<T>(
265 key: &GenericRsaPublicKey<ModMathValue<T>, ModMathParams<T>>,
266 input: &[u8],
267) -> Result<<ModMathValue<T> as UnsignedModularInt>::Bytes>
268where
269 T: ModMathInt,
270{
271 let input = wrap_value(<T as FixedWidthUnsignedInt>::try_from_be_bytes_vartime(
272 input,
273 )?);
274 Ok(rsa_encrypt(key, &input)?.to_be_bytes())
275}
276
/// An integer held in Montgomery form, together with the parameters it belongs to.
#[derive(Clone, Debug)]
pub struct ModMathForm<T: ModMathInt> {
    /// The value in Montgomery representation (`x * R mod n` for the represented `x`).
    integer_mont: ModMathValue<T>,
    /// Modulus and Montgomery constants used for every operation on this value.
    params: ModMathParams<T>,
}
285
286impl<T: ModMathInt> IntoMontyForm<ModMathParams<T>> for ModMathForm<T> {
287 fn from_reduced(integer: ModMathValue<T>, params: &ModMathParams<T>) -> Self {
288 let a_mont = T::cios_mont_mul(
290 unwrap_value_ref(&integer),
291 ¶ms.r2_mod_n,
292 unwrap_value_ref(params.modulus.as_ref()),
293 ¶ms.n_prime,
294 )
295 .expect("CIOS Montgomery mul requires non-empty word array");
296 Self {
297 integer_mont: wrap_value(a_mont),
298 params: params.clone(),
299 }
300 }
301}
302
303impl<T: ModMathInt> ModMathForm<T> {
304 fn pow_loop(&self, exp_raw: T) -> T {
305 let modulus = unwrap_value_ref(self.params.modulus.as_ref());
306 let n_prime = &self.params.n_prime;
307 let mut base_mont = unwrap_value(&self.integer_mont);
308 let mut result_mont = self.params.r_mod_n;
310 let mut e = exp_raw;
311 while !e.is_zero() {
312 if e.is_odd() {
313 result_mont = T::cios_mont_mul(&result_mont, &base_mont, modulus, n_prime)
314 .expect("CIOS Montgomery mul requires non-empty word array");
315 }
316 base_mont = T::cios_mont_mul(&base_mont, &base_mont, modulus, n_prime)
317 .expect("CIOS Montgomery mul requires non-empty word array");
318 e >>= 1;
319 }
320 result_mont
321 }
322
323 fn to_reduced(&self) -> T {
324 let one = <T as From<u8>>::from(1u8);
326 T::cios_mont_mul(
327 unwrap_value_ref(&self.integer_mont),
328 &one,
329 unwrap_value_ref(self.params.modulus.as_ref()),
330 &self.params.n_prime,
331 )
332 .expect("CIOS Montgomery mul requires non-empty word array")
333 }
334}
335
336impl<T: ModMathInt> Pow<ModMathParams<T>> for ModMathForm<T> {
337 fn pow(&self, exp: &ModMathValue<T>) -> Self {
338 let result_mont = self.pow_loop(unwrap_value(exp));
339 Self {
340 integer_mont: wrap_value(result_mont),
341 params: self.params.clone(),
342 }
343 }
344}
345
346impl<T: ModMathInt> PowBoundedExp<ModMathParams<T>> for ModMathForm<T> {
347 fn pow_bounded_exp(&self, exp: &ModMathValue<T>, _exp_bits: u32) -> Self {
348 let result_mont = self.pow_loop(unwrap_value(exp));
351 Self {
352 integer_mont: wrap_value(result_mont),
353 params: self.params.clone(),
354 }
355 }
356
357 fn retrieve(&self) -> ModMathValue<T> {
358 wrap_value(self.to_reduced())
359 }
360}
361
362impl<T: ModMathInt> ModulusParams for ModMathParams<T> {
363 type Modulus = ModMathValue<T>;
364 type MontgomeryForm = ModMathForm<T>;
365
366 fn modulus(&self) -> &Odd<Self::Modulus> {
367 &self.modulus
368 }
369
370 fn bits_precision(&self) -> u32 {
371 self.modulus.bits_precision()
372 }
373}
374
#[cfg(test)]
#[cfg(all(feature = "alloc", feature = "private-key"))]
mod tests {
    use fixed_bigint::FixedUInt;
    use rand::rngs::ChaCha8Rng;
    use rand_core::SeedableRng;
    use sha1::Sha1;
    use signature::hazmat::PrehashVerifier;

    use super::{public_key_from_be_bytes, rsa_public_op, ModMathParams, ModMathValue};
    use crate::key::GenericRsaPublicKey;
    use crate::pkcs1v15::{GenericEncryptingKey, GenericSignature, GenericVerifyingKey};
    use crate::{traits::RandomizedEncryptor, BoxedUint, Pkcs1v15Encrypt, RsaPublicKey};

    /// SHA-1 digest that `SIGNATURE` signs.
    const DIGEST: [u8; 20] = [
        0x43, 0x0c, 0xe3, 0x4d, 0x02, 0x07, 0x24, 0xed, 0x75, 0xa1, 0x96, 0xdf, 0xc2, 0xad,
        0x67, 0xc7, 0x77, 0x72, 0xd1, 0x69,
    ];

    /// 512-bit RSA modulus, used with public exponent 3 throughout these tests.
    const MODULUS: [u8; 64] = [
        0x96, 0x9D, 0x03, 0xFF, 0xA9, 0x8D, 0x88, 0x8F, 0x3A, 0xA4, 0xF2, 0xFE, 0xD2, 0x32,
        0xE6, 0x1C, 0x4A, 0xCF, 0x06, 0x63, 0xA9, 0x2F, 0x99, 0x03, 0x4C, 0xF7, 0xB7, 0x24,
        0x5A, 0x1A, 0x1E, 0x5E, 0xAF, 0xA5, 0x65, 0xAF, 0xB9, 0x0B, 0xAB, 0x22, 0x85, 0x71,
        0x2F, 0xAA, 0x50, 0x39, 0x39, 0xA0, 0x65, 0xFB, 0x60, 0xDD, 0x08, 0x28, 0xA3, 0x84,
        0xF2, 0x6D, 0x8A, 0xFC, 0x28, 0x6D, 0xF6, 0xCF,
    ];

    /// PKCS#1 v1.5 / SHA-1 signature over `DIGEST` for `MODULUS` with e = 3.
    const SIGNATURE: [u8; 64] = [
        0x45, 0x53, 0xF3, 0xAF, 0x16, 0xAF, 0x63, 0x97, 0xB0, 0xD3, 0x2F, 0x8A, 0xEC, 0xD5,
        0x4C, 0xF1, 0xF3, 0xD0, 0x0C, 0x9F, 0x42, 0xDC, 0x68, 0xCB, 0xD7, 0x05, 0xCE, 0xA5,
        0xA9, 0x70, 0x95, 0x3E, 0xC0, 0xBC, 0x4A, 0x18, 0xED, 0x91, 0xA3, 0x5D, 0x66, 0xEC,
        0xDA, 0x4A, 0x83, 0x32, 0xCF, 0xC3, 0xA3, 0xAB, 0x21, 0xAD, 0x59, 0xB2, 0x2E, 0x87,
        0xC2, 0x73, 0xFF, 0x08, 0x88, 0xDD, 0x4D, 0xE0,
    ];

    #[test]
    fn verify_pkcs1v15_signature_with_modmath_fixed_uint() {
        type U512 = FixedUInt<u8, 64>;

        let key = public_key_from_be_bytes::<U512>(&MODULUS, 3).unwrap();
        let verifying_key = GenericVerifyingKey::<Sha1, _, _>::new(key);
        let signature =
            GenericSignature::from(ModMathValue::from_inner(U512::from_be_bytes(&SIGNATURE)));
        verifying_key.verify_prehash(&DIGEST, &signature).unwrap();
    }

    #[test]
    fn verify_pkcs1v15_signature_with_modmath_fixed_uint32() {
        type U512 = FixedUInt<u32, 16>;

        let n = U512::from_be_bytes(&MODULUS);
        let e = U512::from(3u8);
        let key = GenericRsaPublicKey::from_components(
            ModMathValue::from_inner(n),
            ModMathValue::from_inner(e),
            ModMathParams::new(n).unwrap(),
        )
        .unwrap();
        let verifying_key = GenericVerifyingKey::<Sha1, _, _>::new(key);
        let signature =
            GenericSignature::from(ModMathValue::from_inner(U512::from_be_bytes(&SIGNATURE)));
        verifying_key.verify_prehash(&DIGEST, &signature).unwrap();
    }

    #[test]
    fn modmath_params_rejects_even_modulus() {
        type U512 = FixedUInt<u8, 64>;

        // Montgomery arithmetic needs an odd modulus; zero and other even values must be rejected.
        assert!(ModMathParams::new(U512::from(0u8)).is_err());
        assert!(ModMathParams::new(U512::from(4u8)).is_err());
    }

    #[test]
    fn rsa_public_op_recovers_pkcs1v15_encoded_message() {
        type U512 = FixedUInt<u8, 64>;

        let key = public_key_from_be_bytes::<U512>(&MODULUS, 3).unwrap();
        let em = rsa_public_op(&key, &SIGNATURE).unwrap();
        let em: &[u8] = em.as_ref();
        // EMSA-PKCS1-v1_5: 0x00 || 0x01 || PS || 0x00 || DigestInfo || digest.
        assert_eq!(em[..2], [0x00, 0x01]);
        assert_eq!(em[em.len() - DIGEST.len()..], DIGEST);
    }

    #[test]
    fn encrypt_pkcs1v15_with_modmath_fixed_uint_matches_boxeduint() {
        type U512 = FixedUInt<u8, 64>;

        let msg = b"hello world!";

        let modmath_key = public_key_from_be_bytes::<U512>(&MODULUS, 3).unwrap();
        let boxed_key = RsaPublicKey::new(
            BoxedUint::from_be_slice(&MODULUS, 512).unwrap(),
            3u64.into(),
        )
        .unwrap();

        // Identically seeded RNGs make both backends draw the same padding bytes.
        let mut modmath_rng = ChaCha8Rng::from_seed([42; 32]);
        let mut boxed_rng = ChaCha8Rng::from_seed([42; 32]);
        let mut storage = [0u8; 64];

        let modmath_ciphertext = GenericEncryptingKey::new(modmath_key)
            .encrypt_with_rng_into(&mut modmath_rng, msg, &mut storage)
            .unwrap();
        let boxed_ciphertext = boxed_key
            .encrypt(&mut boxed_rng, Pkcs1v15Encrypt, msg)
            .unwrap();

        assert_eq!(modmath_ciphertext, boxed_ciphertext.as_slice());
    }
}
489}