1use std::collections::HashMap;
2
3use crate::report::eval::EvalError;
4
5use super::{
6 context::Account,
7 eval::{Amount, OwnedEvalError, PostingAmount},
8 ReportContext,
9};
10
11#[derive(Debug, thiserror::Error, PartialEq, Eq)]
13pub enum BalanceError {
14 #[error("balance = 0 cannot deduce posting amount when balance has multi commodities")]
15 MultiCommodityWithPartialSet(#[from] OwnedEvalError),
16}
17
18#[derive(Debug, Default, PartialEq, Eq, Clone)]
20pub struct Balance<'ctx> {
21 accounts: HashMap<Account<'ctx>, Amount<'ctx>>,
22}
23
24impl<'ctx> FromIterator<(Account<'ctx>, Amount<'ctx>)> for Balance<'ctx> {
25 fn from_iter<T>(iter: T) -> Self
27 where
28 T: IntoIterator<Item = (Account<'ctx>, Amount<'ctx>)>,
29 {
30 Self {
31 accounts: iter.into_iter().collect(),
32 }
33 }
34}
35
36impl<'ctx> Balance<'ctx> {
37 pub fn add_amount(&mut self, account: Account<'ctx>, amount: Amount<'ctx>) -> &Amount<'ctx> {
39 let curr: &mut Amount = self.accounts.entry(account).or_default();
40 *curr += amount;
41 curr.remove_zero_entries();
42 curr
43 }
44
45 pub(super) fn add_posting_amount(
47 &mut self,
48 account: Account<'ctx>,
49 amount: PostingAmount<'ctx>,
50 ) -> &Amount<'ctx> {
51 let curr: &mut Amount = self.accounts.entry(account).or_default();
52 *curr += amount;
53 curr.remove_zero_entries();
54 curr
55 }
56
57 pub(super) fn set_partial(
60 &mut self,
61 ctx: &ReportContext<'ctx>,
62 account: Account<'ctx>,
63 amount: PostingAmount<'ctx>,
64 ) -> Result<PostingAmount<'ctx>, BalanceError> {
65 match amount {
66 PostingAmount::Zero => {
67 let prev: Amount<'ctx> = self
68 .accounts
69 .insert(account, Amount::zero())
70 .unwrap_or_default();
71 (&prev).try_into().map_err(|e: EvalError<'_>| {
72 BalanceError::MultiCommodityWithPartialSet(e.into_owned(ctx))
73 })
74 }
75 PostingAmount::Single(single_amount) => {
76 let prev = self
77 .accounts
78 .entry(account)
79 .or_default()
80 .set_partial(single_amount);
81 Ok(PostingAmount::Single(prev))
82 }
83 }
84 }
85
86 pub fn get(&self, account: &Account<'ctx>) -> Option<&Amount<'ctx>> {
88 self.accounts.get(account)
89 }
90
91 pub fn round(&mut self, ctx: &ReportContext<'ctx>) {
93 for amount in self.accounts.values_mut() {
94 amount.round_mut(ctx);
95 }
96 }
97
98 pub(crate) fn iter(&self) -> impl Iterator<Item = (&Account<'ctx>, &Amount<'ctx>)> {
100 self.accounts.iter()
101 }
102
103 pub fn into_vec(self) -> Vec<(Account<'ctx>, Amount<'ctx>)> {
105 let mut ret: Vec<(Account<'ctx>, Amount<'ctx>)> = self.accounts.into_iter().collect();
106 ret.sort_unstable_by_key(|(a, _)| a.as_str());
107 ret
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;

    use bumpalo::Bump;
    use pretty_assertions::assert_eq;
    use rust_decimal_macros::dec;

    use super::super::context::ReportContext;

    /// An account that was never touched has no entry at all (not a zero entry).
    #[test]
    fn balance_get_returns_none_when_not_initialized() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let expenses = ctx.accounts.ensure("Expenses");

        let balance = Balance::default();

        assert_eq!(balance.get(&expenses), None);
    }

    #[test]
    fn test_balance_increment_adds_value() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let expenses = ctx.accounts.ensure("Expenses");
        let jpy = ctx.commodities.ensure("JPY");
        let mut balance = Balance::default();

        let updated = balance
            .add_posting_amount(expenses, PostingAmount::from_value(dec!(1000), jpy))
            .clone();

        assert_eq!(updated, Amount::from_value(dec!(1000), jpy));
        assert_eq!(balance.get(&expenses), Some(&updated));

        // Adding the exact opposite cancels out; the account keeps a zero entry.
        let updated = balance
            .add_posting_amount(expenses, PostingAmount::from_value(dec!(-1000), jpy))
            .clone();

        assert_eq!(updated, Amount::zero());
        assert_eq!(balance.get(&expenses), Some(&updated));
    }

    #[test]
    fn test_balance_set_partial_from_absolute_zero() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let expenses = ctx.accounts.ensure("Expenses");
        let jpy = ctx.commodities.ensure("JPY");
        let mut balance = Balance::default();

        let prev = balance
            .set_partial(&ctx, expenses, PostingAmount::from_value(dec!(1000), jpy))
            .unwrap();

        assert_eq!(prev, PostingAmount::from_value(dec!(0), jpy));
        assert_eq!(
            balance.get(&expenses),
            Some(&Amount::from_value(dec!(1000), jpy))
        );
    }

    #[test]
    fn test_balance_set_partial_hit_same_commodity() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let expenses = ctx.accounts.ensure("Expenses");
        let jpy = ctx.commodities.ensure("JPY");
        let mut balance = Balance::default();
        balance.add_posting_amount(expenses, PostingAmount::from_value(dec!(1000), jpy));

        let prev = balance
            .set_partial(&ctx, expenses, PostingAmount::from_value(dec!(-1000), jpy))
            .unwrap();

        assert_eq!(prev, PostingAmount::from_value(dec!(1000), jpy));
        assert_eq!(
            balance.get(&expenses),
            Some(&Amount::from_value(dec!(-1000), jpy))
        );
    }

    #[test]
    fn test_balance_set_partial_multi_commodities() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let expenses = ctx.accounts.ensure("Expenses");
        let jpy = ctx.commodities.ensure("JPY");
        let chf = ctx.commodities.ensure("CHF");
        let mut balance = Balance::default();
        balance.add_posting_amount(expenses, PostingAmount::from_value(dec!(1000), jpy));
        balance.add_posting_amount(expenses, PostingAmount::from_value(dec!(200), chf));

        let prev = balance
            .set_partial(&ctx, expenses, PostingAmount::from_value(dec!(100), chf))
            .unwrap();

        // Only the CHF part is replaced; JPY is untouched.
        assert_eq!(prev, PostingAmount::from_value(dec!(200), chf));
        assert_eq!(
            balance.get(&expenses),
            Some(&Amount::from_values([(dec!(1000), jpy), (dec!(100), chf)]))
        );
    }

    #[test]
    fn test_balance_set_partial_zero_on_zero() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let expenses = ctx.accounts.ensure("Expenses");
        let mut balance = Balance::default();

        let prev = balance
            .set_partial(&ctx, expenses, PostingAmount::zero())
            .unwrap();

        assert_eq!(prev, PostingAmount::zero());
        assert_eq!(balance.get(&expenses), Some(&Amount::zero()));
    }

    #[test]
    fn test_balance_set_partial_zero_on_single_commodity() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let expenses = ctx.accounts.ensure("Expenses");
        let jpy = ctx.commodities.ensure("JPY");
        let mut balance = Balance::default();
        balance.add_posting_amount(expenses, PostingAmount::from_value(dec!(1000), jpy));

        let prev = balance
            .set_partial(&ctx, expenses, PostingAmount::zero())
            .unwrap();

        assert_eq!(prev, PostingAmount::from_value(dec!(1000), jpy));
        assert_eq!(balance.get(&expenses), Some(&Amount::zero()));
    }

    #[test]
    fn test_balance_set_partial_zero_fails_on_multi_commodities() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let expenses = ctx.accounts.ensure("Expenses");
        let jpy = ctx.commodities.ensure("JPY");
        let chf = ctx.commodities.ensure("CHF");
        let mut balance = Balance::default();
        balance.add_posting_amount(expenses, PostingAmount::from_value(dec!(1000), jpy));
        balance.add_posting_amount(expenses, PostingAmount::from_value(dec!(200), chf));

        let err = balance
            .set_partial(&ctx, expenses, PostingAmount::zero())
            .unwrap_err();

        assert_eq!(
            err,
            BalanceError::MultiCommodityWithPartialSet(OwnedEvalError::PostingAmountRequired)
        );
    }
}