baseunits_rs/util/
ratio.rs

1use std::ops::Div;
2use rust_decimal::{Decimal, RoundingStrategy};
3use rust_decimal::prelude::Zero;
4
5#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
6pub struct Ratio {
7  numerator: Decimal,
8  denominator: Decimal,
9}
10
11impl Ratio {
12  pub fn new_i64(numerator: i64, denominator: i64) -> Self {
13    Self::new(Decimal::from(numerator), Decimal::from(denominator))
14  }
15  pub fn new(numerator: Decimal, denominator: Decimal) -> Self {
16    if denominator.is_zero() {
17      panic!("denominator is zero");
18    }
19    Self {
20      numerator,
21      denominator,
22    }
23  }
24
25  pub fn decimal_value(&self, scale: u32, rounding_strategy: Option<RoundingStrategy>) -> Decimal {
26    let Ratio {
27      numerator,
28      denominator,
29    } = self.clone();
30    match rounding_strategy {
31      None => numerator.div(denominator),
32      Some(s) => numerator.div(denominator).round_dp_with_strategy(scale, s),
33    }
34  }
35
36  fn gcd(numerator: Decimal, denominator: Decimal) -> Decimal {
37    if denominator.is_zero() {
38      numerator
39    } else {
40      Self::gcd(denominator, numerator % denominator)
41    }
42  }
43
44  pub fn reduce(self) -> Self {
45    let gcd = Self::gcd(self.numerator, self.denominator);
46    Self::new(self.numerator / gcd, self.denominator / gcd)
47  }
48
49  pub fn times(self, multiplier: Self) -> Self {
50    Self::new(
51      self.numerator * multiplier.numerator,
52      self.denominator * multiplier.denominator,
53    )
54  }
55
56  pub fn times_by_big_decimal(self, multiplier: Decimal) -> Self {
57    Self::new(self.numerator * multiplier, self.denominator)
58  }
59}
60
#[cfg(test)]
mod tests {
  use super::*;
  use std::str::FromStr;
  use rust_decimal::prelude::*;
  use rust_decimal::Decimal;

  /// Parses a decimal literal used as a test fixture.
  fn dec(text: &str) -> Decimal {
    Decimal::from_str(text).unwrap()
  }

  /// Converts an `i32` test fixture into a `Decimal`.
  fn int(value: i32) -> Decimal {
    Decimal::from_i32(value).unwrap()
  }

  /// Checks `decimal_value` without rounding and with several rounding
  /// strategies and scales.
  #[test]
  fn test_big_decimal_ratio() {
    // No rounding strategy: the plain quotient comes back.
    let three_halves = Ratio::new(int(3), int(2));
    assert_eq!(three_halves.decimal_value(1, None), dec("1.5"));

    // A repeating quotient rounded to three places in both directions.
    let ten_thirds = Ratio::new(int(10), int(3));
    assert_eq!(
      ten_thirds.decimal_value(3, Some(RoundingStrategy::RoundDown)),
      dec("3.333")
    );
    assert_eq!(
      ten_thirds.decimal_value(3, Some(RoundingStrategy::RoundUp)),
      dec("3.334")
    );

    // A fractional numerator rounded at larger scales.
    let many_digits = Ratio::new(dec("9.001"), int(3));
    assert_eq!(
      many_digits.decimal_value(6, Some(RoundingStrategy::RoundUp)),
      dec("3.000334")
    );
    assert_eq!(
      many_digits.decimal_value(7, Some(RoundingStrategy::RoundUp)),
      dec("3.0003334")
    );
    assert_eq!(
      many_digits.decimal_value(7, Some(RoundingStrategy::RoundHalfUp)),
      dec("3.0003333")
    );
  }
}