1use std::collections::HashMap;
2
3use crate::report::eval::EvalError;
4
5use super::{
6 context::Account,
7 eval::{Amount, OwnedEvalError, PostingAmount},
8 ReportContext,
9};
10
11#[derive(Debug, thiserror::Error, PartialEq, Eq)]
13pub enum BalanceError {
14 #[error("balance = 0 should be used on single commodity balance")]
15 MultiCommodityWithPartialSet(#[source] OwnedEvalError, String),
16}
17
18impl BalanceError {
19 pub(super) fn note(&self) -> impl std::fmt::Display + '_ {
20 BalanceErrorNote(self)
21 }
22}
23
24struct BalanceErrorNote<'a>(&'a BalanceError);
25
26impl std::fmt::Display for BalanceErrorNote<'_> {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 match self.0 {
29 BalanceError::MultiCommodityWithPartialSet(_, balance) => {
30 write!(f, "actual: {balance}")
31 }
32 }
33 }
34}
35
36#[derive(Debug, Default, PartialEq, Eq, Clone)]
38pub struct Balance<'ctx> {
39 accounts: HashMap<Account<'ctx>, Amount<'ctx>>,
40}
41
42impl<'ctx> FromIterator<(Account<'ctx>, Amount<'ctx>)> for Balance<'ctx> {
43 fn from_iter<T>(iter: T) -> Self
44 where
45 T: IntoIterator<Item = (Account<'ctx>, Amount<'ctx>)>,
46 {
47 Self {
48 accounts: iter.into_iter().collect(),
49 }
50 }
51}
52
53impl<'ctx> Balance<'ctx> {
54 pub fn from_map(values: HashMap<Account<'ctx>, Amount<'ctx>>) -> Self {
56 Self { accounts: values }
57 }
58
59 pub fn into_map(self) -> HashMap<Account<'ctx>, Amount<'ctx>> {
61 self.accounts
62 }
63
64 pub fn into_vec(self) -> Vec<(Account<'ctx>, Amount<'ctx>)> {
66 let mut ret: Vec<(Account<'ctx>, Amount<'ctx>)> = self.accounts.into_iter().collect();
67 ret.sort_unstable_by_key(|(a, _)| a.as_str());
68 ret
69 }
70
71 pub fn add_amount(&mut self, account: Account<'ctx>, amount: Amount<'ctx>) -> &Amount<'ctx> {
73 let curr: &mut Amount = self.accounts.entry(account).or_default();
74 *curr += amount;
75 curr.remove_zero_entries();
76 curr
77 }
78
79 pub(super) fn add_posting_amount(
81 &mut self,
82 account: Account<'ctx>,
83 amount: PostingAmount<'ctx>,
84 ) -> &Amount<'ctx> {
85 let curr: &mut Amount = self.accounts.entry(account).or_default();
86 *curr += amount;
87 curr.remove_zero_entries();
88 curr
89 }
90
91 pub(super) fn set_partial(
94 &mut self,
95 ctx: &ReportContext<'ctx>,
96 account: Account<'ctx>,
97 amount: PostingAmount<'ctx>,
98 ) -> Result<PostingAmount<'ctx>, BalanceError> {
99 match amount {
100 PostingAmount::Zero => {
101 let prev: Amount<'ctx> = self
102 .accounts
103 .insert(account, Amount::zero())
104 .unwrap_or_default();
105 (&prev).try_into().map_err(|e: EvalError<'_>| {
106 BalanceError::MultiCommodityWithPartialSet(
107 e.into_owned(ctx),
108 prev.as_inline_display(ctx).to_string(),
109 )
110 })
111 }
112 PostingAmount::Single(single_amount) => {
113 let prev = self
114 .accounts
115 .entry(account)
116 .or_default()
117 .set_partial(single_amount);
118 Ok(PostingAmount::Single(prev))
119 }
120 }
121 }
122
123 pub fn get(&self, account: Account<'ctx>) -> Option<&Amount<'ctx>> {
125 self.accounts.get(&account)
126 }
127
128 pub fn round(&mut self, ctx: &ReportContext<'ctx>) {
130 for amount in self.accounts.values_mut() {
131 amount.round_mut(ctx);
132 }
133 }
134
135 pub(crate) fn iter(&self) -> impl Iterator<Item = (&Account<'ctx>, &Amount<'ctx>)> {
137 self.accounts.iter()
138 }
139}
140
/// Unit tests for [`Balance`] accumulation and partial-set behavior.
#[cfg(test)]
mod tests {
    use super::*;

    use bumpalo::Bump;
    use maplit::hashmap;
    use pretty_assertions::assert_eq;
    use rust_decimal_macros::dec;

    use super::super::context::ReportContext;

    /// `from_map` followed by `into_map` round-trips the map unchanged.
    #[test]
    fn to_from_map() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let expenses = ctx.accounts.ensure("Expenses");

        let m = hashmap! {
            expenses =>
                Amount::from_value(ctx.commodities.ensure("JPY"), dec!(10)),
            ctx.accounts.ensure("Income") =>
                Amount::from_iter([
                    (ctx.commodities.ensure("CHF"), dec!(15)),
                    (ctx.commodities.ensure("USD"), dec!(-5)),
                ]),
        };

        let b = Balance::from_map(m.clone());
        assert_eq!(
            b.get(expenses),
            Some(&Amount::from_value(ctx.commodities.ensure("JPY"), dec!(10)))
        );

        let m2 = b.into_map();

        assert_eq!(m, m2);
    }

    /// `get` on an account that was never touched returns `None`, not zero.
    #[test]
    fn balance_get_returns_none_when_not_initialized() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);

        let balance = Balance::default();
        assert_eq!(balance.get(ctx.accounts.ensure("Expenses")), None);
    }

    #[test]
    fn test_balance_increment_adds_value() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);

        let mut balance = Balance::default();
        let updated = balance
            .add_posting_amount(
                ctx.accounts.ensure("Expenses"),
                PostingAmount::from_value(ctx.commodities.ensure("JPY"), dec!(1000)),
            )
            .clone();

        assert_eq!(
            updated,
            Amount::from_value(ctx.commodities.ensure("JPY"), dec!(1000))
        );
        assert_eq!(balance.get(ctx.accounts.ensure("Expenses")), Some(&updated));

        let updated = balance
            .add_posting_amount(
                ctx.accounts.ensure("Expenses"),
                PostingAmount::from_value(ctx.commodities.ensure("JPY"), dec!(-1000)),
            )
            .clone();

        assert_eq!(updated, Amount::zero());
        assert_eq!(balance.get(ctx.accounts.ensure("Expenses")), Some(&updated));
    }

    #[test]
    fn test_balance_set_partial_from_absolute_zero() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let mut balance = Balance::default();

        let expenses = ctx.accounts.ensure("Expenses");
        let jpy = ctx.commodities.insert("JPY").unwrap();
        let prev = balance
            .set_partial(&ctx, expenses, PostingAmount::from_value(jpy, dec!(1000)))
            .unwrap();

        assert_eq!(prev, PostingAmount::from_value(jpy, dec!(0)));
        assert_eq!(
            balance.get(expenses),
            Some(&Amount::from_value(jpy, dec!(1000)))
        );
    }

    #[test]
    fn test_balance_set_partial_hit_same_commodity() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let mut balance = Balance::default();
        let jpy = ctx.commodities.ensure("JPY");
        balance.add_posting_amount(
            ctx.accounts.ensure("Expenses"),
            PostingAmount::from_value(jpy, dec!(1000)),
        );

        let expenses = ctx.accounts.ensure("Expenses");

        let prev = balance
            .set_partial(&ctx, expenses, PostingAmount::from_value(jpy, dec!(-1000)))
            .unwrap();

        assert_eq!(prev, PostingAmount::from_value(jpy, dec!(1000)));
        assert_eq!(
            balance.get(ctx.accounts.ensure("Expenses")),
            Some(&Amount::from_value(jpy, dec!(-1000)))
        );
    }

    /// A single-commodity partial set only touches that commodity.
    #[test]
    fn test_balance_set_partial_multi_commodities() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let mut balance = Balance::default();
        let jpy = ctx.commodities.ensure("JPY");
        let chf = ctx.commodities.ensure("CHF");
        balance.add_posting_amount(
            ctx.accounts.ensure("Expenses"),
            PostingAmount::from_value(jpy, dec!(1000)),
        );
        balance.add_posting_amount(
            ctx.accounts.ensure("Expenses"),
            PostingAmount::from_value(chf, dec!(200)),
        );

        let expenses = ctx.accounts.ensure("Expenses");

        let prev = balance
            .set_partial(&ctx, expenses, PostingAmount::from_value(chf, dec!(100)))
            .unwrap();

        assert_eq!(prev, PostingAmount::from_value(chf, dec!(200)));
        assert_eq!(
            balance.get(ctx.accounts.ensure("Expenses")),
            Some(&Amount::from_iter([(jpy, dec!(1000)), (chf, dec!(100)),]))
        );
    }

    #[test]
    fn test_balance_set_partial_zero_on_zero() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let mut balance = Balance::default();

        let expenses = ctx.accounts.ensure("Expenses");

        let prev = balance
            .set_partial(&ctx, expenses, PostingAmount::zero())
            .unwrap();

        assert_eq!(prev, PostingAmount::zero());
        assert_eq!(
            balance.get(ctx.accounts.ensure("Expenses")),
            Some(&Amount::zero())
        );
    }

    #[test]
    fn test_balance_set_partial_zero_on_single_commodity() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let mut balance = Balance::default();
        let jpy = ctx.commodities.ensure("JPY");
        balance.add_posting_amount(
            ctx.accounts.ensure("Expenses"),
            PostingAmount::from_value(jpy, dec!(1000)),
        );

        let expenses = ctx.accounts.ensure("Expenses");

        let prev = balance
            .set_partial(&ctx, expenses, PostingAmount::zero())
            .unwrap();

        assert_eq!(prev, PostingAmount::from_value(jpy, dec!(1000)));
        assert_eq!(
            balance.get(ctx.accounts.ensure("Expenses")),
            Some(&Amount::zero())
        );
    }

    /// `balance = 0` on a multi-commodity account is rejected with the
    /// rendered actual balance attached.
    #[test]
    fn test_balance_set_partial_zero_fails_on_multi_commodities() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let mut balance = Balance::default();
        balance.add_posting_amount(
            ctx.accounts.ensure("Expenses"),
            PostingAmount::from_value(ctx.commodities.ensure("JPY"), dec!(1000)),
        );
        balance.add_posting_amount(
            ctx.accounts.ensure("Expenses"),
            PostingAmount::from_value(ctx.commodities.ensure("CHF"), dec!(200)),
        );

        let expenses = ctx.accounts.ensure("Expenses");

        let err = balance
            .set_partial(&ctx, expenses, PostingAmount::zero())
            .unwrap_err();

        assert_eq!(
            err,
            BalanceError::MultiCommodityWithPartialSet(
                OwnedEvalError::PostingAmountRequired,
                "(1000 JPY + 200 CHF)".to_string()
            )
        );
    }
}