1use std::cmp::Ordering;
2use std::fmt::Debug;
3use std::iter::Sum;
4use std::ops::Add;
5use std::ops::AddAssign;
6use std::ops::Mul;
7use std::ops::Neg;
8use std::ops::Sub;
9
/// An integer type extended with negative and positive infinity.
///
/// `Int` is the type of the finite payload and defaults to `i32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IntExt<Int = i32> {
    /// A finite value.
    Int(Int),
    /// Negative infinity, ordered below every finite value.
    NegativeInf,
    /// Positive infinity, ordered above every finite value.
    PositiveInf,
}
23
24impl<Int: Copy> IntExt<Int> {
25 pub fn as_int(&self) -> Option<Int> {
26 match self {
27 IntExt::Int(int) => Some(*int),
28 IntExt::NegativeInf | IntExt::PositiveInf => None,
29 }
30 }
31}
32
33impl IntExt<i32> {
34 pub fn div_ceil(&self, other: IntExt<i32>) -> Option<IntExt<i32>> {
35 let result = self.div(other).ceil();
36
37 Self::int_ext_from_int_f64(result)
38 }
39
40 pub fn div_floor(&self, other: IntExt<i32>) -> Option<IntExt<i32>> {
41 let result = self.div(other).floor();
42
43 Self::int_ext_from_int_f64(result)
44 }
45
46 fn int_ext_from_int_f64(value: f64) -> Option<IntExt<i32>> {
47 if value.is_nan() {
48 return None;
49 }
50
51 if value.is_infinite() {
52 if value.is_sign_positive() {
53 return Some(IntExt::PositiveInf);
54 } else {
55 return Some(IntExt::NegativeInf);
56 }
57 }
58
59 assert!(value.fract().abs() < 1e-10);
60
61 Some(IntExt::Int(value as i32))
62 }
63}
64
65impl<Int: Into<f64>> IntExt<Int> {
66 fn div(self, rhs: Self) -> f64 {
67 let value: f64 = self.into();
68 let rhs_value: f64 = rhs.into();
69
70 value / rhs_value
71 }
72}
73
74impl<Int: Into<f64>> From<IntExt<Int>> for f64 {
75 fn from(value: IntExt<Int>) -> Self {
76 match value {
77 IntExt::Int(inner) => inner.into(),
78 IntExt::NegativeInf => -f64::INFINITY,
79 IntExt::PositiveInf => f64::INFINITY,
80 }
81 }
82}
83
84impl From<i32> for IntExt {
85 fn from(value: i32) -> Self {
86 IntExt::Int(value)
87 }
88}
89
90impl From<IntExt<i32>> for IntExt<i64> {
91 fn from(value: IntExt<i32>) -> Self {
92 match value {
93 IntExt::Int(int) => IntExt::Int(int.into()),
94 IntExt::NegativeInf => IntExt::NegativeInf,
95 IntExt::PositiveInf => IntExt::PositiveInf,
96 }
97 }
98}
99
100impl TryInto<i32> for IntExt {
102 type Error = ();
103
104 fn try_into(self) -> Result<i32, Self::Error> {
105 match self {
106 IntExt::Int(inner) => Ok(inner),
107 IntExt::NegativeInf | IntExt::PositiveInf => Err(()),
108 }
109 }
110}
111
112impl<Int: PartialEq> PartialEq<Int> for IntExt<Int> {
113 fn eq(&self, other: &Int) -> bool {
114 match self {
115 IntExt::Int(v1) => v1 == other,
116 IntExt::NegativeInf | IntExt::PositiveInf => false,
117 }
118 }
119}
120
121impl PartialEq<IntExt> for i32 {
122 fn eq(&self, other: &IntExt) -> bool {
123 other.eq(self)
124 }
125}
126
127impl PartialOrd<IntExt> for i32 {
128 fn partial_cmp(&self, other: &IntExt) -> Option<Ordering> {
129 other.neg().partial_cmp(&self.neg())
130 }
131}
132
133impl<Int: Ord> PartialOrd for IntExt<Int> {
134 fn partial_cmp(&self, other: &IntExt<Int>) -> Option<Ordering> {
135 Some(self.cmp(other))
136 }
137}
138
139impl<Int: Ord> Ord for IntExt<Int> {
140 fn cmp(&self, other: &Self) -> Ordering {
141 match self {
142 IntExt::Int(v1) => match other {
143 IntExt::Int(v2) => v1.cmp(v2),
144 IntExt::NegativeInf => Ordering::Greater,
145 IntExt::PositiveInf => Ordering::Less,
146 },
147 IntExt::NegativeInf => match other {
148 IntExt::Int(_) => Ordering::Less,
149 IntExt::PositiveInf => Ordering::Less,
150 IntExt::NegativeInf => Ordering::Equal,
151 },
152 IntExt::PositiveInf => match other {
153 IntExt::Int(_) => Ordering::Greater,
154 IntExt::NegativeInf => Ordering::Greater,
155 IntExt::PositiveInf => Ordering::Greater,
156 },
157 }
158 }
159}
160
161impl PartialOrd<i32> for IntExt {
162 fn partial_cmp(&self, other: &i32) -> Option<Ordering> {
163 match self {
164 IntExt::Int(v1) => v1.partial_cmp(other),
165 IntExt::NegativeInf => Some(Ordering::Less),
166 IntExt::PositiveInf => Some(Ordering::Greater),
167 }
168 }
169}
170
171impl PartialOrd<i64> for IntExt<i64> {
172 fn partial_cmp(&self, other: &i64) -> Option<Ordering> {
173 match self {
174 IntExt::Int(v1) => v1.partial_cmp(other),
175 IntExt::NegativeInf => Some(Ordering::Less),
176 IntExt::PositiveInf => Some(Ordering::Greater),
177 }
178 }
179}
180
181impl Add<i32> for IntExt {
182 type Output = IntExt;
183
184 fn add(self, rhs: i32) -> Self::Output {
185 self + IntExt::Int(rhs)
186 }
187}
188
189impl<Int: Add<Output = Int> + Debug> Add for IntExt<Int> {
190 type Output = IntExt<Int>;
191
192 fn add(self, rhs: IntExt<Int>) -> Self::Output {
193 match (self, rhs) {
194 (IntExt::Int(lhs), IntExt::Int(rhs)) => IntExt::Int(lhs + rhs),
195
196 (IntExt::Int(_), Self::NegativeInf) => Self::NegativeInf,
197 (IntExt::Int(_), Self::PositiveInf) => Self::PositiveInf,
198 (Self::NegativeInf, IntExt::Int(_)) => Self::NegativeInf,
199 (Self::PositiveInf, IntExt::Int(_)) => Self::PositiveInf,
200
201 (IntExt::NegativeInf, IntExt::NegativeInf) => IntExt::NegativeInf,
202 (IntExt::PositiveInf, IntExt::PositiveInf) => IntExt::PositiveInf,
203
204 (lhs @ IntExt::NegativeInf, rhs @ IntExt::PositiveInf)
205 | (lhs @ IntExt::PositiveInf, rhs @ IntExt::NegativeInf) => {
206 panic!("the result of {lhs:?} + {rhs:?} is indeterminate")
207 }
208 }
209 }
210}
211
212impl Sub<IntExt<i64>> for i64 {
213 type Output = IntExt<i64>;
214
215 fn sub(self, rhs: IntExt<i64>) -> Self::Output {
216 IntExt::Int(self) - rhs
217 }
218}
219
220impl<Int: Sub<Output = Int> + Debug> Sub for IntExt<Int> {
221 type Output = IntExt<Int>;
222
223 fn sub(self, rhs: IntExt<Int>) -> Self::Output {
224 match (self, rhs) {
225 (IntExt::Int(lhs), IntExt::Int(rhs)) => IntExt::Int(lhs - rhs),
226
227 (IntExt::Int(_), Self::NegativeInf) => Self::PositiveInf,
228 (IntExt::Int(_), Self::PositiveInf) => Self::NegativeInf,
229 (Self::NegativeInf, IntExt::Int(_)) => Self::NegativeInf,
230 (Self::PositiveInf, IntExt::Int(_)) => Self::PositiveInf,
231
232 (lhs @ IntExt::NegativeInf, rhs @ IntExt::NegativeInf)
233 | (lhs @ IntExt::PositiveInf, rhs @ IntExt::PositiveInf)
234 | (lhs @ IntExt::NegativeInf, rhs @ IntExt::PositiveInf)
235 | (lhs @ IntExt::PositiveInf, rhs @ IntExt::NegativeInf) => {
236 panic!("the result of {lhs:?} - {rhs:?} is indeterminate")
237 }
238 }
239 }
240}
241
242impl<Int> AddAssign<Int> for IntExt<Int>
243where
244 Int: AddAssign<Int>,
245{
246 fn add_assign(&mut self, rhs: Int) {
247 match self {
248 IntExt::Int(value) => {
249 value.add_assign(rhs);
250 }
251
252 IntExt::NegativeInf | IntExt::PositiveInf => {}
253 }
254 }
255}
256
257impl Mul<i32> for IntExt {
258 type Output = IntExt;
259
260 fn mul(self, rhs: i32) -> Self::Output {
261 self * IntExt::Int(rhs)
262 }
263}
264
265impl Mul for IntExt {
266 type Output = Self;
267
268 fn mul(self, rhs: Self) -> Self::Output {
269 match (self, rhs) {
270 (IntExt::Int(lhs), IntExt::Int(rhs)) => IntExt::Int(lhs * rhs),
271
272 (IntExt::Int(0), Self::NegativeInf)
274 | (IntExt::Int(0), Self::PositiveInf)
275 | (Self::NegativeInf, IntExt::Int(0))
276 | (Self::PositiveInf, IntExt::Int(0)) => IntExt::Int(0),
277
278 (IntExt::Int(value), IntExt::NegativeInf)
279 | (IntExt::NegativeInf, IntExt::Int(value)) => {
280 if value >= 0 {
281 IntExt::NegativeInf
282 } else {
283 IntExt::PositiveInf
284 }
285 }
286
287 (IntExt::Int(value), IntExt::PositiveInf)
288 | (IntExt::PositiveInf, IntExt::Int(value)) => {
289 if value >= 0 {
290 IntExt::PositiveInf
291 } else {
292 IntExt::NegativeInf
293 }
294 }
295
296 (IntExt::NegativeInf, IntExt::NegativeInf)
297 | (IntExt::PositiveInf, IntExt::PositiveInf) => IntExt::PositiveInf,
298
299 (IntExt::NegativeInf, IntExt::PositiveInf)
300 | (IntExt::PositiveInf, IntExt::NegativeInf) => IntExt::NegativeInf,
301 }
302 }
303}
304
305impl Neg for IntExt {
306 type Output = Self;
307
308 fn neg(self) -> Self::Output {
309 match self {
310 IntExt::Int(value) => IntExt::Int(-value),
311 IntExt::NegativeInf => IntExt::PositiveInf,
312 IntExt::PositiveInf => Self::NegativeInf,
313 }
314 }
315}
316
317impl Sum for IntExt {
318 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
319 iter.fold(IntExt::Int(0), |acc, value| acc + value)
320 }
321}
322
323impl Sum for IntExt<i64> {
324 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
325 iter.fold(IntExt::Int(0), |acc, value| acc + value)
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use IntExt::*;

    // Mixed comparisons: IntExt on the left, bare i32 on the right.
    #[test]
    fn ordering_of_i32_with_i32_ext() {
        let cases = [(Int(2), 3), (Int(-1), 3), (Int(-10), -1)];
        for (lhs, rhs) in cases {
            assert!(lhs < rhs);
        }
    }

    // Mixed comparisons: bare i32 on the left, IntExt on the right.
    #[test]
    fn ordering_of_i32_ext_with_i32() {
        let cases = [(1, Int(2)), (-10, Int(-1)), (-11, Int(-10))];
        for (lhs, rhs) in cases {
            assert!(lhs < rhs);
        }
    }

    #[test]
    fn test_adding_i32s() {
        let sum = Int(3) + Int(4);
        assert_eq!(sum, Int(7));
    }

    #[test]
    fn test_adding_negative_inf() {
        let sum = Int(3) + NegativeInf;
        assert_eq!(sum, NegativeInf);
    }

    #[test]
    fn test_adding_positive_inf() {
        let sum = Int(3) + PositiveInf;
        assert_eq!(sum, PositiveInf);
    }
}