// candid/types/number.rs

//! Data structures for the Candid `int` and `nat` types: arbitrary-precision
//! integers with LEB128 encoding.
3use super::{CandidType, Serializer, Type, TypeInner};
4use crate::{utils::pp_num_str, Error};
5use num_bigint::{BigInt, BigUint};
6use serde::{
7    de::{self, Deserialize, SeqAccess, Visitor},
8    Serialize,
9};
10use std::convert::From;
11use std::{fmt, io};
12
13#[derive(Serialize, Ord, PartialOrd, Eq, PartialEq, Debug, Clone, Hash, Default)]
14pub struct Int(pub BigInt);
15#[derive(Serialize, Ord, PartialOrd, Eq, PartialEq, Debug, Clone, Hash, Default)]
16pub struct Nat(pub BigUint);
17
18impl From<BigInt> for Int {
19    fn from(i: BigInt) -> Self {
20        Self(i)
21    }
22}
23
24impl From<BigUint> for Nat {
25    fn from(i: BigUint) -> Self {
26        Self(i)
27    }
28}
29
30impl From<Nat> for Int {
31    fn from(n: Nat) -> Self {
32        let i: BigInt = n.0.into();
33        i.into()
34    }
35}
36
37impl From<Int> for BigInt {
38    fn from(i: Int) -> Self {
39        i.0
40    }
41}
42
43impl From<Nat> for BigUint {
44    fn from(i: Nat) -> Self {
45        i.0
46    }
47}
48
49impl From<Nat> for BigInt {
50    fn from(i: Nat) -> Self {
51        i.0.into()
52    }
53}
54
55impl Int {
56    #[inline]
57    pub fn parse(v: &[u8]) -> crate::Result<Self> {
58        let res = BigInt::parse_bytes(v, 10).ok_or_else(|| Error::msg("Cannot parse BigInt"))?;
59        Ok(Int(res))
60    }
61}
62
63impl Nat {
64    #[inline]
65    pub fn parse(v: &[u8]) -> crate::Result<Self> {
66        let res = BigUint::parse_bytes(v, 10).ok_or_else(|| Error::msg("Cannot parse BigUint"))?;
67        Ok(Nat(res))
68    }
69}
70
71impl std::str::FromStr for Int {
72    type Err = crate::Error;
73    fn from_str(str: &str) -> Result<Self, Self::Err> {
74        Self::parse(str.as_bytes())
75    }
76}
77
78impl std::str::FromStr for Nat {
79    type Err = crate::Error;
80    fn from_str(str: &str) -> Result<Self, Self::Err> {
81        Self::parse(str.as_bytes())
82    }
83}
84
85impl fmt::Display for Int {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        let s = self.0.to_str_radix(10);
88        f.write_str(&pp_num_str(&s))
89    }
90}
91
92impl fmt::Display for Nat {
93    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94        let s = self.0.to_str_radix(10);
95        f.write_str(&pp_num_str(&s))
96    }
97}
98
99impl CandidType for Int {
100    fn _ty() -> Type {
101        TypeInner::Int.into()
102    }
103    fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
104    where
105        S: Serializer,
106    {
107        serializer.serialize_int(self)
108    }
109}
110
111impl CandidType for Nat {
112    fn _ty() -> Type {
113        TypeInner::Nat.into()
114    }
115    fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
116    where
117        S: Serializer,
118    {
119        serializer.serialize_nat(self)
120    }
121}
122
123impl<'de> Deserialize<'de> for Int {
124    fn deserialize<D>(deserializer: D) -> Result<Int, D::Error>
125    where
126        D: serde::Deserializer<'de>,
127    {
128        struct IntVisitor;
129        impl Visitor<'_> for IntVisitor {
130            type Value = Int;
131            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
132                formatter.write_str("Int value")
133            }
134            fn visit_i64<E>(self, v: i64) -> Result<Int, E> {
135                Ok(Int::from(v))
136            }
137            fn visit_u64<E>(self, v: u64) -> Result<Int, E> {
138                Ok(Int::from(v))
139            }
140            fn visit_str<E: de::Error>(self, v: &str) -> Result<Int, E> {
141                v.parse::<Int>()
142                    .map_err(|_| de::Error::custom(format!("{v:?} is not int")))
143            }
144            fn visit_byte_buf<E: de::Error>(self, v: Vec<u8>) -> Result<Int, E> {
145                Ok(Int(match v.first() {
146                    Some(0) => BigInt::from_signed_bytes_le(&v[1..]),
147                    Some(1) => BigInt::from_biguint(
148                        num_bigint::Sign::Plus,
149                        BigUint::from_bytes_le(&v[1..]),
150                    ),
151                    _ => return Err(de::Error::custom("not int nor nat")),
152                }))
153            }
154        }
155        deserializer.deserialize_any(IntVisitor)
156    }
157}
158
159impl<'de> Deserialize<'de> for Nat {
160    fn deserialize<D>(deserializer: D) -> Result<Nat, D::Error>
161    where
162        D: serde::Deserializer<'de>,
163    {
164        struct NatVisitor;
165        impl<'de> Visitor<'de> for NatVisitor {
166            type Value = Nat;
167            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
168                formatter.write_str("Nat value")
169            }
170            fn visit_i64<E: de::Error>(self, v: i64) -> Result<Nat, E> {
171                use num_bigint::ToBigUint;
172                v.to_biguint()
173                    .ok_or_else(|| de::Error::custom("i64 cannot be converted to nat"))
174                    .map(Nat)
175            }
176            fn visit_u64<E>(self, v: u64) -> Result<Nat, E> {
177                Ok(Nat::from(v))
178            }
179            fn visit_str<E: de::Error>(self, v: &str) -> Result<Nat, E> {
180                v.parse::<Nat>()
181                    .map_err(|_| de::Error::custom(format!("{v:?} is not nat")))
182            }
183            fn visit_byte_buf<E: de::Error>(self, v: Vec<u8>) -> Result<Nat, E> {
184                if v[0] == 1 {
185                    Ok(Nat(BigUint::from_bytes_le(&v[1..])))
186                } else {
187                    Err(de::Error::custom("not nat"))
188                }
189            }
190
191            fn visit_seq<S>(self, mut seq: S) -> Result<Nat, S::Error>
192            where
193                S: SeqAccess<'de>,
194            {
195                let len = seq.size_hint().unwrap_or(0);
196                let mut data = Vec::with_capacity(len);
197
198                while let Some(value) = seq.next_element::<u32>()? {
199                    data.push(value);
200                }
201
202                Ok(Nat(BigUint::new(data)))
203            }
204        }
205        deserializer.deserialize_any(NatVisitor)
206    }
207}
208
// LEB128 encoding for bignums.
210
211impl Nat {
212    pub fn encode<W>(&self, w: &mut W) -> crate::Result<()>
213    where
214        W: ?Sized + io::Write,
215    {
216        use num_traits::cast::ToPrimitive;
217        let zero = BigUint::from(0u8);
218        let mut value = self.0.clone();
219        loop {
220            let big_byte = &value & BigUint::from(0x7fu8);
221            let mut byte = big_byte.to_u8().unwrap();
222            value >>= 7;
223            if value != zero {
224                byte |= 0x80u8;
225            }
226            let buf = [byte];
227            w.write_all(&buf)?;
228            if value == zero {
229                return Ok(());
230            }
231        }
232    }
233    pub fn decode<R>(r: &mut R) -> crate::Result<Self>
234    where
235        R: io::Read,
236    {
237        let mut result = BigUint::from(0u8);
238        let mut shift = 0;
239        loop {
240            let mut buf = [0];
241            r.read_exact(&mut buf)?;
242            let low_bits = BigUint::from(buf[0] & 0x7fu8);
243            result |= low_bits << shift;
244            if buf[0] & 0x80u8 == 0 {
245                return Ok(Nat(result));
246            }
247            shift += 7;
248        }
249    }
250}
251
252impl Int {
253    pub fn encode<W>(&self, w: &mut W) -> crate::Result<()>
254    where
255        W: ?Sized + io::Write,
256    {
257        use num_traits::cast::ToPrimitive;
258        let zero = BigInt::from(0);
259        let mut value = self.0.clone();
260        loop {
261            let big_byte = &value & BigInt::from(0xff);
262            let mut byte = big_byte.to_u8().unwrap();
263            value >>= 6;
264            let done = value == zero || value == BigInt::from(-1);
265            if done {
266                byte &= 0x7f;
267            } else {
268                value >>= 1;
269                byte |= 0x80;
270            }
271            let buf = [byte];
272            w.write_all(&buf)?;
273            if done {
274                return Ok(());
275            }
276        }
277    }
278    pub fn decode<R>(r: &mut R) -> crate::Result<Self>
279    where
280        R: io::Read,
281    {
282        let mut result = BigInt::from(0);
283        let mut shift = 0;
284        let mut byte;
285        loop {
286            let mut buf = [0];
287            r.read_exact(&mut buf)?;
288            byte = buf[0];
289            let low_bits = BigInt::from(byte & 0x7fu8);
290            result |= low_bits << shift;
291            shift += 7;
292            if byte & 0x80u8 == 0 {
293                break;
294            }
295        }
296        if (0x40u8 & byte) == 0x40u8 {
297            result |= BigInt::from(-1) << shift;
298        }
299        Ok(Int(result))
300    }
301}
302
// Define all operators and traits relevant for Nat and Int.
304use std::cmp::{Ord, Ordering, PartialEq, PartialOrd};
305use std::ops::*;
306
307macro_rules! define_from {
308    ($f: ty, $($t: ty)*) => ($(
309        impl From<$t> for $f {
310            #[inline]
311            fn from(v: $t) -> Self { Self(v.into()) }
312        }
313    )*)
314}
315
316macro_rules! define_eq {
317    ($f: ty, $($t: ty)*) => ($(
318        impl PartialEq<$t> for $f {
319            #[inline]
320            fn eq(&self, v: &$t) -> bool { self.0.eq(&(*v).into()) }
321        }
322        impl PartialEq<$f> for $t {
323            #[inline]
324            fn eq(&self, v: &$f) -> bool { v.0.eq(&(*self).into()) }
325        }
326    )*)
327}
328
329macro_rules! define_op {
330    (impl $imp: ident < $scalar: ty > for $res: ty, $method: ident) => {
331        // Implement A * B
332        impl $imp<$scalar> for $res {
333            type Output = $res;
334
335            #[inline]
336            fn $method(self, other: $scalar) -> $res {
337                $imp::$method(self.0, &other).into()
338            }
339        }
340
341        // Implement B * A
342        impl $imp<$res> for $scalar {
343            type Output = $res;
344
345            #[inline]
346            fn $method(self, other: $res) -> $res {
347                $imp::$method(&self, other.0).into()
348            }
349        }
350    };
351}
352
353macro_rules! define_ord {
354    ($scalar: ty, $res: ty) => {
355        // A < B
356        impl PartialOrd<$scalar> for $res {
357            #[inline]
358            fn partial_cmp(&self, other: &$scalar) -> Option<Ordering> {
359                PartialOrd::partial_cmp(self, &<$res>::from(*other))
360            }
361        }
362        // B < A
363        impl PartialOrd<$res> for $scalar {
364            #[inline]
365            fn partial_cmp(&self, other: &$res) -> Option<Ordering> {
366                PartialOrd::partial_cmp(&<$res>::from(*self), other)
367            }
368        }
369    };
370}
371
372macro_rules! define_op_assign {
373    (impl $imp: ident < $scalar: ty > for $res: ty, $method: ident) => {
374        // Implement A * B
375        impl $imp<$scalar> for $res {
376            #[inline]
377            fn $method(&mut self, other: $scalar) {
378                $imp::$method(&mut self.0, other)
379            }
380        }
381    };
382}
383
384macro_rules! define_ops {
385    ($f: ty, $($t: ty)*) => ($(
386        define_op!(impl Add<$t> for $f, add);
387        define_op!(impl Sub<$t> for $f, sub);
388        define_op!(impl Mul<$t> for $f, mul);
389        define_op!(impl Div<$t> for $f, div);
390        define_op!(impl Rem<$t> for $f, rem);
391
392        define_ord!($t, $f);
393
394        define_op_assign!(impl AddAssign<$t> for $f, add_assign);
395        define_op_assign!(impl SubAssign<$t> for $f, sub_assign);
396        define_op_assign!(impl MulAssign<$t> for $f, mul_assign);
397        define_op_assign!(impl DivAssign<$t> for $f, div_assign);
398        define_op_assign!(impl RemAssign<$t> for $f, rem_assign);
399    )*)
400}
401
402define_from!( Nat, usize u8 u16 u32 u64 u128 );
403define_from!( Int, usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 );
404
405define_eq!( Nat, usize u8 u16 u32 u64 u128 );
406define_eq!( Int, usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 );
407
408define_ops!( Nat, usize u8 u16 u32 u64 u128 );
409define_ops!( Int, usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 );
410
// A separate macro is needed to extract the Big[U]Int from both Nat/Int operands.
412macro_rules! define_op_0 {
413    (impl $imp: ident < $scalar: ty > for $res: ty, $method: ident) => {
414        impl $imp<$scalar> for $res {
415            type Output = $res;
416
417            #[inline]
418            fn $method(self, other: $scalar) -> $res {
419                $imp::$method(self.0, &other.0).into()
420            }
421        }
422    };
423}
424
425macro_rules! define_op_0_assign {
426    (impl $imp: ident < $scalar: ty > for $res: ty, $method: ident) => {
427        // Implement A * B
428        impl $imp<$scalar> for $res {
429            #[inline]
430            fn $method(&mut self, other: $scalar) {
431                $imp::$method(&mut self.0, other.0)
432            }
433        }
434    };
435}
436
437define_op_0!(impl Add<Nat> for Nat, add);
438define_op_0!(impl Sub<Nat> for Nat, sub);
439define_op_0!(impl Mul<Nat> for Nat, mul);
440define_op_0!(impl Div<Nat> for Nat, div);
441define_op_0!(impl Rem<Nat> for Nat, rem);
442
443define_op_0_assign!(impl AddAssign<Nat> for Nat, add_assign);
444define_op_0_assign!(impl SubAssign<Nat> for Nat, sub_assign);
445define_op_0_assign!(impl MulAssign<Nat> for Nat, mul_assign);
446define_op_0_assign!(impl DivAssign<Nat> for Nat, div_assign);
447define_op_0_assign!(impl RemAssign<Nat> for Nat, rem_assign);
448
449define_op_0!(impl Add<Int> for Int, add);
450define_op_0!(impl Sub<Int> for Int, sub);
451define_op_0!(impl Mul<Int> for Int, mul);
452define_op_0!(impl Div<Int> for Int, div);
453define_op_0!(impl Rem<Int> for Int, rem);
454
455define_op_0_assign!(impl AddAssign<Int> for Int, add_assign);
456define_op_0_assign!(impl SubAssign<Int> for Int, sub_assign);
457define_op_0_assign!(impl MulAssign<Int> for Int, mul_assign);
458define_op_0_assign!(impl DivAssign<Int> for Int, div_assign);
459define_op_0_assign!(impl RemAssign<Int> for Int, rem_assign);
460
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    #[derive(Default, Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
    pub struct TestStruct {
        inner: Nat,
    }

    /// Fixture whose nat fits in a single u32 digit.
    fn small_fixture() -> TestStruct {
        TestStruct {
            inner: Nat::from(1000u64),
        }
    }

    /// Fixture whose nat needs two u32 digits:
    /// 13969838 * 2^32 + 2659581952 == 60000000000000000.
    fn two_digit_fixture() -> TestStruct {
        TestStruct {
            inner: Nat::parse(b"60000000000000000").unwrap(),
        }
    }

    #[ignore]
    #[test]
    fn test_serde_with_bincode() {
        // Ignored on purpose: this documents that bincode is not supported.
        // Deserialization panics with 'called `Result::unwrap()` on an `Err`
        // value: DeserializeAnyNotSupported'.
        let expected = small_fixture();
        let encoded = bincode::serialize(&expected).unwrap();
        let decoded: TestStruct = bincode::deserialize(&encoded).unwrap();
        assert_eq!(expected, decoded);
    }

    #[test]
    fn test_serde_with_json() {
        for expected in vec![small_fixture(), two_digit_fixture()] {
            let text = serde_json::to_string(&expected).unwrap();
            let decoded: TestStruct = serde_json::from_str(&text).unwrap();
            assert_eq!(expected, decoded);
        }

        // In JSON a Nat is an array of u32 digits, least significant first.
        let text = serde_json::to_string(&two_digit_fixture()).unwrap();
        assert_eq!(text, "{\"inner\":[2659581952,13969838]}");
    }

    #[test]
    fn test_serde_with_cbor() {
        for expected in vec![small_fixture(), two_digit_fixture()] {
            let encoded = serde_cbor::to_vec(&expected).unwrap();
            let decoded: TestStruct = serde_cbor::from_slice(&encoded).unwrap();
            assert_eq!(expected, decoded);
        }
    }
}