1use alloc::fmt;
2use core::{
3 cmp::Ordering,
4 fmt::Display,
5 mem,
6 ops::{Add, AddAssign, Mul, Sub, SubAssign},
7};
8
9#[cfg(not(feature = "std"))]
10use num_traits::float::FloatCore;
11use num_traits::ops::overflowing::{OverflowingAdd, OverflowingMul};
12
13pub trait UInt:
14 num_traits::Zero
15 + num_traits::One
16 + num_traits::Unsigned
17 + OverflowingAdd
18 + num_traits::Bounded
19 + Sub<Output = Self>
20 + PartialOrd
21 + Copy
22 + Sized
23 + OverflowingMul
24 + Display
25 + fmt::Debug
26{
27}
28
29impl UInt for u128 {}
32impl UInt for u8 {}
33
34#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
38pub enum UIntPlusOne<T>
39where
40 T: UInt,
41{
42 UInt(T),
44 MaxPlusOne,
46}
47
48impl<T> UIntPlusOne<T>
49where
50 T: UInt,
51{
52 #[allow(clippy::missing_panics_doc)]
54 #[must_use]
55 pub fn max_plus_one_as_f64() -> f64 {
56 let bits = i32::try_from(mem::size_of::<T>() * 8)
57 .expect("Real assert: bit width of T fits in i32 (u8 to u128) and gets optimized away");
58 2.0f64.powi(bits)
59 }
60}
61
62impl<T> Display for UIntPlusOne<T>
63where
64 T: UInt + Display,
65{
66 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67 match self {
68 Self::UInt(v) => write!(f, "{v}"),
69 Self::MaxPlusOne => write!(f, "(u128::MAX + 1"),
70 }
71 }
72}
73
74impl<T> num_traits::Zero for UIntPlusOne<T>
75where
76 T: UInt,
77{
78 fn zero() -> Self {
79 Self::UInt(T::zero())
80 }
81
82 fn is_zero(&self) -> bool {
83 matches!(self, Self::UInt(v) if v.is_zero())
84 }
85}
86
87impl<T> Add for UIntPlusOne<T>
88where
89 T: UInt,
90{
91 type Output = Self;
92
93 fn add(self, rhs: Self) -> Self {
95 let zero = T::zero();
96 let one: T = T::one();
97 let max: T = T::max_value();
98
99 match (self, rhs) {
100 (Self::UInt(z), b) | (b, Self::UInt(z)) if z == zero => b,
101 (Self::UInt(a), Self::UInt(b)) => {
102 let (wrapped_less1, overflow) = a.overflowing_add(&(b - one));
103 assert!(!overflow, "arithmetic operation overflowed: {self} + {rhs}");
104 if wrapped_less1 == max {
105 Self::MaxPlusOne
106 } else {
107 Self::UInt(wrapped_less1 + T::one())
108 }
109 }
110 (Self::MaxPlusOne, _) | (_, Self::MaxPlusOne) => {
111 panic!("arithmetic operation overflowed: {self} + {rhs}");
112 }
113 }
114 }
115}
116
117impl<T> SubAssign for UIntPlusOne<T>
118where
119 T: UInt,
120{
121 fn sub_assign(&mut self, rhs: Self) {
122 let zero = T::zero();
123 let one: T = T::one();
124 let max: T = T::max_value();
125
126 *self = match (*self, rhs) {
127 (Self::UInt(a), Self::UInt(b)) => Self::UInt(a - b),
128 (Self::MaxPlusOne, Self::UInt(z)) if z == zero => Self::MaxPlusOne,
129 (Self::MaxPlusOne, Self::UInt(v)) => Self::UInt(max - (v - one)),
130 (Self::MaxPlusOne, Self::MaxPlusOne) => Self::UInt(zero),
131 (Self::UInt(_), Self::MaxPlusOne) => {
132 panic!("underflow: UIntPlusOne::UInt - UIntPlusOne::Max")
133 }
134 }
135 }
136}
137
138impl<T> AddAssign for UIntPlusOne<T>
139where
140 T: UInt,
141{
142 fn add_assign(&mut self, rhs: Self) {
143 *self = self.add(rhs);
144 }
145}
146
147impl<T> num_traits::One for UIntPlusOne<T>
148where
149 T: UInt,
150{
151 fn one() -> Self {
152 Self::UInt(T::one())
153 }
154}
155
156impl<T> Mul for UIntPlusOne<T>
157where
158 T: UInt,
159{
160 type Output = Self;
161
162 fn mul(self, rhs: Self) -> Self {
164 let zero = T::zero();
165 let one: T = T::one();
166
167 match (self, rhs) {
168 (Self::UInt(o1), b) | (b, Self::UInt(o1)) if o1 == one => b,
169 (Self::UInt(z), _) | (_, Self::UInt(z)) if z == zero => Self::UInt(zero),
170 (Self::UInt(a), Self::UInt(b)) => {
171 let (a_times_b_less1, overflow) = a.overflowing_mul(&(b - one));
172 assert!(!overflow, "arithmetic operation overflowed: {self} * {rhs}");
173 Self::UInt(a_times_b_less1) + self
174 }
175 (Self::MaxPlusOne, _) | (_, Self::MaxPlusOne) => {
176 panic!("arithmetic operation overflowed: {self} * {rhs}");
177 }
178 }
179 }
180}
181
182impl<T> PartialOrd for UIntPlusOne<T>
183where
184 T: UInt,
185{
186 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
187 match (self, other) {
188 (Self::MaxPlusOne, Self::MaxPlusOne) => Some(Ordering::Equal),
189 (Self::MaxPlusOne, _) => Some(Ordering::Greater),
190 (_, Self::MaxPlusOne) => Some(Ordering::Less),
191 (Self::UInt(a), Self::UInt(b)) => a.partial_cmp(b),
192 }
193 }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use core::prelude::v1::*;
    #[cfg(not(target_arch = "wasm32"))]
    use std::panic::{self, AssertUnwindSafe};

    use wasm_bindgen_test::*;
    wasm_bindgen_test_configure!(run_in_browser);

    /// Encodes 0..=255 as `UInt` and 256 as `MaxPlusOne`; panics on anything larger.
    #[cfg(not(target_arch = "wasm32"))]
    fn u16_to_p1(v: u16) -> UIntPlusOne<u8> {
        if v != 256 {
            return UIntPlusOne::UInt(u8::try_from(v).expect("value must be <= 255 or == 256"));
        }
        UIntPlusOne::MaxPlusOne
    }

    /// Runs `reference` on plain `u16`s (treating results above 256 as overflow) and
    /// `candidate` on the `UIntPlusOne<u8>` encodings. Returns true when both panic,
    /// or when both succeed with the same value.
    #[cfg(not(target_arch = "wasm32"))]
    fn matches_reference<R, P>(a: u16, b: u16, reference: R, candidate: P) -> bool
    where
        R: FnOnce(u16, u16) -> u16,
        P: FnOnce(UIntPlusOne<u8>, UIntPlusOne<u8>) -> UIntPlusOne<u8>,
    {
        let a_p1 = u16_to_p1(a);
        let b_p1 = u16_to_p1(b);

        let expected = panic::catch_unwind(AssertUnwindSafe(|| {
            let c = reference(a, b);
            assert!(c <= 256, "overflow");
            c
        }));
        let actual = panic::catch_unwind(AssertUnwindSafe(|| candidate(a_p1, b_p1)));

        match (expected, actual) {
            (Ok(c), Ok(c_p1)) => u16_to_p1(c) == c_p1,
            (Err(_), Err(_)) => true,
            (Ok(_), Err(_)) | (Err(_), Ok(_)) => false,
        }
    }

    /// Checks that `partial_cmp` agrees between `u16` and the `UIntPlusOne<u8>` encoding.
    #[cfg(not(target_arch = "wasm32"))]
    fn compare_em(a: u16, b: u16) -> bool {
        let a_p1 = u16_to_p1(a);
        let b_p1 = u16_to_p1(b);

        let expected = panic::catch_unwind(AssertUnwindSafe(|| a.partial_cmp(&b)));
        let actual = panic::catch_unwind(AssertUnwindSafe(|| a_p1.partial_cmp(&b_p1)));

        if let (Ok(Some(want)), Ok(Some(got))) = (expected, actual) {
            want == got
        } else {
            panic!("never happens")
        }
    }

    /// Asserts `check` for every pair of operands in `0..=256`.
    #[cfg(not(target_arch = "wasm32"))]
    fn assert_all_pairs(check: impl Fn(u16, u16) -> bool) {
        for a in 0..=256 {
            for b in 0..=256 {
                assert!(check(a, b), "a: {a}, b: {b}");
            }
        }
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_add_equivalence() {
        assert_all_pairs(|a, b| matches_reference(a, b, |x, y| x + y, |x, y| x + y));
    }

    // The `u16` reference relies on debug-mode overflow panics, hence the gate.
    #[cfg(debug_assertions)]
    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_mul_equivalence() {
        assert_all_pairs(|a, b| matches_reference(a, b, |x, y| x * y, |x, y| x * y));
    }

    // The `u16` reference relies on debug-mode underflow panics, hence the gate.
    #[cfg(debug_assertions)]
    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_sub_equivalence() {
        assert_all_pairs(|a, b| {
            matches_reference(
                a,
                b,
                |x, y| {
                    let mut diff = x;
                    diff -= y;
                    diff
                },
                |x, y| {
                    let mut diff = x;
                    diff -= y;
                    diff
                },
            )
        });
    }

    #[cfg(not(target_arch = "wasm32"))]
    #[test]
    fn test_compare_equivalence() {
        assert_all_pairs(compare_em);
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn test_add_assign() {
        let mut total = UIntPlusOne::<u128>::UInt(1);
        total += UIntPlusOne::UInt(1);
        assert_eq!(total, UIntPlusOne::UInt(2));
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn test_is_zero() {
        use num_traits::Zero;

        let zero = UIntPlusOne::<u128>::zero();
        let one = UIntPlusOne::<u128>::UInt(1);
        let max_plus_one = UIntPlusOne::<u128>::MaxPlusOne;
        assert!(zero.is_zero());
        assert!(!one.is_zero());
        assert!(!max_plus_one.is_zero());
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[should_panic(expected = "underflow: UIntPlusOne::UInt - UIntPlusOne::Max")]
    fn test_sub_assign_max_plus_one_underflow() {
        let mut value = UIntPlusOne::UInt(1u128);
        value -= UIntPlusOne::MaxPlusOne;
    }
}