Skip to main content

candid/types/
number.rs

//! Data structures for the Candid `Int` and `Nat` types: arbitrary-precision integers with LEB128 encoding.
2
3use super::{CandidType, Serializer, Type, TypeInner};
4use crate::{utils::pp_num_str, Error};
5use num_bigint::{BigInt, BigUint};
6use serde::{
7    de::{self, Deserialize, SeqAccess, Visitor},
8    Serialize,
9};
10use std::convert::From;
11use std::{fmt, io};
12
13#[derive(Serialize, Ord, PartialOrd, Eq, PartialEq, Debug, Clone, Hash, Default)]
14pub struct Int(pub BigInt);
15#[derive(Serialize, Ord, PartialOrd, Eq, PartialEq, Debug, Clone, Hash, Default)]
16pub struct Nat(pub BigUint);
17
18impl From<BigInt> for Int {
19    fn from(i: BigInt) -> Self {
20        Self(i)
21    }
22}
23
24impl From<BigUint> for Nat {
25    fn from(i: BigUint) -> Self {
26        Self(i)
27    }
28}
29
30impl From<Nat> for Int {
31    fn from(n: Nat) -> Self {
32        let i: BigInt = n.0.into();
33        i.into()
34    }
35}
36
37impl From<Int> for BigInt {
38    fn from(i: Int) -> Self {
39        i.0
40    }
41}
42
43impl From<Nat> for BigUint {
44    fn from(i: Nat) -> Self {
45        i.0
46    }
47}
48
49impl From<Nat> for BigInt {
50    fn from(i: Nat) -> Self {
51        i.0.into()
52    }
53}
54
55impl Int {
56    #[inline]
57    pub fn parse(v: &[u8]) -> crate::Result<Self> {
58        let res = BigInt::parse_bytes(v, 10).ok_or_else(|| Error::msg("Cannot parse BigInt"))?;
59        Ok(Int(res))
60    }
61}
62
63impl Nat {
64    #[inline]
65    pub fn parse(v: &[u8]) -> crate::Result<Self> {
66        let res = BigUint::parse_bytes(v, 10).ok_or_else(|| Error::msg("Cannot parse BigUint"))?;
67        Ok(Nat(res))
68    }
69}
70
71impl std::str::FromStr for Int {
72    type Err = crate::Error;
73    fn from_str(str: &str) -> Result<Self, Self::Err> {
74        Self::parse(str.as_bytes())
75    }
76}
77
78impl std::str::FromStr for Nat {
79    type Err = crate::Error;
80    fn from_str(str: &str) -> Result<Self, Self::Err> {
81        Self::parse(str.as_bytes())
82    }
83}
84
85impl fmt::Display for Int {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        let s = self.0.to_str_radix(10);
88        f.write_str(&pp_num_str(&s))
89    }
90}
91
92impl fmt::Display for Nat {
93    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94        let s = self.0.to_str_radix(10);
95        f.write_str(&pp_num_str(&s))
96    }
97}
98
99impl CandidType for Int {
100    fn _ty() -> Type {
101        TypeInner::Int.into()
102    }
103    fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
104    where
105        S: Serializer,
106    {
107        serializer.serialize_int(self)
108    }
109}
110
111impl CandidType for Nat {
112    fn _ty() -> Type {
113        TypeInner::Nat.into()
114    }
115    fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
116    where
117        S: Serializer,
118    {
119        serializer.serialize_nat(self)
120    }
121}
122
123impl<'de> Deserialize<'de> for Int {
124    fn deserialize<D>(deserializer: D) -> Result<Int, D::Error>
125    where
126        D: serde::Deserializer<'de>,
127    {
128        struct IntVisitor;
129        impl Visitor<'_> for IntVisitor {
130            type Value = Int;
131            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
132                formatter.write_str("Int value")
133            }
134            fn visit_i64<E>(self, v: i64) -> Result<Int, E> {
135                Ok(Int::from(v))
136            }
137            fn visit_u64<E>(self, v: u64) -> Result<Int, E> {
138                Ok(Int::from(v))
139            }
140            fn visit_str<E: de::Error>(self, v: &str) -> Result<Int, E> {
141                v.parse::<Int>()
142                    .map_err(|_| de::Error::custom(format!("{v:?} is not int")))
143            }
144            fn visit_byte_buf<E: de::Error>(self, v: Vec<u8>) -> Result<Int, E> {
145                Ok(Int(match v.first() {
146                    Some(0) => BigInt::from_signed_bytes_le(&v[1..]),
147                    Some(1) => BigInt::from_biguint(
148                        num_bigint::Sign::Plus,
149                        BigUint::from_bytes_le(&v[1..]),
150                    ),
151                    _ => return Err(de::Error::custom("not int nor nat")),
152                }))
153            }
154        }
155        deserializer.deserialize_any(IntVisitor)
156    }
157}
158
159impl<'de> Deserialize<'de> for Nat {
160    fn deserialize<D>(deserializer: D) -> Result<Nat, D::Error>
161    where
162        D: serde::Deserializer<'de>,
163    {
164        struct NatVisitor;
165        impl<'de> Visitor<'de> for NatVisitor {
166            type Value = Nat;
167            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
168                formatter.write_str("Nat value")
169            }
170            fn visit_i64<E: de::Error>(self, v: i64) -> Result<Nat, E> {
171                use num_bigint::ToBigUint;
172                v.to_biguint()
173                    .ok_or_else(|| de::Error::custom("i64 cannot be converted to nat"))
174                    .map(Nat)
175            }
176            fn visit_u64<E>(self, v: u64) -> Result<Nat, E> {
177                Ok(Nat::from(v))
178            }
179            fn visit_str<E: de::Error>(self, v: &str) -> Result<Nat, E> {
180                v.parse::<Nat>()
181                    .map_err(|_| de::Error::custom(format!("{v:?} is not nat")))
182            }
183            fn visit_byte_buf<E: de::Error>(self, v: Vec<u8>) -> Result<Nat, E> {
184                if v[0] == 1 {
185                    Ok(Nat(BigUint::from_bytes_le(&v[1..])))
186                } else {
187                    Err(de::Error::custom("not nat"))
188                }
189            }
190
191            fn visit_seq<S>(self, mut seq: S) -> Result<Nat, S::Error>
192            where
193                S: SeqAccess<'de>,
194            {
195                let len = seq.size_hint().unwrap_or(0);
196                let mut data = Vec::with_capacity(len);
197
198                while let Some(value) = seq.next_element::<u32>()? {
199                    data.push(value);
200                }
201
202                Ok(Nat(BigUint::new(data)))
203            }
204        }
205        deserializer.deserialize_any(NatVisitor)
206    }
207}
208
// LEB128 encoding and decoding for arbitrary-precision (bignum) values.
210
211impl Nat {
212    pub fn encode<W>(&self, w: &mut W) -> crate::Result<()>
213    where
214        W: ?Sized + io::Write,
215    {
216        use num_traits::cast::ToPrimitive;
217        if let Some(value) = self.0.to_u64() {
218            leb128::write::unsigned(w, value)?;
219            return Ok(());
220        }
221        let zero = BigUint::from(0u8);
222        let mut value = self.0.clone();
223        loop {
224            let big_byte = &value & BigUint::from(0x7fu8);
225            let mut byte = big_byte.to_u8().unwrap();
226            value >>= 7;
227            if value != zero {
228                byte |= 0x80u8;
229            }
230            let buf = [byte];
231            w.write_all(&buf)?;
232            if value == zero {
233                return Ok(());
234            }
235        }
236    }
237    pub fn decode<R>(r: &mut R) -> crate::Result<Self>
238    where
239        R: io::Read,
240    {
241        let mut small = 0u64;
242        let mut shift = 0u32;
243        loop {
244            let mut buf = [0];
245            r.read_exact(&mut buf)?;
246            let byte = buf[0];
247            let low_bits = u64::from(byte & 0x7f);
248            if shift == 0 || (shift < 64 && low_bits < (1u64 << (64 - shift))) {
249                small |= low_bits << shift;
250                if byte & 0x80u8 == 0 {
251                    return Ok(Nat(BigUint::from(small)));
252                }
253                shift += 7;
254                continue;
255            }
256
257            let mut result = BigUint::from(small);
258            result |= BigUint::from(low_bits) << shift;
259            if byte & 0x80u8 == 0 {
260                return Ok(Nat(result));
261            }
262            shift += 7;
263            loop {
264                let mut buf = [0];
265                r.read_exact(&mut buf)?;
266                let byte = buf[0];
267                let low_bits = BigUint::from(byte & 0x7fu8);
268                result |= low_bits << shift;
269                if byte & 0x80u8 == 0 {
270                    return Ok(Nat(result));
271                }
272                shift += 7;
273            }
274        }
275    }
276}
277
278impl Int {
279    pub fn encode<W>(&self, w: &mut W) -> crate::Result<()>
280    where
281        W: ?Sized + io::Write,
282    {
283        use num_traits::cast::ToPrimitive;
284        if let Some(value) = self.0.to_i64() {
285            leb128::write::signed(w, value)?;
286            return Ok(());
287        }
288        let zero = BigInt::from(0);
289        let mut value = self.0.clone();
290        loop {
291            let big_byte = &value & BigInt::from(0xff);
292            let mut byte = big_byte.to_u8().unwrap();
293            value >>= 6;
294            let done = value == zero || value == BigInt::from(-1);
295            if done {
296                byte &= 0x7f;
297            } else {
298                value >>= 1;
299                byte |= 0x80;
300            }
301            let buf = [byte];
302            w.write_all(&buf)?;
303            if done {
304                return Ok(());
305            }
306        }
307    }
308    pub fn decode<R>(r: &mut R) -> crate::Result<Self>
309    where
310        R: io::Read,
311    {
312        let mut small = 0i64;
313        let mut shift = 0u32;
314        loop {
315            let mut buf = [0];
316            r.read_exact(&mut buf)?;
317            let byte = buf[0];
318            let low_bits = i64::from(byte & 0x7f);
319
320            let fits_i64 = if shift < 57 {
321                true
322            } else if shift < 64 && byte & 0x80 == 0 {
323                // Only the terminal byte can confirm the value fits in i64:
324                // if continuation is set, more bytes follow at shifts >= 64
325                // and the value's high bits won't fit. Without this guard,
326                // `low_bits << shift` would silently truncate bits 1..6.
327                let remaining_bits = 64 - shift;
328                if (byte & 0x40) != 0 {
329                    (low_bits | !0x7f) >> (remaining_bits - 1) == -1
330                } else {
331                    low_bits >> (remaining_bits - 1) == 0
332                }
333            } else {
334                false
335            };
336
337            if fits_i64 {
338                small |= low_bits << shift;
339                shift += 7;
340                if byte & 0x80 == 0 {
341                    if shift < 64 && (byte & 0x40) != 0 {
342                        small |= !0i64 << shift;
343                    }
344                    return Ok(Int(BigInt::from(small)));
345                }
346                continue;
347            }
348
349            let mut result = BigInt::from(small);
350            let big_low_bits = BigInt::from(byte & 0x7fu8);
351            result |= big_low_bits << shift;
352            shift += 7;
353            if byte & 0x80 == 0 {
354                if (byte & 0x40) != 0 {
355                    result |= BigInt::from(-1) << shift;
356                }
357                return Ok(Int(result));
358            }
359            loop {
360                let mut buf = [0];
361                r.read_exact(&mut buf)?;
362                let byte = buf[0];
363                let big_low_bits = BigInt::from(byte & 0x7fu8);
364                result |= big_low_bits << shift;
365                shift += 7;
366                if byte & 0x80 == 0 {
367                    if (byte & 0x40) != 0 {
368                        result |= BigInt::from(-1) << shift;
369                    }
370                    return Ok(Int(result));
371                }
372            }
373        }
374    }
375}
376
// Operator, comparison, and conversion trait impls for Nat and Int, including mixed operations with Rust primitive integers.
378use std::cmp::{Ord, Ordering, PartialEq, PartialOrd};
379use std::ops::*;
380
/// Implements `From<$t> for $f` for every listed primitive `$t`, converting
/// through the wrapped big integer's own `From` impl.
macro_rules! define_from {
    ($f: ty, $($t: ty)*) => {
        $(
            impl From<$t> for $f {
                #[inline]
                fn from(value: $t) -> Self {
                    Self(value.into())
                }
            }
        )*
    };
}
389
/// Implements equality between the wrapper `$f` and each primitive `$t`, in
/// both directions, by converting the primitive into the wrapped big integer.
macro_rules! define_eq {
    ($f: ty, $($t: ty)*) => {
        $(
            // wrapper == primitive
            impl PartialEq<$t> for $f {
                #[inline]
                fn eq(&self, prim: &$t) -> bool {
                    self.0.eq(&(*prim).into())
                }
            }
            // primitive == wrapper
            impl PartialEq<$f> for $t {
                #[inline]
                fn eq(&self, wrapper: &$f) -> bool {
                    wrapper.0.eq(&(*self).into())
                }
            }
        )*
    };
}
402
/// Implements a binary operator between a wrapper (`$wrapper`, e.g. `Nat`) and
/// a primitive scalar (`$prim`) in both operand orders. Each impl forwards to
/// the wrapped big integer's operator and converts the result back with `.into()`.
macro_rules! define_op {
    (impl $op_trait: ident < $prim: ty > for $wrapper: ty, $op_fn: ident) => {
        // wrapper OP primitive
        impl $op_trait<$prim> for $wrapper {
            type Output = $wrapper;

            #[inline]
            fn $op_fn(self, prim: $prim) -> $wrapper {
                $op_trait::$op_fn(self.0, &prim).into()
            }
        }

        // primitive OP wrapper
        impl $op_trait<$wrapper> for $prim {
            type Output = $wrapper;

            #[inline]
            fn $op_fn(self, wrapper: $wrapper) -> $wrapper {
                $op_trait::$op_fn(&self, wrapper.0).into()
            }
        }
    };
}
426
/// Implements ordering between a primitive `$prim` and the wrapper `$wrapper`
/// in both directions by first converting the primitive into the wrapper.
macro_rules! define_ord {
    ($prim: ty, $wrapper: ty) => {
        // wrapper < primitive
        impl PartialOrd<$prim> for $wrapper {
            #[inline]
            fn partial_cmp(&self, prim: &$prim) -> Option<Ordering> {
                let widened = <$wrapper>::from(*prim);
                PartialOrd::partial_cmp(self, &widened)
            }
        }
        // primitive < wrapper
        impl PartialOrd<$wrapper> for $prim {
            #[inline]
            fn partial_cmp(&self, wrapper: &$wrapper) -> Option<Ordering> {
                let widened = <$wrapper>::from(*self);
                PartialOrd::partial_cmp(&widened, wrapper)
            }
        }
    };
}
445
/// Implements a compound-assignment operator (`+=`, `-=`, ...) whose right-hand
/// side is a primitive, applied in place to the wrapped big integer.
macro_rules! define_op_assign {
    (impl $op_trait: ident < $prim: ty > for $wrapper: ty, $op_fn: ident) => {
        impl $op_trait<$prim> for $wrapper {
            #[inline]
            fn $op_fn(&mut self, prim: $prim) {
                $op_trait::$op_fn(&mut self.0, prim)
            }
        }
    };
}
457
/// Generates the full mixed-type surface between the wrapper `$f` and each
/// primitive `$t`: ordering, plus `+ - * / %` in both operand orders and their
/// compound-assignment forms.
macro_rules! define_ops {
    ($f: ty, $($t: ty)*) => {
        $(
            define_ord!($t, $f);

            define_op!(impl Add<$t> for $f, add);
            define_op_assign!(impl AddAssign<$t> for $f, add_assign);

            define_op!(impl Sub<$t> for $f, sub);
            define_op_assign!(impl SubAssign<$t> for $f, sub_assign);

            define_op!(impl Mul<$t> for $f, mul);
            define_op_assign!(impl MulAssign<$t> for $f, mul_assign);

            define_op!(impl Div<$t> for $f, div);
            define_op_assign!(impl DivAssign<$t> for $f, div_assign);

            define_op!(impl Rem<$t> for $f, rem);
            define_op_assign!(impl RemAssign<$t> for $f, rem_assign);
        )*
    };
}
475
476define_from!( Nat, usize u8 u16 u32 u64 u128 );
477define_from!( Int, usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 );
478
479define_eq!( Nat, usize u8 u16 u32 u64 u128 );
480define_eq!( Int, usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 );
481
482define_ops!( Nat, usize u8 u16 u32 u64 u128 );
483define_ops!( Int, usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 );
484
/// Same-type binary operator (`Nat op Nat`, `Int op Int`). Unlike `define_op`,
/// the right-hand side is itself a wrapper, so its inner big integer (`.0`) is
/// borrowed before forwarding to the big integer's operator.
macro_rules! define_op_0 {
    (impl $op_trait: ident < $rhs: ty > for $lhs: ty, $op_fn: ident) => {
        impl $op_trait<$rhs> for $lhs {
            type Output = $lhs;

            #[inline]
            fn $op_fn(self, rhs: $rhs) -> $lhs {
                $op_trait::$op_fn(self.0, &rhs.0).into()
            }
        }
    };
}
498
/// Same-type compound assignment (`Nat += Nat`, ...): the right-hand wrapper is
/// unwrapped and applied in place to the left-hand side's big integer.
macro_rules! define_op_0_assign {
    (impl $op_trait: ident < $rhs: ty > for $lhs: ty, $op_fn: ident) => {
        impl $op_trait<$rhs> for $lhs {
            #[inline]
            fn $op_fn(&mut self, rhs: $rhs) {
                $op_trait::$op_fn(&mut self.0, rhs.0)
            }
        }
    };
}
510
511define_op_0!(impl Add<Nat> for Nat, add);
512define_op_0!(impl Sub<Nat> for Nat, sub);
513define_op_0!(impl Mul<Nat> for Nat, mul);
514define_op_0!(impl Div<Nat> for Nat, div);
515define_op_0!(impl Rem<Nat> for Nat, rem);
516
517define_op_0_assign!(impl AddAssign<Nat> for Nat, add_assign);
518define_op_0_assign!(impl SubAssign<Nat> for Nat, sub_assign);
519define_op_0_assign!(impl MulAssign<Nat> for Nat, mul_assign);
520define_op_0_assign!(impl DivAssign<Nat> for Nat, div_assign);
521define_op_0_assign!(impl RemAssign<Nat> for Nat, rem_assign);
522
523define_op_0!(impl Add<Int> for Int, add);
524define_op_0!(impl Sub<Int> for Int, sub);
525define_op_0!(impl Mul<Int> for Int, mul);
526define_op_0!(impl Div<Int> for Int, div);
527define_op_0!(impl Rem<Int> for Int, rem);
528
529define_op_0_assign!(impl AddAssign<Int> for Int, add_assign);
530define_op_0_assign!(impl SubAssign<Int> for Int, sub_assign);
531define_op_0_assign!(impl MulAssign<Int> for Int, mul_assign);
532define_op_0_assign!(impl DivAssign<Int> for Int, div_assign);
533define_op_0_assign!(impl RemAssign<Int> for Int, rem_assign);
534
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    #[derive(Default, Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
    pub struct TestStruct {
        inner: Nat,
    }

    /// A `TestStruct` holding a small value (1000).
    fn small_struct() -> TestStruct {
        TestStruct {
            inner: Nat::from(1000u64),
        }
    }

    /// A `TestStruct` holding a value that needs two u32 digits:
    /// 13969838 * 2^32 + 2659581952 == 60000000000000000.
    fn big_struct() -> TestStruct {
        TestStruct {
            inner: Nat::parse(b"60000000000000000").unwrap(),
        }
    }

    #[ignore]
    #[test]
    fn test_serde_with_bincode() {
        // Documents a known limitation: bincode is not supported. Deserializing
        // panics with `DeserializeAnyNotSupported`.
        let original = small_struct();
        let encoded = bincode::serialize(&original).unwrap();
        let decoded: TestStruct = bincode::deserialize(&encoded).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_serde_with_json() {
        let original = small_struct();
        let json = serde_json::to_string(&original).unwrap();
        let decoded: TestStruct = serde_json::from_str(&json).unwrap();
        assert_eq!(original, decoded);

        // In JSON a Nat is the array of its little-endian u32 digits.
        let original = big_struct();
        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, "{\"inner\":[2659581952,13969838]}");
        let decoded: TestStruct = serde_json::from_str(&json).unwrap();
        assert_eq!(original, decoded);
    }

    #[test]
    fn test_serde_with_cbor() {
        for original in [small_struct(), big_struct()] {
            let cbor = serde_cbor::to_vec(&original).unwrap();
            let decoded: TestStruct = serde_cbor::from_slice(&cbor).unwrap();
            assert_eq!(original, decoded);
        }
    }
}
595}