1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149
use std::borrow::Cow;
use std::ops::Neg;
use std::str::FromStr;

use lazy_static::lazy_static;
use regex::Regex;
use rust_decimal::RoundingStrategy;

use crate::core::GenericResult;
use crate::currency::Cash;
use crate::types::Decimal;

/// Sign/zero constraints that a decimal value may be required to satisfy.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum DecimalRestrictions {
    /// Any value is accepted.
    No,
    /// Only zero is accepted.
    Zero,
    /// Any value except zero is accepted.
    NonZero,
    /// Negative values and zero are accepted.
    NegativeOrZero,
    /// Positive values and zero are accepted.
    PositiveOrZero,
    /// Only values greater than zero are accepted.
    StrictlyPositive,
    /// Only values less than zero are accepted.
    StrictlyNegative,
}

/// Parses `string` as a [`Decimal`] and checks it against `restrictions`.
///
/// # Errors
///
/// Returns `"Invalid decimal value"` if `string` can't be parsed by `Decimal::from_str`, or the
/// [`validate_decimal`] error if the parsed value violates `restrictions`.
pub fn parse_decimal(string: &str, restrictions: DecimalRestrictions) -> GenericResult<Decimal> {
    let value = Decimal::from_str(string).map_err(|_| "Invalid decimal value")?;
    validate_decimal(value, restrictions)
}

/// Checks `value` against `restrictions` and returns it unchanged if it complies.
///
/// Zero is checked with `is_zero()` before its sign bit matters, so a negative zero (`-0`) is
/// treated like any other zero: it's rejected by `NonZero`, `StrictlyPositive` and
/// `StrictlyNegative` and accepted by the rest.
///
/// # Errors
///
/// Returns an error if `value` doesn't satisfy `restrictions`.
pub fn validate_decimal(
    value: Decimal,
    restrictions: DecimalRestrictions,
) -> GenericResult<Decimal> {
    let complies = match restrictions {
        DecimalRestrictions::No => true,
        DecimalRestrictions::Zero => value.is_zero(),
        DecimalRestrictions::NonZero => !value.is_zero(),
        DecimalRestrictions::NegativeOrZero => value.is_sign_negative() || value.is_zero(),
        DecimalRestrictions::PositiveOrZero => value.is_sign_positive() || value.is_zero(),
        DecimalRestrictions::StrictlyPositive => value.is_sign_positive() && !value.is_zero(),
        DecimalRestrictions::StrictlyNegative => value.is_sign_negative() && !value.is_zero(),
    };

    if !complies {
        return Err!("The value doesn't comply to the specified restrictions");
    }

    Ok(value)
}

/// Same as [`validate_decimal`], but the error message is prefixed with `name` and the offending
/// value: `Invalid <name> (<value>): <error>`.
///
/// # Errors
///
/// Returns an error if `value` doesn't satisfy `restrictions`.
pub fn validate_named_decimal(
    name: &str,
    value: Decimal,
    restrictions: DecimalRestrictions,
) -> GenericResult<Decimal> {
    // `?` converts the formatted `String` error into the `GenericResult` error type.
    Ok(validate_decimal(value, restrictions)
        .map_err(|e| format!("Invalid {} ({}): {}", name, value, e))?)
}

/// Validates `value` like [`validate_named_decimal`] and wraps it into a [`Cash`] of `currency`.
///
/// # Errors
///
/// Returns an error if `value` doesn't satisfy `restrictions`.
pub fn validate_named_cash(
    name: &str,
    currency: &str,
    value: Decimal,
    restrictions: DecimalRestrictions,
) -> GenericResult<Cash> {
    Ok(Cash::new(currency, validate_named_decimal(name, value, restrictions)?))
}

/// Returns the number of fractional digits stored in `value`.
///
/// The scale is taken as is, without normalization, so trailing zeros are counted: `321.0` has
/// precision 1.
pub fn decimal_precision(value: Decimal) -> u32 {
    value.fract().scale()
}

/// Rounds `value` to `points` decimal places, rounding midpoints away from zero.
///
/// Shorthand for [`round_with`] with [`RoundingMethod::Round`].
pub fn round(value: Decimal, points: u32) -> Decimal {
    round_with(value, points, RoundingMethod::Round)
}

/// Strategy used by [`round_with`] to drop excess decimal places.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RoundingMethod {
    /// Round to the nearest value; midpoints go away from zero (`1.5` -> `2`, `-1.5` -> `-2`).
    Round,
    /// Discard the excess digits, i.e. round towards zero (`1.6` -> `1`, `-1.6` -> `-1`).
    Truncate,
}

/// Rounds `value` to `points` decimal places using `method`.
///
/// The result never carries a negative zero and has trailing zeros stripped from its scale
/// (via `normalize()`).
pub fn round_with(value: Decimal, points: u32, method: RoundingMethod) -> Decimal {
    let strategy = match method {
        RoundingMethod::Round => RoundingStrategy::MidpointAwayFromZero,
        RoundingMethod::Truncate => RoundingStrategy::ToZero,
    };

    let mut round_value = value.round_dp_with_strategy(points, strategy);

    // Rounding a small negative value (e.g. -0.4 to 0 places) keeps the sign bit and yields -0;
    // flip it so callers always get a plain zero.
    if round_value.is_zero() && round_value.is_sign_negative() {
        round_value = round_value.neg();
    }

    round_value.normalize()
}

/// Collapses every run of two or more whitespace characters into a single space.
///
/// A lone whitespace character (including a single tab or newline) is left untouched.
pub fn fold_spaces(string: &str) -> Cow<'_, str> {
    lazy_static! {
        // Compiled once on first use; the pattern is a constant, so `unwrap` can't fail.
        static ref SPACES_REGEX: Regex = Regex::new(r"\s{2,}").unwrap();
    }
    SPACES_REGEX.replace_all(string, " ")
}

#[cfg(test)]
mod tests {
    use rstest::rstest;
    use super::*;

    #[rstest(num, scale, precision,
        case(321, 0, 0),
        case(321, 1, 1),
        case(321, 2, 2),
        case(321, 3, 3),
        case(321, 4, 4),

        case(3210, 0, 0),
        case(3210, 1, 1),
        case(3210, 2, 2),
        case(3210, 3, 3),
        case(3210, 4, 4),
        case(3210, 5, 5),
    )]
    fn decimal_precision(num: i64, scale: u32, precision: u32) {
        let value = Decimal::new(num, scale);
        assert_eq!(super::decimal_precision(value), precision)
    }

    #[rstest(value, expected,
        case(dec!(-1.5), dec!(-2)),
        case(dec!(-1.4), dec!(-1)),
        case(dec!(-1), dec!(-1)),
        case(dec!(-0.5), dec!(-1)),
        case(dec!(-0.4), dec!(0)),
        case(dec!(0), dec!(0)),
        case(dec!(-0), dec!(0)),
        case(dec!(0.4), dec!(0)),
        case(dec!(0.5), dec!(1)),
        case(dec!(1), dec!(1)),
        case(dec!(1.4), dec!(1)),
        case(dec!(1.5), dec!(2)),
    )]
    fn rounding(value: Decimal, expected: Decimal) {
        assert_eq!(round(value, 0), expected);
    }

    #[rstest(value, expected,
        case(dec!(-1.6), dec!(-1)),
        case(dec!(-1.4), dec!(-1)),
        case(dec!(-1), dec!(-1)),
        case(dec!(-0.6), dec!(0)),
        case(dec!(-0.4), dec!(0)),
        case(dec!(0), dec!(0)),
        case(dec!(-0), dec!(0)),
        case(dec!(0.4), dec!(0)),
        case(dec!(0.6), dec!(0)),
        case(dec!(1), dec!(1)),
        case(dec!(1.4), dec!(1)),
        case(dec!(1.6), dec!(1)),
    )]
    fn truncate_rounding(value: Decimal, expected: Decimal) {
        assert_eq!(round_with(value, 0, RoundingMethod::Truncate), expected);
    }

    #[test]
    fn decimal_validation() {
        let complies = |value: Decimal, restrictions: DecimalRestrictions| {
            validate_decimal(value, restrictions).is_ok()
        };

        assert!(complies(dec!(-1), DecimalRestrictions::No));
        assert!(complies(dec!(0), DecimalRestrictions::No));
        assert!(complies(dec!(1), DecimalRestrictions::No));

        assert!(complies(dec!(0), DecimalRestrictions::Zero));
        assert!(complies(dec!(-0), DecimalRestrictions::Zero));
        assert!(!complies(dec!(1), DecimalRestrictions::Zero));

        assert!(complies(dec!(-1), DecimalRestrictions::NonZero));
        assert!(!complies(dec!(0), DecimalRestrictions::NonZero));
        assert!(!complies(dec!(-0), DecimalRestrictions::NonZero));

        assert!(complies(dec!(-1), DecimalRestrictions::NegativeOrZero));
        assert!(complies(dec!(0), DecimalRestrictions::NegativeOrZero));
        assert!(!complies(dec!(1), DecimalRestrictions::NegativeOrZero));

        assert!(complies(dec!(1), DecimalRestrictions::PositiveOrZero));
        assert!(complies(dec!(0), DecimalRestrictions::PositiveOrZero));
        assert!(complies(dec!(-0), DecimalRestrictions::PositiveOrZero));
        assert!(!complies(dec!(-1), DecimalRestrictions::PositiveOrZero));

        assert!(complies(dec!(1), DecimalRestrictions::StrictlyPositive));
        assert!(!complies(dec!(0), DecimalRestrictions::StrictlyPositive));
        assert!(!complies(dec!(-1), DecimalRestrictions::StrictlyPositive));

        assert!(complies(dec!(-1), DecimalRestrictions::StrictlyNegative));
        assert!(!complies(dec!(-0), DecimalRestrictions::StrictlyNegative));
        assert!(!complies(dec!(1), DecimalRestrictions::StrictlyNegative));
    }

    #[test]
    fn spaces_folding() {
        assert_eq!(fold_spaces(""), "");
        assert_eq!(fold_spaces("a b"), "a b");
        assert_eq!(fold_spaces("a  b   c"), "a b c");
        assert_eq!(fold_spaces("a \t\n b"), "a b");
        assert_eq!(fold_spaces("a\tb"), "a\tb");
    }
}