1use std::{
2 fmt::Display,
3 ops::{Mul, Neg},
4};
5
6use rust_decimal::Decimal;
7
8use crate::report::{commodity::CommodityTag, ReportContext};
9
10use super::error::EvalError;
11
/// An amount denominated in exactly one commodity: a decimal quantity paired
/// with the tag of the commodity it is expressed in.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub struct SingleAmount<'ctx> {
    /// Numeric quantity of the amount.
    pub(crate) value: Decimal,
    /// Tag identifying the commodity that `value` is counted in.
    pub(crate) commodity: CommodityTag<'ctx>,
}
18
19impl Neg for SingleAmount<'_> {
20 type Output = Self;
21
22 fn neg(self) -> Self::Output {
23 SingleAmount {
24 value: -self.value,
25 commodity: self.commodity,
26 }
27 }
28}
29
30impl Mul<Decimal> for SingleAmount<'_> {
31 type Output = Self;
32
33 fn mul(self, rhs: Decimal) -> Self::Output {
34 Self {
35 value: self.value * rhs,
36 commodity: self.commodity,
37 }
38 }
39}
40
41impl<'ctx> SingleAmount<'ctx> {
42 #[inline]
44 pub fn from_value(value: Decimal, commodity: CommodityTag<'ctx>) -> Self {
45 Self { value, commodity }
46 }
47
48 pub fn check_add(self, rhs: Self) -> Result<Self, EvalError<'ctx>> {
50 if self.commodity != rhs.commodity {
51 Err(EvalError::UnmatchingCommodities(
52 self.commodity,
53 rhs.commodity,
54 ))
55 } else {
56 Ok(Self {
57 value: self
58 .value
59 .checked_add(rhs.value)
60 .ok_or(EvalError::NumberOverflow)?,
61 commodity: self.commodity,
62 })
63 }
64 }
65
66 pub fn check_sub(self, rhs: Self) -> Result<Self, EvalError<'ctx>> {
68 self.check_add(-rhs)
69 }
70
71 pub fn check_div(self, rhs: Decimal) -> Result<Self, EvalError<'ctx>> {
73 if rhs.is_zero() {
74 return Err(EvalError::DivideByZero);
75 }
76 Ok(Self {
77 value: self
78 .value
79 .checked_div(rhs)
80 .ok_or(EvalError::NumberOverflow)?,
81 commodity: self.commodity,
82 })
83 }
84
85 pub fn abs(self) -> Self {
87 Self {
88 value: self.value.abs(),
89 commodity: self.commodity,
90 }
91 }
92
93 pub fn round(self, ctx: &ReportContext) -> Self {
95 match ctx.commodities.get_decimal_point(self.commodity) {
96 None => self,
97 Some(dp) => Self {
98 value: self.value.round_dp_with_strategy(
99 dp,
100 rust_decimal::RoundingStrategy::MidpointNearestEven,
101 ),
102 commodity: self.commodity,
103 },
104 }
105 }
106
107 pub(crate) fn with_sign_of(mut self, sign: Self) -> Self {
109 self.value.set_sign_positive(sign.value.is_sign_positive());
110 self
111 }
112
113 pub fn as_display<'a>(&'a self, ctx: &'a ReportContext<'ctx>) -> impl Display + 'a
115 where
116 'a: 'ctx,
117 {
118 SingleAmountDisplay(self, ctx)
119 }
120}
121
/// Display adaptor pairing an amount with the report context whose commodity
/// store is used to render the amount's commodity tag as text.
struct SingleAmountDisplay<'a, 'ctx>(&'a SingleAmount<'ctx>, &'a ReportContext<'ctx>);
123
124impl Display for SingleAmountDisplay<'_, '_> {
125 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
126 write!(
127 f,
128 "{} {}",
129 self.0.value,
130 self.0.commodity.to_str_lossy(&self.1.commodities)
131 )
132 }
133}
134
#[cfg(test)]
mod tests {
    use super::*;

    use bumpalo::Bump;
    use pretty_assertions::assert_eq;
    use pretty_decimal::PrettyDecimal;
    use rust_decimal_macros::dec;

    use crate::report::ReportContext;

    /// Unary minus flips the sign of the value and keeps the commodity.
    #[test]
    fn neg_returns_negative_value() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);

        let jpy = ctx.commodities.insert("JPY").unwrap();

        let negated = -SingleAmount::from_value(dec!(5), jpy);
        assert_eq!(SingleAmount::from_value(dec!(-5), jpy), negated);
    }

    /// Adding amounts of different commodities reports both commodities.
    #[test]
    fn check_add_fails_different_commodity() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);

        let jpy = ctx.commodities.insert("JPY").unwrap();
        let chf = ctx.commodities.insert("CHF").unwrap();

        let lhs = SingleAmount::from_value(dec!(10), jpy);
        let rhs = SingleAmount::from_value(dec!(20), chf);
        assert_eq!(
            Err(EvalError::UnmatchingCommodities(jpy.into(), chf.into())),
            lhs.check_add(rhs)
        );
    }

    /// Adding amounts of the same commodity sums their values.
    #[test]
    fn check_add_succeeds() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);

        let jpy = ctx.commodities.insert("JPY").unwrap();

        let lhs = SingleAmount::from_value(dec!(10), jpy);
        let rhs = SingleAmount::from_value(dec!(-20), jpy);
        assert_eq!(
            SingleAmount::from_value(dec!(-10), jpy),
            lhs.check_add(rhs).unwrap()
        );
    }

    /// Subtracting amounts of different commodities reports both commodities.
    #[test]
    fn check_sub_fails_different_commodity() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);

        let jpy = ctx.commodities.insert("JPY").unwrap();
        let chf = ctx.commodities.insert("CHF").unwrap();

        let lhs = SingleAmount::from_value(dec!(10), jpy);
        let rhs = SingleAmount::from_value(dec!(0), chf);
        assert_eq!(
            Err(EvalError::UnmatchingCommodities(jpy.into(), chf.into())),
            lhs.check_sub(rhs)
        );
    }

    /// Subtracting amounts of the same commodity yields their difference.
    #[test]
    fn check_sub_succeeds() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);

        let jpy = ctx.commodities.insert("JPY").unwrap();

        let lhs = SingleAmount::from_value(dec!(10), jpy);
        let rhs = SingleAmount::from_value(dec!(5), jpy);
        assert_eq!(
            SingleAmount::from_value(dec!(5), jpy),
            lhs.check_sub(rhs).unwrap()
        );
    }

    /// The display adaptor prints the value followed by the commodity name.
    #[test]
    fn single_amount_to_string() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);

        let usd = ctx.commodities.insert("USD").unwrap();

        let rendered = SingleAmount::from_value(dec!(1.20), usd)
            .as_display(&ctx)
            .to_string();
        assert_eq!("1.20 USD".to_string(), rendered);
    }

    /// Rounding follows the format registered for each commodity.
    #[test]
    fn single_amount_round() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        ctx.commodities
            .set_format(jpy, PrettyDecimal::comma3dot(dec!(12345)));
        ctx.commodities
            .set_format(eur, PrettyDecimal::plain(dec!(123.45)));
        ctx.commodities
            .set_format(chf, PrettyDecimal::comma3dot(dec!(123.450)));

        // Each row is (commodity, expected value, input value).
        let cases = [
            // Inputs that already fit the registered scale stay as they are.
            (jpy, dec!(812), dec!(812)),
            (eur, dec!(-100.00), dec!(-100.0)),
            (chf, dec!(6.660), dec!(6.66)),
            // Inputs exactly on a midpoint go to the nearest even digit.
            (jpy, dec!(812), dec!(812.5)),
            (eur, dec!(-100.02), dec!(-100.015)),
            (chf, dec!(6.666), dec!(6.6665)),
        ];
        for &(commodity, expected, input) in cases.iter() {
            assert_eq!(
                SingleAmount::from_value(expected, commodity),
                SingleAmount::from_value(input, commodity).round(&ctx),
            );
        }
    }

    /// `with_sign_of` copies the sign of the reference amount, whatever the
    /// reference's commodity is.
    #[test]
    fn with_sign_negative() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);

        let jpy = ctx.commodities.insert("JPY").unwrap();
        let eur = ctx.commodities.insert("EUR").unwrap();

        // Each row is (expected value, input value).
        let positive = SingleAmount::from_value(dec!(1000), jpy);
        let positive_cases = [
            (dec!(15), dec!(15)),
            (dec!(0), dec!(0)),
            (dec!(15), dec!(-15)),
        ];
        for &(expected, input) in positive_cases.iter() {
            assert_eq!(
                SingleAmount::from_value(expected, eur),
                SingleAmount::from_value(input, eur).with_sign_of(positive)
            );
        }

        let negative = SingleAmount::from_value(dec!(-1000), jpy);
        let negative_cases = [
            (dec!(-15), dec!(15)),
            (dec!(0), dec!(0)),
            (dec!(-15), dec!(-15)),
        ];
        for &(expected, input) in negative_cases.iter() {
            assert_eq!(
                SingleAmount::from_value(expected, eur),
                SingleAmount::from_value(input, eur).with_sign_of(negative)
            );
        }
    }
}