1use serde::{de, Deserialize, Deserializer, Serialize, Serializer};
2
3use std::cmp::Ordering;
4use std::fmt;
5use std::hash::{Hash, Hasher};
6use std::mem::discriminant;
7use std::num::FpCategory;
8use std::ops::{Add, Div, Mul, Rem, Sub};
9
/// A number that is either an exact 64-bit signed integer or a 64-bit float.
///
/// Comparison (`PartialOrd`/`PartialEq`) and `Hash` work by numeric value across
/// variants, so e.g. `Integer(1) == Float(1.0)` and both hash identically.
/// Arithmetic operators return `Option<Numeric>`; `None` signals integer
/// overflow (or an undefined integer remainder).
#[derive(Copy, Clone, Debug, Deserialize, Serialize)]
pub enum Numeric {
    /// An exact integer value.
    Integer(i64),

    /// A floating-point value. NaN and the infinities are (de)serialized as the
    /// strings `"NaN"`, `"Infinity"` and `"-Infinity"` via the custom functions
    /// named below.
    #[serde(
        serialize_with = "serialize_float",
        deserialize_with = "deserialize_float"
    )]
    Float(f64),
}
20
21fn serialize_float<S>(f: &f64, s: S) -> Result<S::Ok, S::Error>
24where
25 S: Serializer,
26{
27 match f.classify() {
28 FpCategory::Nan => s.serialize_str("NaN"),
29 FpCategory::Infinite => s.serialize_str(if *f == f64::INFINITY {
30 "Infinity"
31 } else {
32 "-Infinity"
33 }),
34 FpCategory::Zero | FpCategory::Subnormal | FpCategory::Normal => s.serialize_f64(*f),
35 }
36}
37
38fn deserialize_float<'de, D>(deserializer: D) -> Result<f64, D::Error>
40where
41 D: Deserializer<'de>,
42{
43 struct FloatVisitor;
44
45 impl<'de> de::Visitor<'de> for FloatVisitor {
46 type Value = f64;
47
48 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
49 formatter.write_str("An integer (42), a float (1.2), or a string (\"NaN\")")
50 }
51
52 fn visit_u64<E>(self, v: u64) -> Result<Self::Value, E>
53 where
54 E: de::Error,
55 {
56 Ok(v as f64)
57 }
58
59 fn visit_i64<E>(self, v: i64) -> Result<Self::Value, E>
60 where
61 E: de::Error,
62 {
63 Ok(v as f64)
64 }
65
66 fn visit_f64<E>(self, v: f64) -> Result<Self::Value, E>
67 where
68 E: de::Error,
69 {
70 Ok(v)
71 }
72
73 fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
74 where
75 E: de::Error,
76 {
77 match v {
78 "Infinity" => Ok(f64::INFINITY),
79 "-Infinity" => Ok(f64::NEG_INFINITY),
80 "NaN" => Ok(f64::NAN),
81 _ => Err(de::Error::custom("invalid float")),
82 }
83 }
84 }
85
86 deserializer.deserialize_any(FloatVisitor)
87}
88
89impl Add for Numeric {
90 type Output = Option<Self>;
91
92 fn add(self, other: Self) -> Option<Self> {
93 match (self, other) {
94 (Numeric::Integer(a), Numeric::Integer(b)) => a.checked_add(b).map(Numeric::Integer),
95 (Numeric::Integer(a), Numeric::Float(b)) => Some(Numeric::Float(a as f64 + b)),
96 (Numeric::Float(a), Numeric::Integer(b)) => Some(Numeric::Float(a + b as f64)),
97 (Numeric::Float(a), Numeric::Float(b)) => Some(Numeric::Float(a + b)),
98 }
99 }
100}
101
102impl Sub for Numeric {
103 type Output = Option<Self>;
104
105 fn sub(self, other: Self) -> Option<Self> {
106 match (self, other) {
107 (Numeric::Integer(a), Numeric::Integer(b)) => a.checked_sub(b).map(Numeric::Integer),
108 (Numeric::Integer(a), Numeric::Float(b)) => Some(Numeric::Float(a as f64 - b)),
109 (Numeric::Float(a), Numeric::Integer(b)) => Some(Numeric::Float(a - b as f64)),
110 (Numeric::Float(a), Numeric::Float(b)) => Some(Numeric::Float(a - b)),
111 }
112 }
113}
114
115impl Numeric {
116 pub fn modulo(self, modulus: Self) -> Option<Self> {
117 fn modulo(a: f64, b: f64) -> f64 {
118 ((a % b) + b) % b
119 }
120
121 match (self, modulus) {
122 (Numeric::Integer(a), Numeric::Integer(b)) => {
123 a.checked_rem(b).map(|c| (c + b) % b).map(Numeric::Integer)
124 }
125 (Numeric::Integer(a), Numeric::Float(b)) => Some(Numeric::Float(modulo(a as f64, b))),
126 (Numeric::Float(a), Numeric::Integer(b)) => Some(Numeric::Float(modulo(a, b as f64))),
127 (Numeric::Float(a), Numeric::Float(b)) => Some(Numeric::Float(modulo(a, b))),
128 }
129 }
130}
131
132impl Rem for Numeric {
133 type Output = Option<Self>;
134
135 fn rem(self, other: Self) -> Option<Self> {
136 match (self, other) {
137 (Numeric::Integer(a), Numeric::Integer(b)) => a.checked_rem(b).map(Numeric::Integer),
138 (Numeric::Integer(a), Numeric::Float(b)) => Some(Numeric::Float((a as f64) % b)),
139 (Numeric::Float(a), Numeric::Integer(b)) => Some(Numeric::Float(a % (b as f64))),
140 (Numeric::Float(a), Numeric::Float(b)) => Some(Numeric::Float(a % b)),
141 }
142 }
143}
144
145impl Mul for Numeric {
146 type Output = Option<Self>;
147
148 fn mul(self, other: Self) -> Option<Self> {
149 match (self, other) {
150 (Numeric::Integer(a), Numeric::Integer(b)) => a.checked_mul(b).map(Numeric::Integer),
151 (Numeric::Integer(a), Numeric::Float(b)) => Some(Numeric::Float(a as f64 * b)),
152 (Numeric::Float(a), Numeric::Integer(b)) => Some(Numeric::Float(a * b as f64)),
153 (Numeric::Float(a), Numeric::Float(b)) => Some(Numeric::Float(a * b)),
154 }
155 }
156}
157
158impl Div for Numeric {
159 type Output = Option<Self>;
160
161 fn div(self, other: Self) -> Option<Self> {
162 match (self, other) {
163 (Numeric::Integer(a), Numeric::Integer(b)) => Some(Numeric::Float(a as f64 / b as f64)),
164 (Numeric::Integer(a), Numeric::Float(b)) => Some(Numeric::Float(a as f64 / b)),
165 (Numeric::Float(a), Numeric::Integer(b)) => Some(Numeric::Float(a / b as f64)),
166 (Numeric::Float(a), Numeric::Float(b)) => Some(Numeric::Float(a / b)),
167 }
168 }
169}
170
171impl PartialEq for Numeric {
172 fn eq(&self, other: &Self) -> bool {
173 matches!(self.partial_cmp(other), Some(Ordering::Equal))
174 }
175}
176
// NOTE(review): `Eq` promises reflexivity (`x == x`), but `PartialEq` delegates
// to `partial_cmp`, which returns `None` for NaN, so `Float(NaN) != Float(NaN)`
// (the tests assert exactly this). A NaN-valued key in a `HashMap`/`HashSet`
// can therefore never be found again. Left in place because removing `Eq` may
// break callers that rely on it (e.g. as a map key) — confirm before changing.
impl Eq for Numeric {}
178
/// 2^53 (`f64` has 53 significand bits): every integer `i` with
/// `-2^53 < i < 2^53` converts to `f64` without rounding.
pub const MOST_POSITIVE_EXACT_FLOAT: i64 = 1_i64 << f64::MANTISSA_DIGITS;

/// `i64::MIN` as an `f64`, i.e. exactly -2^63.
const MOST_NEGATIVE_I64_FLOAT: f64 = i64::MIN as f64;
/// 2^63, one past `i64::MAX`. Used as an *exclusive* upper bound for floats
/// that fit in an `i64`; it is not itself an `i64` value.
const MOST_POSITIVE_I64_FLOAT: f64 = -MOST_NEGATIVE_I64_FLOAT;
187
188impl Hash for Numeric {
189 fn hash<H>(&self, state: &mut H)
190 where
191 H: Hasher,
192 {
193 match self {
194 Numeric::Integer(i) => {
195 discriminant(self).hash(state);
196 *i as u64
197 }
198 Numeric::Float(f) => match f.classify() {
199 FpCategory::Zero => {
200 discriminant(&Numeric::Integer(0)).hash(state);
202 0u64
203 }
204 FpCategory::Nan | FpCategory::Infinite | FpCategory::Subnormal => {
205 discriminant(self).hash(state);
206 f.to_bits()
207 }
208 FpCategory::Normal => {
209 if f.fract() == 0.0 {
211 if MOST_NEGATIVE_I64_FLOAT <= *f && *f < MOST_POSITIVE_I64_FLOAT {
212 discriminant(&Numeric::Integer(0)).hash(state);
214 (*f as i64) as u64
215 } else {
216 discriminant(self).hash(state);
218 f.to_bits()
219 }
220 } else {
221 discriminant(self).hash(state);
223 f.to_bits()
224 }
225 }
226 },
227 }
228 .hash(state)
229 }
230}
231
232impl PartialOrd for Numeric {
233 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
234 let partial_cmp = |i: i64, f: f64| {
237 if f.is_nan() {
238 None
239 } else if -MOST_POSITIVE_EXACT_FLOAT < i && i < MOST_POSITIVE_EXACT_FLOAT {
240 (i as f64).partial_cmp(&f)
242 } else if f >= MOST_POSITIVE_I64_FLOAT {
243 Some(Ordering::Less)
245 } else if f < MOST_NEGATIVE_I64_FLOAT {
246 Some(Ordering::Greater)
248 } else {
249 i.partial_cmp(&(f as i64))
252 }
253 };
254 match (*self, *other) {
255 (Self::Integer(left), Self::Integer(right)) => left.partial_cmp(&right),
256 (Self::Integer(i), Self::Float(f)) => partial_cmp(i, f),
257 (Self::Float(f), Self::Integer(i)) => partial_cmp(i, f).map(Ordering::reverse),
258 (Self::Float(left), Self::Float(right)) => left.partial_cmp(&right),
259 }
260 }
261}
262
263impl From<i64> for Numeric {
264 fn from(other: i64) -> Self {
265 Self::Integer(other)
266 }
267}
268impl From<f64> for Numeric {
269 fn from(other: f64) -> Self {
270 Self::Float(other)
271 }
272}
// Unit tests for `Numeric`: cross-variant comparison, hash/equality
// consistency, and the JSON (serde_json) wire format, including the string
// encoding of non-finite floats.
#[cfg(test)]
mod tests {
    use super::*;

    use serde_json::{from_str as from_json, to_string as to_json};

    use std::collections::hash_map::DefaultHasher;

    /// Hashes `t` with a fresh `DefaultHasher` and returns the digest.
    fn hash<T: Hash>(t: &T) -> u64 {
        let mut s = DefaultHasher::new();
        t.hash(&mut s);
        s.finish()
    }

    /// Ordering/equality between `Integer` and `Float`, around NaN, signed
    /// zero, infinities and the precision boundaries 2^52, 2^53 and ±2^63.
    #[test]
    // `!(a < b)` is deliberate: with NaN it is NOT equivalent to `a >= b`,
    // which is what clippy's lint would suggest.
    #[allow(clippy::neg_cmp_op_on_partial_ord)]
    fn mixed_comparison() {
        // NaN is unordered with (and unequal to) everything, itself included.
        assert!(Numeric::Integer(1) != Numeric::Float(f64::NAN));
        assert!(Numeric::Integer(-1) != Numeric::Float(f64::NAN));
        assert!(!(Numeric::Integer(1) < Numeric::Float(f64::NAN)));
        assert!(!(Numeric::Integer(1) > Numeric::Float(f64::NAN)));
        assert!(!(Numeric::Integer(-1) > Numeric::Float(f64::NAN)));
        assert!(Numeric::Float(f64::NAN) != Numeric::Float(f64::NAN));

        // Both signed zeros equal integer zero.
        assert!(Numeric::Integer(0) == Numeric::Float(0.0));
        assert!(Numeric::Integer(0) == Numeric::Float(-0.0));
        assert!(Numeric::Float(0.0) == Numeric::Float(-0.0));

        // Infinities bound every integer, including i64::MIN/MAX.
        assert!(Numeric::Integer(1) < Numeric::Float(f64::INFINITY));
        assert!(Numeric::Integer(i64::MAX) < Numeric::Float(f64::INFINITY));
        assert!(Numeric::Integer(i64::MIN) < Numeric::Float(f64::INFINITY));
        assert!(Numeric::Integer(i64::MIN) > Numeric::Float(f64::NEG_INFINITY));
        assert!(Numeric::Integer(0) > Numeric::Float(f64::NEG_INFINITY));
        assert!(Numeric::Integer(i64::MAX) > Numeric::Float(f64::NEG_INFINITY));

        // Small and exactly-representable magnitudes (below 2^53).
        assert!(Numeric::Integer(1) == Numeric::Float(1.0));
        assert!(Numeric::Integer(-1) != Numeric::Float(1.0));
        assert!(Numeric::Integer(2) > Numeric::Float(1.0));
        assert!(Numeric::Integer(-2) < Numeric::Float(1.0));
        assert!(Numeric::Integer(1 << 52) == Numeric::Float((2.0_f64).powi(52)));
        assert!(Numeric::Integer(1 << 53) == Numeric::Float((2.0_f64).powi(53)));
        assert!(Numeric::Integer((1 << 52) + 1) == Numeric::Float((2.0_f64).powi(52) + 1.0));
        assert!(Numeric::Integer(1 << 52) < Numeric::Float((2.0_f64).powi(52) + 1.0));
        assert!(Numeric::Integer((1 << 52) + 1) > Numeric::Float((2.0_f64).powi(52)));
        assert!(Numeric::Integer(-(1 << 52) - 1) < Numeric::Float(-(2.0_f64).powi(52)));

        // Around 2^53, where odd integers stop being representable as f64:
        // the comparison must not round the integer.
        assert!(Numeric::Integer((1 << 53) + 1) > Numeric::Float((2.0_f64).powi(53)));
        assert!(Numeric::Integer((1 << 53) - 1) == Numeric::Float((2.0_f64).powi(53) - 1.0));
        assert!(Numeric::Integer(-(1 << 53) - 1) < Numeric::Float(-(2.0_f64).powi(53)));
        assert!(Numeric::Integer(-(1 << 54)) < Numeric::Float(-(2.0_f64).powi(53)));
        assert!(Numeric::Integer(1 << 54) > Numeric::Float((2.0_f64).powi(53)));
        assert!(Numeric::Integer(1 << 56) > Numeric::Float((2.0_f64).powi(54)));

        // Floats far beyond the i64 range.
        assert!(Numeric::Integer(1 << 56) < Numeric::Float((2.0_f64).powi(70)));

        assert!(Numeric::Integer(1 << 56) > Numeric::Float(-(2.0_f64).powi(70)));
        assert!(Numeric::Integer(-(1 << 56)) > Numeric::Float(-(2.0_f64).powi(70)));
        assert!(Numeric::Integer(i64::MIN) > Numeric::Float(-(2.0_f64).powi(70)));
        assert!(Numeric::Integer(i64::MAX) < Numeric::Float((2.0_f64).powi(65) + 3.1));

        // Near i64::MAX: 2^63 is one past MAX; f64 spacing there is 1024/2048.
        assert!(Numeric::Integer(i64::MAX) < Numeric::Float((2.0_f64).powi(63)));
        assert!(Numeric::Integer(i64::MAX) > Numeric::Float((2.0_f64).powi(63) - 1024.0));
        assert!(Numeric::Integer(i64::MAX) < Numeric::Float((2.0_f64).powi(63) + 2048.0));

        // Near i64::MIN, which is exactly -2^63.
        assert!(Numeric::Integer(i64::MIN) == Numeric::Float(-(2.0_f64).powi(63)));
        assert!(Numeric::Integer(i64::MIN) > Numeric::Float(-(2.0_f64).powi(63) - 2048.0));
        assert!(Numeric::Integer(i64::MIN) < Numeric::Float(-(2.0_f64).powi(63) + 1024.0));

        assert!(Numeric::Integer(i64::MIN) < Numeric::Float(-(2.0_f64).powi(62)));
        assert!(Numeric::Integer(i64::MIN) > Numeric::Float(-(2.0_f64).powi(65)));

        // Fractional floats adjacent to small integers.
        assert!(Numeric::Integer(2) == Numeric::Float(2.0));
        assert!(Numeric::Integer(2) < Numeric::Float(2.1));
        assert!(Numeric::Integer(2) < Numeric::Float(2.0 + 2.0 * f64::EPSILON));
        assert!(Numeric::Integer(2) > Numeric::Float(2.0 - 2.0 * f64::EPSILON));
        assert!(Numeric::Integer(1) < Numeric::Float(1.0 + f64::EPSILON));
        assert!(Numeric::Integer(1) > Numeric::Float(1.0 - f64::EPSILON));
        assert!(Numeric::Integer(2) < Numeric::Float(3.0));
    }

    /// `Hash` must agree with `PartialEq`: equal values (across variants)
    /// hash equally; distinct values are expected to hash differently.
    #[test]
    fn numeric_hash() {
        // Two NaNs with different payload bits: each is stable, and they
        // differ because the raw bits are hashed.
        let nan1 = f64::NAN;
        let nan2 = f64::from_bits(f64::NAN.to_bits() | 0xDEADBEEF);
        assert!(nan1.is_nan() && nan2.is_nan());

        assert_eq!(hash(&Numeric::Float(nan1)), hash(&Numeric::Float(nan1)));
        assert_ne!(hash(&Numeric::Float(nan1)), hash(&Numeric::Float(nan2)));
        assert_eq!(hash(&Numeric::Float(nan2)), hash(&Numeric::Float(nan2)));

        let inf = f64::INFINITY;
        let ninf = f64::NEG_INFINITY;
        assert!(inf.is_infinite() && ninf.is_infinite());
        assert_eq!(hash(&Numeric::Float(inf)), hash(&Numeric::Float(inf)));
        assert_ne!(hash(&Numeric::Float(inf)), hash(&Numeric::Float(ninf)));
        assert_eq!(hash(&Numeric::Float(ninf)), hash(&Numeric::Float(ninf)));

        // ±0.0 compare equal, so they must hash equally.
        assert_eq!(hash(&Numeric::Float(0.0)), hash(&Numeric::Float(0.0)));
        assert_eq!(hash(&Numeric::Float(0.0)), hash(&Numeric::Float(-0.0)));
        assert_eq!(hash(&Numeric::Float(1.0)), hash(&Numeric::Float(1.0)));
        assert_ne!(hash(&Numeric::Float(1.0)), hash(&Numeric::Float(-1.0)));
        assert_eq!(hash(&Numeric::Float(1e100)), hash(&Numeric::Float(1e100)));
        assert_ne!(hash(&Numeric::Float(1e100)), hash(&Numeric::Float(2e100)));

        // Neighbouring non-integral floats hash differently.
        let eps = f64::EPSILON;
        assert!(eps.is_normal() && eps > 0.0);
        assert_eq!(hash(&Numeric::Float(1.1)), hash(&Numeric::Float(1.1)));
        assert_ne!(hash(&Numeric::Float(1.1)), hash(&Numeric::Float(1.1 + eps)));
        assert_ne!(hash(&Numeric::Float(1.1)), hash(&Numeric::Float(1.1 - eps)));

        // Integer/Float pairs at the interesting boundaries.
        let min = i64::MIN;
        let max = i64::MAX;
        let mid = 1_i64 << 53;
        let fmin = MOST_NEGATIVE_I64_FLOAT;
        let fmax = MOST_POSITIVE_I64_FLOAT;
        let fmid = 2_f64.powi(53);
        assert_eq!(hash(&Numeric::Integer(0)), hash(&Numeric::Float(-0.0)));
        assert_eq!(hash(&Numeric::Integer(0)), hash(&Numeric::Float(0.0)));
        assert_eq!(hash(&Numeric::Integer(1)), hash(&Numeric::Float(1.0)));
        assert_ne!(hash(&Numeric::Integer(-1)), hash(&Numeric::Float(1.0)));
        assert_eq!(hash(&Numeric::Integer(-1)), hash(&Numeric::Float(-1.0)));
        assert_eq!(hash(&Numeric::Integer(min)), hash(&Numeric::Float(fmin)));

        assert_ne!(
            hash(&Numeric::Integer(mid)),
            hash(&Numeric::Float(fmid - 1.0))
        );
        assert_eq!(
            hash(&Numeric::Integer(mid - 1)),
            hash(&Numeric::Float(fmid - 1.0))
        );
        assert_eq!(hash(&Numeric::Integer(mid)), hash(&Numeric::Float(fmid)));
        // `fmax` is 2^63, one past i64::MAX, so it never equals an Integer.
        assert_ne!(hash(&Numeric::Integer(max)), hash(&Numeric::Float(fmax)));

        assert_ne!(
            hash(&Numeric::Integer(max)),
            hash(&Numeric::Float(fmax + 2048.0))
        );
        assert_ne!(
            hash(&Numeric::Integer(max)),
            hash(&Numeric::Float(fmax - 1024.0))
        );
        assert_ne!(
            hash(&Numeric::Integer(min)),
            hash(&Numeric::Float(fmin + 2048.0))
        );
        assert_ne!(
            hash(&Numeric::Integer(min)),
            hash(&Numeric::Float(fmin - 2048.0))
        );
    }

    /// JSON output uses serde's externally tagged enum form; non-finite floats
    /// are written as strings.
    #[test]
    fn json_serialization() {
        assert_eq!(to_json(&Numeric::Integer(0)).unwrap(), r#"{"Integer":0}"#);
        assert_eq!(to_json(&Numeric::Integer(1)).unwrap(), r#"{"Integer":1}"#);
        assert_eq!(to_json(&Numeric::Integer(-1)).unwrap(), r#"{"Integer":-1}"#);
        assert_eq!(
            to_json(&Numeric::Integer(i64::MAX)).unwrap(),
            r#"{"Integer":9223372036854775807}"#
        );

        assert_eq!(
            to_json(&Numeric::Float(MOST_POSITIVE_EXACT_FLOAT as f64)).unwrap(),
            r#"{"Float":9007199254740992.0}"#
        );
        assert_eq!(to_json(&Numeric::Float(1.0)).unwrap(), r#"{"Float":1.0}"#);
        assert_eq!(
            to_json(&Numeric::Float(f64::EPSILON)).unwrap(),
            r#"{"Float":2.220446049250313e-16}"#
        );
        assert_eq!(to_json(&Numeric::Float(0.0)).unwrap(), r#"{"Float":0.0}"#);
        assert_eq!(to_json(&Numeric::Float(-0.0)).unwrap(), r#"{"Float":-0.0}"#);
        assert_eq!(to_json(&Numeric::Float(-1.0)).unwrap(), r#"{"Float":-1.0}"#);
        assert_eq!(
            to_json(&Numeric::Float(f64::NEG_INFINITY)).unwrap(),
            r#"{"Float":"-Infinity"}"#
        );
        assert_eq!(
            to_json(&Numeric::Float(f64::INFINITY)).unwrap(),
            r#"{"Float":"Infinity"}"#
        );
        assert_eq!(
            to_json(&Numeric::Float(f64::NAN)).unwrap(),
            r#"{"Float":"NaN"}"#
        );
    }

    /// Round-trips the formats produced by `json_serialization`. NaN is checked
    /// with `is_nan` because NaN != NaN under `PartialEq`.
    #[test]
    fn json_deserialization() {
        assert_eq!(
            from_json::<Numeric>(r#"{"Integer":0}"#).unwrap(),
            Numeric::Integer(0)
        );
        assert_eq!(
            from_json::<Numeric>(r#"{"Integer":1}"#).unwrap(),
            Numeric::Integer(1)
        );
        assert_eq!(
            from_json::<Numeric>(r#"{"Integer":-1}"#).unwrap(),
            Numeric::Integer(-1)
        );
        assert_eq!(
            from_json::<Numeric>(r#"{"Integer":9223372036854775807}"#).unwrap(),
            Numeric::Integer(i64::MAX)
        );

        assert_eq!(
            from_json::<Numeric>(r#"{"Float":9007199254740992.0}"#).unwrap(),
            Numeric::Float(MOST_POSITIVE_EXACT_FLOAT as f64)
        );
        assert_eq!(
            from_json::<Numeric>(r#"{"Float":1.0}"#).unwrap(),
            Numeric::Float(1.0)
        );
        assert_eq!(
            from_json::<Numeric>(r#"{"Float":2.220446049250313e-16}"#).unwrap(),
            Numeric::Float(f64::EPSILON)
        );
        assert_eq!(
            from_json::<Numeric>(r#"{"Float":0.0}"#).unwrap(),
            Numeric::Float(0.0)
        );
        assert_eq!(
            from_json::<Numeric>(r#"{"Float":-0.0}"#).unwrap(),
            Numeric::Float(-0.0)
        );
        assert_eq!(
            from_json::<Numeric>(r#"{"Float":-1.0}"#).unwrap(),
            Numeric::Float(-1.0)
        );
        assert_eq!(
            from_json::<Numeric>(r#"{"Float":"-Infinity"}"#).unwrap(),
            Numeric::Float(f64::NEG_INFINITY)
        );
        assert_eq!(
            from_json::<Numeric>(r#"{"Float":"Infinity"}"#).unwrap(),
            Numeric::Float(f64::INFINITY)
        );
        assert!(match from_json::<Numeric>(r#"{"Float":"NaN"}"#).unwrap() {
            Numeric::Float(f) => f.is_nan(),
            _ => panic!("expected a float"),
        });
    }
}