1use std::{
2 collections::{hash_map, HashMap},
3 fmt::Display,
4 iter::FusedIterator,
5 ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign},
6};
7
8use rust_decimal::Decimal;
9
10use crate::report::{commodity::Commodity, context::ReportContext};
11
12use super::{error::EvalError, PostingAmount, SingleAmount};
13
/// A multi-commodity amount: a mapping from each commodity to its value.
///
/// An explicit zero entry is distinct from an absent commodity: `is_absolute_zero`
/// only holds for an empty map, and the derived `PartialEq` compares the maps
/// entry by entry (so `{JPY: 0}` != `{}`).
#[derive(Debug, Default, PartialEq, Eq, Clone)]
pub struct Amount<'ctx> {
    // Value held for each commodity; entries may be zero until
    // `remove_zero_entries` is called.
    values: HashMap<Commodity<'ctx>, Decimal>,
}
22
23impl<'ctx> TryFrom<Amount<'ctx>> for SingleAmount<'ctx> {
24 type Error = EvalError;
25
26 fn try_from(value: Amount<'ctx>) -> Result<Self, Self::Error> {
27 SingleAmount::try_from(&value)
28 }
29}
30
31impl<'ctx> TryFrom<Amount<'ctx>> for PostingAmount<'ctx> {
32 type Error = EvalError;
33
34 fn try_from(value: Amount<'ctx>) -> Result<Self, Self::Error> {
35 PostingAmount::try_from(&value)
36 }
37}
38
39impl<'ctx> TryFrom<&Amount<'ctx>> for SingleAmount<'ctx> {
40 type Error = EvalError;
41
42 fn try_from(value: &Amount<'ctx>) -> Result<Self, Self::Error> {
43 let (commodity, value) = value
44 .values
45 .iter()
46 .next()
47 .ok_or(EvalError::SingleAmountRequired)?;
48 Ok(SingleAmount {
49 value: *value,
50 commodity: *commodity,
51 })
52 }
53}
54
55impl<'ctx> TryFrom<&Amount<'ctx>> for PostingAmount<'ctx> {
56 type Error = EvalError;
57
58 fn try_from(value: &Amount<'ctx>) -> Result<Self, Self::Error> {
59 if value.values.len() > 1 {
60 Err(EvalError::PostingAmountRequired)
61 } else {
62 Ok(value
63 .values
64 .iter()
65 .next()
66 .map(|(commodity, value)| {
67 PostingAmount::Single(SingleAmount {
68 value: *value,
69 commodity: *commodity,
70 })
71 })
72 .unwrap_or_default())
73 }
74 }
75}
76
77impl<'ctx> From<PostingAmount<'ctx>> for Amount<'ctx> {
78 fn from(value: PostingAmount<'ctx>) -> Self {
79 match value {
80 PostingAmount::Zero => Amount::zero(),
81 PostingAmount::Single(single_amount) => single_amount.into(),
82 }
83 }
84}
85
86impl<'ctx> From<SingleAmount<'ctx>> for Amount<'ctx> {
87 fn from(value: SingleAmount<'ctx>) -> Self {
88 Amount::from_value(value.value, value.commodity)
89 }
90}
91
92impl<'ctx> Amount<'ctx> {
93 #[inline(always)]
95 pub fn zero() -> Self {
96 Self::default()
97 }
98
99 pub fn from_value(amount: Decimal, commodity: Commodity<'ctx>) -> Self {
101 Self::zero() + SingleAmount::from_value(amount, commodity)
102 }
103
104 pub fn from_values<T>(values: T) -> Self
106 where
107 T: IntoIterator<Item = (Decimal, Commodity<'ctx>)>,
108 {
109 let mut ret = Amount::zero();
110 for (value, commodity) in values.into_iter() {
111 ret += SingleAmount::from_value(value, commodity);
112 }
113 ret
114 }
115
116 pub fn into_values(self) -> HashMap<Commodity<'ctx>, Decimal> {
118 self.values
119 }
120
121 pub fn iter(&self) -> impl Iterator<Item = SingleAmount<'ctx>> + '_ {
123 AmountIter(self.values.iter())
124 }
125
126 pub fn as_inline_display(&self) -> impl Display + '_ {
128 InlinePrintAmount(self)
129 }
130
131 pub fn is_absolute_zero(&self) -> bool {
134 self.values.is_empty()
135 }
136
137 pub fn is_zero(&self) -> bool {
139 self.values.iter().all(|(_, v)| v.is_zero())
140 }
141
142 pub fn remove_zero_entries(&mut self) {
146 self.values.retain(|_, v| !v.is_zero());
147 }
148
149 pub(crate) fn set_partial(&mut self, amount: SingleAmount<'ctx>) -> SingleAmount<'ctx> {
154 let value = if amount.value.is_zero() {
155 self.values.remove(&amount.commodity)
156 } else {
157 self.values.insert(amount.commodity, amount.value)
158 }
159 .unwrap_or_default();
160 SingleAmount {
161 value,
162 commodity: amount.commodity,
163 }
164 }
165
166 fn get_part(&self, commodity: Commodity<'ctx>) -> Decimal {
168 self.values.get(&commodity).copied().unwrap_or_default()
169 }
170
171 pub fn maybe_pair(&self) -> Option<(SingleAmount<'ctx>, SingleAmount<'ctx>)> {
174 if self.values.len() != 2 {
175 return None;
176 }
177 let ((c1, v1), (c2, v2)) = self.values.iter().zip(self.values.iter().skip(1)).next()?;
178 Some((
179 SingleAmount::from_value(*v1, *c1),
180 SingleAmount::from_value(*v2, *c2),
181 ))
182 }
183
184 pub fn round(mut self, ctx: &ReportContext) -> Self {
186 self.round_mut(ctx);
187 self
188 }
189
190 pub fn round_mut(&mut self, ctx: &ReportContext) {
192 for (k, v) in self.values.iter_mut() {
193 match ctx.commodities.get_decimal_point(*k) {
194 None => (),
195 Some(dp) => {
196 let updated = v.round_dp_with_strategy(
197 dp,
198 rust_decimal::RoundingStrategy::MidpointNearestEven,
199 );
200 *v = updated;
201 }
202 }
203 }
204 }
205
206 pub fn negate(mut self) -> Self {
208 for (_, v) in self.values.iter_mut() {
209 v.set_sign_positive(!v.is_sign_positive())
210 }
211 self
212 }
213
214 pub fn check_div(mut self, rhs: Decimal) -> Result<Self, EvalError> {
216 if rhs.is_zero() {
217 return Err(EvalError::DivideByZero);
218 }
219 for (_, v) in self.values.iter_mut() {
220 *v = v.checked_div(rhs).ok_or(EvalError::NumberOverflow)?;
221 }
222 Ok(self)
223 }
224
225 pub(crate) fn is_consistent(&self, rhs: &PostingAmount<'ctx>) -> bool {
232 match rhs {
233 PostingAmount::Zero => self.is_zero(),
234 PostingAmount::Single(single) => self.get_part(single.commodity) == single.value,
235 }
236 }
237}
238
/// Iterator behind [`Amount::iter`].
///
/// Yields each commodity/value entry of the borrowed map as a [`SingleAmount`],
/// in `HashMap` iteration order.
#[derive(Debug)]
struct AmountIter<'a, 'ctx>(hash_map::Iter<'a, Commodity<'ctx>, Decimal>);
241
242impl<'ctx> Iterator for AmountIter<'_, 'ctx> {
243 type Item = SingleAmount<'ctx>;
244
245 fn next(&mut self) -> Option<Self::Item> {
246 self.0.next().map(|(c, v)| SingleAmount::from_value(*v, *c))
247 }
248}
249
// Sound because `next` only maps the wrapped `hash_map::Iter`, which is itself a
// `FusedIterator`: once it returns `None`, it keeps returning `None`.
impl FusedIterator for AmountIter<'_, '_> {}
251
/// Single-line `Display` adapter returned by [`Amount::as_inline_display`].
#[derive(Debug)]
struct InlinePrintAmount<'a, 'ctx>(&'a Amount<'ctx>);
254
255impl Display for InlinePrintAmount<'_, '_> {
256 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257 let vs = &self.0.values;
258 match vs.len() {
259 0 | 1 => match vs.iter().next() {
260 Some((c, v)) => write!(f, "{} {}", v, c.as_str()),
261 None => write!(f, "0"),
262 },
263 _ => {
264 write!(f, "(")?;
265 for (i, (c, v)) in vs.iter().enumerate() {
266 if i != 0 {
267 write!(f, " + ")?;
268 }
269 write!(f, "{} {}", v, c.as_str())?;
270 }
271 write!(f, ")")
272 }
273 }
274 }
275}
276
277impl Neg for Amount<'_> {
278 type Output = Self;
279
280 fn neg(self) -> Self::Output {
281 self.negate()
282 }
283}
284
285impl Add for Amount<'_> {
286 type Output = Self;
287
288 fn add(mut self, rhs: Self) -> Self::Output {
289 self += rhs;
290 self
291 }
292}
293
294impl AddAssign for Amount<'_> {
295 fn add_assign(&mut self, rhs: Self) {
296 for (c, v2) in rhs.values {
297 let mut v1 = self.values.entry(c).or_insert(Decimal::ZERO);
298 v1 += v2;
299 }
302 }
303}
304
305impl<'ctx> Add<SingleAmount<'ctx>> for Amount<'ctx> {
306 type Output = Amount<'ctx>;
307
308 fn add(mut self, rhs: SingleAmount<'ctx>) -> Self::Output {
309 self += rhs;
310 self
311 }
312}
313
314impl<'ctx> AddAssign<SingleAmount<'ctx>> for Amount<'ctx> {
315 fn add_assign(&mut self, rhs: SingleAmount<'ctx>) {
316 let curr = self.values.entry(rhs.commodity).or_default();
317 *curr += rhs.value;
318 }
319}
320
321impl<'ctx> AddAssign<PostingAmount<'ctx>> for Amount<'ctx> {
322 fn add_assign(&mut self, rhs: PostingAmount<'ctx>) {
323 match rhs {
324 PostingAmount::Zero => (),
325 PostingAmount::Single(single) => *self += single,
326 }
327 }
328}
329
330impl Sub for Amount<'_> {
331 type Output = Self;
332
333 fn sub(mut self, rhs: Self) -> Self::Output {
334 self -= rhs;
335 self
336 }
337}
338
339impl SubAssign for Amount<'_> {
340 fn sub_assign(&mut self, rhs: Self) {
341 for (c, v2) in rhs.values {
342 let mut v1 = self.values.entry(c).or_insert(Decimal::ZERO);
343 v1 -= v2;
344 }
345 }
346}
347
348impl Mul<Decimal> for Amount<'_> {
349 type Output = Self;
350
351 fn mul(mut self, rhs: Decimal) -> Self::Output {
352 self *= rhs;
353 self
354 }
355}
356
357impl MulAssign<Decimal> for Amount<'_> {
358 fn mul_assign(&mut self, rhs: Decimal) {
359 for (_, mut v) in self.values.iter_mut() {
360 v *= rhs;
361 }
362 }
363}
364
#[cfg(test)]
mod tests {
    use super::*;

    use bumpalo::Bump;
    use maplit::hashmap;
    use pretty_assertions::assert_eq;
    use rust_decimal_macros::dec;

    use crate::{report::ReportContext, syntax::pretty_decimal::PrettyDecimal};

    // An empty (default) amount renders as a bare "0".
    #[test]
    fn test_default() {
        let amount = Amount::default();
        assert_eq!(format!("{}", amount.as_inline_display()), "0")
    }

    // A single-commodity amount renders as "<value> <commodity>".
    #[test]
    fn test_from_value() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let amount = Amount::from_value(dec!(123.45), jpy);
        assert_eq!(format!("{}", amount.as_inline_display()), "123.45 JPY")
    }

    // Values for the same commodity are summed; a sum of zero still leaves an
    // explicit entry.
    #[test]
    fn test_from_values() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let chf = ctx.commodities.ensure("CHF");

        let amount = Amount::from_values([(dec!(10), jpy), (dec!(1), chf)]);
        assert_eq!(
            amount.into_values(),
            hashmap! {jpy => dec!(10), chf => dec!(1)},
        );

        let amount = Amount::from_values([(dec!(10), jpy), (dec!(1), jpy)]);
        assert_eq!(amount.into_values(), hashmap! {jpy => dec!(11)});

        let amount = Amount::from_values([(dec!(10), jpy), (dec!(-10), jpy)]);
        assert_eq!(amount.into_values(), hashmap! {jpy => dec!(0)});
    }

    // Only an amount with no entries is "absolute zero"; explicit zero entries
    // count until `remove_zero_entries` drops them.
    #[test]
    fn test_is_absolute_zero() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");

        assert!(Amount::default().is_absolute_zero());
        assert!(!Amount::from_value(dec!(0), jpy).is_absolute_zero());

        let mut amount = Amount::from_values([(dec!(0), jpy), (dec!(0), usd)]);
        assert!(!amount.is_absolute_zero(), "{}", amount.as_inline_display());

        amount.remove_zero_entries();
        assert!(amount.is_absolute_zero(), "{}", amount.as_inline_display());
    }

    // `is_zero` holds when every entry is zero, including the empty amount.
    #[test]
    fn test_is_zero() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");

        assert!(Amount::default().is_zero());
        assert!(Amount::from_value(dec!(0), jpy).is_zero());
        assert!(Amount::from_values([(dec!(0), jpy), (dec!(0), usd)]).is_zero());

        assert!(!Amount::from_value(dec!(1), jpy).is_zero());
        assert!(!Amount::from_values([(dec!(0), jpy), (dec!(1), usd)]).is_zero());
    }

    // Unary minus flips the sign of every entry.
    #[test]
    fn test_neg() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");

        assert_eq!(-Amount::zero(), Amount::zero());
        assert_eq!(
            -Amount::from_value(dec!(100), jpy),
            Amount::from_value(dec!(-100), jpy)
        );
        assert_eq!(
            -Amount::from_values([(dec!(100), jpy), (dec!(-20.35), usd)]),
            Amount::from_values([(dec!(-100), jpy), (dec!(20.35), usd)]),
        );
    }

    // Addition merges per commodity; entries that cancel out (CHF here) remain
    // as explicit zeros.
    #[test]
    fn test_add_amount() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        let zero_plus_zero = Amount::zero() + Amount::zero();
        assert_eq!(zero_plus_zero, Amount::zero());

        assert_eq!(
            Amount::from_value(dec!(1), jpy) + Amount::zero(),
            Amount::from_value(dec!(1), jpy),
        );
        assert_eq!(
            Amount::zero() + Amount::from_value(dec!(1), jpy),
            Amount::from_value(dec!(1), jpy),
        );
        assert_eq!(
            Amount::from_values([
                (dec!(123.00), jpy),
                (dec!(456.0), usd),
                (dec!(7.89), eur),
                (dec!(0), chf),
            ]),
            Amount::from_value(dec!(123.45), jpy)
                + Amount::from_value(dec!(-0.45), jpy)
                + Amount::from_value(dec!(456), usd)
                + Amount::from_value(dec!(0.0), usd)
                + -Amount::from_value(dec!(100), chf)
                + Amount::from_value(dec!(7.89), eur)
                + Amount::from_value(dec!(100), chf),
        );

        assert_eq!(
            Amount::from_values([(dec!(0), jpy), (dec!(0), usd), (dec!(0), chf)]),
            Amount::from_values([(dec!(1), jpy), (dec!(2), usd), (dec!(3), chf)])
                + Amount::from_values([(dec!(-1), jpy), (dec!(-2), usd), (dec!(-3), chf)])
        );
    }

    // Adding a zero single amount still creates an entry for its commodity.
    #[test]
    fn test_add_single_amount() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");

        let amount = Amount::zero() + SingleAmount::from_value(dec!(0), usd);
        assert_eq!(amount, Amount::from_value(dec!(0), usd));

        assert_eq!(
            Amount::zero() + SingleAmount::from_value(dec!(1), jpy),
            Amount::from_value(dec!(1), jpy),
        );
    }

    // Subtraction inserts commodities missing on the left-hand side with the
    // negated value; USD cancels out and stays as an explicit zero.
    #[test]
    fn test_sub() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let usd = ctx.commodities.ensure("USD");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        let zero_minus_zero = Amount::zero() - Amount::zero();
        assert_eq!(zero_minus_zero, Amount::zero());

        assert_eq!(
            Amount::from_value(dec!(1), jpy) - Amount::zero(),
            Amount::from_value(dec!(1), jpy),
        );
        assert_eq!(
            Amount::zero() - Amount::from_value(dec!(1), jpy),
            Amount::from_value(dec!(-1), jpy),
        );
        assert_eq!(
            Amount::from_values([
                (dec!(12345), jpy),
                (dec!(-200), eur),
                (dec!(13.3), chf),
                (dec!(0), usd)
            ]),
            Amount::from_values([(dec!(12345), jpy), (dec!(56.78), usd)])
                - Amount::from_values([(dec!(56.780), usd), (dec!(200), eur), (dec!(-13.3), chf),]),
        );
    }

    // 1e-28: mantissa 1 at scale 28.
    fn eps() -> Decimal {
        Decimal::try_from_i128_with_scale(1, 28).unwrap()
    }

    // Scalar multiplication applies to every entry; eps * eps is expected to
    // come out as zero.
    #[test]
    fn test_mul() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        assert_eq!(Amount::zero() * dec!(5), Amount::zero());
        assert_eq!(
            Amount::from_value(dec!(1), jpy) * Decimal::ZERO,
            Amount::from_value(dec!(0), jpy),
        );
        assert_eq!(
            Amount::from_value(dec!(123), jpy) * dec!(3),
            Amount::from_value(dec!(369), jpy),
        );
        assert_eq!(
            Amount::from_values([(dec!(10081), jpy), (dec!(200), eur), (dec!(-13.3), chf)])
                * dec!(-0.5),
            Amount::from_values([(dec!(-5040.5), jpy), (dec!(-100.0), eur), (dec!(6.65), chf)]),
        );
        assert_eq!(
            Amount::from_value(eps(), jpy) * eps(),
            Amount::from_value(dec!(0), jpy)
        );
    }

    // `check_div` rejects a zero divisor, reports overflow as an error, and
    // lets a too-small quotient come out as zero.
    #[test]
    fn test_check_div() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        assert_eq!(Amount::zero().check_div(dec!(5)).unwrap(), Amount::zero());
        // The zero-divisor check runs even when there are no entries to divide.
        assert_eq!(
            Amount::zero().check_div(dec!(0)).unwrap_err(),
            EvalError::DivideByZero
        );

        assert_eq!(
            Amount::from_value(dec!(50), jpy)
                .check_div(dec!(4))
                .unwrap(),
            Amount::from_value(dec!(12.5), jpy)
        );

        assert_eq!(
            Amount::from_value(Decimal::MAX, jpy)
                .check_div(eps())
                .unwrap_err(),
            EvalError::NumberOverflow
        );

        assert_eq!(
            Amount::from_value(eps(), jpy)
                .check_div(Decimal::MAX)
                .unwrap(),
            Amount::from_value(dec!(0), jpy)
        );

        assert_eq!(
            Amount::from_values([(dec!(810), jpy), (dec!(-100.0), eur), (dec!(6.66), chf)])
                .check_div(dec!(3))
                .unwrap(),
            Amount::from_values([
                (dec!(270), jpy),
                (dec!(-33.333333333333333333333333333), eur),
                (dec!(2.22), chf)
            ]),
        );
    }

    // Rounding uses the decimal points derived from each commodity's configured
    // format. The expectations imply 0 / 2 / 3 places for JPY / EUR / CHF, with
    // ties going to the even digit.
    #[test]
    fn test_round() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        ctx.commodities
            .set_format(jpy, PrettyDecimal::comma3dot(dec!(12345)));
        ctx.commodities
            .set_format(eur, PrettyDecimal::plain(dec!(123.45)));
        ctx.commodities
            .set_format(chf, PrettyDecimal::comma3dot(dec!(123.450)));

        assert_eq!(Amount::zero(), Amount::zero().round(&ctx));

        assert_eq!(
            Amount::from_values([(dec!(812), jpy), (dec!(-100.00), eur), (dec!(6.660), chf)]),
            Amount::from_values([(dec!(812), jpy), (dec!(-100.0), eur), (dec!(6.66), chf)])
                .round(&ctx),
        );

        // 812.5 -> 812, -100.015 -> -100.02, 6.6665 -> 6.666: midpoints go to even.
        assert_eq!(
            Amount::from_values([(dec!(812), jpy), (dec!(-100.02), eur), (dec!(6.666), chf)]),
            Amount::from_values([
                (dec!(812.5), jpy),
                (dec!(-100.015), eur),
                (dec!(6.6665), chf)
            ])
            .round(&ctx),
        );
    }
}