Skip to main content

okane_core/report/
balance.rs

1use std::collections::HashMap;
2
3use crate::report::eval::EvalError;
4
5use super::{
6    context::Account,
7    eval::{Amount, OwnedEvalError, PostingAmount},
8    ReportContext,
9};
10
11/// Error related to [Balance] operations.
12#[derive(Debug, thiserror::Error, PartialEq, Eq)]
13pub enum BalanceError {
14    #[error("balance = 0 should be used on single commodity balance")]
15    MultiCommodityWithPartialSet(#[source] OwnedEvalError, String),
16}
17
18impl BalanceError {
19    pub(super) fn note(&self) -> impl std::fmt::Display + '_ {
20        BalanceErrorNote(self)
21    }
22}
23
24struct BalanceErrorNote<'a>(&'a BalanceError);
25
26impl std::fmt::Display for BalanceErrorNote<'_> {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        match self.0 {
29            BalanceError::MultiCommodityWithPartialSet(_, balance) => {
30                write!(f, "actual: {balance}")
31            }
32        }
33    }
34}
35
36/// Accumulated balance of accounts.
37#[derive(Debug, Default, PartialEq, Eq, Clone)]
38pub struct Balance<'ctx> {
39    accounts: HashMap<Account<'ctx>, Amount<'ctx>>,
40}
41
42impl<'ctx> FromIterator<(Account<'ctx>, Amount<'ctx>)> for Balance<'ctx> {
43    fn from_iter<T>(iter: T) -> Self
44    where
45        T: IntoIterator<Item = (Account<'ctx>, Amount<'ctx>)>,
46    {
47        Self {
48            accounts: iter.into_iter().collect(),
49        }
50    }
51}
52
53impl<'ctx> Balance<'ctx> {
54    /// Constructs an instance directly from the map.
55    pub fn from_map(values: HashMap<Account<'ctx>, Amount<'ctx>>) -> Self {
56        Self { accounts: values }
57    }
58
59    /// Converts into the underlying [`HashMap`].
60    pub fn into_map(self) -> HashMap<Account<'ctx>, Amount<'ctx>> {
61        self.accounts
62    }
63
64    /// Constructs sorted vec of account and commodity tuple.
65    pub fn into_vec(self) -> Vec<(Account<'ctx>, Amount<'ctx>)> {
66        let mut ret: Vec<(Account<'ctx>, Amount<'ctx>)> = self.accounts.into_iter().collect();
67        ret.sort_unstable_by_key(|(a, _)| a.as_str());
68        ret
69    }
70
71    /// Adds a particular account value, and returns the updated balance.
72    pub fn add_amount(&mut self, account: Account<'ctx>, amount: Amount<'ctx>) -> &Amount<'ctx> {
73        let curr: &mut Amount = self.accounts.entry(account).or_default();
74        *curr += amount;
75        curr.remove_zero_entries();
76        curr
77    }
78
79    /// Adds a particular account value with the specified commodity, and returns the updated balance.
80    pub(super) fn add_posting_amount(
81        &mut self,
82        account: Account<'ctx>,
83        amount: PostingAmount<'ctx>,
84    ) -> &Amount<'ctx> {
85        let curr: &mut Amount = self.accounts.entry(account).or_default();
86        *curr += amount;
87        curr.remove_zero_entries();
88        curr
89    }
90
91    /// Tries to set the particular account's balance with the specified commodity,
92    /// and returns the delta which should have caused the difference.
93    pub(super) fn set_partial(
94        &mut self,
95        ctx: &ReportContext<'ctx>,
96        account: Account<'ctx>,
97        amount: PostingAmount<'ctx>,
98    ) -> Result<PostingAmount<'ctx>, BalanceError> {
99        match amount {
100            PostingAmount::Zero => {
101                let prev: Amount<'ctx> = self
102                    .accounts
103                    .insert(account, Amount::zero())
104                    .unwrap_or_default();
105                (&prev).try_into().map_err(|e: EvalError<'_>| {
106                    BalanceError::MultiCommodityWithPartialSet(
107                        e.into_owned(ctx),
108                        prev.as_inline_display(ctx).to_string(),
109                    )
110                })
111            }
112            PostingAmount::Single(single_amount) => {
113                let prev = self
114                    .accounts
115                    .entry(account)
116                    .or_default()
117                    .set_partial(single_amount);
118                Ok(PostingAmount::Single(prev))
119            }
120        }
121    }
122
123    /// Gets the balance of the given account.
124    pub fn get(&self, account: Account<'ctx>) -> Option<&Amount<'ctx>> {
125        self.accounts.get(&account)
126    }
127
128    /// Rounds the balance following the context.
129    pub fn round(&mut self, ctx: &ReportContext<'ctx>) {
130        for amount in self.accounts.values_mut() {
131            amount.round_mut(ctx);
132        }
133    }
134
135    /// Returns unordered iterator for the account and the amount.
136    pub(crate) fn iter(&self) -> impl Iterator<Item = (&Account<'ctx>, &Amount<'ctx>)> {
137        self.accounts.iter()
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    use bumpalo::Bump;
    use maplit::hashmap;
    use pretty_assertions::assert_eq;
    use rust_decimal_macros::dec;

    use super::super::context::ReportContext;

    /// `from_map` followed by `into_map` hands back the original map.
    #[test]
    fn to_from_map() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let expenses = ctx.accounts.ensure("Expenses");
        let jpy = ctx.commodities.ensure("JPY");
        let income = ctx.accounts.ensure("Income");
        let chf = ctx.commodities.ensure("CHF");
        let usd = ctx.commodities.ensure("USD");

        let source = hashmap! {
            expenses => Amount::from_value(jpy, dec!(10)),
            income => Amount::from_iter([(chf, dec!(15)), (usd, dec!(-5))]),
        };

        let balance = Balance::from_map(source.clone());
        assert_eq!(
            balance.get(expenses),
            Some(&Amount::from_value(jpy, dec!(10)))
        );

        let round_tripped = balance.into_map();

        assert_eq!(source, round_tripped);
    }

    /// An account that was never touched has no entry at all.
    #[test]
    fn balance_gives_zero_amount_when_not_initalized() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let expenses = ctx.accounts.ensure("Expenses");

        let empty = Balance::default();

        assert_eq!(empty.get(expenses), None);
    }

    /// Adding and then subtracting the same value ends up at zero.
    #[test]
    fn test_balance_increment_adds_value() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let expenses = ctx.accounts.ensure("Expenses");
        let jpy = ctx.commodities.ensure("JPY");

        let mut balance = Balance::default();

        // The returned reference borrows `balance`, hence the clone.
        let after_add = balance
            .add_posting_amount(expenses, PostingAmount::from_value(jpy, dec!(1000)))
            .clone();

        assert_eq!(after_add, Amount::from_value(jpy, dec!(1000)));
        assert_eq!(balance.get(expenses), Some(&after_add));

        let after_sub = balance
            .add_posting_amount(expenses, PostingAmount::from_value(jpy, dec!(-1000)))
            .clone();

        assert_eq!(after_sub, Amount::zero());
        assert_eq!(balance.get(expenses), Some(&after_sub));
    }

    /// Setting a commodity amount on an untouched account.
    #[test]
    fn test_balance_set_partial_from_absolute_zero() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let expenses = ctx.accounts.ensure("Expenses");
        let jpy = ctx.commodities.insert("JPY").unwrap();

        let mut balance = Balance::default();
        let target = PostingAmount::from_value(jpy, dec!(1000));
        let previous = balance.set_partial(&ctx, expenses, target).unwrap();

        // Note it won't be PostingAmount::zero(),
        // as set_partial is called with commodity amount.
        assert_eq!(previous, PostingAmount::from_value(jpy, dec!(0)));
        assert_eq!(
            balance.get(expenses),
            Some(&Amount::from_value(jpy, dec!(1000)))
        );
    }

    /// Setting a commodity that the account already holds.
    #[test]
    fn test_balance_set_partial_hit_same_commodity() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let expenses = ctx.accounts.ensure("Expenses");

        let mut balance = Balance::default();
        balance.add_posting_amount(expenses, PostingAmount::from_value(jpy, dec!(1000)));

        let target = PostingAmount::from_value(jpy, dec!(-1000));
        let previous = balance.set_partial(&ctx, expenses, target).unwrap();

        assert_eq!(previous, PostingAmount::from_value(jpy, dec!(1000)));
        assert_eq!(
            balance.get(expenses),
            Some(&Amount::from_value(jpy, dec!(-1000)))
        );
    }

    /// Setting one commodity leaves the other commodity untouched.
    #[test]
    fn test_balance_set_partial_multi_commodities() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let chf = ctx.commodities.ensure("CHF");
        let expenses = ctx.accounts.ensure("Expenses");

        let mut balance = Balance::default();
        balance.add_posting_amount(expenses, PostingAmount::from_value(jpy, dec!(1000)));
        balance.add_posting_amount(expenses, PostingAmount::from_value(chf, dec!(200)));

        let target = PostingAmount::from_value(chf, dec!(100));
        let previous = balance.set_partial(&ctx, expenses, target).unwrap();

        assert_eq!(previous, PostingAmount::from_value(chf, dec!(200)));
        assert_eq!(
            balance.get(expenses),
            Some(&Amount::from_iter([(jpy, dec!(1000)), (chf, dec!(100))]))
        );
    }

    /// Setting zero on an untouched account.
    #[test]
    fn test_balance_set_partial_zero_on_zero() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let expenses = ctx.accounts.ensure("Expenses");

        let mut balance = Balance::default();
        let previous = balance
            .set_partial(&ctx, expenses, PostingAmount::zero())
            .unwrap();

        assert_eq!(previous, PostingAmount::zero());
        assert_eq!(balance.get(expenses), Some(&Amount::zero()));
    }

    /// Setting zero on an account holding exactly one commodity.
    #[test]
    fn test_balance_set_partial_zero_on_single_commodity() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let jpy = ctx.commodities.ensure("JPY");
        let expenses = ctx.accounts.ensure("Expenses");

        let mut balance = Balance::default();
        balance.add_posting_amount(expenses, PostingAmount::from_value(jpy, dec!(1000)));

        let previous = balance
            .set_partial(&ctx, expenses, PostingAmount::zero())
            .unwrap();

        assert_eq!(previous, PostingAmount::from_value(jpy, dec!(1000)));
        assert_eq!(balance.get(expenses), Some(&Amount::zero()));
    }

    /// Setting zero on an account holding several commodities is rejected.
    #[test]
    fn test_balance_set_partial_zero_fails_on_multi_commodities() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let expenses = ctx.accounts.ensure("Expenses");
        // JPY is registered before CHF, matching the expected text below.
        let jpy = ctx.commodities.ensure("JPY");
        let chf = ctx.commodities.ensure("CHF");

        let mut balance = Balance::default();
        balance.add_posting_amount(expenses, PostingAmount::from_value(jpy, dec!(1000)));
        balance.add_posting_amount(expenses, PostingAmount::from_value(chf, dec!(200)));

        let failure = balance
            .set_partial(&ctx, expenses, PostingAmount::zero())
            .unwrap_err();

        assert_eq!(
            failure,
            BalanceError::MultiCommodityWithPartialSet(
                OwnedEvalError::PostingAmountRequired,
                "(1000 JPY + 200 CHF)".to_string()
            )
        );
    }
}
363}