1use std::collections::HashMap;
2
3use super::{
4 context::Account,
5 eval::{Amount, EvalError, PostingAmount},
6};
7
8#[derive(Debug, thiserror::Error, PartialEq, Eq)]
10pub enum BalanceError {
11 #[error("balance = 0 cannot deduce posting amount when balance has multi commodities")]
12 MultiCommodityWithPartialSet(#[from] EvalError),
13}
14
15#[derive(Debug, Default, PartialEq, Eq, Clone)]
17pub struct Balance<'ctx> {
18 accounts: HashMap<Account<'ctx>, Amount<'ctx>>,
19}
20
21impl<'ctx> FromIterator<(Account<'ctx>, Amount<'ctx>)> for Balance<'ctx> {
22 fn from_iter<T>(iter: T) -> Self
24 where
25 T: IntoIterator<Item = (Account<'ctx>, Amount<'ctx>)>,
26 {
27 Self {
28 accounts: iter.into_iter().collect(),
29 }
30 }
31}
32
33impl<'ctx> Balance<'ctx> {
34 pub fn add_amount(&mut self, account: Account<'ctx>, amount: Amount<'ctx>) -> &Amount<'ctx> {
36 let curr: &mut Amount = self.accounts.entry(account).or_default();
37 *curr += amount;
38 curr
39 }
40
41 pub fn add_posting_amount(
43 &mut self,
44 account: Account<'ctx>,
45 amount: PostingAmount<'ctx>,
46 ) -> &Amount<'ctx> {
47 let curr: &mut Amount = self.accounts.entry(account).or_default();
48 *curr += amount;
49 curr
50 }
51
52 pub fn set_partial(
55 &mut self,
56 account: Account<'ctx>,
57 amount: PostingAmount<'ctx>,
58 ) -> Result<PostingAmount<'ctx>, BalanceError> {
59 match amount {
60 PostingAmount::Zero => {
61 let prev: Amount<'ctx> = self
62 .accounts
63 .insert(account, amount.into())
64 .unwrap_or_default();
65 (&prev)
66 .try_into()
67 .map_err(BalanceError::MultiCommodityWithPartialSet)
68 }
69 PostingAmount::Single(single_amount) => {
70 let prev = self
71 .accounts
72 .entry(account)
73 .or_default()
74 .set_partial(single_amount);
75 Ok(PostingAmount::Single(prev))
76 }
77 }
78 }
79
80 pub fn get(&self, account: &Account<'ctx>) -> Option<&Amount<'ctx>> {
82 self.accounts.get(account)
83 }
84
85 pub fn into_vec(self) -> Vec<(Account<'ctx>, Amount<'ctx>)> {
87 let mut ret: Vec<(Account<'ctx>, Amount<'ctx>)> = self.accounts.into_iter().collect();
88 ret.sort_unstable_by_key(|(a, _)| a.as_str());
89 ret
90 }
91}
92
#[cfg(test)]
mod tests {
    use super::*;

    use bumpalo::Bump;
    use pretty_assertions::assert_eq;
    use rust_decimal_macros::dec;

    use super::super::context::ReportContext;

    /// An untouched account has no entry, so `get` yields `None` (not a zero amount).
    #[test]
    fn balance_get_returns_none_when_not_initialized() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);

        let balance = Balance::default();
        assert_eq!(balance.get(&ctx.accounts.ensure("Expenses")), None);
    }

    #[test]
    fn test_balance_increment_adds_value() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);

        let mut balance = Balance::default();
        let updated = balance
            .add_posting_amount(
                ctx.accounts.ensure("Expenses"),
                PostingAmount::from_value(dec!(1000), ctx.commodities.ensure("JPY")),
            )
            .clone();

        assert_eq!(
            updated,
            Amount::from_value(dec!(1000), ctx.commodities.ensure("JPY"))
        );
        assert_eq!(
            balance.get(&ctx.accounts.ensure("Expenses")),
            Some(&updated)
        );

        let updated = balance
            .add_posting_amount(
                ctx.accounts.ensure("Expenses"),
                PostingAmount::from_value(dec!(-1000), ctx.commodities.ensure("JPY")),
            )
            .clone();

        assert_eq!(updated, Amount::zero());
        assert_eq!(
            balance.get(&ctx.accounts.ensure("Expenses")),
            Some(&updated)
        );
    }

    /// `into_vec` returns every account, ordered by account name.
    #[test]
    fn test_balance_into_vec_sorts_by_account_name() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);

        let mut balance = Balance::default();
        balance.add_amount(
            ctx.accounts.ensure("Expenses"),
            Amount::from_value(dec!(1000), ctx.commodities.ensure("JPY")),
        );
        balance.add_amount(
            ctx.accounts.ensure("Assets"),
            Amount::from_value(dec!(-1000), ctx.commodities.ensure("JPY")),
        );

        assert_eq!(
            balance.into_vec(),
            vec![
                (
                    ctx.accounts.ensure("Assets"),
                    Amount::from_value(dec!(-1000), ctx.commodities.ensure("JPY")),
                ),
                (
                    ctx.accounts.ensure("Expenses"),
                    Amount::from_value(dec!(1000), ctx.commodities.ensure("JPY")),
                ),
            ]
        );
    }

    #[test]
    fn test_balance_set_partial_from_absolute_zero() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let mut balance = Balance::default();

        let prev = balance
            .set_partial(
                ctx.accounts.ensure("Expenses"),
                PostingAmount::from_value(dec!(1000), ctx.commodities.ensure("JPY")),
            )
            .unwrap();

        assert_eq!(
            prev,
            PostingAmount::from_value(dec!(0), ctx.commodities.ensure("JPY"))
        );
        assert_eq!(
            balance.get(&ctx.accounts.ensure("Expenses")),
            Some(&Amount::from_value(
                dec!(1000),
                ctx.commodities.ensure("JPY")
            ))
        );
    }

    #[test]
    fn test_balance_set_partial_hit_same_commodity() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let mut balance = Balance::default();
        balance.add_posting_amount(
            ctx.accounts.ensure("Expenses"),
            PostingAmount::from_value(dec!(1000), ctx.commodities.ensure("JPY")),
        );

        let prev = balance
            .set_partial(
                ctx.accounts.ensure("Expenses"),
                PostingAmount::from_value(dec!(-1000), ctx.commodities.ensure("JPY")),
            )
            .unwrap();

        assert_eq!(
            prev,
            PostingAmount::from_value(dec!(1000), ctx.commodities.ensure("JPY"))
        );
        assert_eq!(
            balance.get(&ctx.accounts.ensure("Expenses")),
            Some(&Amount::from_value(
                dec!(-1000),
                ctx.commodities.ensure("JPY")
            ))
        );
    }

    #[test]
    fn test_balance_set_partial_multi_commodities() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let mut balance = Balance::default();
        balance.add_posting_amount(
            ctx.accounts.ensure("Expenses"),
            PostingAmount::from_value(dec!(1000), ctx.commodities.ensure("JPY")),
        );
        balance.add_posting_amount(
            ctx.accounts.ensure("Expenses"),
            PostingAmount::from_value(dec!(200), ctx.commodities.ensure("CHF")),
        );

        let prev = balance
            .set_partial(
                ctx.accounts.ensure("Expenses"),
                PostingAmount::from_value(dec!(100), ctx.commodities.ensure("CHF")),
            )
            .unwrap();

        assert_eq!(
            prev,
            PostingAmount::from_value(dec!(200), ctx.commodities.ensure("CHF"))
        );
        assert_eq!(
            balance.get(&ctx.accounts.ensure("Expenses")),
            Some(&Amount::from_values([
                (dec!(1000), ctx.commodities.ensure("JPY")),
                (dec!(100), ctx.commodities.ensure("CHF")),
            ]))
        );
    }

    #[test]
    fn test_balance_set_partial_zero_on_zero() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let mut balance = Balance::default();

        let prev = balance
            .set_partial(ctx.accounts.ensure("Expenses"), PostingAmount::zero())
            .unwrap();

        assert_eq!(prev, PostingAmount::zero());
        assert_eq!(
            balance.get(&ctx.accounts.ensure("Expenses")),
            Some(&Amount::zero())
        );
    }

    #[test]
    fn test_balance_set_partial_zero_on_single_commodity() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let mut balance = Balance::default();
        balance.add_posting_amount(
            ctx.accounts.ensure("Expenses"),
            PostingAmount::from_value(dec!(1000), ctx.commodities.ensure("JPY")),
        );

        let prev = balance
            .set_partial(ctx.accounts.ensure("Expenses"), PostingAmount::zero())
            .unwrap();

        assert_eq!(
            prev,
            PostingAmount::from_value(dec!(1000), ctx.commodities.ensure("JPY"))
        );
        assert_eq!(
            balance.get(&ctx.accounts.ensure("Expenses")),
            Some(&Amount::zero())
        );
    }

    #[test]
    fn test_balance_set_partial_zero_fails_on_multi_commodities() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let mut balance = Balance::default();
        balance.add_posting_amount(
            ctx.accounts.ensure("Expenses"),
            PostingAmount::from_value(dec!(1000), ctx.commodities.ensure("JPY")),
        );
        balance.add_posting_amount(
            ctx.accounts.ensure("Expenses"),
            PostingAmount::from_value(dec!(200), ctx.commodities.ensure("CHF")),
        );

        let err = balance
            .set_partial(ctx.accounts.ensure("Expenses"), PostingAmount::zero())
            .unwrap_err();

        assert_eq!(
            err,
            BalanceError::MultiCommodityWithPartialSet(EvalError::PostingAmountRequired)
        );
    }
}