1use std::collections::HashMap;
2
3use super::{
4 context::Account,
5 eval::{Amount, EvalError, PostingAmount},
6 ReportContext,
7};
8
9#[derive(Debug, thiserror::Error, PartialEq, Eq)]
11pub enum BalanceError {
12 #[error("balance = 0 cannot deduce posting amount when balance has multi commodities")]
13 MultiCommodityWithPartialSet(#[from] EvalError),
14}
15
16#[derive(Debug, Default, PartialEq, Eq, Clone)]
18pub struct Balance<'ctx> {
19 accounts: HashMap<Account<'ctx>, Amount<'ctx>>,
20}
21
22impl<'ctx> FromIterator<(Account<'ctx>, Amount<'ctx>)> for Balance<'ctx> {
23 fn from_iter<T>(iter: T) -> Self
25 where
26 T: IntoIterator<Item = (Account<'ctx>, Amount<'ctx>)>,
27 {
28 Self {
29 accounts: iter.into_iter().collect(),
30 }
31 }
32}
33
34impl<'ctx> Balance<'ctx> {
35 pub fn add_amount(&mut self, account: Account<'ctx>, amount: Amount<'ctx>) -> &Amount<'ctx> {
37 let curr: &mut Amount = self.accounts.entry(account).or_default();
38 *curr += amount;
39 curr.remove_zero_entries();
40 curr
41 }
42
43 pub(super) fn add_posting_amount(
45 &mut self,
46 account: Account<'ctx>,
47 amount: PostingAmount<'ctx>,
48 ) -> &Amount<'ctx> {
49 let curr: &mut Amount = self.accounts.entry(account).or_default();
50 *curr += amount;
51 curr.remove_zero_entries();
52 curr
53 }
54
55 pub(super) fn set_partial(
58 &mut self,
59 account: Account<'ctx>,
60 amount: PostingAmount<'ctx>,
61 ) -> Result<PostingAmount<'ctx>, BalanceError> {
62 match amount {
63 PostingAmount::Zero => {
64 let prev: Amount<'ctx> = self
65 .accounts
66 .insert(account, Amount::zero())
67 .unwrap_or_default();
68 (&prev)
69 .try_into()
70 .map_err(BalanceError::MultiCommodityWithPartialSet)
71 }
72 PostingAmount::Single(single_amount) => {
73 let prev = self
74 .accounts
75 .entry(account)
76 .or_default()
77 .set_partial(single_amount);
78 Ok(PostingAmount::Single(prev))
79 }
80 }
81 }
82
83 pub fn get(&self, account: &Account<'ctx>) -> Option<&Amount<'ctx>> {
85 self.accounts.get(account)
86 }
87
88 pub fn round(&mut self, ctx: &ReportContext<'ctx>) {
90 for amount in self.accounts.values_mut() {
91 amount.round_mut(ctx);
92 }
93 }
94
95 pub(crate) fn iter(&self) -> impl Iterator<Item = (&Account<'ctx>, &Amount<'ctx>)> {
97 self.accounts.iter()
98 }
99
100 pub fn into_vec(self) -> Vec<(Account<'ctx>, Amount<'ctx>)> {
102 let mut ret: Vec<(Account<'ctx>, Amount<'ctx>)> = self.accounts.into_iter().collect();
103 ret.sort_unstable_by_key(|(a, _)| a.as_str());
104 ret
105 }
106}
107
#[cfg(test)]
mod tests {
    //! Unit tests for [`Balance`] accumulation and partial-set semantics.

    use super::*;

    use bumpalo::Bump;
    use pretty_assertions::assert_eq;
    use rust_decimal_macros::dec;

    use super::super::context::ReportContext;

    /// An account that was never updated has no entry at all (not a zero amount).
    #[test]
    fn balance_gives_none_when_not_initialized() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);

        let balance = Balance::default();
        assert_eq!(balance.get(&ctx.accounts.ensure("Expenses")), None);
    }

    #[test]
    fn test_balance_increment_adds_value() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);

        let mut balance = Balance::default();
        let updated = balance
            .add_posting_amount(
                ctx.accounts.ensure("Expenses"),
                PostingAmount::from_value(dec!(1000), ctx.commodities.ensure("JPY")),
            )
            .clone();

        assert_eq!(
            updated,
            Amount::from_value(dec!(1000), ctx.commodities.ensure("JPY"))
        );
        assert_eq!(
            balance.get(&ctx.accounts.ensure("Expenses")),
            Some(&updated)
        );

        let updated = balance
            .add_posting_amount(
                ctx.accounts.ensure("Expenses"),
                PostingAmount::from_value(dec!(-1000), ctx.commodities.ensure("JPY")),
            )
            .clone();

        assert_eq!(updated, Amount::zero());
        assert_eq!(
            balance.get(&ctx.accounts.ensure("Expenses")),
            Some(&updated)
        );
    }

    #[test]
    fn test_balance_set_partial_from_absolute_zero() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let mut balance = Balance::default();

        let prev = balance
            .set_partial(
                ctx.accounts.ensure("Expenses"),
                PostingAmount::from_value(dec!(1000), ctx.commodities.ensure("JPY")),
            )
            .unwrap();

        assert_eq!(
            prev,
            PostingAmount::from_value(dec!(0), ctx.commodities.ensure("JPY"))
        );
        assert_eq!(
            balance.get(&ctx.accounts.ensure("Expenses")),
            Some(&Amount::from_value(
                dec!(1000),
                ctx.commodities.ensure("JPY")
            ))
        );
    }

    #[test]
    fn test_balance_set_partial_hit_same_commodity() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let mut balance = Balance::default();
        balance.add_posting_amount(
            ctx.accounts.ensure("Expenses"),
            PostingAmount::from_value(dec!(1000), ctx.commodities.ensure("JPY")),
        );

        let prev = balance
            .set_partial(
                ctx.accounts.ensure("Expenses"),
                PostingAmount::from_value(dec!(-1000), ctx.commodities.ensure("JPY")),
            )
            .unwrap();

        assert_eq!(
            prev,
            PostingAmount::from_value(dec!(1000), ctx.commodities.ensure("JPY"))
        );
        assert_eq!(
            balance.get(&ctx.accounts.ensure("Expenses")),
            Some(&Amount::from_value(
                dec!(-1000),
                ctx.commodities.ensure("JPY")
            ))
        );
    }

    #[test]
    fn test_balance_set_partial_multi_commodities() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let mut balance = Balance::default();
        balance.add_posting_amount(
            ctx.accounts.ensure("Expenses"),
            PostingAmount::from_value(dec!(1000), ctx.commodities.ensure("JPY")),
        );
        balance.add_posting_amount(
            ctx.accounts.ensure("Expenses"),
            PostingAmount::from_value(dec!(200), ctx.commodities.ensure("CHF")),
        );

        let prev = balance
            .set_partial(
                ctx.accounts.ensure("Expenses"),
                PostingAmount::from_value(dec!(100), ctx.commodities.ensure("CHF")),
            )
            .unwrap();

        assert_eq!(
            prev,
            PostingAmount::from_value(dec!(200), ctx.commodities.ensure("CHF"))
        );
        assert_eq!(
            balance.get(&ctx.accounts.ensure("Expenses")),
            Some(&Amount::from_values([
                (dec!(1000), ctx.commodities.ensure("JPY")),
                (dec!(100), ctx.commodities.ensure("CHF")),
            ]))
        );
    }

    #[test]
    fn test_balance_set_partial_zero_on_zero() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let mut balance = Balance::default();

        let prev = balance
            .set_partial(ctx.accounts.ensure("Expenses"), PostingAmount::zero())
            .unwrap();

        assert_eq!(prev, PostingAmount::zero());
        assert_eq!(
            balance.get(&ctx.accounts.ensure("Expenses")),
            Some(&Amount::zero())
        );
    }

    #[test]
    fn test_balance_set_partial_zero_on_single_commodity() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let mut balance = Balance::default();
        balance.add_posting_amount(
            ctx.accounts.ensure("Expenses"),
            PostingAmount::from_value(dec!(1000), ctx.commodities.ensure("JPY")),
        );

        let prev = balance
            .set_partial(ctx.accounts.ensure("Expenses"), PostingAmount::zero())
            .unwrap();

        assert_eq!(
            prev,
            PostingAmount::from_value(dec!(1000), ctx.commodities.ensure("JPY"))
        );
        assert_eq!(
            balance.get(&ctx.accounts.ensure("Expenses")),
            Some(&Amount::zero())
        );
    }

    #[test]
    fn test_balance_set_partial_zero_fails_on_multi_commodities() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let mut balance = Balance::default();
        balance.add_posting_amount(
            ctx.accounts.ensure("Expenses"),
            PostingAmount::from_value(dec!(1000), ctx.commodities.ensure("JPY")),
        );
        balance.add_posting_amount(
            ctx.accounts.ensure("Expenses"),
            PostingAmount::from_value(dec!(200), ctx.commodities.ensure("CHF")),
        );

        let err = balance
            .set_partial(ctx.accounts.ensure("Expenses"), PostingAmount::zero())
            .unwrap_err();

        assert_eq!(
            err,
            BalanceError::MultiCommodityWithPartialSet(EvalError::PostingAmountRequired)
        );
    }
}