range_set_blaze/
uint_plus_one.rs

1use alloc::fmt;
2use core::{
3    cmp::Ordering,
4    fmt::Display,
5    mem,
6    ops::{Add, AddAssign, Mul, Sub, SubAssign},
7};
8
9#[cfg(not(feature = "std"))]
10use num_traits::float::FloatCore;
11use num_traits::ops::overflowing::{OverflowingAdd, OverflowingMul};
12
13pub trait UInt:
14    num_traits::Zero
15    + num_traits::One
16    + num_traits::Unsigned
17    + OverflowingAdd
18    + num_traits::Bounded
19    + Sub<Output = Self>
20    + PartialOrd
21    + Copy
22    + Sized
23    + OverflowingMul
24    + Display
25    + fmt::Debug
26{
27}
28
// Only `u128` and `u8` implement `UInt`.
// The `u8` implementation exists for testing purposes: its small range lets the
// tests below enumerate every pair of values.
impl UInt for u128 {}
impl UInt for u8 {}
33
/// An unsigned integer of type `T` extended with one extra value: the maximum
/// value of `T` plus one.
///
/// With `T = u128` this represents values from `0` to `u128::MAX + 1` (inclusive),
/// which is needed to represent every possible length of a `RangeInclusive<i128>`
/// and `RangeInclusive<u128>`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum UIntPlusOne<T>
where
    T: UInt,
{
    /// A variant representing an unsigned integer of type `T`.
    UInt(T),
    /// A variant representing one more than the maximum value of `T`
    /// (that is, `u128::MAX + 1` when `T` is `u128`).
    MaxPlusOne,
}
47
48impl<T> UIntPlusOne<T>
49where
50    T: UInt,
51{
52    /// Returns the maximum value of an unsigned integer type `T` plus one, as an `f64`.
53    #[allow(clippy::missing_panics_doc)]
54    #[must_use]
55    pub fn max_plus_one_as_f64() -> f64 {
56        let bits = i32::try_from(mem::size_of::<T>() * 8)
57            .expect("Real assert: bit width of T fits in i32 (u8 to u128) and gets optimized away");
58        2.0f64.powi(bits)
59    }
60}
61
62impl<T> Display for UIntPlusOne<T>
63where
64    T: UInt + Display,
65{
66    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67        match self {
68            Self::UInt(v) => write!(f, "{v}"),
69            Self::MaxPlusOne => write!(f, "(u128::MAX + 1"),
70        }
71    }
72}
73
74impl<T> num_traits::Zero for UIntPlusOne<T>
75where
76    T: UInt,
77{
78    fn zero() -> Self {
79        Self::UInt(T::zero())
80    }
81
82    fn is_zero(&self) -> bool {
83        matches!(self, Self::UInt(v) if v.is_zero())
84    }
85}
86
87impl<T> Add for UIntPlusOne<T>
88where
89    T: UInt,
90{
91    type Output = Self;
92
93    /// Adds two `UIntPlusOne` values. Always panics on overflow.
94    fn add(self, rhs: Self) -> Self {
95        let zero = T::zero();
96        let one: T = T::one();
97        let max: T = T::max_value();
98
99        match (self, rhs) {
100            (Self::UInt(z), b) | (b, Self::UInt(z)) if z == zero => b,
101            (Self::UInt(a), Self::UInt(b)) => {
102                let (wrapped_less1, overflow) = a.overflowing_add(&(b - one));
103                assert!(!overflow, "arithmetic operation overflowed: {self} + {rhs}");
104                if wrapped_less1 == max {
105                    Self::MaxPlusOne
106                } else {
107                    Self::UInt(wrapped_less1 + T::one())
108                }
109            }
110            (Self::MaxPlusOne, _) | (_, Self::MaxPlusOne) => {
111                panic!("arithmetic operation overflowed: {self} + {rhs}");
112            }
113        }
114    }
115}
116
117impl<T> SubAssign for UIntPlusOne<T>
118where
119    T: UInt,
120{
121    fn sub_assign(&mut self, rhs: Self) {
122        let zero = T::zero();
123        let one: T = T::one();
124        let max: T = T::max_value();
125
126        *self = match (*self, rhs) {
127            (Self::UInt(a), Self::UInt(b)) => Self::UInt(a - b),
128            (Self::MaxPlusOne, Self::UInt(z)) if z == zero => Self::MaxPlusOne,
129            (Self::MaxPlusOne, Self::UInt(v)) => Self::UInt(max - (v - one)),
130            (Self::MaxPlusOne, Self::MaxPlusOne) => Self::UInt(zero),
131            (Self::UInt(_), Self::MaxPlusOne) => {
132                panic!("underflow: UIntPlusOne::UInt - UIntPlusOne::Max")
133            }
134        }
135    }
136}
137
138impl<T> AddAssign for UIntPlusOne<T>
139where
140    T: UInt,
141{
142    fn add_assign(&mut self, rhs: Self) {
143        *self = self.add(rhs);
144    }
145}
146
147impl<T> num_traits::One for UIntPlusOne<T>
148where
149    T: UInt,
150{
151    fn one() -> Self {
152        Self::UInt(T::one())
153    }
154}
155
156impl<T> Mul for UIntPlusOne<T>
157where
158    T: UInt,
159{
160    type Output = Self;
161
162    /// Multiplies two `UIntPlusOne` values. Always panics on overflow.
163    fn mul(self, rhs: Self) -> Self {
164        let zero = T::zero();
165        let one: T = T::one();
166
167        match (self, rhs) {
168            (Self::UInt(o1), b) | (b, Self::UInt(o1)) if o1 == one => b,
169            (Self::UInt(z), _) | (_, Self::UInt(z)) if z == zero => Self::UInt(zero),
170            (Self::UInt(a), Self::UInt(b)) => {
171                let (a_times_b_less1, overflow) = a.overflowing_mul(&(b - one));
172                assert!(!overflow, "arithmetic operation overflowed: {self} * {rhs}");
173                Self::UInt(a_times_b_less1) + self
174            }
175            (Self::MaxPlusOne, _) | (_, Self::MaxPlusOne) => {
176                panic!("arithmetic operation overflowed: {self} * {rhs}");
177            }
178        }
179    }
180}
181
182impl<T> PartialOrd for UIntPlusOne<T>
183where
184    T: UInt,
185{
186    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
187        match (self, other) {
188            (Self::MaxPlusOne, Self::MaxPlusOne) => Some(Ordering::Equal),
189            (Self::MaxPlusOne, _) => Some(Ordering::Greater),
190            (_, Self::MaxPlusOne) => Some(Ordering::Less),
191            (Self::UInt(a), Self::UInt(b)) => a.partial_cmp(b),
192        }
193    }
194}
195
#[cfg(test)]
mod tests {
    use super::*;
    use core::prelude::v1::*;
    #[cfg(not(target_arch = "wasm32"))] // not used by wasm-wasip1
    use std::panic;
    #[cfg(not(target_arch = "wasm32"))] // not used by wasm-wasip1
    use std::panic::AssertUnwindSafe;

    use wasm_bindgen_test::*;
    wasm_bindgen_test_configure!(run_in_browser);

    /// Converts a `u16` model value into a `UIntPlusOne<u8>`: 256 becomes
    /// `MaxPlusOne`, 0..=255 become `UInt`, anything else panics.
    #[cfg(not(target_arch = "wasm32"))] // not used by wasm-wasip1
    fn u16_to_p1(v: u16) -> UIntPlusOne<u8> {
        match v {
            256 => UIntPlusOne::MaxPlusOne,
            _ => UIntPlusOne::UInt(u8::try_from(v).expect("value must be <= 255 or == 256")),
        }
    }

    /// Returns `true` when the `u16` model and the `UIntPlusOne<u8>` operation
    /// agree: both panicked, or both succeeded with equal values.
    #[cfg(not(target_arch = "wasm32"))] // not used by wasm-wasip1
    fn outcomes_agree<E1, E2>(
        expected: Result<u16, E1>,
        actual: Result<UIntPlusOne<u8>, E2>,
    ) -> bool {
        match (expected, actual) {
            (Ok(want), Ok(got)) => u16_to_p1(want) == got,
            (Err(_), Err(_)) => true,
            _ => false, // Don't need to cover this
        }
    }

    /// Checks that `a + b` behaves the same in the `u16` model and in `UIntPlusOne<u8>`.
    #[cfg(not(target_arch = "wasm32"))] // not used by wasm-wasip1
    fn add_em(a: u16, b: u16) -> bool {
        let lhs = u16_to_p1(a);
        let rhs = u16_to_p1(b);

        let expected = panic::catch_unwind(AssertUnwindSafe(|| {
            let sum = a + b;
            assert!(sum <= 256, "overflow");
            sum
        }));
        let actual = panic::catch_unwind(AssertUnwindSafe(|| lhs + rhs));

        outcomes_agree(expected, actual)
    }

    /// Checks that `a * b` behaves the same in the `u16` model and in `UIntPlusOne<u8>`.
    #[cfg(not(target_arch = "wasm32"))]
    #[allow(dead_code)]
    fn mul_em(a: u16, b: u16) -> bool {
        let lhs = u16_to_p1(a);
        let rhs = u16_to_p1(b);

        let expected = panic::catch_unwind(AssertUnwindSafe(|| {
            let product = a * b;
            assert!(product <= 256, "overflow");
            product
        }));
        let actual = panic::catch_unwind(AssertUnwindSafe(|| lhs * rhs));

        outcomes_agree(expected, actual)
    }

    /// Checks that `a -= b` behaves the same in the `u16` model and in `UIntPlusOne<u8>`.
    #[cfg(not(target_arch = "wasm32"))]
    #[allow(dead_code)]
    fn sub_em(a: u16, b: u16) -> bool {
        let lhs = u16_to_p1(a);
        let rhs = u16_to_p1(b);

        let expected = panic::catch_unwind(AssertUnwindSafe(|| {
            let mut difference = a;
            difference -= b;
            assert!(difference <= 256, "overflow");
            difference
        }));
        let actual = panic::catch_unwind(AssertUnwindSafe(|| {
            let mut difference = lhs;
            difference -= rhs;
            difference
        }));

        outcomes_agree(expected, actual)
    }

    /// Checks that ordering `a` against `b` agrees between `u16` and `UIntPlusOne<u8>`.
    #[cfg(not(target_arch = "wasm32"))] // not used by wasm-wasip1
    fn compare_em(a: u16, b: u16) -> bool {
        let lhs = u16_to_p1(a);
        let rhs = u16_to_p1(b);

        let expected = panic::catch_unwind(AssertUnwindSafe(|| a.partial_cmp(&b)));
        let actual = panic::catch_unwind(AssertUnwindSafe(|| lhs.partial_cmp(&rhs)));

        match (expected, actual) {
            (Ok(Some(want)), Ok(Some(got))) => want == got,
            _ => panic!("never happens"), // Don't need to cover this
        }
    }

    #[cfg(not(target_arch = "wasm32"))] // can't test wasm-wasip1 because we need to catch panics
    #[test]
    fn test_add_equivalence() {
        for a in 0..=256u16 {
            for b in 0..=256u16 {
                assert!(add_em(a, b), "a: {a}, b: {b}");
            }
        }
    }

    // Debug-only: the `u16` model relies on overflow checks to panic.
    #[cfg(debug_assertions)]
    #[cfg(not(target_arch = "wasm32"))] // can't test wasm-wasip1 because we need to catch panics
    #[test]
    fn test_mul_equivalence() {
        for a in 0..=256u16 {
            for b in 0..=256u16 {
                assert!(mul_em(a, b), "a: {a}, b: {b}");
            }
        }
    }

    // Debug-only: the `u16` model relies on overflow checks to panic.
    #[cfg(debug_assertions)]
    #[cfg(not(target_arch = "wasm32"))] // can't test wasm-wasip1 because we need to catch panics
    #[test]
    fn test_sub_equivalence() {
        for a in 0..=256u16 {
            for b in 0..=256u16 {
                assert!(sub_em(a, b), "a: {a}, b: {b}");
            }
        }
    }

    #[cfg(not(target_arch = "wasm32"))] // can't test wasm-wasip1 because we need to catch panics
    #[test]
    fn test_compare_equivalence() {
        for a in 0..=256u16 {
            for b in 0..=256u16 {
                assert!(compare_em(a, b), "a: {a}, b: {b}");
            }
        }
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn test_add_assign() {
        let mut total = UIntPlusOne::<u128>::UInt(1);
        total += UIntPlusOne::UInt(1);
        assert_eq!(total, UIntPlusOne::UInt(2));
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    fn test_is_zero() {
        use num_traits::Zero;

        let zero = UIntPlusOne::<u128>::zero();
        let one = UIntPlusOne::<u128>::UInt(1);
        let max_plus_one = UIntPlusOne::<u128>::MaxPlusOne;

        assert!(zero.is_zero());
        assert!(!one.is_zero());
        assert!(!max_plus_one.is_zero());
    }

    #[test]
    #[cfg_attr(target_arch = "wasm32", wasm_bindgen_test)]
    #[should_panic(expected = "underflow: UIntPlusOne::UInt - UIntPlusOne::Max")]
    fn test_sub_assign_max_plus_one_underflow() {
        let mut minuend = UIntPlusOne::UInt(1u128);
        // This should panic because subtracting MaxPlusOne from a UInt should not be allowed
        minuend -= UIntPlusOne::MaxPlusOne;
    }
}