// baseunits_rs/util/ratio.rs

use std::ops::Div;

use rust_decimal::prelude::Zero;
use rust_decimal::{Decimal, RoundingStrategy};
5#[derive(Debug, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
6pub struct Ratio {
7 numerator: Decimal,
8 denominator: Decimal,
9}
10
11impl Ratio {
12 pub fn new_i64(numerator: i64, denominator: i64) -> Self {
13 Self::new(Decimal::from(numerator), Decimal::from(denominator))
14 }
15 pub fn new(numerator: Decimal, denominator: Decimal) -> Self {
16 if denominator.is_zero() {
17 panic!("denominator is zero");
18 }
19 Self {
20 numerator,
21 denominator,
22 }
23 }
24
25 pub fn decimal_value(&self, scale: u32, rounding_strategy: Option<RoundingStrategy>) -> Decimal {
26 let Ratio {
27 numerator,
28 denominator,
29 } = self.clone();
30 match rounding_strategy {
31 None => numerator.div(denominator),
32 Some(s) => numerator.div(denominator).round_dp_with_strategy(scale, s),
33 }
34 }
35
36 fn gcd(numerator: Decimal, denominator: Decimal) -> Decimal {
37 if denominator.is_zero() {
38 numerator
39 } else {
40 Self::gcd(denominator, numerator % denominator)
41 }
42 }
43
44 pub fn reduce(self) -> Self {
45 let gcd = Self::gcd(self.numerator, self.denominator);
46 Self::new(self.numerator / gcd, self.denominator / gcd)
47 }
48
49 pub fn times(self, multiplier: Self) -> Self {
50 Self::new(
51 self.numerator * multiplier.numerator,
52 self.denominator * multiplier.denominator,
53 )
54 }
55
56 pub fn times_by_big_decimal(self, multiplier: Decimal) -> Self {
57 Self::new(self.numerator * multiplier, self.denominator)
58 }
59}
60
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::prelude::*;
    use rust_decimal::Decimal;
    use std::str::FromStr;

    /// Builds a `Decimal` from a small integer.
    fn int(value: i32) -> Decimal {
        Decimal::from_i32(value).unwrap()
    }

    /// Parses a decimal literal such as `"3.333"`.
    fn dec(literal: &str) -> Decimal {
        Decimal::from_str(literal).unwrap()
    }

    #[test]
    fn test_big_decimal_ratio() {
        // Without a strategy the exact quotient comes back untouched.
        let three_halves = Ratio::new(int(3), int(2));
        assert_eq!(three_halves.decimal_value(1, None), dec("1.5"));

        let ten_thirds = Ratio::new(int(10), int(3));
        let nine_point_001_thirds = Ratio::new(dec("9.001"), int(3));

        // (ratio, scale, strategy, expected), checked in order.
        let rounded_cases = [
            (&ten_thirds, 3, RoundingStrategy::RoundDown, "3.333"),
            (&ten_thirds, 3, RoundingStrategy::RoundUp, "3.334"),
            (&nine_point_001_thirds, 6, RoundingStrategy::RoundUp, "3.000334"),
            (&nine_point_001_thirds, 7, RoundingStrategy::RoundUp, "3.0003334"),
            (&nine_point_001_thirds, 7, RoundingStrategy::RoundHalfUp, "3.0003333"),
        ];
        for &(ratio, scale, strategy, expected) in rounded_cases.iter() {
            assert_eq!(ratio.decimal_value(scale, Some(strategy)), dec(expected));
        }
    }
}