Skip to main content

rsa/traits/
modular.rs

1// TODO: document the public surface once the trait shape settles.
2#![allow(missing_docs)]
3
4use core::borrow::Borrow;
5
6#[cfg(feature = "alloc")]
7use alloc::boxed::Box;
8#[cfg(feature = "alloc")]
9use crypto_bigint::{
10    modular::{BoxedMontyForm, BoxedMontyParams},
11    BoxedUint, Resize as CryptoResize,
12};
13#[cfg(feature = "alloc")]
14use crypto_bigint::{NonZero as CryptoNonZero, Odd as CryptoOdd};
15use num_traits::{FromBytes as NumFromBytes, PrimInt, ToBytes as NumToBytes, Zero};
16use zeroize::Zeroize;
17
18use crate::errors::{Error, Result};
19
/// Marker for byte containers that expose their contents as `&[u8]` through both
/// [`AsRef`] and [`Borrow`].
///
/// It is blanket-implemented for every type satisfying both bounds, so it never needs a
/// hand-written impl.
pub trait NumBytes: AsRef<[u8]> + Borrow<[u8]> {}

impl<B> NumBytes for B where B: AsRef<[u8]> + Borrow<[u8]> {}
23
/// Wrapper marking its inner integer as non-zero.
///
/// `#[repr(transparent)]` gives it the same layout as `T`. The checked constructor is
/// `NonZero::new`, which rejects values whose `bits()` is zero.
#[repr(transparent)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct NonZero<T>(T);
27
/// Wrapper marking its inner integer as odd, and therefore non-zero.
///
/// `#[repr(transparent)]` gives it the same layout as `T`. The `ModulusParams` impl for
/// `BoxedMontyParams` relies on this to reinterpret a `crypto_bigint::Odd` reference as this
/// type. The checked constructor is `Odd::new`.
#[repr(transparent)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Odd<T>(T);
31
/// Precision-changing conversion modelled on `crypto_bigint::Resize`.
///
/// The `BoxedUint` impl in this module delegates to `crypto_bigint::Resize`.
pub trait IntegerResize: Sized {
    /// Type produced by a resize.
    type Output;

    /// Resizes to at least `at_least_bits_precision` bits without checking that the value fits.
    fn resize_unchecked(self, at_least_bits_precision: u32) -> Self::Output;

    /// Resizes to at least `at_least_bits_precision` bits.
    ///
    /// Returns `None` when the value cannot be represented in that precision.
    fn try_resize(self, at_least_bits_precision: u32) -> Option<Self::Output>;
}
38
39pub trait FixedWidthUnsignedInt: Zeroize + Clone + Copy {
40    type Bytes: NumBytes + Default + AsMut<[u8]>;
41
42    fn leading_zeros(&self) -> u32;
43    fn to_be_bytes(&self) -> Self::Bytes;
44    fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self>;
45    fn bits_precision(&self) -> u32;
46}
47
48impl<T> FixedWidthUnsignedInt for T
49where
50    T: Zeroize + Clone + Copy + PrimInt + NumToBytes + NumFromBytes,
51    T: NumToBytes<Bytes = <T as NumFromBytes>::Bytes>,
52    <T as NumToBytes>::Bytes: NumBytes + Default + AsMut<[u8]>,
53{
54    type Bytes = <T as NumToBytes>::Bytes;
55
56    fn leading_zeros(&self) -> u32 {
57        PrimInt::leading_zeros(*self)
58    }
59
60    fn to_be_bytes(&self) -> Self::Bytes {
61        NumToBytes::to_be_bytes(self)
62    }
63
64    fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self> {
65        let mut repr = <T as NumFromBytes>::Bytes::default();
66        let out = repr.as_mut();
67        let out_len = out.len();
68        if bytes.len() > out_len {
69            return Err(Error::InvalidArguments);
70        }
71        out[out_len - bytes.len()..].copy_from_slice(bytes);
72        Ok(NumFromBytes::from_be_bytes(&repr))
73    }
74
75    fn bits_precision(&self) -> u32 {
76        <T as Zero>::zero().count_zeros()
77    }
78}
79
80#[cfg(not(feature = "alloc"))]
81impl<T> IntegerResize for T
82where
83    T: FixedWidthUnsignedInt,
84{
85    type Output = Self;
86
87    fn resize_unchecked(self, _at_least_bits_precision: u32) -> Self::Output {
88        self
89    }
90
91    fn try_resize(self, at_least_bits_precision: u32) -> Option<Self::Output> {
92        if at_least_bits_precision >= self.bits_precision() {
93            Some(self)
94        } else {
95            None
96        }
97    }
98}
99
100#[cfg(not(feature = "alloc"))]
101impl<T> UnsignedModularInt for T
102where
103    T: FixedWidthUnsignedInt + core::ops::Rem<Output = T> + PartialOrd,
104{
105    type Bytes = <T as FixedWidthUnsignedInt>::Bytes;
106
107    fn leading_zeros(&self) -> u32 {
108        FixedWidthUnsignedInt::leading_zeros(self)
109    }
110
111    fn to_be_bytes(&self) -> Self::Bytes {
112        FixedWidthUnsignedInt::to_be_bytes(self)
113    }
114
115    fn rem_vartime(&self, modulus: &NonZero<Self>) -> Self {
116        *self % *modulus.as_ref()
117    }
118
119    fn as_nz_ref(&self) -> NonZero<Self> {
120        NonZero::new(*self).expect("value is non-zero")
121    }
122
123    fn bits(&self) -> u32 {
124        self.bits_precision() - self.leading_zeros()
125    }
126
127    fn bits_precision(&self) -> u32 {
128        FixedWidthUnsignedInt::bits_precision(self)
129    }
130
131    #[cfg(feature = "alloc")]
132    fn to_be_bytes_trimmed_vartime(&self) -> Box<[u8]> {
133        unreachable!("alloc-gated")
134    }
135}
136
137#[cfg(not(feature = "alloc"))]
138impl<T> TryFromBeBytes for T
139where
140    T: FixedWidthUnsignedInt + core::ops::Rem<Output = T>,
141{
142    fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self> {
143        FixedWidthUnsignedInt::try_from_be_bytes_vartime(bytes)
144    }
145}
146
147pub trait TryFromBeBytes: Sized {
148    fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self>;
149}
150
151pub trait UnsignedModularInt:
152    Zeroize + Clone + PartialOrd + IntegerResize<Output = Self> + TryFromBeBytes
153{
154    type Bytes: NumBytes + AsMut<[u8]>;
155    fn leading_zeros(&self) -> u32;
156    fn to_be_bytes(&self) -> Self::Bytes;
157    fn rem_vartime(&self, modulus: &NonZero<Self>) -> Self;
158    fn as_nz_ref(&self) -> NonZero<Self>;
159    fn bits(&self) -> u32;
160    fn bits_precision(&self) -> u32;
161    #[cfg(feature = "alloc")]
162    fn to_be_bytes_trimmed_vartime(&self) -> Box<[u8]>;
163}
164
165impl<T> NonZero<T>
166where
167    T: UnsignedModularInt,
168{
169    pub fn new(value: T) -> Option<Self> {
170        if value.bits() == 0 {
171            None
172        } else {
173            Some(Self(value))
174        }
175    }
176
177    pub fn get(self) -> T {
178        self.0
179    }
180
181    #[allow(clippy::should_implement_trait)]
182    pub fn as_ref(&self) -> &T {
183        &self.0
184    }
185
186    pub fn bits(&self) -> u32 {
187        self.0.bits()
188    }
189
190    pub fn bits_precision(&self) -> u32 {
191        self.0.bits_precision()
192    }
193
194    pub fn to_be_bytes(&self) -> T::Bytes {
195        self.0.to_be_bytes()
196    }
197
198    #[cfg(feature = "alloc")]
199    pub fn to_be_bytes_trimmed_vartime(&self) -> Box<[u8]> {
200        self.0.to_be_bytes_trimmed_vartime()
201    }
202}
203
204impl<T> Odd<T>
205where
206    T: UnsignedModularInt,
207{
208    pub fn new(value: T) -> Option<Self> {
209        let non_zero = NonZero::new(value)?;
210        let bytes = non_zero.as_ref().to_be_bytes();
211        let bytes = bytes.as_ref();
212        let is_odd = bytes.last().map(|byte| byte & 1 == 1).unwrap_or(false);
213        if is_odd {
214            Some(Self(non_zero.get()))
215        } else {
216            None
217        }
218    }
219
220    pub fn get(self) -> T {
221        self.0
222    }
223
224    #[allow(clippy::should_implement_trait)]
225    pub fn as_ref(&self) -> &T {
226        &self.0
227    }
228
229    pub fn as_nz_ref(&self) -> NonZero<T> {
230        NonZero::new(self.0.clone()).expect("odd values are non-zero")
231    }
232
233    pub fn bits_precision(&self) -> u32 {
234        self.0.bits_precision()
235    }
236}
237
238/// Build a Montgomery-domain value from an integer already reduced modulo `params.modulus()`.
239pub trait IntoMontyForm<P: ModulusParams>: Sized {
240    fn from_reduced(integer: P::Modulus, params: &P) -> Self;
241}
242
/// Montgomery-form construction for crypto-bigint's heap-allocated integers.
#[cfg(feature = "alloc")]
impl IntoMontyForm<BoxedMontyParams> for BoxedMontyForm {
    /// Forwards to the inherent `BoxedMontyForm::new`.
    fn from_reduced(integer: BoxedUint, params: &BoxedMontyParams) -> Self {
        Self::new(integer, params)
    }
}
249
250pub trait PowBoundedExp<M: ModulusParams>: Sized {
251    fn pow_bounded_exp(&self, exp: &M::Modulus, exp_bits: u32) -> Self;
252    fn retrieve(&self) -> M::Modulus;
253}
254
/// Bounded-exponent exponentiation and retrieval for `BoxedMontyForm`.
///
/// Both methods forward to the inherent crypto-bigint methods without cloning `self`.
#[cfg(feature = "alloc")]
impl PowBoundedExp<BoxedMontyParams> for BoxedMontyForm {
    fn pow_bounded_exp(&self, exp: &BoxedUint, exp_bits: u32) -> Self {
        // The inherent method takes `&self` (NOTE(review): true for crypto-bigint 0.6; confirm
        // for the pinned version). Cloning first would only add a heap allocation and a limb
        // copy per call. Path syntax selects the inherent item ahead of this trait method. If
        // the inherent method ever took `self` by value, this would fail to compile rather than
        // recurse.
        BoxedMontyForm::pow_bounded_exp(self, exp, exp_bits)
    }

    fn retrieve(&self) -> BoxedUint {
        // Same reasoning as above: the inherent `retrieve` borrows `self`, so no clone.
        BoxedMontyForm::retrieve(self)
    }
}
265
266pub trait Pow<M: ModulusParams>: Sized {
267    fn pow(&self, exp: &M::Modulus) -> Self;
268}
269
/// Full-exponent exponentiation for `BoxedMontyForm`.
#[cfg(feature = "alloc")]
impl Pow<BoxedMontyParams> for BoxedMontyForm {
    /// Forwards to the inherent `BoxedMontyForm::pow` without cloning `self`.
    fn pow(&self, exp: &BoxedUint) -> Self {
        // The inherent method takes `&self` (NOTE(review): true for crypto-bigint 0.6; confirm
        // for the pinned version), so a clone would only add a heap allocation per call. Path
        // syntax selects the inherent item ahead of this trait method.
        BoxedMontyForm::pow(self, exp)
    }
}
276
277pub trait ModulusParams: Sized {
278    type Modulus: UnsignedModularInt;
279    type MontgomeryForm: IntoMontyForm<Self> + PowBoundedExp<Self>;
280    fn modulus(&self) -> &Odd<Self::Modulus>;
281    fn bits_precision(&self) -> u32;
282}
283
/// [`ModulusParams`] backed by crypto-bigint's heap-allocated Montgomery parameters.
#[cfg(feature = "alloc")]
impl ModulusParams for BoxedMontyParams {
    type Modulus = BoxedUint;
    type MontgomeryForm = BoxedMontyForm;

    fn modulus(&self) -> &Odd<Self::Modulus> {
        // Compile-time layout guard for the reference cast below.
        // - Our `Odd<T>` is `#[repr(transparent)]` over `T`.
        // - `crypto_bigint::Odd<T>` is treated as a single-field wrapper around `T` and is not
        //   relied on to be `#[repr(transparent)]`. NOTE(review): confirm against the pinned
        //   crypto-bigint version.
        // A crypto-bigint release that changes size or alignment fails to build here instead of
        // producing silent UB.
        const _: () = assert!(
            core::mem::size_of::<CryptoOdd<BoxedUint>>() == core::mem::size_of::<Odd<BoxedUint>>()
                && core::mem::align_of::<CryptoOdd<BoxedUint>>()
                    == core::mem::align_of::<Odd<BoxedUint>>()
        );

        // Calls the inherent `BoxedMontyParams::modulus`; inherent items take precedence over
        // this trait method. The explicit type makes any change in resolution a compile error.
        let inner: &CryptoOdd<BoxedUint> = BoxedMontyParams::modulus(self);
        let cast = (inner as *const CryptoOdd<BoxedUint>).cast::<Odd<BoxedUint>>();

        // SAFETY: `cast` is derived from a live reference borrowed from `self`, so it is
        // non-null, aligned, and valid for the returned borrow's lifetime. Both wrappers are
        // parameterised over `BoxedUint`, and the const assertion above checks that their size
        // and alignment match, so the pointee is read as an `Odd<BoxedUint>` of identical layout.
        unsafe { &*cast }
    }

    fn bits_precision(&self) -> u32 {
        // Inherent `BoxedMontyParams::bits_precision`, selected ahead of this trait method.
        BoxedMontyParams::bits_precision(self)
    }
}
309
/// Resizing for `BoxedUint`, forwarding both methods to `crypto_bigint::Resize`.
#[cfg(feature = "alloc")]
impl IntegerResize for BoxedUint {
    type Output = Self;

    fn resize_unchecked(self, at_least_bits_precision: u32) -> Self::Output {
        <Self as CryptoResize>::resize_unchecked(self, at_least_bits_precision)
    }

    fn try_resize(self, at_least_bits_precision: u32) -> Option<Self::Output> {
        <Self as CryptoResize>::try_resize(self, at_least_bits_precision)
    }
}
322
/// [`UnsignedModularInt`] for `BoxedUint`.
///
/// Each method forwards to the inherent `BoxedUint` method of the same name. Path syntax
/// selects inherent items ahead of these trait methods.
#[cfg(feature = "alloc")]
impl UnsignedModularInt for BoxedUint {
    type Bytes = Box<[u8]>;

    fn bits(&self) -> u32 {
        BoxedUint::bits(self)
    }

    fn bits_precision(&self) -> u32 {
        BoxedUint::bits_precision(self)
    }

    fn leading_zeros(&self) -> u32 {
        BoxedUint::leading_zeros(self)
    }

    fn to_be_bytes(&self) -> Self::Bytes {
        BoxedUint::to_be_bytes(self)
    }

    // The whole impl is already `alloc`-gated, so the trait's `cfg(feature = "alloc")` method
    // is always present here.
    fn to_be_bytes_trimmed_vartime(&self) -> Box<[u8]> {
        BoxedUint::to_be_bytes_trimmed_vartime(self)
    }

    /// Converts our `NonZero` into crypto-bigint's by cloning the modulus.
    fn rem_vartime(&self, modulus: &NonZero<Self>) -> Self {
        let divisor = CryptoNonZero::new(modulus.as_ref().clone()).expect("Value is non-zero");
        BoxedUint::rem_vartime(self, &divisor)
    }

    /// # Panics
    /// Panics if `self` is zero.
    fn as_nz_ref(&self) -> NonZero<Self> {
        NonZero::new(self.clone()).expect("Value is non-zero")
    }
}
351
/// Big-endian parsing for `BoxedUint`.
///
/// This impl never returns an error; it forwards to `BoxedUint::from_be_slice_vartime`.
#[cfg(feature = "alloc")]
impl TryFromBeBytes for BoxedUint {
    fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self> {
        let parsed = Self::from_be_slice_vartime(bytes);
        Ok(parsed)
    }
}