1use std::{
2 collections::{btree_map, BTreeMap},
3 fmt::Display,
4 iter::FusedIterator,
5 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
6};
7
8use rust_decimal::Decimal;
9
10use crate::report::{
11 commodity::{CommodityStore, CommodityTag},
12 context::ReportContext,
13};
14
15use super::{error::EvalError, PostingAmount, SingleAmount};
16
17#[derive(Debug, Default, PartialEq, Eq, Clone)]
19pub struct Amount<'ctx> {
20 values: BTreeMap<CommodityTag<'ctx>, Decimal>,
24}
25
26impl<'ctx> TryFrom<Amount<'ctx>> for SingleAmount<'ctx> {
27 type Error = EvalError<'ctx>;
28
29 fn try_from(value: Amount<'ctx>) -> Result<Self, Self::Error> {
30 SingleAmount::try_from(&value)
31 }
32}
33
34impl<'ctx> TryFrom<Amount<'ctx>> for PostingAmount<'ctx> {
35 type Error = EvalError<'ctx>;
36
37 fn try_from(value: Amount<'ctx>) -> Result<Self, Self::Error> {
38 PostingAmount::try_from(&value)
39 }
40}
41
42impl<'ctx> TryFrom<&Amount<'ctx>> for SingleAmount<'ctx> {
43 type Error = EvalError<'ctx>;
44
45 fn try_from(value: &Amount<'ctx>) -> Result<Self, Self::Error> {
46 let (commodity, value) = value
47 .values
48 .iter()
49 .next()
50 .ok_or(EvalError::SingleAmountRequired)?;
51 Ok(SingleAmount {
52 value: *value,
53 commodity: *commodity,
54 })
55 }
56}
57
58impl<'ctx> TryFrom<&Amount<'ctx>> for PostingAmount<'ctx> {
59 type Error = EvalError<'ctx>;
60
61 fn try_from(value: &Amount<'ctx>) -> Result<Self, Self::Error> {
62 if value.values.len() > 1 {
63 Err(EvalError::PostingAmountRequired)
64 } else {
65 Ok(value
66 .values
67 .iter()
68 .next()
69 .map(|(commodity, value)| {
70 PostingAmount::Single(SingleAmount {
71 value: *value,
72 commodity: *commodity,
73 })
74 })
75 .unwrap_or_default())
76 }
77 }
78}
79
80impl<'ctx> From<PostingAmount<'ctx>> for Amount<'ctx> {
81 fn from(value: PostingAmount<'ctx>) -> Self {
82 match value {
83 PostingAmount::Zero => Amount::zero(),
84 PostingAmount::Single(single_amount) => single_amount.into(),
85 }
86 }
87}
88
89impl<'ctx> From<SingleAmount<'ctx>> for Amount<'ctx> {
90 fn from(value: SingleAmount<'ctx>) -> Self {
91 Amount::from_value(value.commodity, value.value)
92 }
93}
94
95impl<'ctx> FromIterator<(CommodityTag<'ctx>, Decimal)> for Amount<'ctx> {
96 fn from_iter<T>(iter: T) -> Self
97 where
98 T: IntoIterator<Item = (CommodityTag<'ctx>, Decimal)>,
99 {
100 let mut ret = Self::zero();
101 for (commodity, value) in iter.into_iter() {
102 ret += SingleAmount::from_value(commodity, value);
103 }
104 ret
105 }
106}
107
108impl<'ctx> Amount<'ctx> {
109 #[inline(always)]
111 pub fn zero() -> Self {
112 Self::default()
113 }
114
115 pub fn from_value(commodity: CommodityTag<'ctx>, amount: Decimal) -> Self {
117 Self::zero() + SingleAmount::from_value(commodity, amount)
118 }
119
120 pub fn from_values(values: BTreeMap<CommodityTag<'ctx>, Decimal>) -> Self {
122 Self { values }
123 }
124
125 pub fn into_values(self) -> BTreeMap<CommodityTag<'ctx>, Decimal> {
127 self.values
128 }
129
130 pub fn iter(&self) -> impl Iterator<Item = SingleAmount<'ctx>> + '_ {
132 AmountIter(self.values.iter())
133 }
134
135 pub fn as_inline_display<'a>(&'a self, ctx: &'a ReportContext<'ctx>) -> impl Display + 'a + 'ctx
139 where
140 'a: 'ctx,
141 {
142 InlinePrintAmount {
143 commodity_store: &ctx.commodities,
144 amount: self,
145 }
146 }
147
148 pub fn is_absolute_zero(&self) -> bool {
151 self.values.is_empty()
152 }
153
154 pub fn is_zero(&self) -> bool {
156 self.values.iter().all(|(_, v)| v.is_zero())
157 }
158
159 pub fn remove_zero_entries(&mut self) {
163 self.values.retain(|_, v| !v.is_zero());
164 }
165
166 pub(crate) fn set_partial(&mut self, amount: SingleAmount<'ctx>) -> SingleAmount<'ctx> {
171 let value = if amount.value.is_zero() {
172 self.values.remove(&amount.commodity)
173 } else {
174 self.values.insert(amount.commodity, amount.value)
175 }
176 .unwrap_or_default();
177 SingleAmount {
178 value,
179 commodity: amount.commodity,
180 }
181 }
182
183 fn get_part(&self, commodity: CommodityTag<'ctx>) -> Decimal {
185 self.values.get(&commodity).copied().unwrap_or_default()
186 }
187
188 pub fn maybe_pair(&self) -> Option<(SingleAmount<'ctx>, SingleAmount<'ctx>)> {
191 if self.values.len() != 2 {
192 return None;
193 }
194 let ((c1, v1), (c2, v2)) = self.values.iter().zip(self.values.iter().skip(1)).next()?;
195 Some((
196 SingleAmount::from_value(*c1, *v1),
197 SingleAmount::from_value(*c2, *v2),
198 ))
199 }
200
201 pub fn round(mut self, ctx: &ReportContext) -> Self {
203 self.round_mut(ctx);
204 self
205 }
206
207 pub fn round_mut(&mut self, ctx: &ReportContext) {
209 for (k, v) in self.values.iter_mut() {
210 match ctx.commodities.get_decimal_point(*k) {
211 None => (),
212 Some(dp) => {
213 let updated = v.round_dp_with_strategy(
214 dp,
215 rust_decimal::RoundingStrategy::MidpointNearestEven,
216 );
217 *v = updated;
218 }
219 }
220 }
221 }
222
223 pub fn negate(mut self) -> Self {
225 for (_, v) in self.values.iter_mut() {
226 v.set_sign_positive(!v.is_sign_positive())
227 }
228 self
229 }
230
231 pub fn check_div(mut self, rhs: Decimal) -> Result<Self, EvalError<'ctx>> {
233 if rhs.is_zero() {
234 return Err(EvalError::DivideByZero);
235 }
236 for (_, v) in self.values.iter_mut() {
237 *v = v.checked_div(rhs).ok_or(EvalError::NumberOverflow)?;
238 }
239 Ok(self)
240 }
241
242 pub(crate) fn assert_balance(&self, expected: &PostingAmount<'ctx>) -> Self {
251 match expected {
252 PostingAmount::Zero => {
253 if self.is_zero() {
254 Self::zero()
255 } else {
256 -self.clone()
257 }
258 }
259 PostingAmount::Single(single) => {
260 let diff = single.value - self.get_part(single.commodity);
261 if diff.is_zero() {
262 Self::zero()
263 } else {
264 Self::from_value(single.commodity, diff)
265 }
266 }
267 }
268 }
269}
270
271#[derive(Debug)]
272struct AmountIter<'a, 'ctx>(btree_map::Iter<'a, CommodityTag<'ctx>, Decimal>);
273
274impl<'ctx> Iterator for AmountIter<'_, 'ctx> {
275 type Item = SingleAmount<'ctx>;
276
277 fn next(&mut self) -> Option<Self::Item> {
278 self.0.next().map(|(c, v)| SingleAmount::from_value(*c, *v))
279 }
280}
281
282impl FusedIterator for AmountIter<'_, '_> {}
283
284#[derive(Debug)]
285struct InlinePrintAmount<'a, 'ctx> {
286 commodity_store: &'a CommodityStore<'ctx>,
287 amount: &'a Amount<'ctx>,
288}
289
290impl Display for InlinePrintAmount<'_, '_> {
291 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
292 let vs = &self.amount.values;
293 if vs.len() <= 1 {
294 return match vs.iter().next() {
295 Some((c, v)) => {
296 write!(f, "{} {}", v, c.to_str_lossy(self.commodity_store))
297 }
298 None => write!(f, "0"),
299 };
300 }
301 write!(f, "(")?;
303 for (i, (c, v)) in vs.iter().enumerate() {
304 let mut v = *v;
305 if i != 0 {
306 if v.is_sign_negative() {
307 v.set_sign_negative(false);
308 write!(f, " - ")?;
309 } else {
310 write!(f, " + ")?;
311 }
312 }
313 write!(f, "{} {}", v, c.to_str_lossy(self.commodity_store))?;
314 }
315 write!(f, ")")
316 }
317}
318
319impl Neg for Amount<'_> {
320 type Output = Self;
321
322 fn neg(self) -> Self::Output {
323 self.negate()
324 }
325}
326
327impl Add for Amount<'_> {
328 type Output = Self;
329
330 fn add(mut self, rhs: Self) -> Self::Output {
331 self += rhs;
332 self
333 }
334}
335
336impl AddAssign for Amount<'_> {
337 fn add_assign(&mut self, rhs: Self) {
338 for (c, v2) in rhs.values {
339 let mut v1 = self.values.entry(c).or_insert(Decimal::ZERO);
340 v1 += v2;
341 }
344 }
345}
346
347impl<'ctx> Add<SingleAmount<'ctx>> for Amount<'ctx> {
348 type Output = Amount<'ctx>;
349
350 fn add(mut self, rhs: SingleAmount<'ctx>) -> Self::Output {
351 self += rhs;
352 self
353 }
354}
355
356impl<'ctx> AddAssign<SingleAmount<'ctx>> for Amount<'ctx> {
357 fn add_assign(&mut self, rhs: SingleAmount<'ctx>) {
358 let curr = self.values.entry(rhs.commodity).or_default();
359 *curr += rhs.value;
360 }
361}
362
363impl<'ctx> AddAssign<PostingAmount<'ctx>> for Amount<'ctx> {
364 fn add_assign(&mut self, rhs: PostingAmount<'ctx>) {
365 match rhs {
366 PostingAmount::Zero => (),
367 PostingAmount::Single(single) => *self += single,
368 }
369 }
370}
371
372impl Sub for Amount<'_> {
373 type Output = Self;
374
375 fn sub(mut self, rhs: Self) -> Self::Output {
376 self -= rhs;
377 self
378 }
379}
380
381impl SubAssign for Amount<'_> {
382 fn sub_assign(&mut self, rhs: Self) {
383 for (c, v2) in rhs.values {
384 let mut v1 = self.values.entry(c).or_insert(Decimal::ZERO);
385 v1 -= v2;
386 }
387 }
388}
389
390impl Mul<Decimal> for Amount<'_> {
391 type Output = Self;
392
393 fn mul(mut self, rhs: Decimal) -> Self::Output {
394 self *= rhs;
395 self
396 }
397}
398
399impl MulAssign<Decimal> for Amount<'_> {
400 fn mul_assign(&mut self, rhs: Decimal) {
401 for (_, mut v) in self.values.iter_mut() {
402 v *= rhs;
403 }
404 }
405}
406
// Unit tests for `Amount`. Each test builds its own arena-backed `ReportContext` and
// registers the commodities it needs through `ctx.commodities.ensure`.
#[cfg(test)]
mod tests {
    use super::*;

    use bumpalo::Bump;
    use maplit::btreemap;
    use pretty_assertions::assert_eq;
    use pretty_decimal::PrettyDecimal;
    use rust_decimal_macros::dec;

    use crate::report::ReportContext;

    // The default (empty) amount renders as a bare "0".
    #[test]
    fn test_default() {
        let arena = Bump::new();
        let ctx = ReportContext::new(&arena);
        let amount = Amount::default();
        assert_eq!(format!("{}", amount.as_inline_display(&ctx)), "0")
    }

    // A single-commodity amount renders as "<value> <commodity>".
    #[test]
    fn test_from_value() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let amount = Amount::from_value(jpy, dec!(123.45));
        assert_eq!(format!("{}", amount.as_inline_display(&ctx)), "123.45 JPY")
    }

    // Building via `from_iter`: repeated commodities accumulate, and a sum of zero keeps
    // its entry rather than dropping it.
    #[test]
    fn test_from_values() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let chf = ctx.commodities.ensure("CHF");

        let amount = Amount::from_iter([(jpy, dec!(10)), (chf, dec!(1))]);
        assert_eq!(
            amount.into_values(),
            btreemap! {jpy => dec!(10), chf => dec!(1)},
        );

        let amount = Amount::from_iter([(jpy, dec!(10)), (jpy, dec!(1))]);
        assert_eq!(amount.into_values(), btreemap! {jpy => dec!(11)});

        let amount = Amount::from_iter([(jpy, dec!(10)), (jpy, dec!(-10))]);
        assert_eq!(amount.into_values(), btreemap! {jpy => dec!(0)});
    }

    // "Absolute zero" means no entries at all: explicit zero entries only disappear after
    // `remove_zero_entries`.
    #[test]
    fn test_is_absolute_zero() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");

        assert!(Amount::default().is_absolute_zero());
        assert!(!Amount::from_value(jpy, dec!(0)).is_absolute_zero());

        let mut amount = Amount::from_iter([(jpy, dec!(0)), (usd, dec!(0))]);
        assert!(
            !amount.is_absolute_zero(),
            "{}",
            amount.as_inline_display(&ctx)
        );

        amount.remove_zero_entries();
        assert!(
            amount.is_absolute_zero(),
            "{}",
            amount.as_inline_display(&ctx)
        );
    }

    // `is_zero` is true for empty amounts and for amounts whose entries are all zero.
    #[test]
    fn test_is_zero() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");

        assert!(Amount::default().is_zero());
        assert!(Amount::from_value(jpy, dec!(0)).is_zero());
        assert!(Amount::from_iter([(jpy, dec!(0)), (usd, dec!(0))]).is_zero());

        assert!(!Amount::from_value(jpy, dec!(1)).is_zero());
        assert!(!Amount::from_iter([(jpy, dec!(0)), (usd, dec!(1))]).is_zero());
    }

    // Negation flips the sign of every entry.
    #[test]
    fn test_neg() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");

        assert_eq!(-Amount::zero(), Amount::zero());
        assert_eq!(
            -Amount::from_value(jpy, dec!(100)),
            Amount::from_value(jpy, dec!(-100))
        );
        assert_eq!(
            -Amount::from_iter([(jpy, dec!(100)), (usd, dec!(-20.35))]),
            Amount::from_iter([(jpy, dec!(-100)), (usd, dec!(20.35))]),
        );
    }

    // Amount + Amount: entries are summed per commodity, and entries that cancel out
    // remain as zero-valued entries.
    #[test]
    fn test_add_amount() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        let zero_plus_zero = Amount::zero() + Amount::zero();
        assert_eq!(zero_plus_zero, Amount::zero());

        assert_eq!(
            Amount::from_value(jpy, dec!(1)) + Amount::zero(),
            Amount::from_value(jpy, dec!(1)),
        );
        assert_eq!(
            Amount::zero() + Amount::from_value(jpy, dec!(1)),
            Amount::from_value(jpy, dec!(1)),
        );
        assert_eq!(
            Amount::from_iter([
                (jpy, dec!(123.00)),
                (usd, dec!(456.0)),
                (eur, dec!(7.89)),
                (chf, dec!(0)),
            ]),
            Amount::from_value(jpy, dec!(123.45))
                + Amount::from_value(jpy, dec!(-0.45))
                + Amount::from_value(usd, dec!(456))
                + Amount::from_value(usd, dec!(0.0))
                + -Amount::from_value(chf, dec!(100))
                + Amount::from_value(eur, dec!(7.89))
                + Amount::from_value(chf, dec!(100)),
        );

        assert_eq!(
            Amount::from_iter([(jpy, dec!(0)), (usd, dec!(0)), (chf, dec!(0))]),
            Amount::from_iter([(jpy, dec!(1)), (usd, dec!(2)), (chf, dec!(3))])
                + Amount::from_iter([(jpy, dec!(-1)), (usd, dec!(-2)), (chf, dec!(-3))])
        );
    }

    // Amount + SingleAmount: a zero single amount still creates its entry.
    #[test]
    fn test_add_single_amount() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");

        let amount = Amount::zero() + SingleAmount::from_value(usd, dec!(0));
        assert_eq!(amount, Amount::from_value(usd, dec!(0)));

        assert_eq!(
            Amount::zero() + SingleAmount::from_value(jpy, dec!(1)),
            Amount::from_value(jpy, dec!(1)),
        );
    }

    // Amount - Amount: commodities missing on the left start from zero, and entries that
    // cancel out remain as zero-valued entries.
    #[test]
    fn test_sub() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        let zero_minus_zero = Amount::zero() - Amount::zero();
        assert_eq!(zero_minus_zero, Amount::zero());

        assert_eq!(
            Amount::from_value(jpy, dec!(1)) - Amount::zero(),
            Amount::from_value(jpy, dec!(1)),
        );
        assert_eq!(
            Amount::zero() - Amount::from_value(jpy, dec!(1)),
            Amount::from_value(jpy, dec!(-1)),
        );
        assert_eq!(
            Amount::from_iter([
                (jpy, dec!(12345)),
                (eur, dec!(-200)),
                (chf, dec!(13.3)),
                (usd, dec!(0))
            ]),
            Amount::from_iter([(jpy, dec!(12345)), (usd, dec!(56.78))])
                - Amount::from_iter([(usd, dec!(56.780)), (eur, dec!(200)), (chf, dec!(-13.3)),]),
        );
    }

    /// Returns the `Decimal` 1 × 10^-28 (mantissa 1 at scale 28), used to probe precision
    /// limits in the multiplication and division tests.
    fn eps() -> Decimal {
        Decimal::try_from_i128_with_scale(1, 28).unwrap()
    }

    // Scalar multiplication scales every entry; multiplying by zero keeps a zero entry,
    // and eps * eps is expected to come out as zero.
    #[test]
    fn test_mul() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        assert_eq!(Amount::zero() * dec!(5), Amount::zero());
        assert_eq!(
            Amount::from_value(jpy, dec!(1)) * Decimal::ZERO,
            Amount::from_value(jpy, dec!(0)),
        );
        assert_eq!(
            Amount::from_value(jpy, dec!(123)) * dec!(3),
            Amount::from_value(jpy, dec!(369)),
        );
        assert_eq!(
            Amount::from_iter([(jpy, dec!(10081)), (eur, dec!(200)), (chf, dec!(-13.3))])
                * dec!(-0.5),
            Amount::from_iter([(jpy, dec!(-5040.5)), (eur, dec!(-100.0)), (chf, dec!(6.65))]),
        );
        assert_eq!(
            Amount::from_value(jpy, eps()) * eps(),
            Amount::from_value(jpy, dec!(0))
        );
    }

    // `check_div`: dividing by zero is an error even for an empty amount; overflow is
    // reported as `NumberOverflow` instead of panicking.
    #[test]
    fn test_check_div() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        assert_eq!(Amount::zero().check_div(dec!(5)).unwrap(), Amount::zero());
        assert_eq!(
            Amount::zero().check_div(dec!(0)).unwrap_err(),
            EvalError::DivideByZero
        );

        assert_eq!(
            Amount::from_value(jpy, dec!(50))
                .check_div(dec!(4))
                .unwrap(),
            Amount::from_value(jpy, dec!(12.5))
        );

        assert_eq!(
            Amount::from_value(jpy, Decimal::MAX)
                .check_div(eps())
                .unwrap_err(),
            EvalError::NumberOverflow
        );

        assert_eq!(
            Amount::from_value(jpy, eps())
                .check_div(Decimal::MAX)
                .unwrap(),
            Amount::from_value(jpy, dec!(0))
        );

        assert_eq!(
            Amount::from_iter([(jpy, dec!(810)), (eur, dec!(-100.0)), (chf, dec!(6.66))])
                .check_div(dec!(3))
                .unwrap(),
            Amount::from_iter([
                (jpy, dec!(270)),
                (eur, dec!(-33.333333333333333333333333333)),
                (chf, dec!(2.22))
            ]),
        );
    }

    // Rounding follows each commodity's configured format. The expectations are
    // consistent with JPY/EUR/CHF rounding to 0/2/3 decimal places (presumably taken from
    // the scale of the sample values given to `set_format`), with midpoints going to the
    // nearest even digit.
    #[test]
    fn test_round() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        ctx.commodities
            .set_format(jpy, PrettyDecimal::comma3dot(dec!(12345)));
        ctx.commodities
            .set_format(eur, PrettyDecimal::plain(dec!(123.45)));
        ctx.commodities
            .set_format(chf, PrettyDecimal::comma3dot(dec!(123.450)));

        assert_eq!(Amount::zero(), Amount::zero().round(&ctx));

        assert_eq!(
            Amount::from_iter([(jpy, dec!(812)), (eur, dec!(-100.00)), (chf, dec!(6.660))]),
            Amount::from_iter([(jpy, dec!(812)), (eur, dec!(-100.0)), (chf, dec!(6.66))])
                .round(&ctx),
        );

        assert_eq!(
            Amount::from_iter([(jpy, dec!(812)), (eur, dec!(-100.02)), (chf, dec!(6.666))]),
            Amount::from_iter([
                (jpy, dec!(812.5)),
                (eur, dec!(-100.015)),
                (chf, dec!(6.6665))
            ])
            .round(&ctx),
        );
    }

    // Inline display: "0" when empty, a bare term for one entry, and a parenthesized
    // "+"/"-" sum for several entries.
    #[test]
    fn test_to_string() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let chf = ctx.commodities.ensure("CHF");

        assert_eq!("0", Amount::default().as_inline_display(&ctx).to_string());

        assert_eq!(
            "10 JPY",
            Amount::from_value(jpy, dec!(10))
                .as_inline_display(&ctx)
                .to_string()
        );

        assert_eq!(
            "(10 JPY + 1 CHF)",
            Amount::from_iter([(jpy, dec!(10)), (chf, dec!(1))])
                .as_inline_display(&ctx)
                .to_string()
        );

        assert_eq!(
            "(-10 JPY - 1 CHF)",
            Amount::from_iter([(jpy, dec!(-10)), (chf, dec!(-1))])
                .as_inline_display(&ctx)
                .to_string()
        );
    }
}