1#![allow(missing_docs)]
3
4use core::borrow::Borrow;
5
6#[cfg(feature = "alloc")]
7use alloc::boxed::Box;
8#[cfg(feature = "alloc")]
9use crypto_bigint::{
10 modular::{BoxedMontyForm, BoxedMontyParams},
11 BoxedUint, Resize as CryptoResize,
12};
13#[cfg(feature = "alloc")]
14use crypto_bigint::{NonZero as CryptoNonZero, Odd as CryptoOdd};
15use num_traits::{FromBytes as NumFromBytes, PrimInt, ToBytes as NumToBytes, Zero};
16use zeroize::Zeroize;
17
18use crate::errors::{Error, Result};
19
20pub trait NumBytes: Borrow<[u8]> + AsRef<[u8]> {}
21
22impl<T> NumBytes for T where T: Borrow<[u8]> + AsRef<[u8]> {}
23
24#[repr(transparent)]
25#[derive(Clone, Debug, Eq, PartialEq)]
26pub struct NonZero<T>(T);
27
28#[repr(transparent)]
29#[derive(Clone, Debug, Eq, PartialEq)]
30pub struct Odd<T>(T);
31
32pub trait IntegerResize: Sized {
33 type Output;
34
35 fn resize_unchecked(self, at_least_bits_precision: u32) -> Self::Output;
36 fn try_resize(self, at_least_bits_precision: u32) -> Option<Self::Output>;
37}
38
39pub trait FixedWidthUnsignedInt: Zeroize + Clone + Copy {
40 type Bytes: NumBytes + Default + AsMut<[u8]>;
41
42 fn leading_zeros(&self) -> u32;
43 fn to_be_bytes(&self) -> Self::Bytes;
44 fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self>;
45 fn bits_precision(&self) -> u32;
46}
47
48impl<T> FixedWidthUnsignedInt for T
49where
50 T: Zeroize + Clone + Copy + PrimInt + NumToBytes + NumFromBytes,
51 T: NumToBytes<Bytes = <T as NumFromBytes>::Bytes>,
52 <T as NumToBytes>::Bytes: NumBytes + Default + AsMut<[u8]>,
53{
54 type Bytes = <T as NumToBytes>::Bytes;
55
56 fn leading_zeros(&self) -> u32 {
57 PrimInt::leading_zeros(*self)
58 }
59
60 fn to_be_bytes(&self) -> Self::Bytes {
61 NumToBytes::to_be_bytes(self)
62 }
63
64 fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self> {
65 let mut repr = <T as NumFromBytes>::Bytes::default();
66 let out = repr.as_mut();
67 let out_len = out.len();
68 if bytes.len() > out_len {
69 return Err(Error::InvalidArguments);
70 }
71 out[out_len - bytes.len()..].copy_from_slice(bytes);
72 Ok(NumFromBytes::from_be_bytes(&repr))
73 }
74
75 fn bits_precision(&self) -> u32 {
76 <T as Zero>::zero().count_zeros()
77 }
78}
79
80#[cfg(not(feature = "alloc"))]
81impl<T> IntegerResize for T
82where
83 T: FixedWidthUnsignedInt,
84{
85 type Output = Self;
86
87 fn resize_unchecked(self, _at_least_bits_precision: u32) -> Self::Output {
88 self
89 }
90
91 fn try_resize(self, at_least_bits_precision: u32) -> Option<Self::Output> {
92 if at_least_bits_precision >= self.bits_precision() {
93 Some(self)
94 } else {
95 None
96 }
97 }
98}
99
100#[cfg(not(feature = "alloc"))]
101impl<T> UnsignedModularInt for T
102where
103 T: FixedWidthUnsignedInt + core::ops::Rem<Output = T> + PartialOrd,
104{
105 type Bytes = <T as FixedWidthUnsignedInt>::Bytes;
106
107 fn leading_zeros(&self) -> u32 {
108 FixedWidthUnsignedInt::leading_zeros(self)
109 }
110
111 fn to_be_bytes(&self) -> Self::Bytes {
112 FixedWidthUnsignedInt::to_be_bytes(self)
113 }
114
115 fn rem_vartime(&self, modulus: &NonZero<Self>) -> Self {
116 *self % *modulus.as_ref()
117 }
118
119 fn as_nz_ref(&self) -> NonZero<Self> {
120 NonZero::new(*self).expect("value is non-zero")
121 }
122
123 fn bits(&self) -> u32 {
124 self.bits_precision() - self.leading_zeros()
125 }
126
127 fn bits_precision(&self) -> u32 {
128 FixedWidthUnsignedInt::bits_precision(self)
129 }
130
131 #[cfg(feature = "alloc")]
132 fn to_be_bytes_trimmed_vartime(&self) -> Box<[u8]> {
133 unreachable!("alloc-gated")
134 }
135}
136
137#[cfg(not(feature = "alloc"))]
138impl<T> TryFromBeBytes for T
139where
140 T: FixedWidthUnsignedInt + core::ops::Rem<Output = T>,
141{
142 fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self> {
143 FixedWidthUnsignedInt::try_from_be_bytes_vartime(bytes)
144 }
145}
146
147pub trait TryFromBeBytes: Sized {
148 fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self>;
149}
150
151pub trait UnsignedModularInt:
152 Zeroize + Clone + PartialOrd + IntegerResize<Output = Self> + TryFromBeBytes
153{
154 type Bytes: NumBytes + AsMut<[u8]>;
155 fn leading_zeros(&self) -> u32;
156 fn to_be_bytes(&self) -> Self::Bytes;
157 fn rem_vartime(&self, modulus: &NonZero<Self>) -> Self;
158 fn as_nz_ref(&self) -> NonZero<Self>;
159 fn bits(&self) -> u32;
160 fn bits_precision(&self) -> u32;
161 #[cfg(feature = "alloc")]
162 fn to_be_bytes_trimmed_vartime(&self) -> Box<[u8]>;
163}
164
165impl<T> NonZero<T>
166where
167 T: UnsignedModularInt,
168{
169 pub fn new(value: T) -> Option<Self> {
170 if value.bits() == 0 {
171 None
172 } else {
173 Some(Self(value))
174 }
175 }
176
177 pub fn get(self) -> T {
178 self.0
179 }
180
181 #[allow(clippy::should_implement_trait)]
182 pub fn as_ref(&self) -> &T {
183 &self.0
184 }
185
186 pub fn bits(&self) -> u32 {
187 self.0.bits()
188 }
189
190 pub fn bits_precision(&self) -> u32 {
191 self.0.bits_precision()
192 }
193
194 pub fn to_be_bytes(&self) -> T::Bytes {
195 self.0.to_be_bytes()
196 }
197
198 #[cfg(feature = "alloc")]
199 pub fn to_be_bytes_trimmed_vartime(&self) -> Box<[u8]> {
200 self.0.to_be_bytes_trimmed_vartime()
201 }
202}
203
204impl<T> Odd<T>
205where
206 T: UnsignedModularInt,
207{
208 pub fn new(value: T) -> Option<Self> {
209 let non_zero = NonZero::new(value)?;
210 let bytes = non_zero.as_ref().to_be_bytes();
211 let bytes = bytes.as_ref();
212 let is_odd = bytes.last().map(|byte| byte & 1 == 1).unwrap_or(false);
213 if is_odd {
214 Some(Self(non_zero.get()))
215 } else {
216 None
217 }
218 }
219
220 pub fn get(self) -> T {
221 self.0
222 }
223
224 #[allow(clippy::should_implement_trait)]
225 pub fn as_ref(&self) -> &T {
226 &self.0
227 }
228
229 pub fn as_nz_ref(&self) -> NonZero<T> {
230 NonZero::new(self.0.clone()).expect("odd values are non-zero")
231 }
232
233 pub fn bits_precision(&self) -> u32 {
234 self.0.bits_precision()
235 }
236}
237
238pub trait IntoMontyForm<P: ModulusParams>: Sized {
240 fn from_reduced(integer: P::Modulus, params: &P) -> Self;
241}
242
243#[cfg(feature = "alloc")]
244impl IntoMontyForm<BoxedMontyParams> for BoxedMontyForm {
245 fn from_reduced(integer: BoxedUint, params: &BoxedMontyParams) -> Self {
246 BoxedMontyForm::new(integer, params)
247 }
248}
249
250pub trait PowBoundedExp<M: ModulusParams>: Sized {
251 fn pow_bounded_exp(&self, exp: &M::Modulus, exp_bits: u32) -> Self;
252 fn retrieve(&self) -> M::Modulus;
253}
254
255#[cfg(feature = "alloc")]
256impl PowBoundedExp<BoxedMontyParams> for BoxedMontyForm {
257 fn pow_bounded_exp(&self, exp: &BoxedUint, exp_bits: u32) -> Self {
258 self.clone().pow_bounded_exp(exp, exp_bits)
259 }
260
261 fn retrieve(&self) -> BoxedUint {
262 self.clone().retrieve()
263 }
264}
265
266pub trait Pow<M: ModulusParams>: Sized {
267 fn pow(&self, exp: &M::Modulus) -> Self;
268}
269
270#[cfg(feature = "alloc")]
271impl Pow<BoxedMontyParams> for BoxedMontyForm {
272 fn pow(&self, exp: &BoxedUint) -> Self {
273 self.clone().pow(exp)
274 }
275}
276
277pub trait ModulusParams: Sized {
278 type Modulus: UnsignedModularInt;
279 type MontgomeryForm: IntoMontyForm<Self> + PowBoundedExp<Self>;
280 fn modulus(&self) -> &Odd<Self::Modulus>;
281 fn bits_precision(&self) -> u32;
282}
283
284#[cfg(feature = "alloc")]
285impl ModulusParams for BoxedMontyParams {
286 type Modulus = BoxedUint;
287 type MontgomeryForm = BoxedMontyForm;
288 fn modulus(&self) -> &Odd<Self::Modulus> {
289 const _: () = assert!(
295 core::mem::size_of::<CryptoOdd<BoxedUint>>() == core::mem::size_of::<Odd<BoxedUint>>()
296 );
297 const _: () = assert!(
298 core::mem::align_of::<CryptoOdd<BoxedUint>>()
299 == core::mem::align_of::<Odd<BoxedUint>>()
300 );
301 unsafe {
302 &*(self.modulus() as *const CryptoOdd<Self::Modulus> as *const Odd<Self::Modulus>)
303 }
304 }
305 fn bits_precision(&self) -> u32 {
306 self.bits_precision()
307 }
308}
309
310#[cfg(feature = "alloc")]
311impl IntegerResize for BoxedUint {
312 type Output = Self;
313
314 fn resize_unchecked(self, at_least_bits_precision: u32) -> Self::Output {
315 CryptoResize::resize_unchecked(self, at_least_bits_precision)
316 }
317
318 fn try_resize(self, at_least_bits_precision: u32) -> Option<Self::Output> {
319 CryptoResize::try_resize(self, at_least_bits_precision)
320 }
321}
322
323#[cfg(feature = "alloc")]
324impl UnsignedModularInt for BoxedUint {
325 type Bytes = alloc::boxed::Box<[u8]>;
326
327 fn leading_zeros(&self) -> u32 {
328 self.leading_zeros()
329 }
330
331 fn to_be_bytes(&self) -> Self::Bytes {
332 self.to_be_bytes()
333 }
334 #[cfg(feature = "alloc")]
335 fn to_be_bytes_trimmed_vartime(&self) -> Box<[u8]> {
336 self.to_be_bytes_trimmed_vartime()
337 }
338 fn rem_vartime(&self, modulus: &NonZero<Self>) -> Self {
339 self.rem_vartime(&CryptoNonZero::new(modulus.as_ref().clone()).expect("Value is non-zero"))
340 }
341 fn as_nz_ref(&self) -> NonZero<Self> {
342 NonZero::new(self.clone()).expect("Value is non-zero")
343 }
344 fn bits(&self) -> u32 {
345 self.bits()
346 }
347 fn bits_precision(&self) -> u32 {
348 self.bits_precision()
349 }
350}
351
352#[cfg(feature = "alloc")]
353impl TryFromBeBytes for BoxedUint {
354 fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self> {
355 Ok(BoxedUint::from_be_slice_vartime(bytes))
356 }
357}