1use super::{CandidType, Serializer, Type, TypeInner};
4use crate::{utils::pp_num_str, Error};
5use num_bigint::{BigInt, BigUint};
6use serde::{
7 de::{self, Deserialize, SeqAccess, Visitor},
8 Serialize,
9};
10use std::convert::From;
11use std::{fmt, io};
12
13#[derive(Serialize, Ord, PartialOrd, Eq, PartialEq, Debug, Clone, Hash, Default)]
14pub struct Int(pub BigInt);
15#[derive(Serialize, Ord, PartialOrd, Eq, PartialEq, Debug, Clone, Hash, Default)]
16pub struct Nat(pub BigUint);
17
18impl From<BigInt> for Int {
19 fn from(i: BigInt) -> Self {
20 Self(i)
21 }
22}
23
24impl From<BigUint> for Nat {
25 fn from(i: BigUint) -> Self {
26 Self(i)
27 }
28}
29
30impl From<Nat> for Int {
31 fn from(n: Nat) -> Self {
32 let i: BigInt = n.0.into();
33 i.into()
34 }
35}
36
37impl From<Int> for BigInt {
38 fn from(i: Int) -> Self {
39 i.0
40 }
41}
42
43impl From<Nat> for BigUint {
44 fn from(i: Nat) -> Self {
45 i.0
46 }
47}
48
49impl From<Nat> for BigInt {
50 fn from(i: Nat) -> Self {
51 i.0.into()
52 }
53}
54
55impl Int {
56 #[inline]
57 pub fn parse(v: &[u8]) -> crate::Result<Self> {
58 let res = BigInt::parse_bytes(v, 10).ok_or_else(|| Error::msg("Cannot parse BigInt"))?;
59 Ok(Int(res))
60 }
61}
62
63impl Nat {
64 #[inline]
65 pub fn parse(v: &[u8]) -> crate::Result<Self> {
66 let res = BigUint::parse_bytes(v, 10).ok_or_else(|| Error::msg("Cannot parse BigUint"))?;
67 Ok(Nat(res))
68 }
69}
70
71impl std::str::FromStr for Int {
72 type Err = crate::Error;
73 fn from_str(str: &str) -> Result<Self, Self::Err> {
74 Self::parse(str.as_bytes())
75 }
76}
77
78impl std::str::FromStr for Nat {
79 type Err = crate::Error;
80 fn from_str(str: &str) -> Result<Self, Self::Err> {
81 Self::parse(str.as_bytes())
82 }
83}
84
85impl fmt::Display for Int {
86 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87 let s = self.0.to_str_radix(10);
88 f.write_str(&pp_num_str(&s))
89 }
90}
91
92impl fmt::Display for Nat {
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 let s = self.0.to_str_radix(10);
95 f.write_str(&pp_num_str(&s))
96 }
97}
98
99impl CandidType for Int {
100 fn _ty() -> Type {
101 TypeInner::Int.into()
102 }
103 fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
104 where
105 S: Serializer,
106 {
107 serializer.serialize_int(self)
108 }
109}
110
111impl CandidType for Nat {
112 fn _ty() -> Type {
113 TypeInner::Nat.into()
114 }
115 fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
116 where
117 S: Serializer,
118 {
119 serializer.serialize_nat(self)
120 }
121}
122
123impl<'de> Deserialize<'de> for Int {
124 fn deserialize<D>(deserializer: D) -> Result<Int, D::Error>
125 where
126 D: serde::Deserializer<'de>,
127 {
128 struct IntVisitor;
129 impl Visitor<'_> for IntVisitor {
130 type Value = Int;
131 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
132 formatter.write_str("Int value")
133 }
134 fn visit_i64<E>(self, v: i64) -> Result<Int, E> {
135 Ok(Int::from(v))
136 }
137 fn visit_u64<E>(self, v: u64) -> Result<Int, E> {
138 Ok(Int::from(v))
139 }
140 fn visit_str<E: de::Error>(self, v: &str) -> Result<Int, E> {
141 v.parse::<Int>()
142 .map_err(|_| de::Error::custom(format!("{v:?} is not int")))
143 }
144 fn visit_byte_buf<E: de::Error>(self, v: Vec<u8>) -> Result<Int, E> {
145 Ok(Int(match v.first() {
146 Some(0) => BigInt::from_signed_bytes_le(&v[1..]),
147 Some(1) => BigInt::from_biguint(
148 num_bigint::Sign::Plus,
149 BigUint::from_bytes_le(&v[1..]),
150 ),
151 _ => return Err(de::Error::custom("not int nor nat")),
152 }))
153 }
154 }
155 deserializer.deserialize_any(IntVisitor)
156 }
157}
158
159impl<'de> Deserialize<'de> for Nat {
160 fn deserialize<D>(deserializer: D) -> Result<Nat, D::Error>
161 where
162 D: serde::Deserializer<'de>,
163 {
164 struct NatVisitor;
165 impl<'de> Visitor<'de> for NatVisitor {
166 type Value = Nat;
167 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
168 formatter.write_str("Nat value")
169 }
170 fn visit_i64<E: de::Error>(self, v: i64) -> Result<Nat, E> {
171 use num_bigint::ToBigUint;
172 v.to_biguint()
173 .ok_or_else(|| de::Error::custom("i64 cannot be converted to nat"))
174 .map(Nat)
175 }
176 fn visit_u64<E>(self, v: u64) -> Result<Nat, E> {
177 Ok(Nat::from(v))
178 }
179 fn visit_str<E: de::Error>(self, v: &str) -> Result<Nat, E> {
180 v.parse::<Nat>()
181 .map_err(|_| de::Error::custom(format!("{v:?} is not nat")))
182 }
183 fn visit_byte_buf<E: de::Error>(self, v: Vec<u8>) -> Result<Nat, E> {
184 if v[0] == 1 {
185 Ok(Nat(BigUint::from_bytes_le(&v[1..])))
186 } else {
187 Err(de::Error::custom("not nat"))
188 }
189 }
190
191 fn visit_seq<S>(self, mut seq: S) -> Result<Nat, S::Error>
192 where
193 S: SeqAccess<'de>,
194 {
195 let len = seq.size_hint().unwrap_or(0);
196 let mut data = Vec::with_capacity(len);
197
198 while let Some(value) = seq.next_element::<u32>()? {
199 data.push(value);
200 }
201
202 Ok(Nat(BigUint::new(data)))
203 }
204 }
205 deserializer.deserialize_any(NatVisitor)
206 }
207}
208
209impl Nat {
212 pub fn encode<W>(&self, w: &mut W) -> crate::Result<()>
213 where
214 W: ?Sized + io::Write,
215 {
216 use num_traits::cast::ToPrimitive;
217 let zero = BigUint::from(0u8);
218 let mut value = self.0.clone();
219 loop {
220 let big_byte = &value & BigUint::from(0x7fu8);
221 let mut byte = big_byte.to_u8().unwrap();
222 value >>= 7;
223 if value != zero {
224 byte |= 0x80u8;
225 }
226 let buf = [byte];
227 w.write_all(&buf)?;
228 if value == zero {
229 return Ok(());
230 }
231 }
232 }
233 pub fn decode<R>(r: &mut R) -> crate::Result<Self>
234 where
235 R: io::Read,
236 {
237 let mut result = BigUint::from(0u8);
238 let mut shift = 0;
239 loop {
240 let mut buf = [0];
241 r.read_exact(&mut buf)?;
242 let low_bits = BigUint::from(buf[0] & 0x7fu8);
243 result |= low_bits << shift;
244 if buf[0] & 0x80u8 == 0 {
245 return Ok(Nat(result));
246 }
247 shift += 7;
248 }
249 }
250}
251
252impl Int {
253 pub fn encode<W>(&self, w: &mut W) -> crate::Result<()>
254 where
255 W: ?Sized + io::Write,
256 {
257 use num_traits::cast::ToPrimitive;
258 let zero = BigInt::from(0);
259 let mut value = self.0.clone();
260 loop {
261 let big_byte = &value & BigInt::from(0xff);
262 let mut byte = big_byte.to_u8().unwrap();
263 value >>= 6;
264 let done = value == zero || value == BigInt::from(-1);
265 if done {
266 byte &= 0x7f;
267 } else {
268 value >>= 1;
269 byte |= 0x80;
270 }
271 let buf = [byte];
272 w.write_all(&buf)?;
273 if done {
274 return Ok(());
275 }
276 }
277 }
278 pub fn decode<R>(r: &mut R) -> crate::Result<Self>
279 where
280 R: io::Read,
281 {
282 let mut result = BigInt::from(0);
283 let mut shift = 0;
284 let mut byte;
285 loop {
286 let mut buf = [0];
287 r.read_exact(&mut buf)?;
288 byte = buf[0];
289 let low_bits = BigInt::from(byte & 0x7fu8);
290 result |= low_bits << shift;
291 shift += 7;
292 if byte & 0x80u8 == 0 {
293 break;
294 }
295 }
296 if (0x40u8 & byte) == 0x40u8 {
297 result |= BigInt::from(-1) << shift;
298 }
299 Ok(Int(result))
300 }
301}
302
303use std::cmp::{Ord, Ordering, PartialEq, PartialOrd};
305use std::ops::*;
306
/// Implements `From<$t>` for the newtype `$f`, for every primitive `$t` listed,
/// by converting the primitive into the wrapped big-number type.
macro_rules! define_from {
    ($f: ty, $($t: ty)*) => {
        $(
            impl From<$t> for $f {
                #[inline]
                fn from(value: $t) -> Self {
                    let inner = value.into();
                    Self(inner)
                }
            }
        )*
    };
}
315
/// Implements symmetric `==` between the newtype `$f` and every primitive `$t`
/// listed. The primitive is converted into the wrapped big-number type and
/// compared there.
macro_rules! define_eq {
    ($f: ty, $($t: ty)*) => {
        $(
            impl PartialEq<$f> for $t {
                #[inline]
                fn eq(&self, other: &$f) -> bool {
                    other.0.eq(&(*self).into())
                }
            }

            impl PartialEq<$t> for $f {
                #[inline]
                fn eq(&self, other: &$t) -> bool {
                    self.0.eq(&(*other).into())
                }
            }
        )*
    };
}
328
/// Implements binary operator `$imp` (method `$method`) between the newtype `$res`
/// and the primitive `$scalar`, in both operand orders. The big-number operator
/// does the work, and its result is converted back into `$res`.
/// NOTE(review): for `Nat`, results below zero are presumably rejected (panic) by
/// the underlying `BigUint` operator — confirm.
macro_rules! define_op {
    (impl $imp: ident < $scalar: ty > for $res: ty, $method: ident) => {
        impl $imp<$res> for $scalar {
            type Output = $res;

            #[inline]
            fn $method(self, rhs: $res) -> $res {
                let lhs = &self;
                $imp::$method(lhs, rhs.0).into()
            }
        }

        impl $imp<$scalar> for $res {
            type Output = $res;

            #[inline]
            fn $method(self, rhs: $scalar) -> $res {
                let inner = self.0;
                $imp::$method(inner, &rhs).into()
            }
        }
    };
}
352
/// Implements `PartialOrd` between the primitive `$scalar` and the newtype `$res`,
/// in both directions. The primitive is first lifted into `$res`.
macro_rules! define_ord {
    ($scalar: ty, $res: ty) => {
        impl PartialOrd<$res> for $scalar {
            #[inline]
            fn partial_cmp(&self, other: &$res) -> Option<Ordering> {
                let lhs = <$res>::from(*self);
                PartialOrd::partial_cmp(&lhs, other)
            }
        }

        impl PartialOrd<$scalar> for $res {
            #[inline]
            fn partial_cmp(&self, other: &$scalar) -> Option<Ordering> {
                let rhs = <$res>::from(*other);
                PartialOrd::partial_cmp(self, &rhs)
            }
        }
    };
}
370}
371
/// Implements compound assignment `$imp` (e.g. `+=`) of a primitive `$scalar` into
/// the newtype `$res`, operating in place on the wrapped big number.
macro_rules! define_op_assign {
    (impl $imp: ident < $scalar: ty > for $res: ty, $method: ident) => {
        impl $imp<$scalar> for $res {
            #[inline]
            fn $method(&mut self, rhs: $scalar) {
                let inner = &mut self.0;
                $imp::$method(inner, rhs);
            }
        }
    };
}
383
/// For every primitive `$t` listed, generates ordering plus the five arithmetic
/// operators (`+ - * / %`) and their compound-assignment forms against `$f`.
macro_rules! define_ops {
    ($f: ty, $($t: ty)*) => {
        $(
            define_ord!($t, $f);

            define_op!(impl Add<$t> for $f, add);
            define_op_assign!(impl AddAssign<$t> for $f, add_assign);

            define_op!(impl Sub<$t> for $f, sub);
            define_op_assign!(impl SubAssign<$t> for $f, sub_assign);

            define_op!(impl Mul<$t> for $f, mul);
            define_op_assign!(impl MulAssign<$t> for $f, mul_assign);

            define_op!(impl Div<$t> for $f, div);
            define_op_assign!(impl DivAssign<$t> for $f, div_assign);

            define_op!(impl Rem<$t> for $f, rem);
            define_op_assign!(impl RemAssign<$t> for $f, rem_assign);
        )*
    };
}
401
402define_from!( Nat, usize u8 u16 u32 u64 u128 );
403define_from!( Int, usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 );
404
405define_eq!( Nat, usize u8 u16 u32 u64 u128 );
406define_eq!( Int, usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 );
407
408define_ops!( Nat, usize u8 u16 u32 u64 u128 );
409define_ops!( Int, usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 );
410
/// Implements binary operator `$imp` between two values of the same newtype, by
/// value: the inner big numbers are combined, and the result is re-wrapped.
macro_rules! define_op_0 {
    (impl $imp: ident < $scalar: ty > for $res: ty, $method: ident) => {
        impl $imp<$scalar> for $res {
            type Output = $res;

            #[inline]
            fn $method(self, rhs: $scalar) -> $res {
                let (lhs, rhs) = (self.0, rhs.0);
                $imp::$method(lhs, &rhs).into()
            }
        }
    };
}

/// Implements compound assignment `$imp` between two values of the same newtype,
/// operating in place on the left-hand side's inner big number.
macro_rules! define_op_0_assign {
    (impl $imp: ident < $scalar: ty > for $res: ty, $method: ident) => {
        impl $imp<$scalar> for $res {
            #[inline]
            fn $method(&mut self, rhs: $scalar) {
                let inner = &mut self.0;
                $imp::$method(inner, rhs.0);
            }
        }
    };
}
436
437define_op_0!(impl Add<Nat> for Nat, add);
438define_op_0!(impl Sub<Nat> for Nat, sub);
439define_op_0!(impl Mul<Nat> for Nat, mul);
440define_op_0!(impl Div<Nat> for Nat, div);
441define_op_0!(impl Rem<Nat> for Nat, rem);
442
443define_op_0_assign!(impl AddAssign<Nat> for Nat, add_assign);
444define_op_0_assign!(impl SubAssign<Nat> for Nat, sub_assign);
445define_op_0_assign!(impl MulAssign<Nat> for Nat, mul_assign);
446define_op_0_assign!(impl DivAssign<Nat> for Nat, div_assign);
447define_op_0_assign!(impl RemAssign<Nat> for Nat, rem_assign);
448
449define_op_0!(impl Add<Int> for Int, add);
450define_op_0!(impl Sub<Int> for Int, sub);
451define_op_0!(impl Mul<Int> for Int, mul);
452define_op_0!(impl Div<Int> for Int, div);
453define_op_0!(impl Rem<Int> for Int, rem);
454
455define_op_0_assign!(impl AddAssign<Int> for Int, add_assign);
456define_op_0_assign!(impl SubAssign<Int> for Int, sub_assign);
457define_op_0_assign!(impl MulAssign<Int> for Int, mul_assign);
458define_op_0_assign!(impl DivAssign<Int> for Int, div_assign);
459define_op_0_assign!(impl RemAssign<Int> for Int, rem_assign);
460
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    #[derive(Default, Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
    pub struct TestStruct {
        inner: Nat,
    }

    // NOTE(review): presumably ignored because the `Deserialize` impls call
    // `deserialize_any`, which bincode does not support — confirm.
    #[ignore]
    #[test]
    fn test_serde_with_bincode() {
        let test_struct = TestStruct {
            inner: Nat::from(1000u64),
        };
        let serialized = bincode::serialize(&test_struct).unwrap();
        let deserialized = bincode::deserialize(&serialized).unwrap();
        assert_eq!(test_struct, deserialized);
    }

    #[test]
    fn test_serde_with_json() {
        let test_struct = TestStruct {
            inner: Nat::from(1000u64),
        };
        let serialized = serde_json::to_string(&test_struct).unwrap();
        let deserialized = serde_json::from_str(&serialized).unwrap();
        assert_eq!(test_struct, deserialized);

        let test_struct = TestStruct {
            inner: Nat::parse(b"60000000000000000").unwrap(),
        };
        let serialized = serde_json::to_string(&test_struct).unwrap();
        assert_eq!(serialized, "{\"inner\":[2659581952,13969838]}");
        let deserialized = serde_json::from_str(&serialized).unwrap();
        assert_eq!(test_struct, deserialized);
    }

    #[test]
    fn test_serde_with_cbor() {
        let test_struct = TestStruct {
            inner: Nat::from(1000u64),
        };
        let serialized = serde_cbor::to_vec(&test_struct).unwrap();
        let deserialized = serde_cbor::from_slice(&serialized).unwrap();
        assert_eq!(test_struct, deserialized);

        let test_struct = TestStruct {
            inner: Nat::parse(b"60000000000000000").unwrap(),
        };
        let serialized = serde_cbor::to_vec(&test_struct).unwrap();
        let deserialized = serde_cbor::from_slice(&serialized).unwrap();
        assert_eq!(test_struct, deserialized);
    }

    /// Encodes `value` as unsigned LEB128, checks the exact bytes, then decodes them back.
    fn check_nat_leb128(value: u64, expected: &[u8]) {
        let nat = Nat::from(value);
        let mut buf = Vec::new();
        nat.encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(Nat::decode(&mut &buf[..]).unwrap(), nat);
    }

    /// Encodes `value` as signed LEB128, checks the exact bytes, then decodes them back.
    fn check_int_leb128(value: i64, expected: &[u8]) {
        let int = Int::from(value);
        let mut buf = Vec::new();
        int.encode(&mut buf).unwrap();
        assert_eq!(buf, expected);
        assert_eq!(Int::decode(&mut &buf[..]).unwrap(), int);
    }

    #[test]
    fn test_nat_leb128() {
        check_nat_leb128(0, &[0x00]);
        check_nat_leb128(127, &[0x7f]);
        check_nat_leb128(128, &[0x80, 0x01]);
        check_nat_leb128(624485, &[0xe5, 0x8e, 0x26]);

        // Multi-limb value: round-trip only.
        let big = Nat::parse(b"60000000000000000").unwrap();
        let mut buf = Vec::new();
        big.encode(&mut buf).unwrap();
        assert_eq!(Nat::decode(&mut &buf[..]).unwrap(), big);
    }

    #[test]
    fn test_int_leb128() {
        check_int_leb128(0, &[0x00]);
        check_int_leb128(-1, &[0x7f]);
        check_int_leb128(63, &[0x3f]);
        check_int_leb128(64, &[0xc0, 0x00]);
        check_int_leb128(-64, &[0x40]);
        check_int_leb128(-65, &[0xbf, 0x7f]);
        check_int_leb128(-123456, &[0xc0, 0xbb, 0x78]);
    }
}