1use std::{
2 fmt::Display,
3 ops::{Mul, Neg},
4};
5
6use rust_decimal::Decimal;
7
8use crate::report::{commodity::Commodity, ReportContext};
9
10use super::error::EvalError;
11
12#[derive(Debug, PartialEq, Eq, Clone, Copy)]
14pub struct SingleAmount<'ctx> {
15 pub(crate) value: Decimal,
16 pub(crate) commodity: Commodity<'ctx>,
17}
18
19impl Neg for SingleAmount<'_> {
20 type Output = Self;
21
22 fn neg(self) -> Self::Output {
23 SingleAmount {
24 value: -self.value,
25 commodity: self.commodity,
26 }
27 }
28}
29
30impl Mul<Decimal> for SingleAmount<'_> {
31 type Output = Self;
32
33 fn mul(self, rhs: Decimal) -> Self::Output {
34 Self {
35 value: self.value * rhs,
36 commodity: self.commodity,
37 }
38 }
39}
40
41impl<'ctx> SingleAmount<'ctx> {
42 #[inline]
44 pub fn from_value(value: Decimal, commodity: Commodity<'ctx>) -> Self {
45 Self { value, commodity }
46 }
47
48 pub fn check_add(self, rhs: Self) -> Result<Self, EvalError> {
50 if self.commodity != rhs.commodity {
51 Err(EvalError::UnmatchingCommodities(
52 self.commodity.into(),
53 rhs.commodity.into(),
54 ))
55 } else {
56 Ok(Self {
57 value: self
58 .value
59 .checked_add(rhs.value)
60 .ok_or(EvalError::NumberOverflow)?,
61 commodity: self.commodity,
62 })
63 }
64 }
65
66 pub fn check_sub(self, rhs: Self) -> Result<Self, EvalError> {
68 self.check_add(-rhs)
69 }
70
71 pub fn check_div(self, rhs: Decimal) -> Result<Self, EvalError> {
73 if rhs.is_zero() {
74 return Err(EvalError::DivideByZero);
75 }
76 Ok(Self {
77 value: self
78 .value
79 .checked_div(rhs)
80 .ok_or(EvalError::NumberOverflow)?,
81 commodity: self.commodity,
82 })
83 }
84
85 pub fn abs(self) -> Self {
87 Self {
88 value: self.value.abs(),
89 commodity: self.commodity,
90 }
91 }
92
93 pub fn round(self, ctx: &ReportContext) -> Self {
95 match ctx.commodities.get_decimal_point(self.commodity) {
96 None => self,
97 Some(dp) => Self {
98 value: self.value.round_dp_with_strategy(
99 dp,
100 rust_decimal::RoundingStrategy::MidpointNearestEven,
101 ),
102 commodity: self.commodity,
103 },
104 }
105 }
106
107 pub(crate) fn with_sign_of(mut self, sign: Self) -> Self {
109 self.value.set_sign_positive(sign.value.is_sign_positive());
110 self
111 }
112}
113
114impl Display for SingleAmount<'_> {
115 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116 write!(f, "{} {}", self.value, self.commodity.as_str())
117 }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    use bumpalo::Bump;
    use pretty_assertions::assert_eq;
    use rust_decimal_macros::dec;

    use crate::{report::ReportContext, syntax::pretty_decimal::PrettyDecimal};

    #[test]
    fn neg_returns_negative_value() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.insert_canonical("JPY").unwrap();

        let expected = SingleAmount::from_value(dec!(-5), jpy);
        let negated = -SingleAmount::from_value(dec!(5), jpy);
        assert_eq!(expected, negated);
    }

    #[test]
    fn check_add_fails_different_commodity() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.insert_canonical("JPY").unwrap();
        let chf = ctx.commodities.insert_canonical("CHF").unwrap();

        let lhs = SingleAmount::from_value(dec!(10), jpy);
        let rhs = SingleAmount::from_value(dec!(20), chf);
        assert_eq!(
            Err(EvalError::UnmatchingCommodities(jpy.into(), chf.into())),
            lhs.check_add(rhs)
        );
    }

    #[test]
    fn check_add_succeeds() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.insert_canonical("JPY").unwrap();

        let lhs = SingleAmount::from_value(dec!(10), jpy);
        let rhs = SingleAmount::from_value(dec!(-20), jpy);
        assert_eq!(
            SingleAmount::from_value(dec!(-10), jpy),
            lhs.check_add(rhs).unwrap()
        );
    }

    #[test]
    fn check_sub_fails_different_commodity() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.insert_canonical("JPY").unwrap();
        let chf = ctx.commodities.insert_canonical("CHF").unwrap();

        let lhs = SingleAmount::from_value(dec!(10), jpy);
        let rhs = SingleAmount::from_value(dec!(0), chf);
        assert_eq!(
            Err(EvalError::UnmatchingCommodities(jpy.into(), chf.into())),
            lhs.check_sub(rhs)
        );
    }

    #[test]
    fn check_sub_succeeds() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.insert_canonical("JPY").unwrap();

        let lhs = SingleAmount::from_value(dec!(10), jpy);
        let rhs = SingleAmount::from_value(dec!(5), jpy);
        assert_eq!(
            SingleAmount::from_value(dec!(5), jpy),
            lhs.check_sub(rhs).unwrap()
        );
    }

    #[test]
    fn single_amount_to_string() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let usd = ctx.commodities.insert_canonical("USD").unwrap();

        let amount = SingleAmount::from_value(dec!(1.20), usd);
        assert_eq!("1.20 USD".to_string(), amount.to_string());
    }

    #[test]
    fn single_amount_round() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let eur = ctx.commodities.ensure("EUR");
        let chf = ctx.commodities.ensure("CHF");

        ctx.commodities
            .set_format(jpy, PrettyDecimal::comma3dot(dec!(12345)));
        ctx.commodities
            .set_format(eur, PrettyDecimal::plain(dec!(123.45)));
        ctx.commodities
            .set_format(chf, PrettyDecimal::comma3dot(dec!(123.450)));

        // (expected, input, commodity): the first three inputs need no
        // rounding; the last three are exact midpoints.
        let cases = [
            (dec!(812), dec!(812), jpy),
            (dec!(-100.00), dec!(-100.0), eur),
            (dec!(6.660), dec!(6.66), chf),
            (dec!(812), dec!(812.5), jpy),
            (dec!(-100.02), dec!(-100.015), eur),
            (dec!(6.666), dec!(6.6665), chf),
        ];
        for (expected, input, commodity) in cases {
            assert_eq!(
                SingleAmount::from_value(expected, commodity),
                SingleAmount::from_value(input, commodity).round(&ctx),
            );
        }
    }

    #[test]
    fn with_sign_negative() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.insert_canonical("JPY").unwrap();
        let eur = ctx.commodities.insert_canonical("EUR").unwrap();
        let eur_of = |value| SingleAmount::from_value(value, eur);

        // (input, expected) under a positive sign source.
        let positive = SingleAmount::from_value(dec!(1000), jpy);
        for (input, expected) in [
            (dec!(15), dec!(15)),
            (dec!(0), dec!(0)),
            (dec!(-15), dec!(15)),
        ] {
            assert_eq!(eur_of(expected), eur_of(input).with_sign_of(positive));
        }

        // (input, expected) under a negative sign source.
        let negative = SingleAmount::from_value(dec!(-1000), jpy);
        for (input, expected) in [
            (dec!(15), dec!(-15)),
            (dec!(0), dec!(0)),
            (dec!(-15), dec!(-15)),
        ] {
            assert_eq!(eur_of(expected), eur_of(input).with_sign_of(negative));
        }
    }
}