1use super::{CandidType, Serializer, Type, TypeInner};
4use crate::{utils::pp_num_str, Error};
5use num_bigint::{BigInt, BigUint};
6use serde::{
7 de::{self, Deserialize, SeqAccess, Visitor},
8 Serialize,
9};
10use std::convert::From;
11use std::{fmt, io};
12
/// Signed arbitrary-precision integer: a newtype around [`BigInt`] whose inner
/// value is publicly accessible as `.0`.
///
/// `Serialize` is derived, while `Deserialize` is implemented by hand further
/// down in this file.
#[derive(Serialize, Ord, PartialOrd, Eq, PartialEq, Debug, Clone, Hash, Default)]
pub struct Int(pub BigInt);
/// Unsigned arbitrary-precision integer: a newtype around [`BigUint`] whose
/// inner value is publicly accessible as `.0`.
///
/// `Serialize` is derived, while `Deserialize` is implemented by hand further
/// down in this file.
#[derive(Serialize, Ord, PartialOrd, Eq, PartialEq, Debug, Clone, Hash, Default)]
pub struct Nat(pub BigUint);
17
18impl From<BigInt> for Int {
19 fn from(i: BigInt) -> Self {
20 Self(i)
21 }
22}
23
24impl From<BigUint> for Nat {
25 fn from(i: BigUint) -> Self {
26 Self(i)
27 }
28}
29
30impl From<Nat> for Int {
31 fn from(n: Nat) -> Self {
32 let i: BigInt = n.0.into();
33 i.into()
34 }
35}
36
37impl From<Int> for BigInt {
38 fn from(i: Int) -> Self {
39 i.0
40 }
41}
42
43impl From<Nat> for BigUint {
44 fn from(i: Nat) -> Self {
45 i.0
46 }
47}
48
49impl From<Nat> for BigInt {
50 fn from(i: Nat) -> Self {
51 i.0.into()
52 }
53}
54
55impl Int {
56 #[inline]
57 pub fn parse(v: &[u8]) -> crate::Result<Self> {
58 let res = BigInt::parse_bytes(v, 10).ok_or_else(|| Error::msg("Cannot parse BigInt"))?;
59 Ok(Int(res))
60 }
61}
62
63impl Nat {
64 #[inline]
65 pub fn parse(v: &[u8]) -> crate::Result<Self> {
66 let res = BigUint::parse_bytes(v, 10).ok_or_else(|| Error::msg("Cannot parse BigUint"))?;
67 Ok(Nat(res))
68 }
69}
70
71impl std::str::FromStr for Int {
72 type Err = crate::Error;
73 fn from_str(str: &str) -> Result<Self, Self::Err> {
74 Self::parse(str.as_bytes())
75 }
76}
77
78impl std::str::FromStr for Nat {
79 type Err = crate::Error;
80 fn from_str(str: &str) -> Result<Self, Self::Err> {
81 Self::parse(str.as_bytes())
82 }
83}
84
85impl fmt::Display for Int {
86 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87 let s = self.0.to_str_radix(10);
88 f.write_str(&pp_num_str(&s))
89 }
90}
91
92impl fmt::Display for Nat {
93 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
94 let s = self.0.to_str_radix(10);
95 f.write_str(&pp_num_str(&s))
96 }
97}
98
99impl CandidType for Int {
100 fn _ty() -> Type {
101 TypeInner::Int.into()
102 }
103 fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
104 where
105 S: Serializer,
106 {
107 serializer.serialize_int(self)
108 }
109}
110
111impl CandidType for Nat {
112 fn _ty() -> Type {
113 TypeInner::Nat.into()
114 }
115 fn idl_serialize<S>(&self, serializer: S) -> Result<(), S::Error>
116 where
117 S: Serializer,
118 {
119 serializer.serialize_nat(self)
120 }
121}
122
123impl<'de> Deserialize<'de> for Int {
124 fn deserialize<D>(deserializer: D) -> Result<Int, D::Error>
125 where
126 D: serde::Deserializer<'de>,
127 {
128 struct IntVisitor;
129 impl Visitor<'_> for IntVisitor {
130 type Value = Int;
131 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
132 formatter.write_str("Int value")
133 }
134 fn visit_i64<E>(self, v: i64) -> Result<Int, E> {
135 Ok(Int::from(v))
136 }
137 fn visit_u64<E>(self, v: u64) -> Result<Int, E> {
138 Ok(Int::from(v))
139 }
140 fn visit_str<E: de::Error>(self, v: &str) -> Result<Int, E> {
141 v.parse::<Int>()
142 .map_err(|_| de::Error::custom(format!("{v:?} is not int")))
143 }
144 fn visit_byte_buf<E: de::Error>(self, v: Vec<u8>) -> Result<Int, E> {
145 Ok(Int(match v.first() {
146 Some(0) => BigInt::from_signed_bytes_le(&v[1..]),
147 Some(1) => BigInt::from_biguint(
148 num_bigint::Sign::Plus,
149 BigUint::from_bytes_le(&v[1..]),
150 ),
151 _ => return Err(de::Error::custom("not int nor nat")),
152 }))
153 }
154 }
155 deserializer.deserialize_any(IntVisitor)
156 }
157}
158
159impl<'de> Deserialize<'de> for Nat {
160 fn deserialize<D>(deserializer: D) -> Result<Nat, D::Error>
161 where
162 D: serde::Deserializer<'de>,
163 {
164 struct NatVisitor;
165 impl<'de> Visitor<'de> for NatVisitor {
166 type Value = Nat;
167 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
168 formatter.write_str("Nat value")
169 }
170 fn visit_i64<E: de::Error>(self, v: i64) -> Result<Nat, E> {
171 use num_bigint::ToBigUint;
172 v.to_biguint()
173 .ok_or_else(|| de::Error::custom("i64 cannot be converted to nat"))
174 .map(Nat)
175 }
176 fn visit_u64<E>(self, v: u64) -> Result<Nat, E> {
177 Ok(Nat::from(v))
178 }
179 fn visit_str<E: de::Error>(self, v: &str) -> Result<Nat, E> {
180 v.parse::<Nat>()
181 .map_err(|_| de::Error::custom(format!("{v:?} is not nat")))
182 }
183 fn visit_byte_buf<E: de::Error>(self, v: Vec<u8>) -> Result<Nat, E> {
184 if v[0] == 1 {
185 Ok(Nat(BigUint::from_bytes_le(&v[1..])))
186 } else {
187 Err(de::Error::custom("not nat"))
188 }
189 }
190
191 fn visit_seq<S>(self, mut seq: S) -> Result<Nat, S::Error>
192 where
193 S: SeqAccess<'de>,
194 {
195 let len = seq.size_hint().unwrap_or(0);
196 let mut data = Vec::with_capacity(len);
197
198 while let Some(value) = seq.next_element::<u32>()? {
199 data.push(value);
200 }
201
202 Ok(Nat(BigUint::new(data)))
203 }
204 }
205 deserializer.deserialize_any(NatVisitor)
206 }
207}
208
209impl Nat {
212 pub fn encode<W>(&self, w: &mut W) -> crate::Result<()>
213 where
214 W: ?Sized + io::Write,
215 {
216 use num_traits::cast::ToPrimitive;
217 if let Some(value) = self.0.to_u64() {
218 leb128::write::unsigned(w, value)?;
219 return Ok(());
220 }
221 let zero = BigUint::from(0u8);
222 let mut value = self.0.clone();
223 loop {
224 let big_byte = &value & BigUint::from(0x7fu8);
225 let mut byte = big_byte.to_u8().unwrap();
226 value >>= 7;
227 if value != zero {
228 byte |= 0x80u8;
229 }
230 let buf = [byte];
231 w.write_all(&buf)?;
232 if value == zero {
233 return Ok(());
234 }
235 }
236 }
237 pub fn decode<R>(r: &mut R) -> crate::Result<Self>
238 where
239 R: io::Read,
240 {
241 let mut small = 0u64;
242 let mut shift = 0u32;
243 loop {
244 let mut buf = [0];
245 r.read_exact(&mut buf)?;
246 let byte = buf[0];
247 let low_bits = u64::from(byte & 0x7f);
248 if shift == 0 || (shift < 64 && low_bits < (1u64 << (64 - shift))) {
249 small |= low_bits << shift;
250 if byte & 0x80u8 == 0 {
251 return Ok(Nat(BigUint::from(small)));
252 }
253 shift += 7;
254 continue;
255 }
256
257 let mut result = BigUint::from(small);
258 result |= BigUint::from(low_bits) << shift;
259 if byte & 0x80u8 == 0 {
260 return Ok(Nat(result));
261 }
262 shift += 7;
263 loop {
264 let mut buf = [0];
265 r.read_exact(&mut buf)?;
266 let byte = buf[0];
267 let low_bits = BigUint::from(byte & 0x7fu8);
268 result |= low_bits << shift;
269 if byte & 0x80u8 == 0 {
270 return Ok(Nat(result));
271 }
272 shift += 7;
273 }
274 }
275 }
276}
277
278impl Int {
279 pub fn encode<W>(&self, w: &mut W) -> crate::Result<()>
280 where
281 W: ?Sized + io::Write,
282 {
283 use num_traits::cast::ToPrimitive;
284 if let Some(value) = self.0.to_i64() {
285 leb128::write::signed(w, value)?;
286 return Ok(());
287 }
288 let zero = BigInt::from(0);
289 let mut value = self.0.clone();
290 loop {
291 let big_byte = &value & BigInt::from(0xff);
292 let mut byte = big_byte.to_u8().unwrap();
293 value >>= 6;
294 let done = value == zero || value == BigInt::from(-1);
295 if done {
296 byte &= 0x7f;
297 } else {
298 value >>= 1;
299 byte |= 0x80;
300 }
301 let buf = [byte];
302 w.write_all(&buf)?;
303 if done {
304 return Ok(());
305 }
306 }
307 }
308 pub fn decode<R>(r: &mut R) -> crate::Result<Self>
309 where
310 R: io::Read,
311 {
312 let mut small = 0i64;
313 let mut shift = 0u32;
314 loop {
315 let mut buf = [0];
316 r.read_exact(&mut buf)?;
317 let byte = buf[0];
318 let low_bits = i64::from(byte & 0x7f);
319
320 let fits_i64 = if shift < 57 {
321 true
322 } else if shift < 64 && byte & 0x80 == 0 {
323 let remaining_bits = 64 - shift;
328 if (byte & 0x40) != 0 {
329 (low_bits | !0x7f) >> (remaining_bits - 1) == -1
330 } else {
331 low_bits >> (remaining_bits - 1) == 0
332 }
333 } else {
334 false
335 };
336
337 if fits_i64 {
338 small |= low_bits << shift;
339 shift += 7;
340 if byte & 0x80 == 0 {
341 if shift < 64 && (byte & 0x40) != 0 {
342 small |= !0i64 << shift;
343 }
344 return Ok(Int(BigInt::from(small)));
345 }
346 continue;
347 }
348
349 let mut result = BigInt::from(small);
350 let big_low_bits = BigInt::from(byte & 0x7fu8);
351 result |= big_low_bits << shift;
352 shift += 7;
353 if byte & 0x80 == 0 {
354 if (byte & 0x40) != 0 {
355 result |= BigInt::from(-1) << shift;
356 }
357 return Ok(Int(result));
358 }
359 loop {
360 let mut buf = [0];
361 r.read_exact(&mut buf)?;
362 let byte = buf[0];
363 let big_low_bits = BigInt::from(byte & 0x7fu8);
364 result |= big_low_bits << shift;
365 shift += 7;
366 if byte & 0x80 == 0 {
367 if (byte & 0x40) != 0 {
368 result |= BigInt::from(-1) << shift;
369 }
370 return Ok(Int(result));
371 }
372 }
373 }
374 }
375}
376
377use std::cmp::{Ord, Ordering, PartialEq, PartialOrd};
379use std::ops::*;
380
// Generates `From<primitive>` for a wrapper type. Each impl converts the
// primitive into the wrapper's inner value with `.into()`.
//
// Usage: `define_from!(Wrapper, prim1 prim2 ...);`
macro_rules! define_from {
    ($wrapper: ty, $($prim: ty)*) => {
        $(
            impl From<$prim> for $wrapper {
                #[inline]
                fn from(value: $prim) -> Self {
                    Self(value.into())
                }
            }
        )*
    };
}
389
// Generates equality between a wrapper type and primitives in both directions
// (`wrapper == prim` and `prim == wrapper`). The primitive is converted into
// the wrapper's inner type and compared against the inner value.
//
// Usage: `define_eq!(Wrapper, prim1 prim2 ...);`
macro_rules! define_eq {
    ($wrapper: ty, $($prim: ty)*) => {
        $(
            impl PartialEq<$prim> for $wrapper {
                #[inline]
                fn eq(&self, rhs: &$prim) -> bool {
                    self.0.eq(&(*rhs).into())
                }
            }

            impl PartialEq<$wrapper> for $prim {
                #[inline]
                fn eq(&self, rhs: &$wrapper) -> bool {
                    rhs.0.eq(&(*self).into())
                }
            }
        )*
    };
}
402
// Generates a binary operator between a wrapper type and a primitive, in both
// operand orders. Both impls produce the wrapper type: the operation is
// applied to the inner value and the result converted back with `.into()`.
//
// Usage: `define_op!(impl Add<u64> for Nat, add);`
macro_rules! define_op {
    (impl $op: ident < $prim: ty > for $wrapper: ty, $func: ident) => {
        // wrapper <op> primitive
        impl $op<$prim> for $wrapper {
            type Output = $wrapper;

            #[inline]
            fn $func(self, rhs: $prim) -> $wrapper {
                $op::$func(self.0, &rhs).into()
            }
        }

        // primitive <op> wrapper
        impl $op<$wrapper> for $prim {
            type Output = $wrapper;

            #[inline]
            fn $func(self, rhs: $wrapper) -> $wrapper {
                $op::$func(&self, rhs.0).into()
            }
        }
    };
}
426
// Generates ordering between a wrapper type and a primitive, in both operand
// orders. The primitive is first lifted into the wrapper via `From`, then the
// two wrappers are compared.
//
// Usage: `define_ord!(u64, Nat);` (primitive first, wrapper second).
macro_rules! define_ord {
    ($prim: ty, $wrapper: ty) => {
        impl PartialOrd<$prim> for $wrapper {
            #[inline]
            fn partial_cmp(&self, other: &$prim) -> Option<Ordering> {
                let lifted = <$wrapper>::from(*other);
                PartialOrd::partial_cmp(self, &lifted)
            }
        }

        impl PartialOrd<$wrapper> for $prim {
            #[inline]
            fn partial_cmp(&self, other: &$wrapper) -> Option<Ordering> {
                let lifted = <$wrapper>::from(*self);
                PartialOrd::partial_cmp(&lifted, other)
            }
        }
    };
}
445
// Generates a compound-assignment operator (`+=`, `-=`, ...) taking a
// primitive on the right-hand side. The operation is forwarded to the
// wrapper's inner value, which is updated in place.
//
// Usage: `define_op_assign!(impl AddAssign<u64> for Nat, add_assign);`
macro_rules! define_op_assign {
    (impl $op: ident < $prim: ty > for $wrapper: ty, $func: ident) => {
        impl $op<$prim> for $wrapper {
            #[inline]
            fn $func(&mut self, rhs: $prim) {
                $op::$func(&mut self.0, rhs)
            }
        }
    };
}
457
// For each listed primitive, generates the full operator set against a
// wrapper type: `+ - * / %` in both operand orders, ordering comparisons in
// both orders, and the matching compound assignments.
//
// Usage: `define_ops!(Wrapper, prim1 prim2 ...);`
macro_rules! define_ops {
    ($wrapper: ty, $($prim: ty)*) => {
        $(
            // Arithmetic, both operand orders.
            define_op!(impl Add<$prim> for $wrapper, add);
            define_op!(impl Sub<$prim> for $wrapper, sub);
            define_op!(impl Mul<$prim> for $wrapper, mul);
            define_op!(impl Div<$prim> for $wrapper, div);
            define_op!(impl Rem<$prim> for $wrapper, rem);

            // Ordering, both operand orders.
            define_ord!($prim, $wrapper);

            // Compound assignment with a primitive right-hand side.
            define_op_assign!(impl AddAssign<$prim> for $wrapper, add_assign);
            define_op_assign!(impl SubAssign<$prim> for $wrapper, sub_assign);
            define_op_assign!(impl MulAssign<$prim> for $wrapper, mul_assign);
            define_op_assign!(impl DivAssign<$prim> for $wrapper, div_assign);
            define_op_assign!(impl RemAssign<$prim> for $wrapper, rem_assign);
        )*
    };
}
475
// Conversions from primitives: `Nat` accepts only unsigned primitives, `Int`
// accepts both unsigned and signed ones.
define_from!( Nat, usize u8 u16 u32 u64 u128 );
define_from!( Int, usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 );

// Equality against the same sets of primitives, in both operand orders.
define_eq!( Nat, usize u8 u16 u32 u64 u128 );
define_eq!( Int, usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 );

// Arithmetic, ordering and compound assignment against the same sets of
// primitives.
define_ops!( Nat, usize u8 u16 u32 u64 u128 );
define_ops!( Int, usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 );
484
// Generates a binary operator whose right-hand side is itself a wrapper type.
// The operation is applied to the two inner values and the result converted
// back into the wrapper with `.into()`.
//
// Usage: `define_op_0!(impl Add<Nat> for Nat, add);`
macro_rules! define_op_0 {
    (impl $op: ident < $rhs_ty: ty > for $wrapper: ty, $func: ident) => {
        impl $op<$rhs_ty> for $wrapper {
            type Output = $wrapper;

            #[inline]
            fn $func(self, rhs: $rhs_ty) -> $wrapper {
                $op::$func(self.0, &rhs.0).into()
            }
        }
    };
}
498
// Generates a compound-assignment operator whose right-hand side is itself a
// wrapper type. The operation is forwarded to the inner values, and the
// right-hand side's inner value is moved into the operation.
//
// Usage: `define_op_0_assign!(impl AddAssign<Nat> for Nat, add_assign);`
macro_rules! define_op_0_assign {
    (impl $op: ident < $rhs_ty: ty > for $wrapper: ty, $func: ident) => {
        impl $op<$rhs_ty> for $wrapper {
            #[inline]
            fn $func(&mut self, rhs: $rhs_ty) {
                $op::$func(&mut self.0, rhs.0)
            }
        }
    };
}
510
// Nat <op> Nat arithmetic.
define_op_0!(impl Add<Nat> for Nat, add);
define_op_0!(impl Sub<Nat> for Nat, sub);
define_op_0!(impl Mul<Nat> for Nat, mul);
define_op_0!(impl Div<Nat> for Nat, div);
define_op_0!(impl Rem<Nat> for Nat, rem);

// Nat <op>= Nat compound assignment.
define_op_0_assign!(impl AddAssign<Nat> for Nat, add_assign);
define_op_0_assign!(impl SubAssign<Nat> for Nat, sub_assign);
define_op_0_assign!(impl MulAssign<Nat> for Nat, mul_assign);
define_op_0_assign!(impl DivAssign<Nat> for Nat, div_assign);
define_op_0_assign!(impl RemAssign<Nat> for Nat, rem_assign);

// Int <op> Int arithmetic.
define_op_0!(impl Add<Int> for Int, add);
define_op_0!(impl Sub<Int> for Int, sub);
define_op_0!(impl Mul<Int> for Int, mul);
define_op_0!(impl Div<Int> for Int, div);
define_op_0!(impl Rem<Int> for Int, rem);

// Int <op>= Int compound assignment.
define_op_0_assign!(impl AddAssign<Int> for Int, add_assign);
define_op_0_assign!(impl SubAssign<Int> for Int, sub_assign);
define_op_0_assign!(impl MulAssign<Int> for Int, mul_assign);
define_op_0_assign!(impl DivAssign<Int> for Int, div_assign);
define_op_0_assign!(impl RemAssign<Int> for Int, rem_assign);
534
#[cfg(test)]
mod tests {
    use super::*;
    use serde::Deserialize;

    /// Minimal container used to round-trip a `Nat` through serde formats.
    #[derive(Default, Debug, Clone, Deserialize, Serialize, PartialEq, Eq)]
    pub struct TestStruct {
        inner: Nat,
    }

    /// Fixture holding the value 1000.
    fn small_fixture() -> TestStruct {
        TestStruct {
            inner: Nat::from(1000u64),
        }
    }

    /// Fixture holding 60000000000000000, which the JSON test below expects
    /// to serialize as two `u32` digits.
    fn large_fixture() -> TestStruct {
        TestStruct {
            inner: Nat::parse(b"60000000000000000").unwrap(),
        }
    }

    // Kept ignored, as in the original suite.
    #[ignore]
    #[test]
    fn test_serde_with_bincode() {
        let expected = small_fixture();
        let bytes = bincode::serialize(&expected).unwrap();
        let decoded: TestStruct = bincode::deserialize(&bytes).unwrap();
        assert_eq!(expected, decoded);
    }

    #[test]
    fn test_serde_with_json() {
        let expected = small_fixture();
        let text = serde_json::to_string(&expected).unwrap();
        let decoded: TestStruct = serde_json::from_str(&text).unwrap();
        assert_eq!(expected, decoded);

        let expected = large_fixture();
        let text = serde_json::to_string(&expected).unwrap();
        assert_eq!(text, "{\"inner\":[2659581952,13969838]}");
        let decoded: TestStruct = serde_json::from_str(&text).unwrap();
        assert_eq!(expected, decoded);
    }

    #[test]
    fn test_serde_with_cbor() {
        for expected in [small_fixture(), large_fixture()] {
            let bytes = serde_cbor::to_vec(&expected).unwrap();
            let decoded: TestStruct = serde_cbor::from_slice(&bytes).unwrap();
            assert_eq!(expected, decoded);
        }
    }
}