kyber_rs/group/integer_field/
integer.rs

1use core::cmp::Ordering::{self, Equal, Greater};
2use core::fmt::{Display, Formatter, LowerHex, UpperHex};
3use lazy_static::lazy_static;
4use num_bigint_dig as num_bigint;
5use thiserror::Error;
6
7use num_bigint::algorithms::jacobi;
8use num_bigint::Sign::Plus;
9use num_bigint::{BigInt, ModInverse, ParseBigIntError, Sign};
10use num_traits::{Num, Signed};
11
12use crate::cipher::stream::Stream;
13use crate::encoding::{BinaryMarshaler, BinaryUnmarshaler, Marshaling, MarshallingError};
14use crate::group::internal::marshalling;
15use crate::group::Scalar;
16use crate::util::random::random_int;
17use serde::{Deserialize, Serialize};
18
19use crate::group::integer_field::integer::ByteOrder::{BigEndian, LittleEndian};
20
21lazy_static! {
22    pub static ref ONE: BigInt = BigInt::from(1_i64);
23    pub static ref TWO: BigInt = BigInt::from(2_i64);
24}
25
/// Eight-byte identifier returned by `marshal_id()` for [`Int`] encodings (`"mod.int "`).
const MARSHAL_INT_ID: [u8; 8] = *b"mod.int ";
27
28/// [`ByteOrder`] denotes the endianness of the operation.
29#[derive(Clone, Copy, PartialEq, Eq, Debug, Serialize, Deserialize)]
30pub enum ByteOrder {
31    LittleEndian,
32    BigEndian,
33}
34
35impl From<ByteOrder> for bool {
36    fn from(val: ByteOrder) -> Self {
37        match val {
38            LittleEndian => true,
39            BigEndian => false,
40        }
41    }
42}
43
44impl From<bool> for ByteOrder {
45    fn from(b: bool) -> Self {
46        match b {
47            true => LittleEndian,
48            false => BigEndian,
49        }
50    }
51}
52
53/// [`Int`] is a generic implementation of finite field arithmetic
54/// on `integer finite fields` with a given constant `modulus`,
55/// built using [`num_bigint_dig`] crate.
56/// The [`Scalar`] trait is implemented for [`Int`],
57/// and hence serves as a basic implementation of [`Scalar`],
58/// e.g., representing discrete-log exponents of `Schnorr groups`
59/// or scalar multipliers for elliptic curves.
60///
61/// [`Int`] offers an API similar to and compatible with [`BigInt`],
62/// but "carries around"  the relevant modulus
63/// and automatically normalizes the value to that modulus
64/// after all arithmetic operations, simplifying modular arithmetic.
65/// Binary operations assume that the source(s)
66/// have the same modulus, but do not check this assumption.
67/// Unary and binary arithmetic operations may be performed on uninitialized
68/// target objects, and receive the modulus of the first operand.
69/// For efficiency the modulus field m is a pointer,
70/// whose target is assumed never to change.
71#[derive(Clone, Eq, Debug, Serialize, Deserialize)]
72pub struct Int {
73    /// integer value from `0` through `m-1`
74    pub(crate) v: BigInt,
75    /// modulus for finite field arithmetic
76    pub(crate) m: BigInt,
77    /// endianness which will be used on input and output
78    pub bo: ByteOrder,
79}
80
81impl Default for Int {
82    fn default() -> Self {
83        Int {
84            bo: LittleEndian,
85            v: BigInt::from(0),
86            m: BigInt::from(0),
87        }
88    }
89}
90
91impl Int {
92    /// [`init64()`] creates an [`Int`] with an [`i64`] value and [`BigInt`] modulus.
93    pub fn init64(mut self, v: i64, m: BigInt) -> Self {
94        self.m = m.clone();
95        self.bo = BigEndian;
96        self.v = BigInt::from(v);
97        // specify euclidean modulus for negative number
98        match self.v.sign() {
99            num_bigint::Sign::Minus => self.v = (self.v % m.clone()) + m.abs(),
100            _ => self.v %= m,
101        }
102        self
103    }
104
105    /// [`init()`] a [`Int`] with a given [`BigInt`] `value` and a given [`BigInt`] `modulus`.
106    fn init(mut self, v: BigInt, m: BigInt) -> Self {
107        self.m = m.clone();
108        self.bo = BigEndian;
109        self.v = v % m;
110        self
111    }
112
113    /// [`little_endian()`] encodes the value of this [`Int`] into a little-endian byte-slice
114    /// at least `min` bytes but no more than `max` bytes long.
115    /// Panics if max != 0 and the Int cannot be represented in max bytes.
116    pub fn little_endian(&self, min: usize, max: usize) -> Result<Vec<u8>, IntError> {
117        let mut act = self.marshal_size();
118        let (_, v_bytes) = self.v.to_bytes_be();
119        let v_size = v_bytes.len();
120        if v_size < act {
121            act = v_size;
122        }
123        let mut pad = act;
124        if pad < min {
125            pad = min
126        }
127        if max != 0 && pad > max {
128            return Err(IntError::NotRepresentable);
129        }
130
131        let buf = vec![0; pad];
132        let buf2 = &buf[0..act];
133        Ok(reverse(buf2, &v_bytes))
134    }
135
136    /// [`new_int()`] creates a new [`Int`] with a given [`BigInt`] and a [`BigInt`] `modulus`.
137    pub fn new_int(v: BigInt, m: BigInt) -> Int {
138        Int::default().init(v, m)
139    }
140
141    /// [`new_int64()`] creates a new [`Int`] with a given [`i64`] value and [`BigInt`] `modulus`.
142    pub fn new_int64(v: i64, m: BigInt) -> Int {
143        Int::default().init64(v, m)
144    }
145
146    /// [`new_int_bytes()`] creates a new [`Int`] with a given slice of bytes and a [`BigInt`]
147    /// `modulus`.
148    pub fn new_int_bytes(a: &[u8], m: &BigInt, byte_order: ByteOrder) -> Int {
149        Int::default().init_bytes(a, m, byte_order)
150    }
151
152    /// [`new_int_string()`] creates a new [`Int`] with a given [`String`] and a [`BigInt`] `modulus`.
153    /// The value is set to a rational fraction n/d in a given base.
154    pub fn new_int_string(n: String, d: String, base: i32, m: &BigInt) -> Int {
155        Int::default().init_string(n, d, base, m)
156    }
157
158    /// [`equal()`] returns [`true`] if the two [`ints`](Int) are equal
159    pub fn equal(&self, s2: &Self) -> bool {
160        self.v.cmp(&s2.v) == Equal
161    }
162
163    /// [`cmpr()`] compares two [`ints`](Int) for equality or inequality
164    pub fn cmpr(&self, s2: &Self) -> Ordering {
165        self.v.cmp(&s2.v)
166    }
167
168    /// [`init_bytes()`] init the [`Int`] to a number represented in a `big-endian` byte string.
169    pub fn init_bytes(self, a: &[u8], m: &BigInt, byte_order: ByteOrder) -> Self {
170        Int {
171            m: m.clone(),
172            bo: byte_order,
173            v: self.v,
174        }
175        .set_bytes(a)
176    }
177
178    /// [`init_string()`] inits the [`Int`] to a rational fraction n/d
179    /// specified with a pair of [`strings`](String) in a given base.
180    fn init_string(mut self, n: String, d: String, base: i32, m: &BigInt) -> Int {
181        self.m = m.clone();
182        self.bo = BigEndian;
183        self.set_string(n, d, base)
184            .expect("init_string: invalid fraction representation")
185    }
186
187    /// [`set_string()`] sets the [`Int`] to a rational fraction n/d represented by a pair of [`strings`](String).
188    /// If `d == ""`, then the denominator is taken to be `1`.
189    /// Returns the [`Int`] on success or an [`Error`](BigIntError)
190    /// if the string failed to parse
191    pub fn set_string(mut self, n: String, d: String, base: i32) -> Result<Self, IntError> {
192        self.v = BigInt::from_str_radix(n.as_str(), base as u32)?;
193        if !d.is_empty() {
194            let mut di = Int {
195                m: self.m.clone(),
196                ..Default::default()
197            };
198            di = di.set_string(d, "".to_string(), base)?;
199            return Ok(self.clone().div(&self, &di));
200        }
201        Ok(self)
202    }
203}
204
205impl Display for Int {
206    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
207        write!(f, "{self:#x}")
208    }
209}
210
211impl PartialEq for Int {
212    fn eq(&self, other: &Self) -> bool {
213        self.equal(other)
214    }
215}
216
217impl Ord for Int {
218    fn cmp(&self, other: &Self) -> Ordering {
219        self.cmpr(other)
220    }
221}
222
223impl PartialOrd for Int {
224    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
225        Some(self.cmpr(other))
226    }
227}
228
229impl BinaryMarshaler for Int {
230    /// [`marshal_binary()`] encodes the value of this [`Int`] into a byte-slice exactly [`self.marshal_size()`] bytes long.
231    /// It uses `i`'s [`ByteOrder`] to determine which byte order to output.
232    fn marshal_binary(&self) -> Result<Vec<u8>, MarshallingError> {
233        let l = self.marshal_size();
234        // may be shorter than l
235        let (_, mut b) = self.v.to_bytes_be();
236        let offset = l - b.len();
237
238        if self.bo == LittleEndian {
239            return self
240                .little_endian(l, l)
241                .map_err(|e| MarshallingError::InvalidInput(e.to_string()));
242        }
243
244        if offset != 0 {
245            let mut nb = vec![0; l];
246            nb.splice((offset).., b);
247            b = nb;
248        }
249        Ok(b)
250    }
251}
252
253impl BinaryUnmarshaler for Int {
254    /// [`unmarshal_binary()`] tries to decode a [`Int`] from a byte-slice buffer.
255    /// Returns an [`Error`](MarshallingError) if the buffer is not exactly [`self.marshal_size()`] bytes long
256    /// or if the contents of the buffer represents an out-of-range integer.
257    fn unmarshal_binary(&mut self, data: &[u8]) -> Result<(), MarshallingError> {
258        let mut buf: Vec<u8> = data.to_vec();
259        if buf.len() != self.marshal_size() {
260            return Err(MarshallingError::InvalidInput(
261                "unmarshal_binary: wrong size buffer".to_owned(),
262            ));
263        }
264        // Still needed here because of the comparison with the modulo
265        if self.bo == LittleEndian {
266            buf = reverse(&vec![0_u8; buf.len()], &buf.to_vec());
267        }
268        self.v = BigInt::from_bytes_be(Plus, buf.as_slice());
269        if matches!(self.v.cmp(&self.m), Greater | Equal) {
270            return Err(MarshallingError::InvalidInput(
271                "unmarshal_binary: value out of range".to_owned(),
272            ));
273        }
274        Ok(())
275    }
276}
277
278impl Marshaling for Int {
279    fn marshal_to(&self, w: &mut impl std::io::Write) -> Result<(), MarshallingError> {
280        marshalling::scalar_marshal_to(self, w)
281    }
282
283    /// [`marshal_size()`] returns the length in bytes of encoded integers with modulus `m`.
284    /// The length of encoded [`ints`](Int) depends only on the size of the modulus,
285    /// and not on the the value of the encoded integer,
286    /// making the encoding is fixed-length for simplicity and security.
287    fn marshal_size(&self) -> usize {
288        ((self.m.bits()) + 7) / 8
289    }
290
291    fn unmarshal_from(&mut self, r: &mut impl std::io::Read) -> Result<(), MarshallingError> {
292        marshalling::scalar_unmarshal_from(self, r)
293    }
294
295    fn unmarshal_from_random(&mut self, r: &mut (impl std::io::Read + Stream)) {
296        marshalling::scalar_unmarshal_from_random(self, r);
297    }
298
299    fn marshal_id(&self) -> [u8; 8] {
300        MARSHAL_INT_ID
301    }
302}
303
304impl LowerHex for Int {
305    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
306        let prefix = if f.alternate() { "0x" } else { "" };
307        let encoded = hex::encode(self.v.to_bytes_be().1);
308        write!(f, "{prefix}{encoded}")
309    }
310}
311
312impl UpperHex for Int {
313    fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result {
314        let prefix = if f.alternate() { "0X" } else { "" };
315        let encoded = hex::encode_upper(self.v.to_bytes_be().1);
316        write!(f, "{prefix}{encoded}")
317    }
318}
319
320use core::ops::{self, Sub};
321impl_op_ex!(*|a: &Int, b: &Int| -> Int {
322    let m = a.m.clone();
323    let v = (a.v.clone() * b.v.clone()) % m.clone();
324    let bo = a.bo;
325    Int { v, m, bo }
326});
327
328impl_op_ex!(+|a: &Int, b: &Int| -> Int {
329        let m = a.m.clone();
330        let v = (a.v.clone() + b.v.clone()) % m.clone();
331        let bo = a.bo;
332        Int{v, m, bo}
333});
334
335impl Scalar for Int {
336    /// [`set()`] sets both `value` and `modulus` to be equal to another [`Int`].
337    /// Since this method copies the modulus as well,
338    fn set(self, a: &Self) -> Self {
339        let mut ai = self;
340        ai.v = a.v.clone();
341        ai.m = a.m.clone();
342        ai
343    }
344
345    /// [`set_int64()`] sets the [`Int`] to an arbitrary 64-bit "small integer" value.
346    /// The `modulus` must already be initialized.
347    fn set_int64(self, v: i64) -> Self {
348        let mut i = self;
349        i.v = BigInt::from(v);
350        // specify euclidean modulus for negative number
351        match i.v.sign() {
352            num_bigint::Sign::Minus => i.v = (i.v % i.m.clone()) + i.m.abs(),
353            _ => i.v %= i.m.clone(),
354        }
355        i
356    }
357
358    /// [`zero()`] set the [`Int`] to the value `0`. The modulus must already be initialized.
359    fn zero(self) -> Self {
360        let mut i = self;
361        i.v = BigInt::from(0_i64);
362        i
363    }
364
365    /// [`sub()`] sets the target to `a - b mod m`.
366    /// Target receives `a`'s modulus.
367    fn sub(mut self, a: &Self, b: &Self) -> Self {
368        self.m = a.m.clone();
369        let sub = &a.v - &b.v;
370        self.v = ((sub % &self.m) + &self.m) % &self.m;
371        // i.V.Sub(&ai.V, &bi.V).Mod(&i.V, i.M)
372        self
373    }
374
375    /// [`pick()`] a pseudo-random integer modulo `m`
376    /// using bits from the given stream cipher.
377    fn pick(self, rand: &mut impl Stream) -> Self {
378        let mut s = self.clone();
379        s.v.clone_from(&random_int(&self.m, rand));
380        s
381    }
382
383    /// [`set_bytes()`] set the value value to a number represented
384    /// by a byte string.
385    /// `Endianness` depends on the endianess set in `i`.
386    fn set_bytes(self, a: &[u8]) -> Self {
387        let mut buff = a.to_vec();
388        if self.bo == LittleEndian {
389            buff = reverse(vec![0; buff.len()].as_ref(), a);
390        }
391        Int {
392            m: self.m.clone(),
393            v: BigInt::from_bytes_be(Plus, buff.as_ref()) % &self.m,
394            bo: self.bo,
395        }
396    }
397
398    /// [`one()`] sets the [`Int`] to the value `1`.  The `modulus` must already be initialized.
399    fn one(self) -> Self {
400        let mut i = self;
401        i.v = BigInt::from(1_i64);
402        i
403    }
404
405    /// [`div()`] sets the target to `a * b^-1 mod m`, where `b^-1` is the modular inverse of `b`.
406    fn div(mut self, a: &Self, b: &Self) -> Self {
407        let _t = BigInt::default();
408        self.v = a.v.clone() * b.v.clone();
409        self.v = self.v.clone() % self.m.clone();
410        self
411    }
412
413    /// [`inv()`] sets the target to the modular inverse of a with respect to modulus `m`.
414    fn inv(self, a: &Self) -> Self {
415        let mut i = self;
416        i.v = a.clone().v.mod_inverse(&a.m.clone()).unwrap();
417        i.m = a.m.clone();
418        i
419    }
420
421    /// [`neg()`] sets the target to `-a mod m`.
422    fn neg(self, a: &Self) -> Self {
423        let mut i = self;
424        i.m = a.m.clone();
425        i.v = match a.v.sign() {
426            Plus => a.m.clone().sub(&a.v),
427            _ => BigInt::from(0_u64),
428        };
429        i
430    }
431}
432
433impl Int {
434    /// [`nonzero()`] returns `true` if the integer value is `nonzero`.
435    pub fn nonzero(&self) -> bool {
436        self.v.sign() != Sign::NoSign
437    }
438
439    /// [`int64()`] returns the [`i64`] representation of the value.
440    /// If the value is not representable in an [`i64`] the result is undefined.
441    pub fn int64(&self) -> i64 {
442        self.uint64() as i64
443    }
444
445    /// [`set_uint64()`] sets the Int to an arbitrary [`u64`] value.
446    /// The modulus must already be initialized.
447    pub fn set_uint64(&self, v: u64) -> Self {
448        let mut i = self.clone();
449        i.v = BigInt::from(v) % i.m.clone();
450        i
451    }
452
453    /// [`uint64()`] returns the [`u64`] representation of the value.
454    /// If the value is not representable in an [`u64`] the result is undefined.
455    pub fn uint64(&self) -> u64 {
456        let mut b = self.v.to_bytes_le().1;
457        b.resize(8, 0_u8);
458        let mut a = [0_u8; 8];
459        for (i, _) in b.iter().enumerate() {
460            a[i] = b[i];
461        }
462        let u = u64::from_le_bytes(a);
463        match self.v.sign() {
464            Sign::Minus => core::u64::MAX - u + 1,
465            _ => u,
466        }
467    }
468
469    /// [`exp()`] sets the target to `a^e mod m`,
470    /// where `e` is an arbitrary [`BigInt`] exponent (not necessarily `0 <= e < m`).
471    pub fn exp(mut self, a: &Self, e: &BigInt) -> Self {
472        self.m = a.m.clone();
473        // to protect against golang/go#22830
474        self.v = self.v.modpow(e, &self.m);
475        self
476    }
477
478    /// [`jacobi()`] computes the `Jacobi` symbol of `(a/m)`, which indicates whether a is
479    /// `zero` (`0`), a positive square in `m` (`1`), or a non-square in `m` (`-1`).
480    pub fn jacobi(&self, a_s: &Self) -> Self {
481        let mut i = self.clone();
482        i.m = a_s.m.clone();
483        i.v = BigInt::from(jacobi(&a_s.v, &i.m) as i64);
484        i
485    }
486
487    /// [`sqrt()`] computes some square root of a `mod m` of ONE exists.
488    /// Assumes the modulus `m` is an `odd prime`.
489    /// Returns `true` on success, `false` if input a is not a square.
490    pub fn sqrt(&mut self, a_s: &Self) -> Result<(), IntError> {
491        if a_s.v.sign() == Sign::Minus {
492            return Err(IntError::ImaginaryRoot);
493        }
494        self.v = a_s.v.sqrt() % a_s.m.clone();
495        self.m = a_s.m.clone();
496        Ok(())
497    }
498
499    /// [`big_endian()`] encodes the value of this [`Int`] into a big-endian byte-slice
500    /// at least `min` bytes but no more than `max` bytes long.
501    /// Returns an [`Error`](IntError) if `max != 0` and the [`Int`] cannot be represented in `max` bytes.
502    pub fn big_endian(&self, min: usize, max: usize) -> Result<Vec<u8>, IntError> {
503        let act = self.marshal_size();
504        let (mut pad, mut ofs) = (act, 0);
505        if pad < min {
506            (pad, ofs) = (min, min - act)
507        }
508        if max != 0 && pad > max {
509            return Err(IntError::NotRepresentable);
510        }
511        let mut buf = vec![0_u8; pad];
512        let b = self.v.to_bytes_be().1;
513        buf[ofs..].copy_from_slice(&b);
514        Ok(buf)
515    }
516}
517
518/// [`reverse()`] copies `src` into `dst` in byte-reversed order and returns `dst`,
519/// such that `src[0]` goes into `dst[len-1]` and vice versa.
520/// `dst` and `src` may be the same slice but otherwise must not overlap.
521fn reverse(dst: &[u8], src: &[u8]) -> Vec<u8> {
522    let mut dst = dst.to_vec();
523    let l = dst.len();
524    for i in 0..(l + 1) / 2 {
525        let j = l - 1 - i;
526        (dst[i], dst[j]) = (src[j], src[i]);
527    }
528    dst.to_vec()
529}
530
531#[derive(Debug, Error)]
532pub enum IntError {
533    #[error("marshalling error")]
534    MarshallingError(#[from] MarshallingError),
535    #[error("parse big int error")]
536    ParseBigIntError(#[from] ParseBigIntError),
537    #[error("Int not representable in max bytes")]
538    NotRepresentable,
539    #[error("input is a negative number, square root is imaginary")]
540    ImaginaryRoot,
541}