1use std::{
2 collections::{hash_map, HashMap},
3 fmt::Display,
4 iter::FusedIterator,
5 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
6};
7
8use rust_decimal::Decimal;
9
10use crate::report::{
11 commodity::{CommodityStore, CommodityTag},
12 context::ReportContext,
13};
14
15use super::{error::EvalError, PostingAmount, SingleAmount};
16
17#[derive(Debug, Default, PartialEq, Eq, Clone)]
19pub struct Amount<'ctx> {
20 values: HashMap<CommodityTag<'ctx>, Decimal>,
24}
25
26impl<'ctx> TryFrom<Amount<'ctx>> for SingleAmount<'ctx> {
27 type Error = EvalError<'ctx>;
28
29 fn try_from(value: Amount<'ctx>) -> Result<Self, Self::Error> {
30 SingleAmount::try_from(&value)
31 }
32}
33
34impl<'ctx> TryFrom<Amount<'ctx>> for PostingAmount<'ctx> {
35 type Error = EvalError<'ctx>;
36
37 fn try_from(value: Amount<'ctx>) -> Result<Self, Self::Error> {
38 PostingAmount::try_from(&value)
39 }
40}
41
42impl<'ctx> TryFrom<&Amount<'ctx>> for SingleAmount<'ctx> {
43 type Error = EvalError<'ctx>;
44
45 fn try_from(value: &Amount<'ctx>) -> Result<Self, Self::Error> {
46 let (commodity, value) = value
47 .values
48 .iter()
49 .next()
50 .ok_or(EvalError::SingleAmountRequired)?;
51 Ok(SingleAmount {
52 value: *value,
53 commodity: *commodity,
54 })
55 }
56}
57
58impl<'ctx> TryFrom<&Amount<'ctx>> for PostingAmount<'ctx> {
59 type Error = EvalError<'ctx>;
60
61 fn try_from(value: &Amount<'ctx>) -> Result<Self, Self::Error> {
62 if value.values.len() > 1 {
63 Err(EvalError::PostingAmountRequired)
64 } else {
65 Ok(value
66 .values
67 .iter()
68 .next()
69 .map(|(commodity, value)| {
70 PostingAmount::Single(SingleAmount {
71 value: *value,
72 commodity: *commodity,
73 })
74 })
75 .unwrap_or_default())
76 }
77 }
78}
79
80impl<'ctx> From<PostingAmount<'ctx>> for Amount<'ctx> {
81 fn from(value: PostingAmount<'ctx>) -> Self {
82 match value {
83 PostingAmount::Zero => Amount::zero(),
84 PostingAmount::Single(single_amount) => single_amount.into(),
85 }
86 }
87}
88
89impl<'ctx> From<SingleAmount<'ctx>> for Amount<'ctx> {
90 fn from(value: SingleAmount<'ctx>) -> Self {
91 Amount::from_value(value.value, value.commodity)
92 }
93}
94
95impl<'ctx> Amount<'ctx> {
96 #[inline(always)]
98 pub fn zero() -> Self {
99 Self::default()
100 }
101
102 pub fn from_value(amount: Decimal, commodity: CommodityTag<'ctx>) -> Self {
104 Self::zero() + SingleAmount::from_value(amount, commodity)
105 }
106
107 pub fn from_values<T>(values: T) -> Self
109 where
110 T: IntoIterator<Item = (Decimal, CommodityTag<'ctx>)>,
111 {
112 let mut ret = Amount::zero();
113 for (value, commodity) in values.into_iter() {
114 ret += SingleAmount::from_value(value, commodity);
115 }
116 ret
117 }
118
119 pub fn into_values(self) -> HashMap<CommodityTag<'ctx>, Decimal> {
121 self.values
122 }
123
124 pub fn iter(&self) -> impl Iterator<Item = SingleAmount<'ctx>> + '_ {
126 AmountIter(self.values.iter())
127 }
128
129 pub fn as_inline_display<'a>(&'a self, ctx: &'a ReportContext<'ctx>) -> impl Display + 'a + 'ctx
131 where
132 'a: 'ctx,
133 {
134 InlinePrintAmount {
135 commodity_store: &ctx.commodities,
136 amount: self,
137 }
138 }
139
140 pub fn is_absolute_zero(&self) -> bool {
143 self.values.is_empty()
144 }
145
146 pub fn is_zero(&self) -> bool {
148 self.values.iter().all(|(_, v)| v.is_zero())
149 }
150
151 pub fn remove_zero_entries(&mut self) {
155 self.values.retain(|_, v| !v.is_zero());
156 }
157
158 pub(crate) fn set_partial(&mut self, amount: SingleAmount<'ctx>) -> SingleAmount<'ctx> {
163 let value = if amount.value.is_zero() {
164 self.values.remove(&amount.commodity)
165 } else {
166 self.values.insert(amount.commodity, amount.value)
167 }
168 .unwrap_or_default();
169 SingleAmount {
170 value,
171 commodity: amount.commodity,
172 }
173 }
174
175 fn get_part(&self, commodity: CommodityTag<'ctx>) -> Decimal {
177 self.values.get(&commodity).copied().unwrap_or_default()
178 }
179
180 pub fn maybe_pair(&self) -> Option<(SingleAmount<'ctx>, SingleAmount<'ctx>)> {
183 if self.values.len() != 2 {
184 return None;
185 }
186 let ((c1, v1), (c2, v2)) = self.values.iter().zip(self.values.iter().skip(1)).next()?;
187 Some((
188 SingleAmount::from_value(*v1, *c1),
189 SingleAmount::from_value(*v2, *c2),
190 ))
191 }
192
193 pub fn round(mut self, ctx: &ReportContext) -> Self {
195 self.round_mut(ctx);
196 self
197 }
198
199 pub fn round_mut(&mut self, ctx: &ReportContext) {
201 for (k, v) in self.values.iter_mut() {
202 match ctx.commodities.get_decimal_point(*k) {
203 None => (),
204 Some(dp) => {
205 let updated = v.round_dp_with_strategy(
206 dp,
207 rust_decimal::RoundingStrategy::MidpointNearestEven,
208 );
209 *v = updated;
210 }
211 }
212 }
213 }
214
215 pub fn negate(mut self) -> Self {
217 for (_, v) in self.values.iter_mut() {
218 v.set_sign_positive(!v.is_sign_positive())
219 }
220 self
221 }
222
223 pub fn check_div(mut self, rhs: Decimal) -> Result<Self, EvalError<'ctx>> {
225 if rhs.is_zero() {
226 return Err(EvalError::DivideByZero);
227 }
228 for (_, v) in self.values.iter_mut() {
229 *v = v.checked_div(rhs).ok_or(EvalError::NumberOverflow)?;
230 }
231 Ok(self)
232 }
233
234 pub(crate) fn assert_balance(&self, expected: &PostingAmount<'ctx>) -> Self {
243 match expected {
244 PostingAmount::Zero => {
245 if self.is_zero() {
246 Self::zero()
247 } else {
248 -self.clone()
249 }
250 }
251 PostingAmount::Single(single) => {
252 let diff = single.value - self.get_part(single.commodity);
253 if diff.is_zero() {
254 Self::zero()
255 } else {
256 Self::from_value(diff, single.commodity)
257 }
258 }
259 }
260 }
261}
262
263#[derive(Debug)]
264struct AmountIter<'a, 'ctx>(hash_map::Iter<'a, CommodityTag<'ctx>, Decimal>);
265
266impl<'ctx> Iterator for AmountIter<'_, 'ctx> {
267 type Item = SingleAmount<'ctx>;
268
269 fn next(&mut self) -> Option<Self::Item> {
270 self.0.next().map(|(c, v)| SingleAmount::from_value(*v, *c))
271 }
272}
273
274impl FusedIterator for AmountIter<'_, '_> {}
275
276#[derive(Debug)]
277struct InlinePrintAmount<'a, 'ctx> {
278 commodity_store: &'a CommodityStore<'ctx>,
279 amount: &'a Amount<'ctx>,
280}
281
282impl Display for InlinePrintAmount<'_, '_> {
283 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284 let vs = &self.amount.values;
285 match vs.len() {
286 0 | 1 => match vs.iter().next() {
287 Some((c, v)) => {
288 write!(f, "{} {}", v, c.to_str_lossy(self.commodity_store))
289 }
290 None => write!(f, "0"),
291 },
292 _ => {
293 write!(f, "(")?;
294 for (i, (c, v)) in vs.iter().enumerate() {
295 if i != 0 {
296 write!(f, " + ")?;
297 }
298 write!(f, "{} {}", v, c.to_str_lossy(self.commodity_store))?;
299 }
300 write!(f, ")")
301 }
302 }
303 }
304}
305
306impl Neg for Amount<'_> {
307 type Output = Self;
308
309 fn neg(self) -> Self::Output {
310 self.negate()
311 }
312}
313
314impl Add for Amount<'_> {
315 type Output = Self;
316
317 fn add(mut self, rhs: Self) -> Self::Output {
318 self += rhs;
319 self
320 }
321}
322
323impl AddAssign for Amount<'_> {
324 fn add_assign(&mut self, rhs: Self) {
325 for (c, v2) in rhs.values {
326 let mut v1 = self.values.entry(c).or_insert(Decimal::ZERO);
327 v1 += v2;
328 }
331 }
332}
333
334impl<'ctx> Add<SingleAmount<'ctx>> for Amount<'ctx> {
335 type Output = Amount<'ctx>;
336
337 fn add(mut self, rhs: SingleAmount<'ctx>) -> Self::Output {
338 self += rhs;
339 self
340 }
341}
342
343impl<'ctx> AddAssign<SingleAmount<'ctx>> for Amount<'ctx> {
344 fn add_assign(&mut self, rhs: SingleAmount<'ctx>) {
345 let curr = self.values.entry(rhs.commodity).or_default();
346 *curr += rhs.value;
347 }
348}
349
350impl<'ctx> AddAssign<PostingAmount<'ctx>> for Amount<'ctx> {
351 fn add_assign(&mut self, rhs: PostingAmount<'ctx>) {
352 match rhs {
353 PostingAmount::Zero => (),
354 PostingAmount::Single(single) => *self += single,
355 }
356 }
357}
358
359impl Sub for Amount<'_> {
360 type Output = Self;
361
362 fn sub(mut self, rhs: Self) -> Self::Output {
363 self -= rhs;
364 self
365 }
366}
367
368impl SubAssign for Amount<'_> {
369 fn sub_assign(&mut self, rhs: Self) {
370 for (c, v2) in rhs.values {
371 let mut v1 = self.values.entry(c).or_insert(Decimal::ZERO);
372 v1 -= v2;
373 }
374 }
375}
376
377impl Mul<Decimal> for Amount<'_> {
378 type Output = Self;
379
380 fn mul(mut self, rhs: Decimal) -> Self::Output {
381 self *= rhs;
382 self
383 }
384}
385
386impl MulAssign<Decimal> for Amount<'_> {
387 fn mul_assign(&mut self, rhs: Decimal) {
388 for (_, mut v) in self.values.iter_mut() {
389 v *= rhs;
390 }
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;

    use bumpalo::Bump;
    use maplit::hashmap;
    use pretty_assertions::assert_eq;
    use pretty_decimal::PrettyDecimal;
    use rust_decimal_macros::dec;

    use crate::report::ReportContext;

    #[test]
    fn test_default() {
        let arena = Bump::new();
        let ctx = ReportContext::new(&arena);
        let amount = Amount::default();
        assert_eq!(format!("{}", amount.as_inline_display(&ctx)), "0")
    }

    #[test]
    fn test_from_value() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let amount = Amount::from_value(dec!(123.45), jpy);
        assert_eq!(format!("{}", amount.as_inline_display(&ctx)), "123.45 JPY")
    }

    #[test]
    fn test_from_values() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let chf = ctx.commodities.ensure("CHF");

        let amount = Amount::from_values([(dec!(10), jpy), (dec!(1), chf)]);
        assert_eq!(
            amount.into_values(),
            hashmap! {jpy => dec!(10), chf => dec!(1)},
        );

        let amount = Amount::from_values([(dec!(10), jpy), (dec!(1), jpy)]);
        assert_eq!(amount.into_values(), hashmap! {jpy => dec!(11)});

        let amount = Amount::from_values([(dec!(10), jpy), (dec!(-10), jpy)]);
        assert_eq!(amount.into_values(), hashmap! {jpy => dec!(0)});
    }

    #[test]
    fn test_is_absolute_zero() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");

        assert!(Amount::default().is_absolute_zero());
        assert!(!Amount::from_value(dec!(0), jpy).is_absolute_zero());

        let mut amount = Amount::from_values([(dec!(0), jpy), (dec!(0), usd)]);
        assert!(
            !amount.is_absolute_zero(),
            "{}",
            amount.as_inline_display(&ctx)
        );

        amount.remove_zero_entries();
        assert!(
            amount.is_absolute_zero(),
            "{}",
            amount.as_inline_display(&ctx)
        );
    }

    #[test]
    fn test_is_zero() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");

        assert!(Amount::default().is_zero());
        assert!(Amount::from_value(dec!(0), jpy).is_zero());
        assert!(Amount::from_values([(dec!(0), jpy), (dec!(0), usd)]).is_zero());

        assert!(!Amount::from_value(dec!(1), jpy).is_zero());
        assert!(!Amount::from_values([(dec!(0), jpy), (dec!(1), usd)]).is_zero());
    }

    #[test]
    fn test_neg() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");

        assert_eq!(-Amount::zero(), Amount::zero());
        assert_eq!(
            -Amount::from_value(dec!(100), jpy),
            Amount::from_value(dec!(-100), jpy)
        );
        assert_eq!(
            -Amount::from_values([(dec!(100), jpy), (dec!(-20.35), usd)]),
            Amount::from_values([(dec!(-100), jpy), (dec!(20.35), usd)]),
        );
    }

    #[test]
    fn test_add_amount() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        let zero_plus_zero = Amount::zero() + Amount::zero();
        assert_eq!(zero_plus_zero, Amount::zero());

        assert_eq!(
            Amount::from_value(dec!(1), jpy) + Amount::zero(),
            Amount::from_value(dec!(1), jpy),
        );
        assert_eq!(
            Amount::zero() + Amount::from_value(dec!(1), jpy),
            Amount::from_value(dec!(1), jpy),
        );
        assert_eq!(
            Amount::from_values([
                (dec!(123.00), jpy),
                (dec!(456.0), usd),
                (dec!(7.89), eur),
                (dec!(0), chf),
            ]),
            Amount::from_value(dec!(123.45), jpy)
                + Amount::from_value(dec!(-0.45), jpy)
                + Amount::from_value(dec!(456), usd)
                + Amount::from_value(dec!(0.0), usd)
                + -Amount::from_value(dec!(100), chf)
                + Amount::from_value(dec!(7.89), eur)
                + Amount::from_value(dec!(100), chf),
        );

        assert_eq!(
            Amount::from_values([(dec!(0), jpy), (dec!(0), usd), (dec!(0), chf)]),
            Amount::from_values([(dec!(1), jpy), (dec!(2), usd), (dec!(3), chf)])
                + Amount::from_values([(dec!(-1), jpy), (dec!(-2), usd), (dec!(-3), chf)])
        );
    }

    #[test]
    fn test_add_single_amount() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");

        let amount = Amount::zero() + SingleAmount::from_value(dec!(0), usd);
        assert_eq!(amount, Amount::from_value(dec!(0), usd));

        assert_eq!(
            Amount::zero() + SingleAmount::from_value(dec!(1), jpy),
            Amount::from_value(dec!(1), jpy),
        );
    }

    #[test]
    fn test_sub() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        let zero_minus_zero = Amount::zero() - Amount::zero();
        assert_eq!(zero_minus_zero, Amount::zero());

        assert_eq!(
            Amount::from_value(dec!(1), jpy) - Amount::zero(),
            Amount::from_value(dec!(1), jpy),
        );
        assert_eq!(
            Amount::zero() - Amount::from_value(dec!(1), jpy),
            Amount::from_value(dec!(-1), jpy),
        );
        assert_eq!(
            Amount::from_values([
                (dec!(12345), jpy),
                (dec!(-200), eur),
                (dec!(13.3), chf),
                (dec!(0), usd)
            ]),
            Amount::from_values([(dec!(12345), jpy), (dec!(56.78), usd)])
                - Amount::from_values([(dec!(56.780), usd), (dec!(200), eur), (dec!(-13.3), chf),]),
        );
    }

    fn eps() -> Decimal {
        Decimal::try_from_i128_with_scale(1, 28).unwrap()
    }

    #[test]
    fn test_mul() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        assert_eq!(Amount::zero() * dec!(5), Amount::zero());
        assert_eq!(
            Amount::from_value(dec!(1), jpy) * Decimal::ZERO,
            Amount::from_value(dec!(0), jpy),
        );
        assert_eq!(
            Amount::from_value(dec!(123), jpy) * dec!(3),
            Amount::from_value(dec!(369), jpy),
        );
        assert_eq!(
            Amount::from_values([(dec!(10081), jpy), (dec!(200), eur), (dec!(-13.3), chf)])
                * dec!(-0.5),
            Amount::from_values([(dec!(-5040.5), jpy), (dec!(-100.0), eur), (dec!(6.65), chf)]),
        );
        assert_eq!(
            Amount::from_value(eps(), jpy) * eps(),
            Amount::from_value(dec!(0), jpy)
        );
    }

    #[test]
    fn test_check_div() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        assert_eq!(Amount::zero().check_div(dec!(5)).unwrap(), Amount::zero());
        assert_eq!(
            Amount::zero().check_div(dec!(0)).unwrap_err(),
            EvalError::DivideByZero
        );

        assert_eq!(
            Amount::from_value(dec!(50), jpy)
                .check_div(dec!(4))
                .unwrap(),
            Amount::from_value(dec!(12.5), jpy)
        );

        assert_eq!(
            Amount::from_value(Decimal::MAX, jpy)
                .check_div(eps())
                .unwrap_err(),
            EvalError::NumberOverflow
        );

        assert_eq!(
            Amount::from_value(eps(), jpy)
                .check_div(Decimal::MAX)
                .unwrap(),
            Amount::from_value(dec!(0), jpy)
        );

        assert_eq!(
            Amount::from_values([(dec!(810), jpy), (dec!(-100.0), eur), (dec!(6.66), chf)])
                .check_div(dec!(3))
                .unwrap(),
            Amount::from_values([
                (dec!(270), jpy),
                (dec!(-33.333333333333333333333333333), eur),
                (dec!(2.22), chf)
            ]),
        );
    }

    #[test]
    fn test_round() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        ctx.commodities
            .set_format(jpy, PrettyDecimal::comma3dot(dec!(12345)));
        ctx.commodities
            .set_format(eur, PrettyDecimal::plain(dec!(123.45)));
        ctx.commodities
            .set_format(chf, PrettyDecimal::comma3dot(dec!(123.450)));

        assert_eq!(Amount::zero(), Amount::zero().round(&ctx));

        assert_eq!(
            Amount::from_values([(dec!(812), jpy), (dec!(-100.00), eur), (dec!(6.660), chf)]),
            Amount::from_values([(dec!(812), jpy), (dec!(-100.0), eur), (dec!(6.66), chf)])
                .round(&ctx),
        );

        assert_eq!(
            Amount::from_values([(dec!(812), jpy), (dec!(-100.02), eur), (dec!(6.666), chf)]),
            Amount::from_values([
                (dec!(812.5), jpy),
                (dec!(-100.015), eur),
                (dec!(6.6665), chf)
            ])
            .round(&ctx),
        );
    }

    #[test]
    fn test_iter() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");

        assert_eq!(Amount::zero().iter().count(), 0);

        let amount = Amount::from_values([(dec!(10), jpy), (dec!(-2.5), usd)]);
        assert_eq!(amount.iter().count(), 2);
        // Re-summing the yielded entries must reproduce the original amount.
        let rebuilt = amount
            .iter()
            .fold(Amount::zero(), |acc, single| acc + single);
        assert_eq!(rebuilt, amount);
    }

    #[test]
    fn test_set_partial() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");

        let mut amount = Amount::from_value(dec!(10), jpy);

        let previous = amount.set_partial(SingleAmount::from_value(dec!(3), jpy));
        assert_eq!(previous.value, dec!(10));
        assert_eq!(amount, Amount::from_value(dec!(3), jpy));

        // Setting zero removes the entry entirely.
        let previous = amount.set_partial(SingleAmount::from_value(dec!(0), jpy));
        assert_eq!(previous.value, dec!(3));
        assert!(amount.is_absolute_zero());

        // A commodity without an entry reports zero as its previous value.
        let previous = amount.set_partial(SingleAmount::from_value(dec!(5), usd));
        assert_eq!(previous.value, Decimal::ZERO);
        assert_eq!(amount, Amount::from_value(dec!(5), usd));
    }

    #[test]
    fn test_maybe_pair() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");
        let eur = ctx.commodities.ensure("EUR");

        assert!(Amount::zero().maybe_pair().is_none());
        assert!(Amount::from_value(dec!(1), jpy).maybe_pair().is_none());
        assert!(
            Amount::from_values([(dec!(1), jpy), (dec!(2), usd), (dec!(3), eur)])
                .maybe_pair()
                .is_none()
        );

        let amount = Amount::from_values([(dec!(1), jpy), (dec!(2), usd)]);
        let (first, second) = amount.maybe_pair().unwrap();
        // Order is unspecified, so compare the recombined amount.
        assert_eq!(Amount::from(first) + Amount::from(second), amount);
    }

    #[test]
    fn test_assert_balance() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");
        let eur = ctx.commodities.ensure("EUR");

        let amount = Amount::from_values([(dec!(10), jpy), (dec!(5), usd)]);

        assert_eq!(
            amount.assert_balance(&PostingAmount::Single(SingleAmount::from_value(
                dec!(10),
                jpy
            ))),
            Amount::zero()
        );
        assert_eq!(
            amount.assert_balance(&PostingAmount::Single(SingleAmount::from_value(
                dec!(12),
                jpy
            ))),
            Amount::from_value(dec!(2), jpy)
        );
        // A commodity absent from the balance counts as zero.
        assert_eq!(
            amount.assert_balance(&PostingAmount::Single(SingleAmount::from_value(
                dec!(3),
                eur
            ))),
            Amount::from_value(dec!(3), eur)
        );
        assert_eq!(
            amount.assert_balance(&PostingAmount::Zero),
            -amount.clone()
        );
        assert_eq!(
            Amount::from_value(dec!(0), jpy).assert_balance(&PostingAmount::Zero),
            Amount::zero()
        );
    }

    #[test]
    fn test_try_from_amount() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");

        let one = Amount::from_value(dec!(7), jpy);
        assert_eq!(Amount::from(SingleAmount::try_from(&one).unwrap()), one);
        assert_eq!(Amount::from(PostingAmount::try_from(&one).unwrap()), one);

        assert!(matches!(
            SingleAmount::try_from(&Amount::zero()),
            Err(EvalError::SingleAmountRequired)
        ));
        assert!(matches!(
            PostingAmount::try_from(&Amount::from_values([(dec!(1), jpy), (dec!(2), usd)])),
            Err(EvalError::PostingAmountRequired)
        ));
    }
}