Skip to main content

okane_core/report/
price_db.rs

1//! Provides [PriceRepository], which can compute the commodity (currency) conversion.
2
3use std::collections::{BinaryHeap, HashMap};
4use std::path::{Path, PathBuf};
5
6use chrono::{NaiveDate, TimeDelta};
7use rust_decimal::Decimal;
8
9use crate::load;
10use crate::parse;
11use crate::report::commodity::{CommodityMap, CommodityTag, OwnedCommodity};
12
13use super::context::ReportContext;
14use super::eval::{Amount, SingleAmount};
15
16#[derive(Debug, thiserror::Error)]
17pub enum LoadError {
18    #[error("failed to load price DB file {0}")]
19    IO(PathBuf, #[source] std::io::Error),
20    #[error("failed to parse price DB file {0}")]
21    Parse(PathBuf, #[source] parse::ParseError),
22}
23
24/// Source of the price information.
25/// In the DB, latter one (larger one as Ord) has priority,
26/// and if you have events with higher priority,
27/// lower priority events are discarded.
28#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
29pub(super) enum PriceSource {
30    Ledger,
31    PriceDB,
32}
33
34#[derive(Debug)]
35struct Entry(PriceSource, Vec<(NaiveDate, Decimal)>);
36
37/// Builder of [`PriceRepository`].
38#[derive(Debug, Default)]
39pub(super) struct PriceRepositoryBuilder<'ctx> {
40    records: HashMap<CommodityTag<'ctx>, HashMap<CommodityTag<'ctx>, Entry>>,
41}
42
43/// Event of commodity price.
44#[derive(Debug, PartialEq, Eq)]
45pub(super) struct PriceEvent<'ctx> {
46    pub date: NaiveDate,
47    pub price_x: SingleAmount<'ctx>,
48    pub price_y: SingleAmount<'ctx>,
49}
50
#[cfg(test)]
impl<'ctx> PriceEvent<'ctx> {
    /// Returns a key to sort events in a stable order within tests:
    /// the date first, then the indices of both commodities.
    /// The amounts themselves do not take part in the key.
    fn sort_key(&self) -> (NaiveDate, usize, usize) {
        let x_index = self.price_x.commodity.as_index();
        let y_index = self.price_y.commodity.as_index();
        (self.date, x_index, y_index)
    }
}
70
71impl<'ctx> PriceRepositoryBuilder<'ctx> {
72    pub fn insert_price(&mut self, source: PriceSource, event: PriceEvent<'ctx>) {
73        if event.price_x.commodity == event.price_y.commodity {
74            // this must be an error returned, instead of log error.
75            log::error!("price log should not contain the self-mention rate");
76        }
77        self.insert_impl(source, event.date, event.price_x, event.price_y);
78        self.insert_impl(source, event.date, event.price_y, event.price_x);
79    }
80
81    fn insert_impl(
82        &mut self,
83        source: PriceSource,
84        date: NaiveDate,
85        price_of: SingleAmount<'ctx>,
86        price_with: SingleAmount<'ctx>,
87    ) {
88        let Entry(stored_source, entries): &mut _ = self
89            .records
90            .entry(price_with.commodity)
91            .or_default()
92            .entry(price_of.commodity)
93            .or_insert(Entry(PriceSource::Ledger, Vec::new()));
94        if *stored_source < source {
95            *stored_source = source;
96            entries.clear();
97        }
98        // price_of: x X
99        // price_with: y Y
100        //
101        // typical use: price_of: 1 X
102        // then records[Y][X] == y (/ 1)
103        entries.push((date, price_with.value / price_of.value));
104    }
105
106    /// Loads PriceDB information from the given file.
107    pub fn load_price_db<F: load::FileSystem>(
108        &mut self,
109        ctx: &mut ReportContext<'ctx>,
110        filesystem: &F,
111        path: &Path,
112    ) -> Result<(), LoadError> {
113        // Even though price db can be up to a few megabytes,
114        // still it's much easier to load everything into memory.
115        let content = filesystem
116            .file_content_utf8(path)
117            .map_err(|e| LoadError::IO(path.to_owned(), e))?;
118        for entry in parse::price::parse_price_db(&parse::ParseOptions::default(), &content) {
119            let (_, entry) = entry.map_err(|e| LoadError::Parse(path.to_owned(), e))?;
120            // we cannot skip commodities which doesn't appear in Ledger source,
121            // as the price might be computed via indirect relationship.
122            // For example, if we have only AUD and JPY in Ledger,
123            // price DB might expose AUD/EUR EUR/CHF CHF/JPY conversion.
124            let target = ctx.commodities.ensure(entry.target.as_ref());
125            let rate: SingleAmount<'ctx> = SingleAmount::from_value(
126                ctx.commodities.ensure(&entry.rate.commodity),
127                entry.rate.value.value,
128            );
129            self.insert_price(
130                PriceSource::PriceDB,
131                PriceEvent {
132                    price_x: SingleAmount::from_value(target, Decimal::ONE),
133                    price_y: rate,
134                    date: entry.datetime.date(),
135                },
136            );
137        }
138        Ok(())
139    }
140
141    /// Returns iterator of [`PriceEvent`] in unspecified order.
142    #[cfg(test)]
143    pub fn iter_events(&self) -> impl Iterator<Item = (PriceSource, PriceEvent<'ctx>)> {
144        self.records.iter().flat_map(|(price_with, v)| {
145            v.iter().flat_map(|(price_of, Entry(source, v))| {
146                v.iter().map(|(date, rate)| {
147                    (
148                        *source,
149                        PriceEvent {
150                            price_x: SingleAmount::from_value(*price_of, Decimal::ONE),
151                            price_y: SingleAmount::from_value(*price_with, *rate),
152                            date: *date,
153                        },
154                    )
155                })
156            })
157        })
158    }
159
160    #[cfg(test)]
161    pub fn to_events(&self) -> Vec<PriceEvent<'ctx>> {
162        let mut ret: Vec<PriceEvent<'ctx>> =
163            self.iter_events().map(|(_source, event)| event).collect();
164        ret.sort_by_key(|x| x.sort_key());
165        ret
166    }
167
168    pub fn build(self) -> PriceRepository<'ctx> {
169        PriceRepository::new(self.build_naive())
170    }
171
172    fn build_naive(mut self) -> NaivePriceRepository<'ctx> {
173        self.records
174            .values_mut()
175            .for_each(|x| x.values_mut().for_each(|x| x.1.sort()));
176        NaivePriceRepository {
177            records: self.records,
178        }
179    }
180}
181
182#[derive(Debug, thiserror::Error)]
183pub enum ConversionError {
184    #[error("commodity rate {0} into {1} at {2} not found")]
185    RateNotFound(OwnedCommodity, OwnedCommodity, NaiveDate),
186}
187
188/// Converts the given amount into the specified commodity.
189pub fn convert_amount<'ctx>(
190    ctx: &ReportContext<'ctx>,
191    price_repos: &mut PriceRepository<'ctx>,
192    amount: &Amount<'ctx>,
193    commodity_with: CommodityTag<'ctx>,
194    date: NaiveDate,
195) -> Result<Amount<'ctx>, ConversionError> {
196    let mut result = Amount::zero();
197    for v in amount.iter() {
198        result += price_repos.convert_single(ctx, v, commodity_with, date)?;
199    }
200    Ok(result)
201}
202
203/// Repository which user can query the conversion rate with.
204#[derive(Debug)]
205pub struct PriceRepository<'ctx> {
206    inner: NaivePriceRepository<'ctx>,
207    // BTreeMap could be used if cursor support is ready.
208    // Then, we can avoid computing rates over and over if no rate update happens.
209    cache: HashMap<(CommodityTag<'ctx>, NaiveDate), CommodityMap<WithDistance<Decimal>>>,
210}
211
212impl<'ctx> PriceRepository<'ctx> {
213    fn new(inner: NaivePriceRepository<'ctx>) -> Self {
214        Self {
215            inner,
216            cache: HashMap::new(),
217        }
218    }
219
220    /// Converts the given `value` into the `commodity_with`.
221    /// If the given value has already the `commodity_with`,
222    /// returns `Ok(value)` as-is.
223    pub fn convert_single(
224        &mut self,
225        ctx: &ReportContext<'ctx>,
226        value: SingleAmount<'ctx>,
227        commodity_with: CommodityTag<'ctx>,
228        date: NaiveDate,
229    ) -> Result<SingleAmount<'ctx>, ConversionError> {
230        if value.commodity == commodity_with {
231            return Ok(value);
232        }
233        let rate = self
234            .cache
235            .entry((commodity_with, date))
236            .or_insert_with(|| self.inner.compute_price_table(ctx, commodity_with, date))
237            .get(value.commodity);
238        match rate {
239            Some(WithDistance(_, rate)) => {
240                Ok(SingleAmount::from_value(commodity_with, value.value * rate))
241            }
242            None => Err(ConversionError::RateNotFound(
243                value.commodity.to_owned_lossy(&ctx.commodities),
244                commodity_with.to_owned_lossy(&ctx.commodities),
245                date,
246            )),
247        }
248    }
249}
250
251#[derive(Debug)]
252struct NaivePriceRepository<'ctx> {
253    // from comodity -> to commodity -> date -> price.
254    // e.g. USD AAPL 2024-01-01 100 means 1 AAPL == 100 USD at 2024-01-01.
255    // the value are sorted in NaiveDate order.
256    records: HashMap<CommodityTag<'ctx>, HashMap<CommodityTag<'ctx>, Entry>>,
257}
258
259impl<'ctx> NaivePriceRepository<'ctx> {
260    /// Copied from CachedPriceRepository, needs to be factored out properly.
261    #[cfg(test)]
262    fn convert(
263        &self,
264        ctx: &ReportContext<'ctx>,
265        value: SingleAmount<'ctx>,
266        commodity_with: CommodityTag<'ctx>,
267        date: NaiveDate,
268    ) -> Result<SingleAmount<'ctx>, SingleAmount<'ctx>> {
269        if value.commodity == commodity_with {
270            return Ok(value);
271        }
272        let rate = self
273            .compute_price_table(ctx, commodity_with, date)
274            .get(value.commodity)
275            .map(|x| x.1);
276        match rate {
277            Some(rate) => Ok(SingleAmount::from_value(commodity_with, value.value * rate)),
278            None => Err(value),
279        }
280    }
281
282    fn compute_price_table(
283        &self,
284        ctx: &ReportContext<'ctx>,
285        price_with: CommodityTag<'ctx>,
286        date: NaiveDate,
287    ) -> CommodityMap<WithDistance<Decimal>> {
288        // minimize the distance, and then minimize the staleness.
289        let mut queue: BinaryHeap<WithDistance<(CommodityTag<'ctx>, Decimal)>> = BinaryHeap::new();
290        let mut distances: CommodityMap<WithDistance<Decimal>> =
291            CommodityMap::with_capacity(ctx.commodities.len());
292        queue.push(WithDistance(
293            Distance {
294                num_ledger_conversions: 0,
295                num_all_conversions: 0,
296                staleness: TimeDelta::zero(),
297            },
298            (price_with, Decimal::ONE),
299        ));
300        while let Some(curr) = queue.pop() {
301            log::debug!("curr: {:?}", curr);
302            let WithDistance(curr_dist, (prev, prev_rate)) = curr;
303            if let Some(WithDistance(prev_dist, _)) = distances.get(prev)
304                && *prev_dist < curr_dist
305            {
306                log::debug!(
307                    "no need to update, prev_dist {:?} is smaller than curr_dist {:?}",
308                    prev_dist,
309                    curr_dist
310                );
311                continue;
312            }
313            for (j, Entry(source, rates)) in match self.records.get(&prev) {
314                None => continue,
315                Some(x) => x,
316            } {
317                let bound = rates.partition_point(|(record_date, _)| record_date <= &date);
318                log::debug!(
319                    "found next commodity #{} with date bound {}",
320                    j.as_index(),
321                    bound
322                );
323                if bound == 0 {
324                    // we cannot find any rate information at the date (all rates are in future).
325                    // let's treat rates are not available.
326                    continue;
327                }
328                let (record_date, rate) = rates[bound - 1];
329                let next_dist = curr_dist.extend(*source, date - record_date);
330                let rate = prev_rate * rate;
331                let next = WithDistance(next_dist.clone(), (*j, rate));
332                let e: &mut Option<_> = distances.get_mut(*j);
333                let updated = match e.as_mut() {
334                    Some(e) => {
335                        if *e <= next_dist {
336                            false
337                        } else {
338                            *e = WithDistance(next_dist, rate);
339                            true
340                        }
341                    }
342                    None => {
343                        *e = Some(WithDistance(next_dist, rate));
344                        true
345                    }
346                };
347                if !updated {
348                    continue;
349                }
350                queue.push(next);
351            }
352        }
353        distances
354    }
355}
356
357/// Distance to minimize during the price DB computation.
358///
359/// Now this is using simple derived [Ord] logic,
360/// but we can work on heuristic cost function instead.
361#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
362struct Distance {
363    /// Number of conversions with [`PriceSource::Ledger`] used to compute the rate.
364    /// Minimize this because we assume [`PriceSource::PriceDB`] is more reliable
365    /// than the one in Ledger.
366    num_ledger_conversions: usize,
367    /// Number of conversions used to compute the rate.
368    num_all_conversions: usize,
369    /// Staleness of the conversion rate.
370    staleness: TimeDelta,
371}
372
373impl Distance {
374    fn extend(&self, source: PriceSource, staleness: TimeDelta) -> Self {
375        let num_ledger_conversions = self.num_ledger_conversions
376            + match source {
377                PriceSource::Ledger => 1,
378                PriceSource::PriceDB => 0,
379            };
380        Self {
381            num_ledger_conversions,
382            num_all_conversions: self.num_all_conversions + 1,
383            staleness: std::cmp::max(self.staleness, staleness),
384        }
385    }
386}
387
388#[derive(Debug, Clone)]
389struct WithDistance<T>(Distance, T);
390
391impl<T> PartialEq for WithDistance<T> {
392    fn eq(&self, other: &Self) -> bool {
393        self.0 == other.0
394    }
395}
396
397impl<T> PartialEq<Distance> for WithDistance<T> {
398    fn eq(&self, other: &Distance) -> bool {
399        self.0 == *other
400    }
401}
402
403impl<T> Eq for WithDistance<T> {}
404
405impl<T> PartialOrd for WithDistance<T> {
406    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
407        self.0.partial_cmp(&other.0)
408    }
409}
410
411impl<T: Eq> PartialOrd<Distance> for WithDistance<T> {
412    fn partial_cmp(&self, other: &Distance) -> Option<std::cmp::Ordering> {
413        self.0.partial_cmp(other)
414    }
415}
416
417impl<T: Eq> Ord for WithDistance<T> {
418    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
419        self.0.cmp(&other.0)
420    }
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    use bumpalo::Bump;
    use pretty_assertions::assert_eq;
    use rust_decimal_macros::dec;

    /// Shorthand to create a date which is known to be valid.
    fn ymd(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    #[test]
    fn price_db_computes_direct_price() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let chf = ctx.commodities.ensure("CHF");
        let eur = ctx.commodities.ensure("EUR");
        let mut builder = PriceRepositoryBuilder::default();
        // 1 EUR = 0.8 CHF since 2024-10-01.
        builder.insert_price(
            PriceSource::Ledger,
            PriceEvent {
                date: ymd(2024, 10, 1),
                price_x: SingleAmount::from_value(eur, dec!(1)),
                price_y: SingleAmount::from_value(chf, dec!(0.8)),
            },
        );

        let db = builder.build_naive();

        // before the event date, no rate is available yet,
        // thus the input is handed back as Err in either direction.
        let day_before = ymd(2024, 9, 30);
        let got = db.convert(&ctx, SingleAmount::from_value(eur, dec!(1)), chf, day_before);
        assert_eq!(got, Err(SingleAmount::from_value(eur, dec!(1))));

        let got = db.convert(&ctx, SingleAmount::from_value(chf, dec!(10)), eur, day_before);
        assert_eq!(got, Err(SingleAmount::from_value(chf, dec!(10))));

        // on the event date, both directions are available.
        let event_day = ymd(2024, 10, 1);
        let got = db.convert(&ctx, SingleAmount::from_value(eur, dec!(1.0)), chf, event_day);
        assert_eq!(got, Ok(SingleAmount::from_value(chf, dec!(0.8))));

        let got = db.convert(
            &ctx,
            SingleAmount::from_value(chf, dec!(10.0)),
            eur,
            event_day,
        );
        assert_eq!(got, Ok(SingleAmount::from_value(eur, dec!(12.5))));
    }

    #[test]
    fn price_db_computes_indirect_price() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let chf = ctx.commodities.ensure("CHF");
        let eur = ctx.commodities.ensure("EUR");
        let usd = ctx.commodities.ensure("USD");
        let jpy = ctx.commodities.ensure("JPY");
        let mut builder = PriceRepositoryBuilder::default();

        // 1 EUR = 0.8 CHF
        // 1 USD = 0.8 EUR
        // 1 USD = 100 JPY
        // 1 CHF == 5/4 EUR == (5/4)*(5/4) USD == 156.25 JPY
        let events = [
            PriceEvent {
                date: ymd(2024, 10, 1),
                price_x: SingleAmount::from_value(chf, dec!(0.8)),
                price_y: SingleAmount::from_value(eur, dec!(1)),
            },
            PriceEvent {
                date: ymd(2024, 10, 2),
                price_x: SingleAmount::from_value(eur, dec!(0.8)),
                price_y: SingleAmount::from_value(usd, dec!(1)),
            },
            PriceEvent {
                date: ymd(2024, 10, 3),
                price_x: SingleAmount::from_value(jpy, dec!(100)),
                price_y: SingleAmount::from_value(usd, dec!(1)),
            },
        ];
        for event in events {
            builder.insert_price(PriceSource::Ledger, event);
        }

        let db = builder.build_naive();

        let got = db.convert(
            &ctx,
            SingleAmount::from_value(chf, dec!(1)),
            jpy,
            ymd(2024, 10, 3),
        );
        assert_eq!(got, Ok(SingleAmount::from_value(jpy, dec!(156.25))));
    }

    #[test]
    fn price_db_load_overrides_ledger_price() {
        let price_db =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../testdata/report/price_db.txt");
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let chf = ctx.commodities.ensure("CHF");
        let eur = ctx.commodities.ensure("EUR");
        let mut builder = PriceRepositoryBuilder::default();

        // This Ledger event must be wiped out by the price DB loaded afterwards.
        builder.insert_price(
            PriceSource::Ledger,
            PriceEvent {
                date: ymd(2024, 1, 1),
                price_x: SingleAmount::from_value(chf, dec!(0.8)),
                price_y: SingleAmount::from_value(eur, dec!(1)),
            },
        );

        builder
            .load_price_db(&mut ctx, &load::ProdFileSystem, &price_db)
            .unwrap();

        let target_date = ymd(2024, 1, 31);
        let is_in_scope = |event: &PriceEvent<'_>| {
            let x = event.price_x.commodity;
            let y = event.price_y.commodity;
            event.date == target_date && ((x == chf && y == eur) || (x == eur && y == chf))
        };
        let got: Vec<_> = builder.iter_events().collect();
        assert_eq!(got.len(), 17 * 2);
        assert!(
            got.iter()
                .all(|(source, _)| *source == PriceSource::PriceDB)
        );
        let mut filtered: Vec<_> = got
            .into_iter()
            .map(|(_, event)| event)
            .filter(is_in_scope)
            .collect();
        filtered.sort_by_key(|x| x.sort_key());
        let want = vec![
            PriceEvent {
                date: target_date,
                price_x: SingleAmount::from_value(chf, Decimal::ONE),
                price_y: SingleAmount::from_value(eur, Decimal::ONE / dec!(0.9348)),
            },
            PriceEvent {
                date: target_date,
                price_x: SingleAmount::from_value(eur, dec!(1)),
                price_y: SingleAmount::from_value(chf, dec!(0.9348)),
            },
        ];
        assert_eq!(want, filtered);
    }
}