1use std::cmp::{PartialEq, PartialOrd};
2use std::convert::From;
3use std::ops::{Add, Div, Mul, Rem, Sub};
4
/// Implements a binary operator trait for a tuple newtype in all three
/// operand combinations.
///
/// Arguments:
/// * `$type`   - the newtype; its wrapped value is read through field `.0`.
/// * `$inner`  - the type the bare operand is cast to (with `as`) before the
///   operator method is called.
/// * `$int`    - the bare operand type accepted next to the newtype.
/// * `$trait`  - the operator trait to implement (for example `Add`); it has
///   to be nameable where the macro is invoked.
/// * `$method` - the method of `$trait` to implement (for example `add`).
///
/// Generated impls: `$type op $int`, `$int op $type` and `$type op $type`.
/// Every one of them produces a `$type`.
macro_rules! implement_trait {
    ($type:ident, $inner:ty, $int:ty, $trait:ident, $method:ident) => {
        // newtype `op` bare value
        impl $trait<$int> for $type {
            type Output = $type;

            fn $method(self, rhs: $int) -> $type {
                let lhs = self.0;
                let rhs_inner = rhs as $inner;
                $type(lhs.$method(rhs_inner))
            }
        }

        // bare value `op` newtype
        impl $trait<$type> for $int {
            type Output = $type;

            fn $method(self, rhs: $type) -> $type {
                let lhs_inner = self as $inner;
                $type(lhs_inner.$method(rhs.0))
            }
        }

        // newtype `op` newtype
        impl $trait for $type {
            type Output = $type;

            fn $method(self, rhs: $type) -> $type {
                let (lhs, rhs_inner) = (self.0, rhs.0);
                $type(lhs.$method(rhs_inner))
            }
        }
    };
}
27
/// Implements `==` / `!=` between a tuple newtype and a bare value, in both
/// operand orders.
///
/// Arguments:
/// * `$type`  - the newtype; its wrapped value is read through field `.0`.
/// * `$inner` - the type the bare operand is cast to (with `as`) before the
///   comparison.
/// * `$int`   - the bare operand type.
///
/// Only the mixed comparisons (`$type == $int`, `$int == $type`) are
/// generated; `$type == $type` is not produced by this macro.
macro_rules! implement_eq_trait {
    ($type:ident, $inner:ty, $int:ty) => {
        // newtype == bare value
        impl PartialEq<$int> for $type {
            fn eq(&self, rhs: &$int) -> bool {
                let rhs_inner = *rhs as $inner;
                self.0 == rhs_inner
            }
        }

        // bare value == newtype
        impl PartialEq<$type> for $int {
            fn eq(&self, rhs: &$type) -> bool {
                let lhs_inner = *self as $inner;
                lhs_inner == rhs.0
            }
        }
    };
}
42
/// Implements `<`, `<=`, `>` and `>=` between a tuple newtype and a bare
/// value, in both operand orders.
///
/// Arguments:
/// * `$type`  - the newtype; its wrapped value is read through field `.0`.
/// * `$inner` - the type the bare operand is cast to (with `as`) before the
///   comparison.
/// * `$int`   - the bare operand type.
///
/// `PartialOrd<Rhs>` has `PartialEq<Rhs>` as a supertrait, so the matching
/// `PartialEq` impls must exist as well; this macro does not generate them.
macro_rules! implement_ord_trait {
    ($type:ident, $inner:ty, $int:ty) => {
        // newtype <=> bare value
        impl PartialOrd<$int> for $type {
            fn partial_cmp(&self, rhs: &$int) -> Option<std::cmp::Ordering> {
                let rhs_inner = *rhs as $inner;
                self.0.partial_cmp(&rhs_inner)
            }
        }

        // bare value <=> newtype
        impl PartialOrd<$type> for $int {
            fn partial_cmp(&self, rhs: &$type) -> Option<std::cmp::Ordering> {
                let lhs_inner = *self as $inner;
                lhs_inner.partial_cmp(&rhs.0)
            }
        }
    };
}
57
/// Implements `From<$int>` for the tuple newtype `$type`, once per listed
/// `$int`.
///
/// Arguments:
/// * `$type`  - the newtype to construct.
/// * `$inner` - the type the source value is cast to before it is wrapped.
/// * `$int`   - comma-separated list of source types.
///
/// The conversion is a plain `as` cast, so no range check takes place: with
/// integer types, a source value that does not fit `$inner` (too wide, or
/// negative for an unsigned `$inner`) is truncated / wrapped rather than
/// rejected.
macro_rules! implement_from {
    ($type:ident, $inner:ty, $($int:ty),*) => {
        $(
            impl From<$int> for $type {
                fn from(raw: $int) -> $type {
                    let converted = raw as $inner;
                    $type(converted)
                }
            }
        )*
    };
}
67
/// Gives a single-field tuple newtype `$type` (wrapping `$inner`) integer-like
/// conversions, arithmetic and comparisons.
///
/// Generated for `$type`:
/// * `From<u8 | u16 | u32 | u64 | i8 | i16 | i32 | i64>` via `implement_from!`
///   (each source value is cast to `$inner` with `as`);
/// * `From<$type> for $inner`, which unwraps field `.0`;
/// * `PartialEq` and `PartialOrd` against `$inner`, in both operand orders;
/// * `Add`, `Sub`, `Mul`, `Div` and `Rem` for `$type op $inner`,
///   `$inner op $type` and `$type op $type`, each yielding `$type`.
///
/// NOTE(review): the helper macros and the operator traits are referenced by
/// their bare names (no `$crate::` prefix), so they presumably have to be in
/// scope wherever this macro is invoked - confirm before using it from
/// another crate.
#[macro_export]
macro_rules! implement_int {
    ($type:ident, $inner:ty) => {
        // Conversions into the newtype from the fixed-width integer types.
        implement_from!($type, $inner, u8, u16, u32, u64, i8, i16, i32, i64);

        // Conversion back out to the wrapped value.
        impl From<$type> for $inner {
            fn from(wrapped: $type) -> $inner {
                wrapped.0
            }
        }

        // Comparisons against the bare inner type.
        implement_eq_trait!($type, $inner, $inner);
        implement_ord_trait!($type, $inner, $inner);

        // Arithmetic with the bare inner type and with the newtype itself.
        implement_trait!($type, $inner, $inner, Add, add);
        implement_trait!($type, $inner, $inner, Sub, sub);
        implement_trait!($type, $inner, $inner, Mul, mul);
        implement_trait!($type, $inner, $inner, Div, div);
        implement_trait!($type, $inner, $inner, Rem, rem);
    };
}
96
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal newtypes the macros are exercised on. `Debug` is required by
    // `assert_eq!`; `Copy` lets the same value be used in several expressions.
    #[derive(Debug, Clone, Copy)]
    struct TestIntU64(u64);
    #[derive(Debug, Clone, Copy)]
    struct TestIntU32(u32);

    implement_int!(TestIntU64, u64);
    implement_int!(TestIntU32, u32);

    #[test]
    fn test_basic_add() {
        let a = TestIntU64(3);
        let b = 4;
        let c = a + b;
        assert_eq!(c, 7);
    }

    // The generated `add` forwards to the primitive `+`, which panics only when
    // overflow checks are enabled. Those follow `debug_assertions` by default,
    // so the test is compiled only for such builds instead of failing elsewhere.
    #[test]
    #[cfg(debug_assertions)]
    #[should_panic(expected = "overflow")]
    fn test_add_overflow() {
        let a = TestIntU32(u32::MAX);
        let b: u32 = 1;
        let _ = a + b;
    }

    #[test]
    fn test_conversion_from_different_type() {
        let widened: TestIntU64 = 10u8.into();
        assert_eq!(widened, 10u64);

        // And back out through `From<TestIntU64> for u64`.
        let unwrapped: u64 = widened.into();
        assert_eq!(unwrapped, 10u64);
    }

    #[test]
    fn all_side_add_works() {
        let a = TestIntU32(13);
        let b = 1u32;
        assert_eq!(a + b, 14u32);
        assert_eq!(b + a, 14u32);

        let a = TestIntU32(3);
        let b = TestIntU32(1);
        assert_eq!(a + b, 4u32);
    }

    #[test]
    fn remaining_operators_work() {
        let a = TestIntU32(13);
        let b = 4u32;
        assert_eq!(a - b, 9u32);
        assert_eq!(a * b, 52u32);
        assert_eq!(a / b, 3u32);
        assert_eq!(a % b, 1u32);
    }

    #[test]
    fn all_side_cmp_works() {
        let a = TestIntU32(13);
        let b = 1u32;
        assert!(!(a < b));
        assert!(b < a);
        assert!(a > b);
        assert!(!(b > a));
        assert!(a >= b);
        assert!(!(b >= a));
        assert!(b != a);
        assert!(a != b);

        // Equality needs a matching value to be observable.
        let same = 13u32;
        assert!(a == same);
        assert!(same == a);
    }
}