1#![allow(missing_docs)]
3
4use core::borrow::Borrow;
5
6#[cfg(feature = "alloc")]
7use alloc::boxed::Box;
8#[cfg(feature = "alloc")]
9use crypto_bigint::{
10 modular::{BoxedMontyForm, BoxedMontyParams},
11 BoxedUint, Resize as CryptoResize,
12};
13#[cfg(feature = "alloc")]
14use crypto_bigint::{NonZero as CryptoNonZero, Odd as CryptoOdd};
15#[cfg(feature = "modmath")]
16use fixed_bigint::ConstBitPrimInt;
17#[cfg(not(feature = "modmath"))]
18use num_traits::PrimInt;
19use num_traits::{FromBytes as NumFromBytes, ToBytes as NumToBytes, Zero};
20use zeroize::Zeroize;
21
22use crate::errors::{Error, Result};
23
/// Marker trait for byte containers whose contents can be viewed as a `[u8]`
/// slice through both [`Borrow`] and [`AsRef`].
pub trait NumBytes: Borrow<[u8]> + AsRef<[u8]> {}

/// Blanket implementation: every type that satisfies the supertrait bounds is
/// automatically a [`NumBytes`].
impl<T: Borrow<[u8]> + AsRef<[u8]>> NumBytes for T {}
27
/// Wrapper marking a value of type `T` as non-zero.
///
/// The field is private, so outside this module a value can only be obtained
/// through the constructors defined here. `#[repr(transparent)]` gives the
/// wrapper the same layout as the wrapped `T`.
#[repr(transparent)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct NonZero<T>(T);
31
/// Wrapper marking a value of type `T` as odd.
///
/// The field is private, so outside this module a value can only be obtained
/// through the constructors defined here. `#[repr(transparent)]` gives the
/// wrapper the same layout as the wrapped `T`; the pointer cast in
/// `ModulusParams::modulus` for `BoxedMontyParams` relies on this attribute,
/// so it must not be removed.
#[repr(transparent)]
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Odd<T>(T);
35
/// Conversion of an integer into a representation with a requested bit
/// precision.
pub trait IntegerResize: Sized {
    /// Type produced by the resize operations.
    type Output;

    /// Resizes `self` to `at_least_bits_precision` bits without reporting
    /// failure to the caller.
    ///
    /// NOTE(review): what happens when the value does not fit is left to the
    /// implementation — confirm against each implementor before relying on it.
    fn resize_unchecked(self, at_least_bits_precision: u32) -> Self::Output;

    /// Fallible counterpart of [`IntegerResize::resize_unchecked`]; returns an
    /// `Option` so an implementation can signal failure with `None`.
    fn try_resize(self, at_least_bits_precision: u32) -> Option<Self::Output>;
}
42
/// Operations this module requires from a fixed-width integer type.
///
/// Implementors must be `Copy` and support zeroization.
pub trait FixedWidthUnsignedInt: Zeroize + Clone + Copy {
    /// Byte buffer holding the big-endian encoding of the integer. It must be
    /// default-constructible and mutable as a `[u8]` slice.
    type Bytes: NumBytes + Default + AsMut<[u8]>;

    /// Returns the number of leading zero bits of the value.
    fn leading_zeros(&self) -> u32;
    /// Returns the big-endian byte encoding of the value.
    fn to_be_bytes(&self) -> Self::Bytes;
    /// Builds a value from a big-endian byte slice.
    ///
    /// The blanket implementations in this file left-pad shorter inputs with
    /// zero bytes and return `Error::InvalidArguments` when `bytes` is longer
    /// than `Self::Bytes`.
    fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self>;
    /// Returns the total bit width of the type.
    fn bits_precision(&self) -> u32;
}
51
/// Blanket [`FixedWidthUnsignedInt`] implementation used when the `modmath`
/// feature is enabled; bit queries go through `ConstBitPrimInt` and byte
/// conversions through the `num_traits` byte traits.
#[cfg(feature = "modmath")]
impl<T> FixedWidthUnsignedInt for T
where
    T: Zeroize + Clone + Copy + ConstBitPrimInt + Zero + NumToBytes + NumFromBytes,
    T: NumToBytes<Bytes = <T as NumFromBytes>::Bytes>,
    <T as NumToBytes>::Bytes: NumBytes + Default + AsMut<[u8]>,
{
    type Bytes = <T as NumToBytes>::Bytes;

    /// Number of leading zero bits, as reported by `ConstBitPrimInt`.
    fn leading_zeros(&self) -> u32 {
        <T as ConstBitPrimInt>::leading_zeros(*self)
    }

    /// Big-endian encoding, as produced by `num_traits::ToBytes`.
    fn to_be_bytes(&self) -> Self::Bytes {
        <T as NumToBytes>::to_be_bytes(self)
    }

    /// Parses `bytes` as a big-endian integer, left-padding with zeros.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidArguments` when `bytes` is longer than the
    /// integer's byte representation.
    fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self> {
        let mut buf = <T as NumFromBytes>::Bytes::default();
        let dst = buf.as_mut();
        // Number of leading zero bytes to keep; `None` means the input is too long.
        let pad = dst
            .len()
            .checked_sub(bytes.len())
            .ok_or(Error::InvalidArguments)?;
        dst[pad..].copy_from_slice(bytes);
        Ok(NumFromBytes::from_be_bytes(&buf))
    }

    /// Bit width of `T`: every bit of the zero value is a zero bit, so
    /// counting them yields the width.
    fn bits_precision(&self) -> u32 {
        let zero = <T as Zero>::zero();
        <T as ConstBitPrimInt>::count_zeros(zero)
    }
}
84
85#[cfg(not(feature = "modmath"))]
86impl<T> FixedWidthUnsignedInt for T
87where
88 T: Zeroize + Clone + Copy + PrimInt + NumToBytes + NumFromBytes,
89 T: NumToBytes<Bytes = <T as NumFromBytes>::Bytes>,
90 <T as NumToBytes>::Bytes: NumBytes + Default + AsMut<[u8]>,
91{
92 type Bytes = <T as NumToBytes>::Bytes;
93
94 fn leading_zeros(&self) -> u32 {
95 PrimInt::leading_zeros(*self)
96 }
97
98 fn to_be_bytes(&self) -> Self::Bytes {
99 NumToBytes::to_be_bytes(self)
100 }
101
102 fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self> {
103 let mut repr = <T as NumFromBytes>::Bytes::default();
104 let out = repr.as_mut();
105 let out_len = out.len();
106 if bytes.len() > out_len {
107 return Err(Error::InvalidArguments);
108 }
109 out[out_len - bytes.len()..].copy_from_slice(bytes);
110 Ok(NumFromBytes::from_be_bytes(&repr))
111 }
112
113 fn bits_precision(&self) -> u32 {
114 <T as Zero>::zero().count_zeros()
115 }
116}
117
118#[cfg(not(feature = "alloc"))]
119impl<T> IntegerResize for T
120where
121 T: FixedWidthUnsignedInt,
122{
123 type Output = Self;
124
125 fn resize_unchecked(self, _at_least_bits_precision: u32) -> Self::Output {
126 self
127 }
128
129 fn try_resize(self, at_least_bits_precision: u32) -> Option<Self::Output> {
130 let value_bits = self.bits_precision() - self.leading_zeros();
136 if value_bits <= at_least_bits_precision {
137 Some(self)
138 } else {
139 None
140 }
141 }
142}
143
144#[cfg(not(feature = "alloc"))]
145impl<T> UnsignedModularInt for T
146where
147 T: FixedWidthUnsignedInt + PartialOrd,
148{
149 type Bytes = <T as FixedWidthUnsignedInt>::Bytes;
150
151 fn leading_zeros(&self) -> u32 {
152 FixedWidthUnsignedInt::leading_zeros(self)
153 }
154
155 fn to_be_bytes(&self) -> Self::Bytes {
156 FixedWidthUnsignedInt::to_be_bytes(self)
157 }
158
159 fn as_nz_ref(&self) -> NonZero<Self> {
160 NonZero::new(*self).expect("value is non-zero")
161 }
162
163 fn bits(&self) -> u32 {
164 self.bits_precision() - self.leading_zeros()
165 }
166
167 fn bits_precision(&self) -> u32 {
168 FixedWidthUnsignedInt::bits_precision(self)
169 }
170
171 #[cfg(feature = "alloc")]
172 fn to_be_bytes_trimmed_vartime(&self) -> Box<[u8]> {
173 unreachable!("alloc-gated")
174 }
175}
176
177#[cfg(not(feature = "alloc"))]
178impl<T> TryFromBeBytes for T
179where
180 T: FixedWidthUnsignedInt,
181{
182 fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self> {
183 FixedWidthUnsignedInt::try_from_be_bytes_vartime(bytes)
184 }
185}
186
/// Fallible construction of a value from big-endian bytes.
pub trait TryFromBeBytes: Sized {
    /// Builds a value from the big-endian byte slice `bytes`, returning an
    /// error when the implementation cannot represent the input.
    ///
    /// NOTE(review): the `_vartime` suffix suggests the running time may depend
    /// on the input — confirm before using with secret data.
    fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self>;
}
190
/// Integer interface required of a [`ModulusParams::Modulus`].
///
/// Implementors must support zeroization, cloning, ordering, resizing to
/// themselves, and parsing from big-endian bytes.
pub trait UnsignedModularInt:
    Zeroize + Clone + PartialOrd + IntegerResize<Output = Self> + TryFromBeBytes
{
    /// Byte buffer holding the big-endian encoding of the integer.
    type Bytes: NumBytes + AsMut<[u8]>;
    /// Returns the number of leading zero bits of the value.
    fn leading_zeros(&self) -> u32;
    /// Returns the big-endian byte encoding of the value.
    fn to_be_bytes(&self) -> Self::Bytes;
    /// Returns the value wrapped in [`NonZero`].
    ///
    /// Despite the `_ref` suffix this returns an owned wrapper. The
    /// implementations in this file copy or clone the value and panic when it
    /// is zero.
    fn as_nz_ref(&self) -> NonZero<Self>;
    /// Returns the number of significant bits; [`NonZero::new`] treats a
    /// result of `0` as a zero value.
    fn bits(&self) -> u32;
    /// Returns the bit precision (storage width) of the value.
    fn bits_precision(&self) -> u32;
    /// Returns the big-endian encoding as a boxed slice.
    ///
    /// NOTE(review): presumably without leading zero bytes, judging by the
    /// name — confirm against the implementor.
    #[cfg(feature = "alloc")]
    fn to_be_bytes_trimmed_vartime(&self) -> Box<[u8]>;
}
203
204impl<T> NonZero<T>
205where
206 T: UnsignedModularInt,
207{
208 pub fn new(value: T) -> Option<Self> {
209 if value.bits() == 0 {
210 None
211 } else {
212 Some(Self(value))
213 }
214 }
215
216 pub fn get(self) -> T {
217 self.0
218 }
219
220 #[allow(clippy::should_implement_trait)]
221 pub fn as_ref(&self) -> &T {
222 &self.0
223 }
224
225 pub fn bits(&self) -> u32 {
226 self.0.bits()
227 }
228
229 pub fn bits_precision(&self) -> u32 {
230 self.0.bits_precision()
231 }
232
233 pub fn to_be_bytes(&self) -> T::Bytes {
234 self.0.to_be_bytes()
235 }
236
237 #[cfg(feature = "alloc")]
238 pub fn to_be_bytes_trimmed_vartime(&self) -> Box<[u8]> {
239 self.0.to_be_bytes_trimmed_vartime()
240 }
241}
242
243impl<T> Odd<T>
244where
245 T: UnsignedModularInt,
246{
247 pub fn new(value: T) -> Option<Self> {
248 let non_zero = NonZero::new(value)?;
249 let bytes = non_zero.as_ref().to_be_bytes();
250 let bytes = bytes.as_ref();
251 let is_odd = bytes.last().map(|byte| byte & 1 == 1).unwrap_or(false);
252 if is_odd {
253 Some(Self(non_zero.get()))
254 } else {
255 None
256 }
257 }
258
259 pub fn get(self) -> T {
260 self.0
261 }
262
263 #[allow(clippy::should_implement_trait)]
264 pub fn as_ref(&self) -> &T {
265 &self.0
266 }
267
268 pub fn as_nz_ref(&self) -> NonZero<T> {
269 NonZero::new(self.0.clone()).expect("odd values are non-zero")
270 }
271
272 pub fn bits_precision(&self) -> u32 {
273 self.0.bits_precision()
274 }
275}
276
/// Construction of a Montgomery-form value from an integer and the modulus
/// parameters `P`.
pub trait IntoMontyForm<P: ModulusParams>: Sized {
    /// Builds the Montgomery form from `integer` without reducing it first.
    ///
    /// NOTE(review): the name implies the caller must pass a value already
    /// reduced modulo the modulus of `params` — confirm with the implementor.
    fn from_reduced(integer: P::Modulus, params: &P) -> Self;

    /// Builds the Montgomery form from an arbitrary `integer`; the
    /// `BoxedMontyForm` implementation in this file reduces it modulo the
    /// modulus of `params` before delegating to
    /// [`IntoMontyForm::from_reduced`].
    fn from_value(integer: P::Modulus, params: &P) -> Self;
}
301
/// [`IntoMontyForm`] backed by `crypto_bigint`'s boxed Montgomery form.
#[cfg(feature = "alloc")]
impl IntoMontyForm<BoxedMontyParams> for BoxedMontyForm {
    /// Passes `integer` straight to `BoxedMontyForm::new`; no reduction is
    /// performed here.
    fn from_reduced(integer: BoxedUint, params: &BoxedMontyParams) -> Self {
        BoxedMontyForm::new(integer, params)
    }

    /// Reduces `integer` modulo the modulus of `params` (variable time) and
    /// then builds the Montgomery form from the remainder.
    ///
    /// # Panics
    ///
    /// Panics with `modulus is non-zero` if the modulus cannot be wrapped in
    /// `crypto_bigint::NonZero`.
    fn from_value(integer: BoxedUint, params: &BoxedMontyParams) -> Self {
        let modulus_value = params.modulus().as_ref().clone();
        let divisor = CryptoNonZero::new(modulus_value).expect("modulus is non-zero");
        Self::from_reduced(integer.rem_vartime(&divisor), params)
    }
}
315
/// Exponentiation with an explicit exponent bit bound, plus conversion back
/// to the plain integer type of the modulus parameters `M`.
pub trait PowBoundedExp<M: ModulusParams>: Sized {
    /// Raises `self` to the power `exp`.
    ///
    /// NOTE(review): `exp_bits` presumably bounds how many low bits of `exp`
    /// are processed — confirm against the implementor.
    fn pow_bounded_exp(&self, exp: &M::Modulus, exp_bits: u32) -> Self;
    /// Converts the value back into the integer type `M::Modulus`.
    fn retrieve(&self) -> M::Modulus;
}
320
/// [`PowBoundedExp`] backed by the inherent methods of `crypto_bigint`'s
/// `BoxedMontyForm`.
#[cfg(feature = "alloc")]
impl PowBoundedExp<BoxedMontyParams> for BoxedMontyForm {
    /// Raises `self` to `exp`, forwarding `exp_bits` to the inherent
    /// `BoxedMontyForm::pow_bounded_exp`.
    fn pow_bounded_exp(&self, exp: &BoxedUint, exp_bits: u32) -> Self {
        // Path syntax resolves to the inherent method (inherent items take
        // precedence over trait items), so this does not recurse and needs no
        // temporary clone of the Montgomery form.
        BoxedMontyForm::pow_bounded_exp(self, exp, exp_bits)
    }

    /// Converts the value out of Montgomery form via the inherent
    /// `BoxedMontyForm::retrieve`.
    fn retrieve(&self) -> BoxedUint {
        // Inherent method, see above; works on `&self` without cloning.
        BoxedMontyForm::retrieve(self)
    }
}
331
/// Exponentiation by an exponent of the integer type of the modulus
/// parameters `M`.
pub trait Pow<M: ModulusParams>: Sized {
    /// Raises `self` to the power `exp` and returns the result.
    fn pow(&self, exp: &M::Modulus) -> Self;
}
335
/// [`Pow`] backed by the inherent `pow` of `crypto_bigint`'s `BoxedMontyForm`.
#[cfg(feature = "alloc")]
impl Pow<BoxedMontyParams> for BoxedMontyForm {
    /// Raises `self` to the power `exp`.
    fn pow(&self, exp: &BoxedUint) -> Self {
        // Path syntax resolves to the inherent method (inherent items take
        // precedence over trait items), so this does not recurse and needs no
        // temporary clone of the Montgomery form.
        BoxedMontyForm::pow(self, exp)
    }
}
342
/// Parameters describing a modulus together with the types used to compute
/// with it.
pub trait ModulusParams: Sized {
    /// Integer type of the modulus and of values exchanged with the
    /// Montgomery form.
    type Modulus: UnsignedModularInt;
    /// Montgomery-form type that can be built from these parameters and
    /// exponentiated with a bounded exponent.
    type MontgomeryForm: IntoMontyForm<Self> + PowBoundedExp<Self>;
    /// Returns the modulus, which is wrapped in [`Odd`].
    fn modulus(&self) -> &Odd<Self::Modulus>;
    /// Returns the bit precision associated with these parameters.
    fn bits_precision(&self) -> u32;
}
349
/// [`ModulusParams`] backed by `crypto_bigint`'s boxed Montgomery parameters.
#[cfg(feature = "alloc")]
impl ModulusParams for BoxedMontyParams {
    type Modulus = BoxedUint;
    type MontgomeryForm = BoxedMontyForm;
    /// Reinterprets the `crypto_bigint` odd-modulus wrapper as this crate's
    /// [`Odd`] wrapper through a pointer cast, without copying the modulus.
    fn modulus(&self) -> &Odd<Self::Modulus> {
        // Compile-time guards: the build fails unless both wrapper types have
        // the same size and the same alignment.
        const _: () = assert!(
            core::mem::size_of::<CryptoOdd<BoxedUint>>() == core::mem::size_of::<Odd<BoxedUint>>()
        );
        const _: () = assert!(
            core::mem::align_of::<CryptoOdd<BoxedUint>>()
                == core::mem::align_of::<Odd<BoxedUint>>()
        );
        // SAFETY: the local `Odd<T>` is `#[repr(transparent)]` over `T`, so a
        // `&Odd<BoxedUint>` may point at anything laid out like a `BoxedUint`.
        // NOTE(review): soundness additionally requires that
        // `crypto_bigint::Odd<BoxedUint>` is itself laid out exactly like a bare
        // `BoxedUint`; the assertions above only check size and alignment —
        // confirm against the `crypto-bigint` version in use.
        // NOTE(review): `self.modulus()` presumably resolves to the inherent
        // `BoxedMontyParams::modulus`, not to this trait method — confirm.
        unsafe {
            &*(self.modulus() as *const CryptoOdd<Self::Modulus> as *const Odd<Self::Modulus>)
        }
    }
    /// Forwards to `bits_precision` on `BoxedMontyParams`.
    fn bits_precision(&self) -> u32 {
        // NOTE(review): presumably the inherent method of `BoxedMontyParams`
        // (otherwise this call would recurse) — confirm.
        self.bits_precision()
    }
}
375
/// [`IntegerResize`] for `BoxedUint`, forwarding to `crypto_bigint`'s `Resize`
/// trait.
#[cfg(feature = "alloc")]
impl IntegerResize for BoxedUint {
    type Output = Self;

    /// Delegates to `crypto_bigint::Resize::resize_unchecked`.
    fn resize_unchecked(self, at_least_bits_precision: u32) -> Self::Output {
        <Self as CryptoResize>::resize_unchecked(self, at_least_bits_precision)
    }

    /// Delegates to `crypto_bigint::Resize::try_resize`.
    fn try_resize(self, at_least_bits_precision: u32) -> Option<Self::Output> {
        <Self as CryptoResize>::try_resize(self, at_least_bits_precision)
    }
}
388
/// [`UnsignedModularInt`] for `BoxedUint`; every method forwards to the
/// inherent `BoxedUint` method of the same name.
#[cfg(feature = "alloc")]
impl UnsignedModularInt for BoxedUint {
    type Bytes = Box<[u8]>;

    /// Number of leading zero bits.
    fn leading_zeros(&self) -> u32 {
        BoxedUint::leading_zeros(self)
    }

    /// Big-endian byte encoding as a boxed slice.
    fn to_be_bytes(&self) -> Self::Bytes {
        BoxedUint::to_be_bytes(self)
    }

    /// Returns a clone of the value wrapped in [`NonZero`].
    ///
    /// # Panics
    ///
    /// Panics with `Value is non-zero` when [`NonZero::new`] rejects the
    /// value.
    fn as_nz_ref(&self) -> NonZero<Self> {
        let copy = self.clone();
        NonZero::new(copy).expect("Value is non-zero")
    }

    /// Number of significant bits.
    fn bits(&self) -> u32 {
        BoxedUint::bits(self)
    }

    /// Bit precision of the value.
    fn bits_precision(&self) -> u32 {
        BoxedUint::bits_precision(self)
    }

    /// Boxed big-endian encoding produced by the inherent
    /// `BoxedUint::to_be_bytes_trimmed_vartime`.
    fn to_be_bytes_trimmed_vartime(&self) -> Box<[u8]> {
        BoxedUint::to_be_bytes_trimmed_vartime(self)
    }
}
414
/// [`TryFromBeBytes`] for `BoxedUint`.
#[cfg(feature = "alloc")]
impl TryFromBeBytes for BoxedUint {
    /// Builds a `BoxedUint` from a big-endian slice with
    /// `BoxedUint::from_be_slice_vartime`; this implementation always returns
    /// `Ok`.
    fn try_from_be_bytes_vartime(bytes: &[u8]) -> Result<Self> {
        let value = Self::from_be_slice_vartime(bytes);
        Ok(value)
    }
}
420}