Skip to main content

rsa/traits/
modular.rs

1// TODO: document the public surface once the trait shape settles.
2#![allow(missing_docs)]
3
4use core::borrow::Borrow;
5
6#[cfg(feature = "alloc")]
7use alloc::boxed::Box;
8#[cfg(feature = "alloc")]
9use crypto_bigint::{
10    modular::{BoxedMontyForm, BoxedMontyParams},
11    BoxedUint, Resize as CryptoResize,
12};
13#[cfg(feature = "alloc")]
14use crypto_bigint::{NonZero as CryptoNonZero, Odd as CryptoOdd};
15#[cfg(feature = "modmath")]
16use fixed_bigint::ConstBitPrimInt;
17#[cfg(not(feature = "modmath"))]
18use num_traits::PrimInt;
19use num_traits::{FromBytes as NumFromBytes, ToBytes as NumToBytes, Zero};
20use zeroize::Zeroize;
21
22use crate::errors::{Error, Result};
23
/// Byte container that can be viewed as a `[u8]` slice through both
/// `AsRef` and `Borrow`.
///
/// Blanket-implemented for every sized type meeting both bounds, so it never
/// has to be implemented by hand.
pub trait NumBytes: AsRef<[u8]> + Borrow<[u8]> {}

impl<B> NumBytes for B where B: AsRef<[u8]> + Borrow<[u8]> {}
27
/// Wrapper for a value known to be non-zero; build it with `NonZero::new`.
///
/// `#[repr(transparent)]` guarantees the wrapper has exactly the layout of `T`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[repr(transparent)]
pub struct NonZero<T>(T);
31
/// Wrapper for a value known to be odd (and hence non-zero); build it with
/// `Odd::new`.
///
/// `#[repr(transparent)]` guarantees the wrapper has exactly the layout of `T`.
#[derive(Debug, Clone, PartialEq, Eq)]
#[repr(transparent)]
pub struct Odd<T>(T);
35
/// Precision adjustment for integers, shaped after `crypto_bigint::Resize`.
pub trait IntegerResize: Sized {
    /// Type produced by a resize.
    type Output;

    /// Resizes to at least `at_least_bits_precision` bits without checking
    /// whether the current value fits.
    fn resize_unchecked(self, at_least_bits_precision: u32) -> Self::Output;

    /// Resizes to at least `at_least_bits_precision` bits, returning `None`
    /// when the value would not fit.
    fn try_resize(self, at_least_bits_precision: u32) -> Option<Self::Output>;
}
42
43pub trait FixedWidthUnsignedInt: Zeroize + Clone + Copy {
44    type Bytes: NumBytes + Default + AsMut<[u8]>;
45
46    fn leading_zeros(&self) -> u32;
47    fn to_be_bytes(&self) -> Self::Bytes;
48    fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self>;
49    fn bits_precision(&self) -> u32;
50}
51
/// `FixedWidthUnsignedInt` for `ConstBitPrimInt` integers (`modmath` feature).
#[cfg(feature = "modmath")]
impl<T> FixedWidthUnsignedInt for T
where
    T: Zeroize + Clone + Copy + ConstBitPrimInt + Zero + NumToBytes + NumFromBytes,
    T: NumToBytes<Bytes = <T as NumFromBytes>::Bytes>,
    <T as NumToBytes>::Bytes: NumBytes + Default + AsMut<[u8]>,
{
    type Bytes = <T as NumToBytes>::Bytes;

    fn leading_zeros(&self) -> u32 {
        let value: T = *self;
        ConstBitPrimInt::leading_zeros(value)
    }

    fn to_be_bytes(&self) -> Self::Bytes {
        NumToBytes::to_be_bytes(self)
    }

    fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self> {
        // Right-align the input in a zeroed full-width buffer so shorter
        // inputs are implicitly zero-extended; longer inputs cannot fit.
        let mut repr = <T as NumFromBytes>::Bytes::default();
        let buf = repr.as_mut();
        let pad = buf
            .len()
            .checked_sub(bytes.len())
            .ok_or(Error::InvalidArguments)?;
        buf[pad..].copy_from_slice(bytes);
        Ok(NumFromBytes::from_be_bytes(&repr))
    }

    fn bits_precision(&self) -> u32 {
        // Every bit of zero is clear, so its zero count equals the width.
        let zero = <T as Zero>::zero();
        ConstBitPrimInt::count_zeros(zero)
    }
}
84
85#[cfg(not(feature = "modmath"))]
86impl<T> FixedWidthUnsignedInt for T
87where
88    T: Zeroize + Clone + Copy + PrimInt + NumToBytes + NumFromBytes,
89    T: NumToBytes<Bytes = <T as NumFromBytes>::Bytes>,
90    <T as NumToBytes>::Bytes: NumBytes + Default + AsMut<[u8]>,
91{
92    type Bytes = <T as NumToBytes>::Bytes;
93
94    fn leading_zeros(&self) -> u32 {
95        PrimInt::leading_zeros(*self)
96    }
97
98    fn to_be_bytes(&self) -> Self::Bytes {
99        NumToBytes::to_be_bytes(self)
100    }
101
102    fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self> {
103        let mut repr = <T as NumFromBytes>::Bytes::default();
104        let out = repr.as_mut();
105        let out_len = out.len();
106        if bytes.len() > out_len {
107            return Err(Error::InvalidArguments);
108        }
109        out[out_len - bytes.len()..].copy_from_slice(bytes);
110        Ok(NumFromBytes::from_be_bytes(&repr))
111    }
112
113    fn bits_precision(&self) -> u32 {
114        <T as Zero>::zero().count_zeros()
115    }
116}
117
118#[cfg(not(feature = "alloc"))]
119impl<T> IntegerResize for T
120where
121    T: FixedWidthUnsignedInt,
122{
123    type Output = Self;
124
125    fn resize_unchecked(self, _at_least_bits_precision: u32) -> Self::Output {
126        self
127    }
128
129    fn try_resize(self, at_least_bits_precision: u32) -> Option<Self::Output> {
130        // Mirrors `crypto_bigint::Resize::try_resize`: returns `Some` iff
131        // the actual value fits in `at_least_bits_precision` bits. T is
132        // fixed-width and `resize_unchecked` is a no-op, but the check
133        // still needs to reject values that wouldn't survive a narrower
134        // precision.
135        let value_bits = self.bits_precision() - self.leading_zeros();
136        if value_bits <= at_least_bits_precision {
137            Some(self)
138        } else {
139            None
140        }
141    }
142}
143
144#[cfg(not(feature = "alloc"))]
145impl<T> UnsignedModularInt for T
146where
147    T: FixedWidthUnsignedInt + PartialOrd,
148{
149    type Bytes = <T as FixedWidthUnsignedInt>::Bytes;
150
151    fn leading_zeros(&self) -> u32 {
152        FixedWidthUnsignedInt::leading_zeros(self)
153    }
154
155    fn to_be_bytes(&self) -> Self::Bytes {
156        FixedWidthUnsignedInt::to_be_bytes(self)
157    }
158
159    fn as_nz_ref(&self) -> NonZero<Self> {
160        NonZero::new(*self).expect("value is non-zero")
161    }
162
163    fn bits(&self) -> u32 {
164        self.bits_precision() - self.leading_zeros()
165    }
166
167    fn bits_precision(&self) -> u32 {
168        FixedWidthUnsignedInt::bits_precision(self)
169    }
170
171    #[cfg(feature = "alloc")]
172    fn to_be_bytes_trimmed_vartime(&self) -> Box<[u8]> {
173        unreachable!("alloc-gated")
174    }
175}
176
177#[cfg(not(feature = "alloc"))]
178impl<T> TryFromBeBytes for T
179where
180    T: FixedWidthUnsignedInt,
181{
182    fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self> {
183        FixedWidthUnsignedInt::try_from_be_bytes_vartime(bytes)
184    }
185}
186
187pub trait TryFromBeBytes: Sized {
188    fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self>;
189}
190
191pub trait UnsignedModularInt:
192    Zeroize + Clone + PartialOrd + IntegerResize<Output = Self> + TryFromBeBytes
193{
194    type Bytes: NumBytes + AsMut<[u8]>;
195    fn leading_zeros(&self) -> u32;
196    fn to_be_bytes(&self) -> Self::Bytes;
197    fn as_nz_ref(&self) -> NonZero<Self>;
198    fn bits(&self) -> u32;
199    fn bits_precision(&self) -> u32;
200    #[cfg(feature = "alloc")]
201    fn to_be_bytes_trimmed_vartime(&self) -> Box<[u8]>;
202}
203
204impl<T> NonZero<T>
205where
206    T: UnsignedModularInt,
207{
208    pub fn new(value: T) -> Option<Self> {
209        if value.bits() == 0 {
210            None
211        } else {
212            Some(Self(value))
213        }
214    }
215
216    pub fn get(self) -> T {
217        self.0
218    }
219
220    #[allow(clippy::should_implement_trait)]
221    pub fn as_ref(&self) -> &T {
222        &self.0
223    }
224
225    pub fn bits(&self) -> u32 {
226        self.0.bits()
227    }
228
229    pub fn bits_precision(&self) -> u32 {
230        self.0.bits_precision()
231    }
232
233    pub fn to_be_bytes(&self) -> T::Bytes {
234        self.0.to_be_bytes()
235    }
236
237    #[cfg(feature = "alloc")]
238    pub fn to_be_bytes_trimmed_vartime(&self) -> Box<[u8]> {
239        self.0.to_be_bytes_trimmed_vartime()
240    }
241}
242
243impl<T> Odd<T>
244where
245    T: UnsignedModularInt,
246{
247    pub fn new(value: T) -> Option<Self> {
248        let non_zero = NonZero::new(value)?;
249        let bytes = non_zero.as_ref().to_be_bytes();
250        let bytes = bytes.as_ref();
251        let is_odd = bytes.last().map(|byte| byte & 1 == 1).unwrap_or(false);
252        if is_odd {
253            Some(Self(non_zero.get()))
254        } else {
255            None
256        }
257    }
258
259    pub fn get(self) -> T {
260        self.0
261    }
262
263    #[allow(clippy::should_implement_trait)]
264    pub fn as_ref(&self) -> &T {
265        &self.0
266    }
267
268    pub fn as_nz_ref(&self) -> NonZero<T> {
269        NonZero::new(self.0.clone()).expect("odd values are non-zero")
270    }
271
272    pub fn bits_precision(&self) -> u32 {
273        self.0.bits_precision()
274    }
275}
276
277/// Build a Montgomery-domain value.
278///
279/// Two constructors with **different input contracts**:
280///
281/// - [`from_reduced`](Self::from_reduced) — caller guarantees `integer <
282///   params.modulus()`. Implementations may rely on this; no reduction is
283///   performed. Use this when you already know the value is reduced.
284/// - [`from_value`](Self::from_value) — accepts any `integer` in
285///   `[0, 2^bits_precision)`. Implementations MUST handle the unreduced
286///   case (either by reducing internally or by using a Montgomery primitive
287///   that tolerates unreduced inputs, e.g. CIOS with `raw * R²`).
288///
289/// No default `from_value` is provided on purpose. Forwarding to
290/// `from_reduced` would silently produce wrong results for unreduced
291/// inputs on backends that don't tolerate them — the trait makes this
292/// distinction explicit so each implementor confronts it.
293pub trait IntoMontyForm<P: ModulusParams>: Sized {
294    /// Build from an integer already reduced modulo `params.modulus()`.
295    fn from_reduced(integer: P::Modulus, params: &P) -> Self;
296
297    /// Build from any integer in `[0, 2^bits_precision)`, handling
298    /// reduction internally if needed.
299    fn from_value(integer: P::Modulus, params: &P) -> Self;
300}
301
/// Montgomery conversion for the heap-allocated crypto-bigint backend.
#[cfg(feature = "alloc")]
impl IntoMontyForm<BoxedMontyParams> for BoxedMontyForm {
    /// Forwards straight to `BoxedMontyForm::new`; no reduction is done.
    fn from_reduced(integer: BoxedUint, params: &BoxedMontyParams) -> Self {
        BoxedMontyForm::new(integer, params)
    }

    /// Reduces `integer` modulo the modulus with `rem_vartime` (not
    /// constant-time), then converts the remainder.
    fn from_value(integer: BoxedUint, params: &BoxedMontyParams) -> Self {
        let modulus: &BoxedUint = params.modulus().as_ref();
        let divisor = CryptoNonZero::new(modulus.clone()).expect("modulus is non-zero");
        Self::from_reduced(integer.rem_vartime(&divisor), params)
    }
}
315
316pub trait PowBoundedExp<M: ModulusParams>: Sized {
317    fn pow_bounded_exp(&self, exp: &M::Modulus, exp_bits: u32) -> Self;
318    fn retrieve(&self) -> M::Modulus;
319}
320
/// Bounded exponentiation for the heap-allocated crypto-bigint backend.
#[cfg(feature = "alloc")]
impl PowBoundedExp<BoxedMontyParams> for BoxedMontyForm {
    fn pow_bounded_exp(&self, exp: &BoxedUint, exp_bits: u32) -> Self {
        // Calling on an owned clone makes method lookup reach the inherent
        // `BoxedMontyForm` method (whether it takes `self` or `&self`) instead
        // of resolving back to this trait method.
        let base: BoxedMontyForm = self.clone();
        base.pow_bounded_exp(exp, exp_bits)
    }

    fn retrieve(&self) -> BoxedUint {
        // Owned receiver for the same dispatch reason as above.
        let owned: BoxedMontyForm = self.clone();
        owned.retrieve()
    }
}
331
332pub trait Pow<M: ModulusParams>: Sized {
333    fn pow(&self, exp: &M::Modulus) -> Self;
334}
335
/// Exponentiation for the heap-allocated crypto-bigint backend.
#[cfg(feature = "alloc")]
impl Pow<BoxedMontyParams> for BoxedMontyForm {
    fn pow(&self, exp: &BoxedUint) -> Self {
        // Owned receiver so lookup reaches the inherent `BoxedMontyForm::pow`
        // rather than resolving back to this trait method.
        let base: BoxedMontyForm = self.clone();
        base.pow(exp)
    }
}
342
343pub trait ModulusParams: Sized {
344    type Modulus: UnsignedModularInt;
345    type MontgomeryForm: IntoMontyForm<Self> + PowBoundedExp<Self>;
346    fn modulus(&self) -> &Odd<Self::Modulus>;
347    fn bits_precision(&self) -> u32;
348}
349
/// Modulus parameters for the heap-allocated crypto-bigint backend.
#[cfg(feature = "alloc")]
impl ModulusParams for BoxedMontyParams {
    type Modulus = BoxedUint;
    type MontgomeryForm = BoxedMontyForm;

    /// Returns the modulus viewed through this crate's [`Odd`] wrapper,
    /// borrowed from `self`.
    ///
    /// The view is built from the inner `&BoxedUint` rather than from
    /// crypto-bigint's `Odd` wrapper. Soundness therefore rests only on this
    /// crate's `#[repr(transparent)]` `Odd<T>`, not on the unspecified layout
    /// of a foreign type. Matching size/align alone does not establish layout
    /// compatibility.
    fn modulus(&self) -> &Odd<Self::Modulus> {
        // Inherent methods take precedence over this trait method, so this
        // calls `BoxedMontyParams::modulus`. The explicit type turns any
        // accidental resolution to this trait method (infinite recursion)
        // into a compile error.
        let crypto_odd: &CryptoOdd<BoxedUint> = self.modulus();
        let inner: &BoxedUint = crypto_odd.as_ref();
        // SAFETY: `Odd<T>` is declared `#[repr(transparent)]` over its single
        // field `T`, so `Odd<BoxedUint>` has the same layout as `BoxedUint`.
        // Reinterpreting a valid `&BoxedUint` as `&Odd<BoxedUint>` is
        // therefore sound. The reference is derived from `self` and keeps its
        // lifetime. The `Odd` invariant holds because `inner` was taken out
        // of crypto-bigint's `Odd` wrapper.
        unsafe { &*(inner as *const BoxedUint).cast::<Odd<BoxedUint>>() }
    }

    fn bits_precision(&self) -> u32 {
        // Resolves to the inherent `BoxedMontyParams::bits_precision`.
        self.bits_precision()
    }
}
375
/// Resizing for `BoxedUint`, delegated to crypto-bigint's `Resize`.
#[cfg(feature = "alloc")]
impl IntegerResize for BoxedUint {
    type Output = Self;

    fn resize_unchecked(self, at_least_bits_precision: u32) -> Self::Output {
        <Self as CryptoResize>::resize_unchecked(self, at_least_bits_precision)
    }

    fn try_resize(self, at_least_bits_precision: u32) -> Option<Self::Output> {
        <Self as CryptoResize>::try_resize(self, at_least_bits_precision)
    }
}
388
/// `UnsignedModularInt` for `BoxedUint`, forwarding to its inherent methods.
#[cfg(feature = "alloc")]
impl UnsignedModularInt for BoxedUint {
    type Bytes = Box<[u8]>;

    // `BoxedUint::…` paths resolve to inherent items first, which share
    // names with this trait's methods.
    fn leading_zeros(&self) -> u32 {
        BoxedUint::leading_zeros(self)
    }

    fn to_be_bytes(&self) -> Self::Bytes {
        BoxedUint::to_be_bytes(self)
    }

    fn to_be_bytes_trimmed_vartime(&self) -> Box<[u8]> {
        BoxedUint::to_be_bytes_trimmed_vartime(self)
    }

    /// Returns a [`NonZero`] copy of `self`.
    ///
    /// # Panics
    ///
    /// Panics if `self` is zero.
    fn as_nz_ref(&self) -> NonZero<Self> {
        let copy = self.clone();
        NonZero::new(copy).expect("Value is non-zero")
    }

    fn bits(&self) -> u32 {
        BoxedUint::bits(self)
    }

    fn bits_precision(&self) -> u32 {
        BoxedUint::bits_precision(self)
    }
}
414
/// Big-endian decoding for `BoxedUint`; this conversion never fails.
#[cfg(feature = "alloc")]
impl TryFromBeBytes for BoxedUint {
    fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self> {
        let decoded = Self::from_be_slice_vartime(bytes);
        Ok(decoded)
    }
}
420}