arcis_compiler/utils/
number.rs

1use num_bigint::{BigInt, ParseBigIntError, RandBigInt, Sign, ToBigInt};
2use num_traits::{FromPrimitive, Num, One, Signed, ToBytes, ToPrimitive, Zero};
3use rand::Rng;
4use serde::{Deserialize, Serialize};
5use std::{
6    cmp::Ordering,
7    fmt::{Display, Formatter},
8    iter::Sum,
9    ops::{Add, BitAnd, Div, Mul, Neg, Rem, Shr, Sub},
10    str::FromStr,
11};
12use Number::{BigNum, SmallNum};
13
/// Signed integer with a fast path for small values.
///
/// Values that can fit into an i64 are SmallNum, others are BigNum.
/// We store them in i128 so that the multiplications can happen without overflow
/// (the product of two values in the `i64` range always fits in an `i128`).
///
/// The `From` conversions in this file always pick the variant according to
/// that rule, and other code relies on it: the derived `PartialEq`/`Hash`
/// compare variant by variant, and `Ord::cmp` treats a `BigNum` with sign
/// `NoSign` as unreachable. Code that builds a variant by hand has to respect
/// the same rule.
#[derive(Debug, Clone, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum Number {
    /// A value inside the `i64` range, stored widened to `i128`.
    SmallNum(i128),
    /// A value outside the `i64` range.
    BigNum(BigInt),
}
21
22impl From<Number> for f64 {
23    fn from(number: Number) -> Self {
24        match number {
25            SmallNum(i) => i.to_f64().unwrap(),
26            BigNum(n) => n.to_f64().unwrap(),
27        }
28    }
29}
30
31impl From<Number> for BigInt {
32    #[inline(always)]
33    fn from(value: Number) -> Self {
34        match value {
35            SmallNum(n) => n.into(),
36            BigNum(n) => n,
37        }
38    }
39}
40
41impl From<BigInt> for Number {
42    #[inline(always)]
43    fn from(value: BigInt) -> Self {
44        if let Some(val64) = value.to_i64() {
45            SmallNum(val64 as i128)
46        } else {
47            BigNum(value)
48        }
49    }
50}
51
52impl From<f64> for Number {
53    fn from(value: f64) -> Self {
54        BigInt::from_f64(value).expect("BigInt from f64.").into()
55    }
56}
57
58impl From<f32> for Number {
59    fn from(value: f32) -> Self {
60        BigInt::from_f32(value).expect("BigInt from f32.").into()
61    }
62}
63
64impl<const N: usize> From<[u8; N]> for Number {
65    fn from(value: [u8; N]) -> Self {
66        BigInt::from_bytes_le(Sign::Plus, &value).into()
67    }
68}
69
/// Copies `val` into a fixed-size array of `N` bytes.
///
/// If `val` is longer than `N`, only its first `N` bytes are kept; if it is
/// shorter, the remaining positions are filled with zeros.
fn cut_off<const N: usize>(val: &[u8]) -> [u8; N] {
    let mut out = [0u8; N];
    // Only the overlapping prefix is copied; the rest of `out` stays zero.
    let kept = N.min(val.len());
    out[..kept].copy_from_slice(&val[..kept]);
    out
}
77
78impl<const N: usize> From<Number> for [u8; N] {
79    fn from(value: Number) -> Self {
80        match value {
81            SmallNum(i) => cut_off(&i.to_le_bytes()),
82            BigNum(i) => cut_off(&i.to_le_bytes()),
83        }
84    }
85}
86
87impl From<&str> for Number {
88    #[inline(always)]
89    fn from(value: &str) -> Self {
90        Number::from(BigInt::from_str(value).unwrap())
91    }
92}
93
// Implements `From<$small>` (and, through `impl_from_ref!`, `From<&$small>`)
// for `Number` for types whose every value fits in the `i64` range, so the
// result is always a `SmallNum`.
macro_rules! impl_from_low {
    ($small: ty) => {
        impl From<$small> for Number {
            #[inline(always)]
            fn from(value: $small) -> Self {
                Number::SmallNum(i128::from(value))
            }
        }
        impl_from_ref!($small);
    };
}
105
// Implements `From<$wide>` (and, through `impl_from_ref!`, `From<&$wide>`)
// for `Number` for types that may hold values outside the `i64` range:
// values that fit become `SmallNum`, everything else becomes `BigNum`.
macro_rules! impl_from_high {
    ($wide: ty) => {
        impl From<$wide> for Number {
            #[inline(always)]
            fn from(value: $wide) -> Self {
                match value.to_i64() {
                    Some(fits) => SmallNum(i128::from(fits)),
                    None => BigNum(BigInt::from(value)),
                }
            }
        }
        impl_from_ref!($wide);
    };
}
121
// Implements `From<&$owned>` for `Number` by copying the referenced value
// and delegating to the by-value `From<$owned>` impl.
macro_rules! impl_from_ref {
    ($owned: ty) => {
        impl<'r> From<&'r $owned> for Number {
            #[inline(always)]
            fn from(value: &'r $owned) -> Self {
                let copied: $owned = *value;
                copied.into()
            }
        }
    };
}
132
// Types whose whole range fits in an i64: always stored as `SmallNum`.
impl_from_low!(bool);
impl_from_low!(u8);
impl_from_low!(u16);
impl_from_low!(u32);
impl_from_low!(i8);
impl_from_low!(i16);
impl_from_low!(i32);
// Types converted with a `to_i64()` check: `SmallNum` when the value fits,
// `BigNum` otherwise.
impl_from_high!(usize);
impl_from_high!(u64);
impl_from_high!(u128);
impl_from_high!(isize);
impl_from_high!(i64);
impl_from_high!(i128);
146
// Expands to a `match` over the variant pair of two `Number` operands and
// calls the method `$f` on the left inner value with the right inner value.
//
//   $f - name of the operator method to call (e.g. `add`)
//   $s - identifier of the left operand
//   $r - identifier of the right operand
//
// Each arm finishes with `.into()`, so the expansion has to be placed where
// the target type is known (in this file: the `Number` return position of an
// operator method).
macro_rules! match_binary_op {
    ($f: ident, $s: ident, $r: ident) => {
        match ($s, $r) {
            (SmallNum(a), SmallNum(b)) => a.$f(b).into(),
            (SmallNum(a), BigNum(b)) => a.$f(b).into(),
            (BigNum(a), SmallNum(b)) => a.$f(b).into(),
            (BigNum(a), BigNum(b)) => a.$f(b).into(),
        }
    };
}
157
// Expands to a `match` on the `Number` named `$n` and calls the method `$f`
// on its inner value with the scalar expression `$i` as right-hand operand.
//
// For `SmallNum` the scalar is cast to `i128` first; for `BigNum` it is passed
// through unchanged. Each arm finishes with `.into()`, so the expansion has to
// be placed where the target type is known.
macro_rules! match_single_op {
    ($f: ident, $n: ident, $i: expr) => {
        match ($n) {
            SmallNum(a) => a.$f($i as i128).into(),
            BigNum(a) => a.$f($i).into(),
        }
    };
}
166
// Mirror image of `match_single_op!`: the scalar expression `$i` is the
// left-hand operand and the inner value of the `Number` named `$n` is the
// right-hand operand of the method `$f`.
//
// The scalar is cast to `i128` in both arms. Each arm finishes with `.into()`,
// so the expansion has to be placed where the target type is known.
macro_rules! match_single_op_reverse {
    ($f: ident, $n: ident, $i: expr) => {
        match ($n) {
            SmallNum(a) => ($i as i128).$f(a).into(),
            BigNum(a) => ($i as i128).$f(a).into(),
        }
    };
}
175
// Generates every operator impl for one binary operator trait.
//
//   $t - the operator trait (e.g. `Add`)
//   $f - the trait's method name (e.g. `add`)
//
// Twelve impls are produced, all with `Output = Number`:
//   * `Number` / `&Number` on the left with `Number` / `&Number` on the right,
//   * `Number` / `&Number` on the left with `i32` / `&i32` on the right,
//   * `i32` / `&i32` on the left with `Number` / `&Number` on the right.
// The bodies are the `match_binary_op!`, `match_single_op!` and
// `match_single_op_reverse!` expansions.
macro_rules! binary_op {
    ($t: ident, $f: ident) => {
        // Number op Number
        impl $t<Number> for Number {
            type Output = Number;
            #[inline(always)]
            fn $f(self, rhs: Number) -> Number {
                match_binary_op!($f, self, rhs)
            }
        }
        // Number op &Number
        impl<'b> $t<&'b Number> for Number {
            type Output = Number;
            #[inline(always)]
            fn $f(self, rhs: &'b Number) -> Number {
                match_binary_op!($f, self, rhs)
            }
        }
        // Number op i32
        impl $t<i32> for Number {
            type Output = Number;
            #[inline(always)]
            fn $f(self, rhs: i32) -> Number {
                match_single_op!($f, self, rhs)
            }
        }
        // Number op &i32
        impl $t<&i32> for Number {
            type Output = Number;
            #[inline(always)]
            fn $f(self, rhs: &i32) -> Number {
                match_single_op!($f, self, *rhs)
            }
        }
        // i32 op Number
        impl $t<Number> for i32 {
            type Output = Number;
            #[inline(always)]
            fn $f(self, rhs: Number) -> Number {
                match_single_op_reverse!($f, rhs, self)
            }
        }
        // &i32 op Number
        impl $t<Number> for &i32 {
            type Output = Number;
            #[inline(always)]
            fn $f(self, rhs: Number) -> Number {
                match_single_op_reverse!($f, rhs, *self)
            }
        }
        // &Number op Number
        impl<'a> $t<Number> for &'a Number {
            type Output = Number;
            #[inline(always)]
            fn $f(self, rhs: Number) -> Number {
                match_binary_op!($f, self, rhs)
            }
        }
        // &Number op &Number
        impl<'a, 'b> $t<&'b Number> for &'a Number {
            type Output = Number;
            #[inline(always)]
            fn $f(self, rhs: &'b Number) -> Number {
                match_binary_op!($f, self, rhs)
            }
        }
        // &Number op i32
        impl<'a> $t<i32> for &'a Number {
            type Output = Number;
            #[inline(always)]
            fn $f(self, rhs: i32) -> Number {
                match_single_op!($f, self, rhs)
            }
        }
        // &Number op &i32
        impl<'a> $t<&i32> for &'a Number {
            type Output = Number;
            #[inline(always)]
            fn $f(self, rhs: &i32) -> Number {
                match_single_op!($f, self, *rhs)
            }
        }
        // i32 op &Number
        impl<'a> $t<&'a Number> for i32 {
            type Output = Number;
            #[inline(always)]
            fn $f(self, rhs: &'a Number) -> Number {
                match_single_op_reverse!($f, rhs, self)
            }
        }
        // &i32 op &Number
        impl<'a> $t<&'a Number> for &i32 {
            type Output = Number;
            #[inline(always)]
            fn $f(self, rhs: &'a Number) -> Number {
                match_single_op_reverse!($f, rhs, *self)
            }
        }
    };
}
264
// Arithmetic operators for `Number`, in every owned/borrowed and `i32`
// combination generated by `binary_op!`.
binary_op!(Add, add);
binary_op!(Sub, sub);
binary_op!(Mul, mul);
binary_op!(Div, div);
binary_op!(Rem, rem);
270
271impl BitAnd for Number {
272    type Output = Number;
273
274    fn bitand(self, rhs: Self) -> Self::Output {
275        match (self, rhs) {
276            (SmallNum(a), SmallNum(b)) => SmallNum(a.bitand(b)),
277            (BigNum(a), SmallNum(b)) => a
278                .bitand(b.to_bigint().expect("i128 to BigInt always works"))
279                .into(),
280            (SmallNum(a), BigNum(b)) => b
281                .bitand(a.to_bigint().expect("i128 to BigInt always works."))
282                .into(),
283            (BigNum(a), BigNum(b)) => a.bitand(b).into(),
284        }
285    }
286}
287
288impl Shr<usize> for Number {
289    type Output = Number;
290
291    #[inline(always)]
292    fn shr(self, rhs: usize) -> Self::Output {
293        match self {
294            SmallNum(n) => (n >> rhs.min(127)).into(), // returning the sign bit in case rhs > 127
295            BigNum(n) => (n >> rhs).into(),
296        }
297    }
298}
299
300impl Shr<&usize> for Number {
301    type Output = Number;
302
303    #[inline(always)]
304    fn shr(self, rhs: &usize) -> Self::Output {
305        match self {
306            SmallNum(n) => (n >> rhs).into(),
307            BigNum(n) => (n >> rhs).into(),
308        }
309    }
310}
311
312impl PartialOrd<Self> for Number {
313    #[inline(always)]
314    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
315        Some(self.cmp(other))
316    }
317}
318
319impl PartialEq<i32> for Number {
320    fn eq(&self, other: &i32) -> bool {
321        self == &Number::from(*other)
322    }
323}
324
325impl PartialEq<Number> for i32 {
326    fn eq(&self, other: &Number) -> bool {
327        other == &Number::from(*self)
328    }
329}
330
331impl PartialOrd<i32> for Number {
332    #[inline(always)]
333    fn partial_cmp(&self, other: &i32) -> Option<Ordering> {
334        Some(self.cmp(&Number::from(*other)))
335    }
336}
337
338impl PartialOrd<Number> for i32 {
339    #[inline(always)]
340    fn partial_cmp(&self, other: &Number) -> Option<Ordering> {
341        Some(Number::from(*self).cmp(other))
342    }
343}
344
345impl Ord for Number {
346    #[inline(always)]
347    fn cmp(&self, other: &Self) -> Ordering {
348        match (self, other) {
349            (SmallNum(a), SmallNum(b)) => a.cmp(b),
350            (SmallNum(_), BigNum(b)) => match b.sign() {
351                Sign::Minus => Ordering::Greater,
352                Sign::NoSign => unreachable!("zero is a small num, not a big num"),
353                Sign::Plus => Ordering::Less,
354            },
355            (BigNum(a), SmallNum(_)) => match a.sign() {
356                Sign::Minus => Ordering::Less,
357                Sign::NoSign => unreachable!("zero is a small num, not a big num"),
358                Sign::Plus => Ordering::Greater,
359            },
360            (BigNum(a), BigNum(b)) => a.cmp(b),
361        }
362    }
363}
364
365impl Num for Number {
366    type FromStrRadixErr = ParseBigIntError;
367
368    fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
369        Ok(BigInt::from_str_radix(str, radix)?.into())
370    }
371}
372
373impl FromStr for Number {
374    type Err = ParseBigIntError;
375    fn from_str(str: &str) -> Result<Self, Self::Err> {
376        if let Some(stripped) = str.strip_prefix("0x") {
377            Self::from_str_radix(stripped, 16)
378        } else if let Some(stripped) = str.strip_prefix("0o") {
379            Self::from_str_radix(stripped, 8)
380        } else if let Some(stripped) = str.strip_prefix("0b") {
381            Self::from_str_radix(stripped, 2)
382        } else {
383            Self::from_str_radix(str, 10)
384        }
385    }
386}
387
388impl Zero for Number {
389    #[inline(always)]
390    fn zero() -> Self {
391        SmallNum(0)
392    }
393
394    #[inline(always)]
395    fn is_zero(&self) -> bool {
396        *self == SmallNum(0)
397    }
398}
399
400impl One for Number {
401    #[inline(always)]
402    fn one() -> Self {
403        SmallNum(1)
404    }
405}
406
407impl Neg for Number {
408    type Output = Number;
409
410    #[inline(always)]
411    fn neg(self) -> Self::Output {
412        match self {
413            SmallNum(n) => (-n).into(),
414            BigNum(n) => (-n).into(),
415        }
416    }
417}
418
419impl Signed for Number {
420    #[inline(always)]
421    fn abs(&self) -> Self {
422        match self {
423            SmallNum(n) => n.abs().into(),
424            BigNum(n) => n.abs().into(),
425        }
426    }
427
428    #[inline(always)]
429    fn abs_sub(&self, other: &Self) -> Self {
430        if self <= other {
431            0.into()
432        } else {
433            self.clone() - other.clone()
434        }
435    }
436
437    #[inline(always)]
438    fn signum(&self) -> Self {
439        match self {
440            SmallNum(n) => n.signum().into(),
441            BigNum(n) => n.signum().into(),
442        }
443    }
444
445    #[inline(always)]
446    fn is_positive(&self) -> bool {
447        match self {
448            SmallNum(n) => n.is_positive(),
449            BigNum(n) => n.is_positive(),
450        }
451    }
452
453    #[inline(always)]
454    fn is_negative(&self) -> bool {
455        match self {
456            SmallNum(n) => n.is_negative(),
457            BigNum(n) => n.is_negative(),
458        }
459    }
460}
461
462impl Number {
463    #[inline(always)]
464    pub fn bit(&self, idx: usize) -> bool {
465        match self {
466            SmallNum(n) => ((n >> idx.min(127)) & 1) == 1, /* returning the sign bit in case idx
467                                                             * > */
468            // 127
469            BigNum(n) => n.bit(idx as u64),
470        }
471    }
472
473    #[inline(always)]
474    pub fn power_of_two(idx: usize) -> Number {
475        if idx < 63 {
476            SmallNum(1 << idx)
477        } else {
478            BigNum(BigInt::from(1) << idx)
479        }
480    }
481
482    #[inline(always)]
483    pub fn negative_power_of_two(idx: usize) -> Number {
484        if idx < 64 {
485            SmallNum(-1 << idx)
486        } else {
487            BigNum(BigInt::from(-1) << idx)
488        }
489    }
490
491    #[inline(always)]
492    pub fn bits(&self) -> usize {
493        match self {
494            SmallNum(n) => {
495                if *n == 0 {
496                    1
497                } else {
498                    1 + n.abs().ilog2() as usize
499                }
500            }
501            BigNum(n) => n.bits() as usize,
502        }
503    }
504    pub fn gen_range<R: Rng + ?Sized>(rng: &mut R, lower: &Number, upper: &Number) -> Number {
505        match (lower, upper) {
506            (SmallNum(l), SmallNum(u)) => rng.gen_range((*l)..(*u)).into(),
507            (BigNum(l), BigNum(u)) => rng.gen_bigint_range(l, u).into(),
508            (SmallNum(l), BigNum(u)) => rng.gen_bigint_range(&(*l).into(), u).into(),
509            (BigNum(l), SmallNum(u)) => rng.gen_bigint_range(l, &(*u).into()).into(),
510        }
511    }
512}
513
514impl Sum for Number {
515    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
516        let mut sum: Number = SmallNum(0);
517        for n in iter {
518            sum = sum + n
519        }
520        sum
521    }
522}
523
524impl Display for Number {
525    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
526        match self {
527            SmallNum(n) => n.fmt(f),
528            BigNum(n) => n.fmt(f),
529        }
530    }
531}
532
533impl ToPrimitive for Number {
534    #[inline(always)]
535    fn to_i64(&self) -> Option<i64> {
536        match self {
537            SmallNum(n) => n.to_i64(),
538            BigNum(n) => n.to_i64(),
539        }
540    }
541
542    #[inline(always)]
543    fn to_i128(&self) -> Option<i128> {
544        match self {
545            SmallNum(n) => n.to_i128(),
546            BigNum(n) => n.to_i128(),
547        }
548    }
549
550    #[inline(always)]
551    fn to_u64(&self) -> Option<u64> {
552        match self {
553            SmallNum(n) => n.to_u64(),
554            BigNum(n) => n.to_u64(),
555        }
556    }
557
558    #[inline(always)]
559    fn to_u128(&self) -> Option<u128> {
560        match self {
561            SmallNum(n) => n.to_u128(),
562            BigNum(n) => n.to_u128(),
563        }
564    }
565}
566
567impl Number {
568    pub fn log2(&self) -> f64 {
569        f64::from(self.clone()).log2()
570    }
571}
572
// Unit tests for the conversions between `Number` and fixed-size byte arrays.
#[cfg(test)]
mod tests {
    use crate::utils::number::Number;

    // Byte arrays are read as little-endian, non-negative integers; the empty
    // array converts to zero.
    #[test]
    fn conversions_from_u8_arr() {
        assert_eq!(Number::from([]), Number::from(0));
        assert_eq!(Number::from([123u8]), Number::from(123));
        assert_eq!(Number::from([44u8, 1]), Number::from(300));
    }
    // Helper: converts `num` into an `N`-byte array and compares it with the
    // expected bytes.
    fn test_conversion_to_u8_arr<const N: usize>(num: Number, true_res: [u8; N]) {
        let arr: [u8; N] = num.into();
        assert_eq!(arr, true_res);
    }
    // Numbers are written as little-endian bytes, zero-padded up to the
    // requested array length.
    #[test]
    fn conversions_to_u8_arr() {
        test_conversion_to_u8_arr(Number::from(0), [0u8, 0u8, 0u8]);
        test_conversion_to_u8_arr(Number::from(123), [123u8]);
        test_conversion_to_u8_arr(Number::from(300), [44u8, 1u8]);
    }
}