Skip to main content

rsa/
modmath_support.rs

1//! Generic `modmath` backend adapters for fixed-width RSA public-key paths.
2
3// TODO: document the public surface once the trait shape settles.
4#![allow(missing_docs)]
5
6#[cfg(feature = "alloc")]
7use alloc::boxed::Box;
8use core::ops::{Rem, Shr, ShrAssign};
9
10use modmath::{
11    compute_n_prime_newton, compute_r2_mod_n, compute_r_mod_n, type_bit_width, CiosMontMul, Parity,
12    WideMul,
13};
14use num_traits::ops::overflowing::OverflowingAdd;
15use num_traits::ops::wrapping::{WrappingAdd, WrappingMul, WrappingSub};
16use num_traits::{One, Zero};
17use zeroize::Zeroize;
18
19use crate::{
20    algorithms::rsa::rsa_encrypt,
21    errors::{Error, Result},
22    key::GenericRsaPublicKey,
23    traits::{
24        modular::{
25            IntegerResize, IntoMontyForm, ModulusParams, NonZero, Odd, Pow, PowBoundedExp,
26            TryFromBeBytes, UnsignedModularInt,
27        },
28        FixedWidthUnsignedInt,
29    },
30};
31
32pub trait ModMathInt:
33    FixedWidthUnsignedInt
34    + From<u8>
35    + PartialOrd
36    + One
37    + Zero
38    + Parity
39    + OverflowingAdd
40    + WideMul
41    + CiosMontMul
42    + WrappingAdd
43    + WrappingMul
44    + WrappingSub
45    + Rem<Output = Self>
46    + Shr<usize, Output = Self>
47    + ShrAssign<usize>
48{
49}
50
51impl<T> ModMathInt for T where
52    T: FixedWidthUnsignedInt
53        + From<u8>
54        + PartialOrd
55        + One
56        + Zero
57        + Parity
58        + OverflowingAdd
59        + WideMul
60        + CiosMontMul
61        + WrappingAdd
62        + WrappingMul
63        + WrappingSub
64        + Rem<Output = Self>
65        + Shr<usize, Output = Self>
66        + ShrAssign<usize>
67{
68}
69
70#[cfg(feature = "alloc")]
71fn wrap_value<T>(value: T) -> ModMathValue<T> {
72    ModMathValue(value)
73}
74
75#[cfg(not(feature = "alloc"))]
76fn wrap_value<T>(value: T) -> ModMathValue<T> {
77    value
78}
79
80#[cfg(feature = "alloc")]
81fn unwrap_value<T: Copy>(value: &ModMathValue<T>) -> T {
82    value.0
83}
84
85#[cfg(feature = "alloc")]
86fn unwrap_value_ref<T>(value: &ModMathValue<T>) -> &T {
87    &value.0
88}
89
90#[cfg(not(feature = "alloc"))]
91fn unwrap_value_ref<T>(value: &ModMathValue<T>) -> &T {
92    value
93}
94
95#[cfg(not(feature = "alloc"))]
96fn unwrap_value<T: Copy>(value: &ModMathValue<T>) -> T {
97    *value
98}
99
/// Newtype over a backend integer `T`, defined only when the `alloc` feature
/// is enabled.
///
/// The wrapped integer is exposed as the public tuple field.
#[cfg(feature = "alloc")]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub struct ModMathValue<T>(pub T);
103
#[cfg(feature = "alloc")]
impl<T> ModMathValue<T> {
    /// Wraps `inner` in a `ModMathValue`.
    pub fn from_inner(inner: T) -> Self {
        ModMathValue(inner)
    }

    /// Returns a shared reference to the wrapped integer.
    pub fn inner(&self) -> &T {
        let Self(wrapped) = self;
        wrapped
    }
}
114
// Zeroizing the wrapper zeroizes the wrapped integer; the wrapper holds no
// other state.
#[cfg(feature = "alloc")]
impl<T: Zeroize> Zeroize for ModMathValue<T> {
    fn zeroize(&mut self) {
        <T as Zeroize>::zeroize(&mut self.0);
    }
}
124
// Small literals are converted by the backend integer's own `From<u8>` and
// then wrapped.
#[cfg(feature = "alloc")]
impl<T: ModMathInt> From<u8> for ModMathValue<T> {
    fn from(value: u8) -> Self {
        let inner: T = <T as From<u8>>::from(value);
        ModMathValue(inner)
    }
}
134
// The wrapped integer has a fixed width, so "resizing" never changes the
// value; the two methods differ only in whether the request is checked.
#[cfg(feature = "alloc")]
impl<T: ModMathInt> IntegerResize for ModMathValue<T> {
    type Output = Self;

    /// Returns `self` unchanged; the requested precision is ignored.
    fn resize_unchecked(self, _at_least_bits_precision: u32) -> Self::Output {
        self
    }

    /// Returns `Some(self)` when `at_least_bits_precision` is at least this
    /// value's own `bits_precision()`, and `None` otherwise.
    fn try_resize(self, at_least_bits_precision: u32) -> Option<Self::Output> {
        let accepted = at_least_bits_precision >= self.bits_precision();
        accepted.then_some(self)
    }
}
154
// Most methods forward to the wrapped integer's `FixedWidthUnsignedInt`
// implementation; the rest are derived from those primitives.
#[cfg(feature = "alloc")]
impl<T: ModMathInt> UnsignedModularInt for ModMathValue<T> {
    type Bytes = <T as FixedWidthUnsignedInt>::Bytes;

    /// Leading zero bits, as reported by the wrapped integer.
    fn leading_zeros(&self) -> u32 {
        <T as FixedWidthUnsignedInt>::leading_zeros(self.inner())
    }

    /// Full-width big-endian encoding of the wrapped integer.
    fn to_be_bytes(&self) -> Self::Bytes {
        <T as FixedWidthUnsignedInt>::to_be_bytes(self.inner())
    }

    /// Big-endian encoding with leading zero bytes removed.
    ///
    /// An all-zero value keeps its final byte, so the result is not empty
    /// unless the full-width encoding itself is empty.
    #[cfg(feature = "alloc")]
    fn to_be_bytes_trimmed_vartime(&self) -> Box<[u8]> {
        let full_width = self.to_be_bytes();
        let full_width = full_width.as_ref();
        let start = match full_width.iter().position(|&byte| byte != 0) {
            Some(index) => index,
            // Every byte is zero: keep only the last one.
            None => full_width.len().saturating_sub(1),
        };
        Box::from(&full_width[start..])
    }

    /// Remainder of `self` divided by `modulus`, via the backend's `%`.
    fn rem_vartime(&self, modulus: &NonZero<Self>) -> Self {
        let divisor = modulus.as_ref().0;
        ModMathValue(self.0 % divisor)
    }

    /// Copies `self` into a `NonZero` wrapper.
    ///
    /// # Panics
    ///
    /// Panics if `NonZero::new` rejects the value.
    fn as_nz_ref(&self) -> NonZero<Self> {
        let candidate = *self;
        NonZero::new(candidate).expect("value is non-zero")
    }

    /// Number of significant bits: total precision minus leading zeros.
    fn bits(&self) -> u32 {
        let total = self.bits_precision();
        let unused = self.leading_zeros();
        total - unused
    }

    /// Bit precision, as reported by the wrapped integer.
    fn bits_precision(&self) -> u32 {
        <T as FixedWidthUnsignedInt>::bits_precision(self.inner())
    }
}
197
#[cfg(feature = "alloc")]
impl<T: ModMathInt> TryFromBeBytes for ModMathValue<T> {
    /// Parses big-endian `bytes` with the backend integer's own parser and
    /// wraps the result; any parse error is propagated.
    fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self> {
        let parsed = <T as FixedWidthUnsignedInt>::try_from_be_bytes_vartime(bytes)?;
        Ok(ModMathValue(parsed))
    }
}
209
/// With the `alloc` feature disabled no wrapper struct is defined:
/// `ModMathValue<T>` is just another name for the backend integer `T`.
210#[cfg(not(feature = "alloc"))]
211pub type ModMathValue<T> = T;
212
213#[derive(Clone, Debug)]
214pub struct ModMathParams<T: ModMathInt> {
215    modulus: Odd<ModMathValue<T>>,
216    // Montgomery constants for R = 2^W, where W = type_bit_width::<T>().
217    // n_prime satisfies modulus * n_prime ≡ -1 (mod R).
218    n_prime: T,
219    // r_mod_n = R mod modulus = 2^W mod modulus.  Also serves as 1 in Montgomery form.
220    r_mod_n: T,
221    // r2_mod_n = R^2 mod modulus.  Used by wide_montgomery_mul to convert into Montgomery form.
222    r2_mod_n: T,
223}
224
225impl<T: ModMathInt> ModMathParams<T> {
226    /// Create modular arithmetic parameters for an odd, non-zero modulus.
227    pub fn new(modulus: T) -> Result<Self> {
228        let modulus_odd = Odd::new(wrap_value(modulus)).ok_or(Error::InvalidModulus)?;
229        let w = type_bit_width::<T>();
230        let n_prime = compute_n_prime_newton(modulus, w);
231        let r_mod_n = compute_r_mod_n(modulus, w);
232        let r2_mod_n = compute_r2_mod_n(r_mod_n, modulus, w);
233        Ok(Self {
234            modulus: modulus_odd,
235            n_prime,
236            r_mod_n,
237            r2_mod_n,
238        })
239    }
240}
241
242/// Construct a public key backed by the `modmath` adapter from big-endian
243/// modulus bytes and a public exponent.
244pub fn public_key_from_be_bytes<T>(
245    modulus: &[u8],
246    exponent: u32,
247) -> Result<GenericRsaPublicKey<ModMathValue<T>, ModMathParams<T>>>
248where
249    T: ModMathInt,
250{
251    let n = wrap_value(<T as FixedWidthUnsignedInt>::try_from_be_bytes_vartime(
252        modulus,
253    )?);
254    let exponent = exponent.to_be_bytes();
255    let e = wrap_value(<T as FixedWidthUnsignedInt>::try_from_be_bytes_vartime(
256        &exponent,
257    )?);
258    GenericRsaPublicKey::from_components(n, e, ModMathParams::new(unwrap_value(&n))?)
259}
260
261/// Apply the raw RSA public operation to a fixed-width block.
262///
263/// For signature use-cases this recovers the encoded message representative.
264pub fn rsa_public_op<T>(
265    key: &GenericRsaPublicKey<ModMathValue<T>, ModMathParams<T>>,
266    input: &[u8],
267) -> Result<<ModMathValue<T> as UnsignedModularInt>::Bytes>
268where
269    T: ModMathInt,
270{
271    let input = wrap_value(<T as FixedWidthUnsignedInt>::try_from_be_bytes_vartime(
272        input,
273    )?);
274    Ok(rsa_encrypt(key, &input)?.to_be_bytes())
275}
276
277/// A value held in Montgomery form modulo a `ModMathParams` modulus.
278///
279/// `integer_mont` stores `a * R mod N`, where `R = 2^W` and `W = type_bit_width::<T>()`.
280#[derive(Clone, Debug)]
281pub struct ModMathForm<T: ModMathInt> {
282    integer_mont: ModMathValue<T>,
283    params: ModMathParams<T>,
284}
285
286impl<T: ModMathInt> IntoMontyForm<ModMathParams<T>> for ModMathForm<T> {
287    fn from_reduced(integer: ModMathValue<T>, params: &ModMathParams<T>) -> Self {
288        // a_mont = a * R mod N, computed via CIOS as a * R^2 * R^-1 mod N.
289        let a_mont = T::cios_mont_mul(
290            unwrap_value_ref(&integer),
291            &params.r2_mod_n,
292            unwrap_value_ref(params.modulus.as_ref()),
293            &params.n_prime,
294        )
295        .expect("CIOS Montgomery mul requires non-empty word array");
296        Self {
297            integer_mont: wrap_value(a_mont),
298            params: params.clone(),
299        }
300    }
301}
302
303impl<T: ModMathInt> ModMathForm<T> {
304    fn pow_loop(&self, exp_raw: T) -> T {
305        let modulus = unwrap_value_ref(self.params.modulus.as_ref());
306        let n_prime = &self.params.n_prime;
307        let mut base_mont = unwrap_value(&self.integer_mont);
308        // 1 in Montgomery form is R mod N.
309        let mut result_mont = self.params.r_mod_n;
310        let mut e = exp_raw;
311        while !e.is_zero() {
312            if e.is_odd() {
313                result_mont = T::cios_mont_mul(&result_mont, &base_mont, modulus, n_prime)
314                    .expect("CIOS Montgomery mul requires non-empty word array");
315            }
316            base_mont = T::cios_mont_mul(&base_mont, &base_mont, modulus, n_prime)
317                .expect("CIOS Montgomery mul requires non-empty word array");
318            e >>= 1;
319        }
320        result_mont
321    }
322
323    fn to_reduced(&self) -> T {
324        // a_mont * 1 * R^-1 mod N = a (regular form).
325        let one = <T as From<u8>>::from(1u8);
326        T::cios_mont_mul(
327            unwrap_value_ref(&self.integer_mont),
328            &one,
329            unwrap_value_ref(self.params.modulus.as_ref()),
330            &self.params.n_prime,
331        )
332        .expect("CIOS Montgomery mul requires non-empty word array")
333    }
334}
335
336impl<T: ModMathInt> Pow<ModMathParams<T>> for ModMathForm<T> {
337    fn pow(&self, exp: &ModMathValue<T>) -> Self {
338        let result_mont = self.pow_loop(unwrap_value(exp));
339        Self {
340            integer_mont: wrap_value(result_mont),
341            params: self.params.clone(),
342        }
343    }
344}
345
346impl<T: ModMathInt> PowBoundedExp<ModMathParams<T>> for ModMathForm<T> {
347    fn pow_bounded_exp(&self, exp: &ModMathValue<T>, _exp_bits: u32) -> Self {
348        // The LSB-first loop exits naturally when the exponent reaches zero,
349        // so the `_exp_bits` hint is unused here.
350        let result_mont = self.pow_loop(unwrap_value(exp));
351        Self {
352            integer_mont: wrap_value(result_mont),
353            params: self.params.clone(),
354        }
355    }
356
357    fn retrieve(&self) -> ModMathValue<T> {
358        wrap_value(self.to_reduced())
359    }
360}
361
362impl<T: ModMathInt> ModulusParams for ModMathParams<T> {
363    type Modulus = ModMathValue<T>;
364    type MontgomeryForm = ModMathForm<T>;
365
366    fn modulus(&self) -> &Odd<Self::Modulus> {
367        &self.modulus
368    }
369
370    fn bits_precision(&self) -> u32 {
371        self.modulus.bits_precision()
372    }
373}
374
#[cfg(all(test, feature = "alloc", feature = "private-key"))]
mod tests {
    use fixed_bigint::FixedUInt;
    use rand::rngs::ChaCha8Rng;
    use rand_core::SeedableRng;
    use sha1::Sha1;
    use signature::hazmat::PrehashVerifier;

    use super::{public_key_from_be_bytes, ModMathParams, ModMathValue};
    use crate::key::GenericRsaPublicKey;
    use crate::pkcs1v15::{GenericEncryptingKey, GenericSignature, GenericVerifyingKey};
    use crate::{traits::RandomizedEncryptor, BoxedUint, Pkcs1v15Encrypt, RsaPublicKey};

    // The same 512-bit key material is shared by every test in this module.

    /// 20-byte prehash handed to `verify_prehash`.
    const DIGEST: [u8; 20] = [
        0x43, 0x0c, 0xe3, 0x4d, 0x02, 0x07, 0x24, 0xed, 0x75, 0xa1, 0x96, 0xdf, 0xc2, 0xad,
        0x67, 0xc7, 0x77, 0x72, 0xd1, 0x69,
    ];

    /// Big-endian RSA modulus.
    const MODULUS: [u8; 64] = [
        0x96, 0x9D, 0x03, 0xFF, 0xA9, 0x8D, 0x88, 0x8F, 0x3A, 0xA4, 0xF2, 0xFE, 0xD2, 0x32,
        0xE6, 0x1C, 0x4A, 0xCF, 0x06, 0x63, 0xA9, 0x2F, 0x99, 0x03, 0x4C, 0xF7, 0xB7, 0x24,
        0x5A, 0x1A, 0x1E, 0x5E, 0xAF, 0xA5, 0x65, 0xAF, 0xB9, 0x0B, 0xAB, 0x22, 0x85, 0x71,
        0x2F, 0xAA, 0x50, 0x39, 0x39, 0xA0, 0x65, 0xFB, 0x60, 0xDD, 0x08, 0x28, 0xA3, 0x84,
        0xF2, 0x6D, 0x8A, 0xFC, 0x28, 0x6D, 0xF6, 0xCF,
    ];

    /// Big-endian PKCS#1 v1.5 signature over `DIGEST` for `MODULUS`.
    const SIGNATURE: [u8; 64] = [
        0x45, 0x53, 0xF3, 0xAF, 0x16, 0xAF, 0x63, 0x97, 0xB0, 0xD3, 0x2F, 0x8A, 0xEC, 0xD5,
        0x4C, 0xF1, 0xF3, 0xD0, 0x0C, 0x9F, 0x42, 0xDC, 0x68, 0xCB, 0xD7, 0x05, 0xCE, 0xA5,
        0xA9, 0x70, 0x95, 0x3E, 0xC0, 0xBC, 0x4A, 0x18, 0xED, 0x91, 0xA3, 0x5D, 0x66, 0xEC,
        0xDA, 0x4A, 0x83, 0x32, 0xCF, 0xC3, 0xA3, 0xAB, 0x21, 0xAD, 0x59, 0xB2, 0x2E, 0x87,
        0xC2, 0x73, 0xFF, 0x08, 0x88, 0xDD, 0x4D, 0xE0,
    ];

    /// Verifies the fixture signature using a key built through
    /// `public_key_from_be_bytes` on `FixedUInt<u8, 64>`.
    #[test]
    fn verify_pkcs1v15_signature_with_modmath_fixed_uint() {
        type U512 = FixedUInt<u8, 64>;

        let key = public_key_from_be_bytes::<U512>(&MODULUS, 3).unwrap();
        let verifying_key = GenericVerifyingKey::<Sha1, _, _>::new(key);

        let signature_value = ModMathValue::from_inner(U512::from_be_bytes(&SIGNATURE));
        let signature = GenericSignature::from(signature_value);

        verifying_key.verify_prehash(&DIGEST, &signature).unwrap();
    }

    /// Verifies the fixture signature using a key assembled by hand from its
    /// components on `FixedUInt<u32, 16>`.
    #[test]
    fn verify_pkcs1v15_signature_with_modmath_fixed_uint32() {
        type U512 = FixedUInt<u32, 16>;

        let n = U512::from_be_bytes(&MODULUS);
        let e = U512::from(3u8);
        let params = ModMathParams::new(n).unwrap();
        let key = GenericRsaPublicKey::from_components(
            ModMathValue::from_inner(n),
            ModMathValue::from_inner(e),
            params,
        )
        .unwrap();
        let verifying_key = GenericVerifyingKey::<Sha1, _, _>::new(key);

        let signature_value = ModMathValue::from_inner(U512::from_be_bytes(&SIGNATURE));
        let signature = GenericSignature::from(signature_value);

        verifying_key.verify_prehash(&DIGEST, &signature).unwrap();
    }

    /// Encrypts the same message with identically seeded RNGs through the
    /// `modmath` key and the `BoxedUint` key and expects equal ciphertexts.
    #[test]
    fn encrypt_pkcs1v15_with_modmath_fixed_uint_matches_boxeduint() {
        type U512 = FixedUInt<u8, 64>;

        let msg = b"hello world!";

        let modmath_key = public_key_from_be_bytes::<U512>(&MODULUS, 3).unwrap();
        let boxed_modulus = BoxedUint::from_be_slice(&MODULUS, 512).unwrap();
        let boxed_key = RsaPublicKey::new(boxed_modulus, 3u64.into()).unwrap();

        // Identical seeds, so both paths draw the same padding bytes.
        let mut modmath_rng = ChaCha8Rng::from_seed([42; 32]);
        let mut boxed_rng = ChaCha8Rng::from_seed([42; 32]);
        let mut storage = [0u8; 64];

        let modmath_ciphertext = GenericEncryptingKey::new(modmath_key)
            .encrypt_with_rng_into(&mut modmath_rng, msg, &mut storage)
            .unwrap();
        let boxed_ciphertext = boxed_key
            .encrypt(&mut boxed_rng, Pkcs1v15Encrypt, msg)
            .unwrap();

        assert_eq!(modmath_ciphertext, boxed_ciphertext.as_slice());
    }
}
489}