okane_core/report/eval/
amount.rs

1use std::{
2    collections::{hash_map, HashMap},
3    fmt::Display,
4    iter::FusedIterator,
5    ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
6};
7
8use rust_decimal::Decimal;
9
10use crate::report::{commodity::Commodity, context::ReportContext};
11
12use super::{error::EvalError, PostingAmount, SingleAmount};
13
14/// Amount with multiple commodities, or simple zero.
15#[derive(Debug, Default, PartialEq, Eq, Clone)]
16pub struct Amount<'ctx> {
17    // if values.len == zero, then it'll be completely zero.
18    // TODO: Consider optimizing for small number of commodities,
19    // as most of the case it needs to be just a few elements.
20    values: HashMap<Commodity<'ctx>, Decimal>,
21}
22
23impl<'ctx> TryFrom<Amount<'ctx>> for SingleAmount<'ctx> {
24    type Error = EvalError;
25
26    fn try_from(value: Amount<'ctx>) -> Result<Self, Self::Error> {
27        SingleAmount::try_from(&value)
28    }
29}
30
31impl<'ctx> TryFrom<Amount<'ctx>> for PostingAmount<'ctx> {
32    type Error = EvalError;
33
34    fn try_from(value: Amount<'ctx>) -> Result<Self, Self::Error> {
35        PostingAmount::try_from(&value)
36    }
37}
38
39impl<'ctx> TryFrom<&Amount<'ctx>> for SingleAmount<'ctx> {
40    type Error = EvalError;
41
42    fn try_from(value: &Amount<'ctx>) -> Result<Self, Self::Error> {
43        let (commodity, value) = value
44            .values
45            .iter()
46            .next()
47            .ok_or(EvalError::SingleAmountRequired)?;
48        Ok(SingleAmount {
49            value: *value,
50            commodity: *commodity,
51        })
52    }
53}
54
55impl<'ctx> TryFrom<&Amount<'ctx>> for PostingAmount<'ctx> {
56    type Error = EvalError;
57
58    fn try_from(value: &Amount<'ctx>) -> Result<Self, Self::Error> {
59        if value.values.len() > 1 {
60            Err(EvalError::PostingAmountRequired)
61        } else {
62            Ok(value
63                .values
64                .iter()
65                .next()
66                .map(|(commodity, value)| {
67                    PostingAmount::Single(SingleAmount {
68                        value: *value,
69                        commodity: *commodity,
70                    })
71                })
72                .unwrap_or_default())
73        }
74    }
75}
76
77impl<'ctx> From<PostingAmount<'ctx>> for Amount<'ctx> {
78    fn from(value: PostingAmount<'ctx>) -> Self {
79        match value {
80            PostingAmount::Zero => Amount::zero(),
81            PostingAmount::Single(single_amount) => single_amount.into(),
82        }
83    }
84}
85
86impl<'ctx> From<SingleAmount<'ctx>> for Amount<'ctx> {
87    fn from(value: SingleAmount<'ctx>) -> Self {
88        Amount::from_value(value.value, value.commodity)
89    }
90}
91
92impl<'ctx> Amount<'ctx> {
93    /// Creates an [`Amount`] with zero value.
94    #[inline(always)]
95    pub fn zero() -> Self {
96        Self::default()
97    }
98
99    /// Creates an [`Amount`] with single value and commodity.
100    pub fn from_value(amount: Decimal, commodity: Commodity<'ctx>) -> Self {
101        Self::zero() + SingleAmount::from_value(amount, commodity)
102    }
103
104    /// Creates an [`Amount`] from a set of values.
105    pub fn from_values<T>(values: T) -> Self
106    where
107        T: IntoIterator<Item = (Decimal, Commodity<'ctx>)>,
108    {
109        let mut ret = Amount::zero();
110        for (value, commodity) in values.into_iter() {
111            ret += SingleAmount::from_value(value, commodity);
112        }
113        ret
114    }
115
116    /// Takes out the instance and returns map from commodity to its value.
117    pub fn into_values(self) -> HashMap<Commodity<'ctx>, Decimal> {
118        self.values
119    }
120
121    /// Returns iterator over its amount.
122    pub fn iter(&self) -> impl Iterator<Item = SingleAmount<'ctx>> + '_ {
123        AmountIter(self.values.iter())
124    }
125
126    /// Returns an objectt to print the amount as inline.
127    pub fn as_inline_display(&self) -> impl Display + '_ {
128        InlinePrintAmount(self)
129    }
130
131    /// Returns `true` if this is 'non-commoditized zero', which is used to assert
132    /// the account balance is completely zero.
133    pub fn is_absolute_zero(&self) -> bool {
134        self.values.is_empty()
135    }
136
137    /// Returns `true` if this is zero, including zero commodities.
138    pub fn is_zero(&self) -> bool {
139        self.values.iter().all(|(_, v)| v.is_zero())
140    }
141
142    /// Removes zero values, useful when callers doesn't care zero value.
143    /// However, if caller must distinguish `0` and `0 commodity`,
144    /// caller must not use this method.
145    pub fn remove_zero_entries(&mut self) {
146        self.values.retain(|_, v| !v.is_zero());
147    }
148
149    /// Replace the amount of the particular commodity, and returns the previous amount for the commodity.
150    /// E.g. (100 USD + 100 EUR).set_partial(200, USD) returns 100.
151    /// Note this method removes the given commodity if value is zero,
152    /// so only meant for [`Balance`].
153    pub(crate) fn set_partial(&mut self, amount: SingleAmount<'ctx>) -> SingleAmount<'ctx> {
154        let value = if amount.value.is_zero() {
155            self.values.remove(&amount.commodity)
156        } else {
157            self.values.insert(amount.commodity, amount.value)
158        }
159        .unwrap_or_default();
160        SingleAmount {
161            value,
162            commodity: amount.commodity,
163        }
164    }
165
166    /// Returns the amount of the particular commodity.
167    fn get_part(&self, commodity: Commodity<'ctx>) -> Decimal {
168        self.values.get(&commodity).copied().unwrap_or_default()
169    }
170
171    /// Returns pair of commodity amount, if the amount contains exactly 2 commodities.
172    /// Otherwise returns None.
173    pub fn maybe_pair(&self) -> Option<(SingleAmount<'ctx>, SingleAmount<'ctx>)> {
174        if self.values.len() != 2 {
175            return None;
176        }
177        let ((c1, v1), (c2, v2)) = self.values.iter().zip(self.values.iter().skip(1)).next()?;
178        Some((
179            SingleAmount::from_value(*v1, *c1),
180            SingleAmount::from_value(*v2, *c2),
181        ))
182    }
183
184    /// Rounds the given Amount and returns the new instance.
185    pub fn round(mut self, ctx: &ReportContext) -> Self {
186        self.round_mut(ctx);
187        self
188    }
189
190    /// Rounds the Amount in-place with the given context provided precision.
191    pub fn round_mut(&mut self, ctx: &ReportContext) {
192        for (k, v) in self.values.iter_mut() {
193            match ctx.commodities.get_decimal_point(*k) {
194                None => (),
195                Some(dp) => {
196                    let updated = v.round_dp_with_strategy(
197                        dp,
198                        rust_decimal::RoundingStrategy::MidpointNearestEven,
199                    );
200                    *v = updated;
201                }
202            }
203        }
204    }
205
206    /// Creates negated instance.
207    pub fn negate(mut self) -> Self {
208        for (_, v) in self.values.iter_mut() {
209            v.set_sign_positive(!v.is_sign_positive())
210        }
211        self
212    }
213
214    /// Run division with error checking.
215    pub fn check_div(mut self, rhs: Decimal) -> Result<Self, EvalError> {
216        if rhs.is_zero() {
217            return Err(EvalError::DivideByZero);
218        }
219        for (_, v) in self.values.iter_mut() {
220            *v = v.checked_div(rhs).ok_or(EvalError::NumberOverflow)?;
221        }
222        Ok(self)
223    }
224
225    /// Checks if the amount is consistent with the given [PostingAmount].
226    /// Consistent means
227    ///
228    /// *   If the [PostingAmount] is zero, then the amount must be zero.
229    /// *   If the [PostingAmount] is a value with commodity,
230    ///     then the amount should be equal to given value only on the commodity.
231    pub(crate) fn is_consistent(&self, rhs: &PostingAmount<'ctx>) -> bool {
232        match rhs {
233            PostingAmount::Zero => self.is_zero(),
234            PostingAmount::Single(single) => self.get_part(single.commodity) == single.value,
235        }
236    }
237}
238
239#[derive(Debug)]
240struct AmountIter<'a, 'ctx>(hash_map::Iter<'a, Commodity<'ctx>, Decimal>);
241
242impl<'ctx> Iterator for AmountIter<'_, 'ctx> {
243    type Item = SingleAmount<'ctx>;
244
245    fn next(&mut self) -> Option<Self::Item> {
246        self.0.next().map(|(c, v)| SingleAmount::from_value(*v, *c))
247    }
248}
249
250impl FusedIterator for AmountIter<'_, '_> {}
251
252#[derive(Debug)]
253struct InlinePrintAmount<'a, 'ctx>(&'a Amount<'ctx>);
254
255impl Display for InlinePrintAmount<'_, '_> {
256    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257        let vs = &self.0.values;
258        match vs.len() {
259            0 | 1 => match vs.iter().next() {
260                Some((c, v)) => write!(f, "{} {}", v, c.as_str()),
261                None => write!(f, "0"),
262            },
263            _ => {
264                write!(f, "(")?;
265                for (i, (c, v)) in vs.iter().enumerate() {
266                    if i != 0 {
267                        write!(f, " + ")?;
268                    }
269                    write!(f, "{} {}", v, c.as_str())?;
270                }
271                write!(f, ")")
272            }
273        }
274    }
275}
276
277impl Neg for Amount<'_> {
278    type Output = Self;
279
280    fn neg(self) -> Self::Output {
281        self.negate()
282    }
283}
284
285impl Add for Amount<'_> {
286    type Output = Self;
287
288    fn add(mut self, rhs: Self) -> Self::Output {
289        self += rhs;
290        self
291    }
292}
293
294impl AddAssign for Amount<'_> {
295    fn add_assign(&mut self, rhs: Self) {
296        for (c, v2) in rhs.values {
297            let mut v1 = self.values.entry(c).or_insert(Decimal::ZERO);
298            v1 += v2;
299            // we should retain the value even if zero,
300            // as (0 USD + 0 EUR) are different from 0 or (0 USD + 0 USD).
301        }
302    }
303}
304
305impl<'ctx> Add<SingleAmount<'ctx>> for Amount<'ctx> {
306    type Output = Amount<'ctx>;
307
308    fn add(mut self, rhs: SingleAmount<'ctx>) -> Self::Output {
309        self += rhs;
310        self
311    }
312}
313
314impl<'ctx> AddAssign<SingleAmount<'ctx>> for Amount<'ctx> {
315    fn add_assign(&mut self, rhs: SingleAmount<'ctx>) {
316        let curr = self.values.entry(rhs.commodity).or_default();
317        *curr += rhs.value;
318    }
319}
320
321impl<'ctx> AddAssign<PostingAmount<'ctx>> for Amount<'ctx> {
322    fn add_assign(&mut self, rhs: PostingAmount<'ctx>) {
323        match rhs {
324            PostingAmount::Zero => (),
325            PostingAmount::Single(single) => *self += single,
326        }
327    }
328}
329
330impl Sub for Amount<'_> {
331    type Output = Self;
332
333    fn sub(mut self, rhs: Self) -> Self::Output {
334        self -= rhs;
335        self
336    }
337}
338
339impl SubAssign for Amount<'_> {
340    fn sub_assign(&mut self, rhs: Self) {
341        for (c, v2) in rhs.values {
342            let mut v1 = self.values.entry(c).or_insert(Decimal::ZERO);
343            v1 -= v2;
344        }
345    }
346}
347
348impl Mul<Decimal> for Amount<'_> {
349    type Output = Self;
350
351    fn mul(mut self, rhs: Decimal) -> Self::Output {
352        self *= rhs;
353        self
354    }
355}
356
357impl MulAssign<Decimal> for Amount<'_> {
358    fn mul_assign(&mut self, rhs: Decimal) {
359        for (_, mut v) in self.values.iter_mut() {
360            v *= rhs;
361        }
362    }
363}
364
// Unit tests for `Amount`: construction, zero predicates, arithmetic
// operators, checked division and rounding.
#[cfg(test)]
mod tests {
    use super::*;

    use bumpalo::Bump;
    use maplit::hashmap;
    use pretty_assertions::assert_eq;
    use rust_decimal_macros::dec;

    use crate::{report::ReportContext, syntax::pretty_decimal::PrettyDecimal};

    // The default amount has no entry and prints as a bare "0".
    #[test]
    fn test_default() {
        let amount = Amount::default();
        assert_eq!(format!("{}", amount.as_inline_display()), "0")
    }

    // A single-commodity amount prints as "<value> <commodity>".
    #[test]
    fn test_from_value() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let amount = Amount::from_value(dec!(123.45), jpy);
        assert_eq!(format!("{}", amount.as_inline_display()), "123.45 JPY")
    }

    // `from_values` sums values per commodity and keeps a zero entry.
    #[test]
    fn test_from_values() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let chf = ctx.commodities.ensure("CHF");

        // Distinct commodities end up as distinct entries.
        let amount = Amount::from_values([(dec!(10), jpy), (dec!(1), chf)]);
        assert_eq!(
            amount.into_values(),
            hashmap! {jpy => dec!(10), chf => dec!(1)},
        );

        // The same commodity given twice is accumulated.
        let amount = Amount::from_values([(dec!(10), jpy), (dec!(1), jpy)]);
        assert_eq!(amount.into_values(), hashmap! {jpy => dec!(11)});

        // Values cancelling out still leave a `0 JPY` entry.
        let amount = Amount::from_values([(dec!(10), jpy), (dec!(-10), jpy)]);
        assert_eq!(amount.into_values(), hashmap! {jpy => dec!(0)});
    }

    // Only an amount without any entry is the absolute zero.
    #[test]
    fn test_is_absolute_zero() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");

        assert!(Amount::default().is_absolute_zero());
        assert!(!Amount::from_value(dec!(0), jpy).is_absolute_zero());

        let mut amount = Amount::from_values([(dec!(0), jpy), (dec!(0), usd)]);
        assert!(!amount.is_absolute_zero(), "{}", amount.as_inline_display());

        // Dropping the zero entries turns it into the absolute zero.
        amount.remove_zero_entries();
        assert!(amount.is_absolute_zero(), "{}", amount.as_inline_display());
    }

    // `is_zero` holds for the absolute zero and for all-zero entries.
    #[test]
    fn test_is_zero() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");

        assert!(Amount::default().is_zero());
        assert!(Amount::from_value(dec!(0), jpy).is_zero());
        assert!(Amount::from_values([(dec!(0), jpy), (dec!(0), usd)]).is_zero());

        // A single non-zero entry is enough to make it non-zero.
        assert!(!Amount::from_value(dec!(1), jpy).is_zero());
        assert!(!Amount::from_values([(dec!(0), jpy), (dec!(1), usd)]).is_zero());
    }

    // Unary minus flips the sign of every entry.
    #[test]
    fn test_neg() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");

        assert_eq!(-Amount::zero(), Amount::zero());
        assert_eq!(
            -Amount::from_value(dec!(100), jpy),
            Amount::from_value(dec!(-100), jpy)
        );
        assert_eq!(
            -Amount::from_values([(dec!(100), jpy), (dec!(-20.35), usd)]),
            Amount::from_values([(dec!(-100), jpy), (dec!(20.35), usd)]),
        );
    }

    // `Amount + Amount` adds per commodity and retains zero entries.
    #[test]
    fn test_add_amount() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        let zero_plus_zero = Amount::zero() + Amount::zero();
        assert_eq!(zero_plus_zero, Amount::zero());

        // The absolute zero is the identity on either side.
        assert_eq!(
            Amount::from_value(dec!(1), jpy) + Amount::zero(),
            Amount::from_value(dec!(1), jpy),
        );
        assert_eq!(
            Amount::zero() + Amount::from_value(dec!(1), jpy),
            Amount::from_value(dec!(1), jpy),
        );
        assert_eq!(
            Amount::from_values([
                (dec!(123.00), jpy),
                (dec!(456.0), usd),
                (dec!(7.89), eur),
                (dec!(0), chf), // 0 CHF retained
            ]),
            Amount::from_value(dec!(123.45), jpy)
                + Amount::from_value(dec!(-0.45), jpy)
                + Amount::from_value(dec!(456), usd)
                + Amount::from_value(dec!(0.0), usd)
                + -Amount::from_value(dec!(100), chf)
                + Amount::from_value(dec!(7.89), eur)
                + Amount::from_value(dec!(100), chf),
        );

        // Fully cancelling amounts leave a zero entry per commodity.
        assert_eq!(
            Amount::from_values([(dec!(0), jpy), (dec!(0), usd), (dec!(0), chf)]),
            Amount::from_values([(dec!(1), jpy), (dec!(2), usd), (dec!(3), chf)])
                + Amount::from_values([(dec!(-1), jpy), (dec!(-2), usd), (dec!(-3), chf)])
        );
    }

    // `Amount + SingleAmount` creates the entry even for a zero value.
    #[test]
    fn test_add_single_amount() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");

        let amount = Amount::zero() + SingleAmount::from_value(dec!(0), usd);
        assert_eq!(amount, Amount::from_value(dec!(0), usd));

        assert_eq!(
            Amount::zero() + SingleAmount::from_value(dec!(1), jpy),
            Amount::from_value(dec!(1), jpy),
        );
    }

    // `Amount - Amount` subtracts per commodity and retains zero entries.
    #[test]
    fn test_sub() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        let zero_minus_zero = Amount::zero() - Amount::zero();
        assert_eq!(zero_minus_zero, Amount::zero());

        assert_eq!(
            Amount::from_value(dec!(1), jpy) - Amount::zero(),
            Amount::from_value(dec!(1), jpy),
        );
        assert_eq!(
            Amount::zero() - Amount::from_value(dec!(1), jpy),
            Amount::from_value(dec!(-1), jpy),
        );
        // Commodities only present on the right-hand side appear negated,
        // and the cancelled USD remains as a zero entry.
        assert_eq!(
            Amount::from_values([
                (dec!(12345), jpy),
                (dec!(-200), eur),
                (dec!(13.3), chf),
                (dec!(0), usd)
            ]),
            Amount::from_values([(dec!(12345), jpy), (dec!(56.78), usd)])
                - Amount::from_values([(dec!(56.780), usd), (dec!(200), eur), (dec!(-13.3), chf),]),
        );
    }

    // Helper returning 1e-28, i.e. the value 1 at scale 28.
    fn eps() -> Decimal {
        Decimal::try_from_i128_with_scale(1, 28).unwrap()
    }

    // `Amount * Decimal` scales every entry.
    #[test]
    fn test_mul() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        assert_eq!(Amount::zero() * dec!(5), Amount::zero());
        // Multiplying by zero keeps the entry as `0 JPY`.
        assert_eq!(
            Amount::from_value(dec!(1), jpy) * Decimal::ZERO,
            Amount::from_value(dec!(0), jpy),
        );
        assert_eq!(
            Amount::from_value(dec!(123), jpy) * dec!(3),
            Amount::from_value(dec!(369), jpy),
        );
        assert_eq!(
            Amount::from_values([(dec!(10081), jpy), (dec!(200), eur), (dec!(-13.3), chf)])
                * dec!(-0.5),
            Amount::from_values([(dec!(-5040.5), jpy), (dec!(-100.0), eur), (dec!(6.65), chf)]),
        );
        // A product too small to represent is expected to compare equal to zero.
        assert_eq!(
            Amount::from_value(eps(), jpy) * eps(),
            Amount::from_value(dec!(0), jpy)
        );
    }

    // `check_div` divides every entry and reports the error cases.
    #[test]
    fn test_check_div() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        assert_eq!(Amount::zero().check_div(dec!(5)).unwrap(), Amount::zero());
        // Division by zero is rejected even when there is no entry to divide.
        assert_eq!(
            Amount::zero().check_div(dec!(0)).unwrap_err(),
            EvalError::DivideByZero
        );

        assert_eq!(
            Amount::from_value(dec!(50), jpy)
                .check_div(dec!(4))
                .unwrap(),
            Amount::from_value(dec!(12.5), jpy)
        );

        // A quotient too large to represent is reported as an overflow.
        assert_eq!(
            Amount::from_value(Decimal::MAX, jpy)
                .check_div(eps())
                .unwrap_err(),
            EvalError::NumberOverflow
        );

        // A quotient too small to represent is expected to compare equal to zero.
        assert_eq!(
            Amount::from_value(eps(), jpy)
                .check_div(Decimal::MAX)
                .unwrap(),
            Amount::from_value(dec!(0), jpy)
        );

        assert_eq!(
            Amount::from_values([(dec!(810), jpy), (dec!(-100.0), eur), (dec!(6.66), chf)])
                .check_div(dec!(3))
                .unwrap(),
            Amount::from_values([
                (dec!(270), jpy),
                (dec!(-33.333333333333333333333333333), eur),
                (dec!(2.22), chf)
            ]),
        );
    }

    // `round` rounds each entry to the decimal point the context reports for
    // its commodity, with the midpoint-nearest-even strategy.
    #[test]
    fn test_round() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        // NOTE(review): presumably the number of fractional digits of each
        // sample value (0 for JPY, 2 for EUR, 3 for CHF) becomes the decimal
        // point of the commodity; confirm against `set_format`.
        ctx.commodities
            .set_format(jpy, PrettyDecimal::comma3dot(dec!(12345)));
        ctx.commodities
            .set_format(eur, PrettyDecimal::plain(dec!(123.45)));
        ctx.commodities
            .set_format(chf, PrettyDecimal::comma3dot(dec!(123.450)));

        assert_eq!(Amount::zero(), Amount::zero().round(&ctx));

        // Values that need no rounding stay equal.
        assert_eq!(
            Amount::from_values([(dec!(812), jpy), (dec!(-100.00), eur), (dec!(6.660), chf)]),
            Amount::from_values([(dec!(812), jpy), (dec!(-100.0), eur), (dec!(6.66), chf)])
                .round(&ctx),
        );

        // Midpoints go to the nearest even digit: 812.5 -> 812,
        // -100.015 -> -100.02 and 6.6665 -> 6.666.
        assert_eq!(
            Amount::from_values([(dec!(812), jpy), (dec!(-100.02), eur), (dec!(6.666), chf)]),
            Amount::from_values([
                (dec!(812.5), jpy),
                (dec!(-100.015), eur),
                (dec!(6.6665), chf)
            ])
            .round(&ctx),
        );
    }
}