okane_core/report/
price_db.rs

1//! Provides [PriceRepository], which can compute the commodity (currency) conversion.
2
3use std::{
4    collections::{hash_map, BinaryHeap, HashMap},
5    path::Path,
6};
7
8use chrono::{NaiveDate, TimeDelta};
9use rust_decimal::Decimal;
10
11use crate::parse;
12
13use super::{
14    commodity::Commodity,
15    context::ReportContext,
16    eval::{Amount, SingleAmount},
17};
18
19#[derive(Debug, thiserror::Error)]
20pub enum LoadError {
21    #[error("failed to perform IO")]
22    IO(#[from] std::io::Error),
23    #[error("failed to parse price DB entry: {0}")]
24    Parse(#[from] parse::ParseError),
25}
26
27/// Source of the price information.
28/// In the DB, latter one (larger one as Ord) has priority,
29/// and if you have events with higher priority,
30/// lower priority events are discarded.
31#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
32pub(super) enum PriceSource {
33    Ledger,
34    PriceDB,
35}
36
37#[derive(Debug)]
38struct Entry(PriceSource, Vec<(NaiveDate, Decimal)>);
39
40/// Builder of [`PriceRepository`].
41#[derive(Debug, Default)]
42pub(super) struct PriceRepositoryBuilder<'ctx> {
43    records: HashMap<Commodity<'ctx>, HashMap<Commodity<'ctx>, Entry>>,
44}
45
46/// Event of commodity price.
47#[derive(Debug, PartialEq, Eq)]
48pub(super) struct PriceEvent<'ctx> {
49    pub date: NaiveDate,
50    pub price_x: SingleAmount<'ctx>,
51    pub price_y: SingleAmount<'ctx>,
52}
53
54#[cfg(test)]
55impl<'ctx> PriceEvent<'ctx> {
56    fn sort_key<'a>(&'a self) -> (NaiveDate, &'ctx str, &'ctx str, &'a Decimal, &'a Decimal) {
57        let PriceEvent {
58            date,
59            price_x:
60                SingleAmount {
61                    value: value_x,
62                    commodity: commodity_x,
63                },
64            price_y:
65                SingleAmount {
66                    value: value_y,
67                    commodity: commodity_y,
68                },
69        } = self;
70        (
71            *date,
72            commodity_x.as_str(),
73            commodity_y.as_str(),
74            value_x,
75            value_y,
76        )
77    }
78}
79
80impl<'ctx> PriceRepositoryBuilder<'ctx> {
81    pub fn insert_price(&mut self, source: PriceSource, event: PriceEvent<'ctx>) {
82        if event.price_x.commodity == event.price_y.commodity {
83            // this must be an error returned, instead of log error.
84            log::error!("price log should not contain the self-mention rate");
85        }
86        self.insert_impl(source, event.date, event.price_x, event.price_y);
87        self.insert_impl(source, event.date, event.price_y, event.price_x);
88    }
89
90    fn insert_impl(
91        &mut self,
92        source: PriceSource,
93        date: NaiveDate,
94        price_of: SingleAmount<'ctx>,
95        price_with: SingleAmount<'ctx>,
96    ) {
97        let Entry(stored_source, entries): &mut _ = self
98            .records
99            .entry(price_with.commodity)
100            .or_default()
101            .entry(price_of.commodity)
102            .or_insert(Entry(PriceSource::Ledger, Vec::new()));
103        if *stored_source < source {
104            *stored_source = source;
105            entries.clear();
106        }
107        // price_of: x X
108        // price_with: y Y
109        //
110        // typical use: price_of: 1 X
111        // then records[Y][X] == y (/ 1)
112        entries.push((date, price_with.value / price_of.value));
113    }
114
115    /// Loads PriceDB information from the given file.
116    pub fn load_price_db(
117        &mut self,
118        ctx: &mut ReportContext<'ctx>,
119        path: &Path,
120    ) -> Result<(), LoadError> {
121        // Even though price db can be up to a few megabytes,
122        // still it's much easier to load everything into memory.
123        let content = std::fs::read_to_string(path)?;
124        for entry in parse::price::parse_price_db(&parse::ParseOptions::default(), &content) {
125            let (_, entry) = entry?;
126            // we cannot skip commodities we don't know, as the price might be indirected in DB.
127            // For example, if we have only AUD and JPY in Ledger,
128            // price DB might just expose AUD/EUR EUR/CHF CHF/JPY.
129            let target = ctx.commodities.ensure(entry.target.as_ref());
130            let rate: SingleAmount<'ctx> = SingleAmount::from_value(
131                entry.rate.value.value,
132                ctx.commodities.ensure(&entry.rate.commodity),
133            );
134            self.insert_price(
135                PriceSource::PriceDB,
136                PriceEvent {
137                    price_x: SingleAmount::from_value(Decimal::ONE, target),
138                    price_y: rate,
139                    date: entry.datetime.date(),
140                },
141            );
142        }
143        Ok(())
144    }
145
146    #[cfg(test)]
147    pub fn into_events(self) -> Vec<PriceEvent<'ctx>> {
148        let mut ret = Vec::new();
149        for (price_with, v) in self.records {
150            for (price_of, Entry(_, v)) in v {
151                for (date, rate) in v {
152                    ret.push(PriceEvent {
153                        price_x: SingleAmount::from_value(Decimal::ONE, price_of),
154                        price_y: SingleAmount::from_value(rate, price_with),
155                        date,
156                    });
157                }
158            }
159        }
160        ret.sort_by(|x, y| x.sort_key().cmp(&y.sort_key()));
161        ret
162    }
163
164    pub fn build(self) -> PriceRepository<'ctx> {
165        PriceRepository::new(self.build_naive())
166    }
167
168    fn build_naive(mut self) -> NaivePriceRepository<'ctx> {
169        self.records
170            .values_mut()
171            .for_each(|x| x.values_mut().for_each(|x| x.1.sort()));
172        NaivePriceRepository {
173            records: self.records,
174        }
175    }
176}
177
178#[derive(Debug, thiserror::Error)]
179pub enum ConversionError<'ctx> {
180    #[error("commodity rate {0} into {1} at {2} not found")]
181    RateNotFound(SingleAmount<'ctx>, Commodity<'ctx>, NaiveDate),
182}
183
184/// Converts the given amount into the specified commodity.
185pub fn convert_amount<'ctx>(
186    price_repos: &mut PriceRepository<'ctx>,
187    amount: &Amount<'ctx>,
188    commodity_with: Commodity<'ctx>,
189    date: NaiveDate,
190) -> Result<Amount<'ctx>, ConversionError<'ctx>> {
191    let mut result = Amount::zero();
192    for v in amount.iter() {
193        result += price_repos.convert_single(v, commodity_with, date)?;
194    }
195    Ok(result)
196}
197
198/// Repository which user can query the conversion rate with.
199#[derive(Debug)]
200pub struct PriceRepository<'ctx> {
201    inner: NaivePriceRepository<'ctx>,
202    // TODO: add price_with as a key, otherwise it's wrong.
203    // BTreeMap could be used if cursor support is ready.
204    // Then, we can avoid computing rates over and over if no rate update happens.
205    cache: HashMap<(Commodity<'ctx>, NaiveDate), HashMap<Commodity<'ctx>, WithDistance<Decimal>>>,
206}
207
208impl<'ctx> PriceRepository<'ctx> {
209    fn new(inner: NaivePriceRepository<'ctx>) -> Self {
210        Self {
211            inner,
212            cache: HashMap::new(),
213        }
214    }
215
216    /// Converts the given `value` into the `commodity_with`.
217    /// If the given value has already the `commodity_with`,
218    /// returns `Ok(value)` as-is.
219    pub fn convert_single(
220        &mut self,
221        value: SingleAmount<'ctx>,
222        commodity_with: Commodity<'ctx>,
223        date: NaiveDate,
224    ) -> Result<SingleAmount<'ctx>, ConversionError<'ctx>> {
225        if value.commodity == commodity_with {
226            return Ok(value);
227        }
228        let rate = self
229            .cache
230            .entry((commodity_with, date))
231            .or_insert_with(|| self.inner.compute_price_table(commodity_with, date))
232            .get(&value.commodity);
233        match rate {
234            Some(WithDistance(_, rate)) => {
235                Ok(SingleAmount::from_value(value.value * rate, commodity_with))
236            }
237            None => Err(ConversionError::RateNotFound(value, commodity_with, date)),
238        }
239    }
240}
241
242#[derive(Debug)]
243struct NaivePriceRepository<'ctx> {
244    // from comodity -> to commodity -> date -> price.
245    // e.g. USD AAPL 2024-01-01 100 means 1 AAPL == 100 USD at 2024-01-01.
246    // the value are sorted in NaiveDate order.
247    records: HashMap<Commodity<'ctx>, HashMap<Commodity<'ctx>, Entry>>,
248}
249
250impl<'ctx> NaivePriceRepository<'ctx> {
251    /// Copied from CachedPriceRepository, needs to be factored out properly.
252    #[cfg(test)]
253    fn convert(
254        &self,
255        value: SingleAmount<'ctx>,
256        commodity_with: Commodity<'ctx>,
257        date: NaiveDate,
258    ) -> Result<SingleAmount<'ctx>, SingleAmount<'ctx>> {
259        if value.commodity == commodity_with {
260            return Ok(value);
261        }
262        let rate = self
263            .compute_price_table(commodity_with, date)
264            .get(&value.commodity)
265            .map(|x| x.1);
266        match rate {
267            Some(rate) => Ok(SingleAmount::from_value(value.value * rate, commodity_with)),
268            None => Err(value),
269        }
270    }
271
272    fn compute_price_table(
273        &self,
274        price_with: Commodity<'ctx>,
275        date: NaiveDate,
276    ) -> HashMap<Commodity<'ctx>, WithDistance<Decimal>> {
277        // minimize the distance, and then minimize the staleness.
278        let mut queue: BinaryHeap<WithDistance<(Commodity<'ctx>, Decimal)>> = BinaryHeap::new();
279        let mut distances: HashMap<Commodity<'ctx>, WithDistance<Decimal>> = HashMap::new();
280        queue.push(WithDistance(
281            Distance {
282                num_ledger_conversions: 0,
283                num_all_conversions: 0,
284                staleness: TimeDelta::zero(),
285            },
286            (price_with, Decimal::ONE),
287        ));
288        while let Some(curr) = queue.pop() {
289            log::debug!("curr: {:?}", curr);
290            let WithDistance(curr_dist, (prev, prev_rate)) = curr;
291            if let Some(WithDistance(prev_dist, _)) = distances.get(&prev) {
292                if *prev_dist < curr_dist {
293                    log::debug!(
294                        "no need to update, prev_dist {:?} is smaller than curr_dist {:?}",
295                        prev_dist,
296                        curr_dist
297                    );
298                    continue;
299                }
300            }
301            for (j, Entry(source, rates)) in match self.records.get(&prev) {
302                None => continue,
303                Some(x) => x,
304            } {
305                let bound = rates.partition_point(|(record_date, _)| record_date <= &date);
306                log::debug!(
307                    "found next commodity {} with date bound {}",
308                    j.as_str(),
309                    bound
310                );
311                if bound == 0 {
312                    // we cannot find any rate information at the date (all rates are in future).
313                    // let's treat rates are not available.
314                    continue;
315                }
316                let (record_date, rate) = rates[bound - 1];
317                let next_dist = curr_dist.extend(*source, date - record_date);
318                let rate = prev_rate * rate;
319                let next = WithDistance(next_dist.clone(), (*j, rate));
320                let updated = match distances.entry(*j) {
321                    hash_map::Entry::Occupied(mut e) => {
322                        if e.get() <= &next_dist {
323                            false
324                        } else {
325                            e.insert(WithDistance(next_dist, rate));
326                            true
327                        }
328                    }
329                    hash_map::Entry::Vacant(e) => {
330                        e.insert(WithDistance(next_dist, rate));
331                        true
332                    }
333                };
334                if !updated {
335                    continue;
336                }
337                queue.push(next);
338            }
339        }
340        distances
341    }
342}
343
344/// Distance to minimize during the price DB computation.
345///
346/// Now this is using simple derived [Ord] logic,
347/// but we can work on heuristic cost function instead.
348#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
349struct Distance {
350    /// Number of conversions with [`PriceSource::Ledger`] used to compute the rate.
351    /// Minimize this because we assume [`PriceSource::PriceDB`] is more reliable
352    /// than the one in Ledger.
353    num_ledger_conversions: usize,
354    /// Number of conversions used to compute the rate.
355    num_all_conversions: usize,
356    /// Staleness of the conversion rate.
357    staleness: TimeDelta,
358}
359
360impl Distance {
361    fn extend(&self, source: PriceSource, staleness: TimeDelta) -> Self {
362        let num_ledger_conversions = self.num_ledger_conversions
363            + match source {
364                PriceSource::Ledger => 1,
365                PriceSource::PriceDB => 0,
366            };
367        Self {
368            num_ledger_conversions,
369            num_all_conversions: self.num_all_conversions + 1,
370            staleness: std::cmp::max(self.staleness, staleness),
371        }
372    }
373}
374
375#[derive(Debug)]
376struct WithDistance<T>(Distance, T);
377
378impl<T> PartialEq for WithDistance<T> {
379    fn eq(&self, other: &Self) -> bool {
380        self.0 == other.0
381    }
382}
383
384impl<T> PartialEq<Distance> for WithDistance<T> {
385    fn eq(&self, other: &Distance) -> bool {
386        self.0 == *other
387    }
388}
389
390impl<T> Eq for WithDistance<T> {}
391
392impl<T> PartialOrd for WithDistance<T> {
393    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
394        self.0.partial_cmp(&other.0)
395    }
396}
397
398impl<T: Eq> PartialOrd<Distance> for WithDistance<T> {
399    fn partial_cmp(&self, other: &Distance) -> Option<std::cmp::Ordering> {
400        self.0.partial_cmp(other)
401    }
402}
403
404impl<T: Eq> Ord for WithDistance<T> {
405    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
406        self.0.cmp(&other.0)
407    }
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    use bumpalo::Bump;
    use rust_decimal_macros::dec;

    /// Shorthand to create a date known to be valid.
    fn ymd(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    #[test]
    fn price_db_computes_direct_price() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let chf = ctx.commodities.ensure("CHF");
        let eur = ctx.commodities.ensure("EUR");

        let event_date = ymd(2024, 10, 1);
        let day_before = ymd(2024, 9, 30);

        let mut builder = PriceRepositoryBuilder::default();
        builder.insert_price(
            PriceSource::Ledger,
            PriceEvent {
                date: event_date,
                price_x: SingleAmount::from_value(dec!(1), eur),
                price_y: SingleAmount::from_value(dec!(0.8), chf),
            },
        );
        let db = builder.build_naive();

        // before the event date no rate is known yet,
        // thus the input is given back as Err in both directions.
        let got = db.convert(SingleAmount::from_value(dec!(1), eur), chf, day_before);
        assert_eq!(got, Err(SingleAmount::from_value(dec!(1), eur)));

        let got = db.convert(SingleAmount::from_value(dec!(10), chf), eur, day_before);
        assert_eq!(got, Err(SingleAmount::from_value(dec!(10), chf)));

        // on the event date the rate is available in both directions.
        let got = db.convert(SingleAmount::from_value(dec!(1.0), eur), chf, event_date);
        assert_eq!(got, Ok(SingleAmount::from_value(dec!(0.8), chf)));

        let got = db.convert(SingleAmount::from_value(dec!(10.0), chf), eur, event_date);
        assert_eq!(got, Ok(SingleAmount::from_value(dec!(12.5), eur)));
    }

    #[test]
    fn price_db_computes_indirect_price() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let chf = ctx.commodities.ensure("CHF");
        let eur = ctx.commodities.ensure("EUR");
        let usd = ctx.commodities.ensure("USD");
        let jpy = ctx.commodities.ensure("JPY");

        // (date, price_x, price_y)
        // 1 EUR = 0.8 CHF
        // 1 USD = 0.8 EUR
        // 1 USD = 100 JPY
        let events = vec![
            (ymd(2024, 10, 1), (dec!(0.8), chf), (dec!(1), eur)),
            (ymd(2024, 10, 2), (dec!(0.8), eur), (dec!(1), usd)),
            (ymd(2024, 10, 3), (dec!(100), jpy), (dec!(1), usd)),
        ];

        let mut builder = PriceRepositoryBuilder::default();
        for (date, (value_x, commodity_x), (value_y, commodity_y)) in events {
            builder.insert_price(
                PriceSource::Ledger,
                PriceEvent {
                    date,
                    price_x: SingleAmount::from_value(value_x, commodity_x),
                    price_y: SingleAmount::from_value(value_y, commodity_y),
                },
            );
        }
        let db = builder.build_naive();

        // 1 CHF == 5/4 EUR == (5/4)*(5/4) USD == 156.25 JPY
        let got = db.convert(
            SingleAmount::from_value(dec!(1), chf),
            jpy,
            ymd(2024, 10, 3),
        );
        assert_eq!(got, Ok(SingleAmount::from_value(dec!(156.25), jpy)));
    }
}