// pumpkin_checking/int_ext.rs

1use std::cmp::Ordering;
2use std::fmt::Debug;
3use std::iter::Sum;
4use std::ops::Add;
5use std::ops::AddAssign;
6use std::ops::Mul;
7use std::ops::Neg;
8use std::ops::Sub;
9
/// An integer of type `Int` (by default [`i32`]) extended with positive and negative infinity.
///
/// # Notes on arithmetic operations:
/// - Operations whose result is indeterminate, such as `+inf + -inf` or `+inf - +inf`, panic when
///   evaluated.
/// - Multiplying [`IntExt::PositiveInf`] or [`IntExt::NegativeInf`] with `IntExt::Int(0)` will
///   yield `IntExt::Int(0)`.
/// - Arithmetic on two finite values is delegated to `Int` itself, so it overflows exactly as
///   `Int` does.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum IntExt<Int = i32> {
    /// A finite value.
    Int(Int),
    /// Orders below every finite value.
    NegativeInf,
    /// Orders above every finite value.
    PositiveInf,
}
23
24impl<Int: Copy> IntExt<Int> {
25    pub fn as_int(&self) -> Option<Int> {
26        match self {
27            IntExt::Int(int) => Some(*int),
28            IntExt::NegativeInf | IntExt::PositiveInf => None,
29        }
30    }
31}
32
33impl IntExt<i32> {
34    pub fn div_ceil(&self, other: IntExt<i32>) -> Option<IntExt<i32>> {
35        let result = self.div(other).ceil();
36
37        Self::int_ext_from_int_f64(result)
38    }
39
40    pub fn div_floor(&self, other: IntExt<i32>) -> Option<IntExt<i32>> {
41        let result = self.div(other).floor();
42
43        Self::int_ext_from_int_f64(result)
44    }
45
46    fn int_ext_from_int_f64(value: f64) -> Option<IntExt<i32>> {
47        if value.is_nan() {
48            return None;
49        }
50
51        if value.is_infinite() {
52            if value.is_sign_positive() {
53                return Some(IntExt::PositiveInf);
54            } else {
55                return Some(IntExt::NegativeInf);
56            }
57        }
58
59        assert!(value.fract().abs() < 1e-10);
60
61        Some(IntExt::Int(value as i32))
62    }
63}
64
65impl<Int: Into<f64>> IntExt<Int> {
66    fn div(self, rhs: Self) -> f64 {
67        let value: f64 = self.into();
68        let rhs_value: f64 = rhs.into();
69
70        value / rhs_value
71    }
72}
73
74impl<Int: Into<f64>> From<IntExt<Int>> for f64 {
75    fn from(value: IntExt<Int>) -> Self {
76        match value {
77            IntExt::Int(inner) => inner.into(),
78            IntExt::NegativeInf => -f64::INFINITY,
79            IntExt::PositiveInf => f64::INFINITY,
80        }
81    }
82}
83
84impl From<i32> for IntExt {
85    fn from(value: i32) -> Self {
86        IntExt::Int(value)
87    }
88}
89
90impl From<IntExt<i32>> for IntExt<i64> {
91    fn from(value: IntExt<i32>) -> Self {
92        match value {
93            IntExt::Int(int) => IntExt::Int(int.into()),
94            IntExt::NegativeInf => IntExt::NegativeInf,
95            IntExt::PositiveInf => IntExt::PositiveInf,
96        }
97    }
98}
99
100// TODO: This is not a great pattern, but for now I do not want to touch this.
101impl TryInto<i32> for IntExt {
102    type Error = ();
103
104    fn try_into(self) -> Result<i32, Self::Error> {
105        match self {
106            IntExt::Int(inner) => Ok(inner),
107            IntExt::NegativeInf | IntExt::PositiveInf => Err(()),
108        }
109    }
110}
111
112impl<Int: PartialEq> PartialEq<Int> for IntExt<Int> {
113    fn eq(&self, other: &Int) -> bool {
114        match self {
115            IntExt::Int(v1) => v1 == other,
116            IntExt::NegativeInf | IntExt::PositiveInf => false,
117        }
118    }
119}
120
121impl PartialEq<IntExt> for i32 {
122    fn eq(&self, other: &IntExt) -> bool {
123        other.eq(self)
124    }
125}
126
127impl PartialOrd<IntExt> for i32 {
128    fn partial_cmp(&self, other: &IntExt) -> Option<Ordering> {
129        other.neg().partial_cmp(&self.neg())
130    }
131}
132
133impl<Int: Ord> PartialOrd for IntExt<Int> {
134    fn partial_cmp(&self, other: &IntExt<Int>) -> Option<Ordering> {
135        Some(self.cmp(other))
136    }
137}
138
139impl<Int: Ord> Ord for IntExt<Int> {
140    fn cmp(&self, other: &Self) -> Ordering {
141        match self {
142            IntExt::Int(v1) => match other {
143                IntExt::Int(v2) => v1.cmp(v2),
144                IntExt::NegativeInf => Ordering::Greater,
145                IntExt::PositiveInf => Ordering::Less,
146            },
147            IntExt::NegativeInf => match other {
148                IntExt::Int(_) => Ordering::Less,
149                IntExt::PositiveInf => Ordering::Less,
150                IntExt::NegativeInf => Ordering::Equal,
151            },
152            IntExt::PositiveInf => match other {
153                IntExt::Int(_) => Ordering::Greater,
154                IntExt::NegativeInf => Ordering::Greater,
155                IntExt::PositiveInf => Ordering::Greater,
156            },
157        }
158    }
159}
160
161impl PartialOrd<i32> for IntExt {
162    fn partial_cmp(&self, other: &i32) -> Option<Ordering> {
163        match self {
164            IntExt::Int(v1) => v1.partial_cmp(other),
165            IntExt::NegativeInf => Some(Ordering::Less),
166            IntExt::PositiveInf => Some(Ordering::Greater),
167        }
168    }
169}
170
171impl PartialOrd<i64> for IntExt<i64> {
172    fn partial_cmp(&self, other: &i64) -> Option<Ordering> {
173        match self {
174            IntExt::Int(v1) => v1.partial_cmp(other),
175            IntExt::NegativeInf => Some(Ordering::Less),
176            IntExt::PositiveInf => Some(Ordering::Greater),
177        }
178    }
179}
180
181impl Add<i32> for IntExt {
182    type Output = IntExt;
183
184    fn add(self, rhs: i32) -> Self::Output {
185        self + IntExt::Int(rhs)
186    }
187}
188
189impl<Int: Add<Output = Int> + Debug> Add for IntExt<Int> {
190    type Output = IntExt<Int>;
191
192    fn add(self, rhs: IntExt<Int>) -> Self::Output {
193        match (self, rhs) {
194            (IntExt::Int(lhs), IntExt::Int(rhs)) => IntExt::Int(lhs + rhs),
195
196            (IntExt::Int(_), Self::NegativeInf) => Self::NegativeInf,
197            (IntExt::Int(_), Self::PositiveInf) => Self::PositiveInf,
198            (Self::NegativeInf, IntExt::Int(_)) => Self::NegativeInf,
199            (Self::PositiveInf, IntExt::Int(_)) => Self::PositiveInf,
200
201            (IntExt::NegativeInf, IntExt::NegativeInf) => IntExt::NegativeInf,
202            (IntExt::PositiveInf, IntExt::PositiveInf) => IntExt::PositiveInf,
203
204            (lhs @ IntExt::NegativeInf, rhs @ IntExt::PositiveInf)
205            | (lhs @ IntExt::PositiveInf, rhs @ IntExt::NegativeInf) => {
206                panic!("the result of {lhs:?} + {rhs:?} is indeterminate")
207            }
208        }
209    }
210}
211
212impl Sub<IntExt<i64>> for i64 {
213    type Output = IntExt<i64>;
214
215    fn sub(self, rhs: IntExt<i64>) -> Self::Output {
216        IntExt::Int(self) - rhs
217    }
218}
219
220impl<Int: Sub<Output = Int> + Debug> Sub for IntExt<Int> {
221    type Output = IntExt<Int>;
222
223    fn sub(self, rhs: IntExt<Int>) -> Self::Output {
224        match (self, rhs) {
225            (IntExt::Int(lhs), IntExt::Int(rhs)) => IntExt::Int(lhs - rhs),
226
227            (IntExt::Int(_), Self::NegativeInf) => Self::PositiveInf,
228            (IntExt::Int(_), Self::PositiveInf) => Self::NegativeInf,
229            (Self::NegativeInf, IntExt::Int(_)) => Self::NegativeInf,
230            (Self::PositiveInf, IntExt::Int(_)) => Self::PositiveInf,
231
232            (lhs @ IntExt::NegativeInf, rhs @ IntExt::NegativeInf)
233            | (lhs @ IntExt::PositiveInf, rhs @ IntExt::PositiveInf)
234            | (lhs @ IntExt::NegativeInf, rhs @ IntExt::PositiveInf)
235            | (lhs @ IntExt::PositiveInf, rhs @ IntExt::NegativeInf) => {
236                panic!("the result of {lhs:?} - {rhs:?} is indeterminate")
237            }
238        }
239    }
240}
241
242impl<Int> AddAssign<Int> for IntExt<Int>
243where
244    Int: AddAssign<Int>,
245{
246    fn add_assign(&mut self, rhs: Int) {
247        match self {
248            IntExt::Int(value) => {
249                value.add_assign(rhs);
250            }
251
252            IntExt::NegativeInf | IntExt::PositiveInf => {}
253        }
254    }
255}
256
257impl Mul<i32> for IntExt {
258    type Output = IntExt;
259
260    fn mul(self, rhs: i32) -> Self::Output {
261        self * IntExt::Int(rhs)
262    }
263}
264
265impl Mul for IntExt {
266    type Output = Self;
267
268    fn mul(self, rhs: Self) -> Self::Output {
269        match (self, rhs) {
270            (IntExt::Int(lhs), IntExt::Int(rhs)) => IntExt::Int(lhs * rhs),
271
272            // Multiplication with 0 will always yield 0.
273            (IntExt::Int(0), Self::NegativeInf)
274            | (IntExt::Int(0), Self::PositiveInf)
275            | (Self::NegativeInf, IntExt::Int(0))
276            | (Self::PositiveInf, IntExt::Int(0)) => IntExt::Int(0),
277
278            (IntExt::Int(value), IntExt::NegativeInf)
279            | (IntExt::NegativeInf, IntExt::Int(value)) => {
280                if value >= 0 {
281                    IntExt::NegativeInf
282                } else {
283                    IntExt::PositiveInf
284                }
285            }
286
287            (IntExt::Int(value), IntExt::PositiveInf)
288            | (IntExt::PositiveInf, IntExt::Int(value)) => {
289                if value >= 0 {
290                    IntExt::PositiveInf
291                } else {
292                    IntExt::NegativeInf
293                }
294            }
295
296            (IntExt::NegativeInf, IntExt::NegativeInf)
297            | (IntExt::PositiveInf, IntExt::PositiveInf) => IntExt::PositiveInf,
298
299            (IntExt::NegativeInf, IntExt::PositiveInf)
300            | (IntExt::PositiveInf, IntExt::NegativeInf) => IntExt::NegativeInf,
301        }
302    }
303}
304
305impl Neg for IntExt {
306    type Output = Self;
307
308    fn neg(self) -> Self::Output {
309        match self {
310            IntExt::Int(value) => IntExt::Int(-value),
311            IntExt::NegativeInf => IntExt::PositiveInf,
312            IntExt::PositiveInf => Self::NegativeInf,
313        }
314    }
315}
316
317impl Sum for IntExt {
318    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
319        iter.fold(IntExt::Int(0), |acc, value| acc + value)
320    }
321}
322
323impl Sum for IntExt<i64> {
324    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
325        iter.fold(IntExt::Int(0), |acc, value| acc + value)
326    }
327}
328
#[cfg(test)]
mod tests {
    use IntExt::*;

    use super::*;

    // `IntExt < i32`: exercises `PartialOrd<i32> for IntExt`.
    #[test]
    fn ordering_of_i32_ext_with_i32() {
        assert!(Int(2) < 3);
        assert!(Int(-1) < 3);
        assert!(Int(-10) < -1);
    }

    // `i32 < IntExt`: exercises `PartialOrd<IntExt> for i32`.
    #[test]
    fn ordering_of_i32_with_i32_ext() {
        assert!(1 < Int(2));
        assert!(-10 < Int(-1));
        assert!(-11 < Int(-10));
    }

    #[test]
    fn ordering_of_infinities_with_i32() {
        let negative_inf: IntExt = NegativeInf;
        let positive_inf: IntExt = PositiveInf;
        assert!(negative_inf < i32::MIN);
        assert!(positive_inf > i32::MAX);
    }

    #[test]
    fn test_adding_i32s() {
        assert_eq!(Int(3) + Int(4), Int(7));
    }

    #[test]
    fn test_adding_negative_inf() {
        assert_eq!(Int(3) + NegativeInf, NegativeInf);
    }

    #[test]
    fn test_adding_positive_inf() {
        assert_eq!(Int(3) + PositiveInf, PositiveInf);
    }

    #[test]
    #[should_panic(expected = "indeterminate")]
    fn test_adding_opposite_infinities_panics() {
        let negative_inf: IntExt = NegativeInf;
        let _ = negative_inf + PositiveInf;
    }

    #[test]
    fn test_multiplying_infinity_by_zero_yields_zero() {
        let zero: IntExt = Int(0);
        let positive_inf: IntExt = PositiveInf;
        assert_eq!(zero * positive_inf, Int(0));
        assert_eq!(positive_inf * 0, Int(0));
    }

    #[test]
    fn test_multiplying_infinity_by_negative_flips_sign() {
        let positive_inf: IntExt = PositiveInf;
        assert_eq!(positive_inf * -2, NegativeInf);
    }

    #[test]
    fn test_negating_infinity() {
        let negative_inf: IntExt = NegativeInf;
        assert_eq!(-negative_inf, PositiveInf);
    }

    #[test]
    fn test_summing() {
        let total: IntExt = vec![Int(1), Int(2), Int(3)].into_iter().sum();
        assert_eq!(total, Int(6));
    }
}