1use std::{
4 collections::{BinaryHeap, HashMap},
5 path::Path,
6};
7
8use chrono::{NaiveDate, TimeDelta};
9use rust_decimal::Decimal;
10
11use crate::{
12 parse,
13 report::commodity::{CommodityMap, CommodityTag, OwnedCommodity},
14};
15
16use super::{
17 context::ReportContext,
18 eval::{Amount, SingleAmount},
19};
20
/// Errors raised by `PriceRepositoryBuilder::load_price_db`.
#[derive(Debug, thiserror::Error)]
pub enum LoadError {
    /// Reading the price DB file failed.
    #[error("failed to perform IO")]
    IO(#[from] std::io::Error),
    /// An entry of the price DB file could not be parsed.
    #[error("failed to parse price DB entry: {0}")]
    Parse(#[from] parse::ParseError),
}
28
/// Where a recorded price came from.
///
/// The declaration order is significant: the derived `Ord` makes
/// `Ledger < PriceDB`. When the builder stores rates for a commodity pair, a
/// higher-ordered source replaces rates stored from a lower-ordered one. Path
/// costs also count hops through `Ledger` rates separately (see `Distance`).
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub(super) enum PriceSource {
    /// Presumably a price derived from the ledger's own postings (TODO confirm
    /// against the caller that inserts it).
    Ledger,
    /// A price loaded from an external price DB file via `load_price_db`.
    PriceDB,
}
38
/// Rates recorded for one `(price_with, price_of)` commodity pair.
///
/// `.0` is the highest-ordered `PriceSource` seen for this pair. `.1` is a list
/// of `(date, rate)` pairs, where `rate` is how many units of `price_with` one
/// unit of `price_of` is worth on that date.
#[derive(Debug)]
struct Entry(PriceSource, Vec<(NaiveDate, Decimal)>);
41
/// Accumulates price observations before they are frozen into a
/// [`PriceRepository`] by `build`.
#[derive(Debug, Default)]
pub(super) struct PriceRepositoryBuilder<'ctx> {
    // Indexed as `records[price_with][price_of]`: the rates for converting
    // `price_of` into `price_with`.
    records: HashMap<CommodityTag<'ctx>, HashMap<CommodityTag<'ctx>, Entry>>,
}
47
/// A single price observation: on `date`, `price_x` was worth `price_y`.
///
/// For example, `price_x = 1 EUR` and `price_y = 0.8 CHF`.
#[derive(Debug, PartialEq, Eq)]
pub(super) struct PriceEvent<'ctx> {
    /// Day on which the price applies.
    pub date: NaiveDate,
    /// One side of the exchange.
    pub price_x: SingleAmount<'ctx>,
    /// The other side, equal in value to `price_x`.
    pub price_y: SingleAmount<'ctx>,
}
55
#[cfg(test)]
impl<'ctx> PriceEvent<'ctx> {
    /// Total-order key that makes test output deterministic.
    ///
    /// Orders by date, then by the commodity indices of both sides, then by
    /// both values.
    fn sort_key(&self) -> (NaiveDate, usize, usize, &Decimal, &Decimal) {
        let x = &self.price_x;
        let y = &self.price_y;
        let x_index = x.commodity.as_index();
        let y_index = y.commodity.as_index();
        (self.date, x_index, y_index, &x.value, &y.value)
    }
}
81
82impl<'ctx> PriceRepositoryBuilder<'ctx> {
83 pub fn insert_price(&mut self, source: PriceSource, event: PriceEvent<'ctx>) {
84 if event.price_x.commodity == event.price_y.commodity {
85 log::error!("price log should not contain the self-mention rate");
87 }
88 self.insert_impl(source, event.date, event.price_x, event.price_y);
89 self.insert_impl(source, event.date, event.price_y, event.price_x);
90 }
91
92 fn insert_impl(
93 &mut self,
94 source: PriceSource,
95 date: NaiveDate,
96 price_of: SingleAmount<'ctx>,
97 price_with: SingleAmount<'ctx>,
98 ) {
99 let Entry(stored_source, entries): &mut _ = self
100 .records
101 .entry(price_with.commodity)
102 .or_default()
103 .entry(price_of.commodity)
104 .or_insert(Entry(PriceSource::Ledger, Vec::new()));
105 if *stored_source < source {
106 *stored_source = source;
107 entries.clear();
108 }
109 entries.push((date, price_with.value / price_of.value));
115 }
116
117 pub fn load_price_db(
119 &mut self,
120 ctx: &mut ReportContext<'ctx>,
121 path: &Path,
122 ) -> Result<(), LoadError> {
123 let content = std::fs::read_to_string(path)?;
126 for entry in parse::price::parse_price_db(&parse::ParseOptions::default(), &content) {
127 let (_, entry) = entry?;
128 let target = ctx.commodities.ensure(entry.target.as_ref());
133 let rate: SingleAmount<'ctx> = SingleAmount::from_value(
134 entry.rate.value.value,
135 ctx.commodities.ensure(&entry.rate.commodity),
136 );
137 self.insert_price(
138 PriceSource::PriceDB,
139 PriceEvent {
140 price_x: SingleAmount::from_value(Decimal::ONE, target),
141 price_y: rate,
142 date: entry.datetime.date(),
143 },
144 );
145 }
146 Ok(())
147 }
148
149 #[cfg(test)]
150 pub fn into_events(self) -> Vec<PriceEvent<'ctx>> {
151 let mut ret = Vec::new();
152 for (price_with, v) in self.records {
153 for (price_of, Entry(_, v)) in v {
154 for (date, rate) in v {
155 ret.push(PriceEvent {
156 price_x: SingleAmount::from_value(Decimal::ONE, price_of),
157 price_y: SingleAmount::from_value(rate, price_with),
158 date,
159 });
160 }
161 }
162 }
163 ret.sort_by(|x, y| x.sort_key().cmp(&y.sort_key()));
164 ret
165 }
166
167 pub fn build(self) -> PriceRepository<'ctx> {
168 PriceRepository::new(self.build_naive())
169 }
170
171 fn build_naive(mut self) -> NaivePriceRepository<'ctx> {
172 self.records
173 .values_mut()
174 .for_each(|x| x.values_mut().for_each(|x| x.1.sort()));
175 NaivePriceRepository {
176 records: self.records,
177 }
178 }
179}
180
/// Errors returned when converting amounts between commodities.
#[derive(Debug, thiserror::Error)]
pub enum ConversionError {
    /// No rate was available to convert the first commodity into the second
    /// at the given date.
    #[error("commodity rate {0} into {1} at {2} not found")]
    RateNotFound(OwnedCommodity, OwnedCommodity, NaiveDate),
}
186
187pub fn convert_amount<'ctx>(
189 ctx: &ReportContext<'ctx>,
190 price_repos: &mut PriceRepository<'ctx>,
191 amount: &Amount<'ctx>,
192 commodity_with: CommodityTag<'ctx>,
193 date: NaiveDate,
194) -> Result<Amount<'ctx>, ConversionError> {
195 let mut result = Amount::zero();
196 for v in amount.iter() {
197 result += price_repos.convert_single(ctx, v, commodity_with, date)?;
198 }
199 Ok(result)
200}
201
/// Price repository that memoizes the price table computed for each
/// `(target commodity, date)` pair.
#[derive(Debug)]
pub struct PriceRepository<'ctx> {
    // Underlying repository that computes price tables on demand.
    inner: NaivePriceRepository<'ctx>,
    // Price tables keyed by `(commodity_with, date)`. Each table maps a
    // commodity to its rate into `commodity_with` and the path distance.
    cache: HashMap<(CommodityTag<'ctx>, NaiveDate), CommodityMap<WithDistance<Decimal>>>,
}
211
212impl<'ctx> PriceRepository<'ctx> {
213 fn new(inner: NaivePriceRepository<'ctx>) -> Self {
214 Self {
215 inner,
216 cache: HashMap::new(),
217 }
218 }
219
220 pub fn convert_single(
224 &mut self,
225 ctx: &ReportContext<'ctx>,
226 value: SingleAmount<'ctx>,
227 commodity_with: CommodityTag<'ctx>,
228 date: NaiveDate,
229 ) -> Result<SingleAmount<'ctx>, ConversionError> {
230 if value.commodity == commodity_with {
231 return Ok(value);
232 }
233 let rate = self
234 .cache
235 .entry((commodity_with, date))
236 .or_insert_with(|| self.inner.compute_price_table(ctx, commodity_with, date))
237 .get(value.commodity);
238 match rate {
239 Some(WithDistance(_, rate)) => {
240 Ok(SingleAmount::from_value(value.value * rate, commodity_with))
241 }
242 None => Err(ConversionError::RateNotFound(
243 value.commodity.to_owned_lossy(&ctx.commodities),
244 commodity_with.to_owned_lossy(&ctx.commodities),
245 date,
246 )),
247 }
248 }
249}
250
/// Frozen price records that compute conversion tables on demand, without
/// any caching.
#[derive(Debug)]
struct NaivePriceRepository<'ctx> {
    // Same layout as the builder's records: `records[price_with][price_of]`.
    // `build_naive` orders each rate list by ascending date, which the
    // `partition_point` lookup relies on.
    records: HashMap<CommodityTag<'ctx>, HashMap<CommodityTag<'ctx>, Entry>>,
}
258
259impl<'ctx> NaivePriceRepository<'ctx> {
260 #[cfg(test)]
262 fn convert(
263 &self,
264 ctx: &ReportContext<'ctx>,
265 value: SingleAmount<'ctx>,
266 commodity_with: CommodityTag<'ctx>,
267 date: NaiveDate,
268 ) -> Result<SingleAmount<'ctx>, SingleAmount<'ctx>> {
269 if value.commodity == commodity_with {
270 return Ok(value);
271 }
272 let rate = self
273 .compute_price_table(ctx, commodity_with, date)
274 .get(value.commodity)
275 .map(|x| x.1);
276 match rate {
277 Some(rate) => Ok(SingleAmount::from_value(value.value * rate, commodity_with)),
278 None => Err(value),
279 }
280 }
281
282 fn compute_price_table(
283 &self,
284 ctx: &ReportContext<'ctx>,
285 price_with: CommodityTag<'ctx>,
286 date: NaiveDate,
287 ) -> CommodityMap<WithDistance<Decimal>> {
288 let mut queue: BinaryHeap<WithDistance<(CommodityTag<'ctx>, Decimal)>> = BinaryHeap::new();
290 let mut distances: CommodityMap<WithDistance<Decimal>> =
291 CommodityMap::with_capacity(ctx.commodities.len());
292 queue.push(WithDistance(
293 Distance {
294 num_ledger_conversions: 0,
295 num_all_conversions: 0,
296 staleness: TimeDelta::zero(),
297 },
298 (price_with, Decimal::ONE),
299 ));
300 while let Some(curr) = queue.pop() {
301 log::debug!("curr: {:?}", curr);
302 let WithDistance(curr_dist, (prev, prev_rate)) = curr;
303 if let Some(WithDistance(prev_dist, _)) = distances.get(prev) {
304 if *prev_dist < curr_dist {
305 log::debug!(
306 "no need to update, prev_dist {:?} is smaller than curr_dist {:?}",
307 prev_dist,
308 curr_dist
309 );
310 continue;
311 }
312 }
313 for (j, Entry(source, rates)) in match self.records.get(&prev) {
314 None => continue,
315 Some(x) => x,
316 } {
317 let bound = rates.partition_point(|(record_date, _)| record_date <= &date);
318 log::debug!(
319 "found next commodity #{} with date bound {}",
320 j.as_index(),
321 bound
322 );
323 if bound == 0 {
324 continue;
327 }
328 let (record_date, rate) = rates[bound - 1];
329 let next_dist = curr_dist.extend(*source, date - record_date);
330 let rate = prev_rate * rate;
331 let next = WithDistance(next_dist.clone(), (*j, rate));
332 let e: &mut Option<_> = distances.get_mut(*j);
333 let updated = match e.as_mut() {
334 Some(e) => {
335 if *e <= next_dist {
336 false
337 } else {
338 *e = WithDistance(next_dist, rate);
339 true
340 }
341 }
342 None => {
343 *e = Some(WithDistance(next_dist, rate));
344 true
345 }
346 };
347 if !updated {
348 continue;
349 }
350 queue.push(next);
351 }
352 }
353 distances
354 }
355}
356
/// Cost of a conversion path. Smaller is preferred.
///
/// The derived `Ord` compares fields lexicographically in declaration order,
/// so the field order below is significant.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
struct Distance {
    // Number of hops that used `PriceSource::Ledger` rates; compared first.
    num_ledger_conversions: usize,
    // Total number of hops, regardless of source.
    num_all_conversions: usize,
    // Largest gap between the query date and the date of any rate on the path.
    staleness: TimeDelta,
}
372
373impl Distance {
374 fn extend(&self, source: PriceSource, staleness: TimeDelta) -> Self {
375 let num_ledger_conversions = self.num_ledger_conversions
376 + match source {
377 PriceSource::Ledger => 1,
378 PriceSource::PriceDB => 0,
379 };
380 Self {
381 num_ledger_conversions,
382 num_all_conversions: self.num_all_conversions + 1,
383 staleness: std::cmp::max(self.staleness, staleness),
384 }
385 }
386}
387
/// A payload tagged with a [`Distance`].
///
/// The hand-written equality and ordering impls for this type look only at
/// the distance; the payload is ignored.
#[derive(Debug, Clone)]
struct WithDistance<T>(Distance, T);
390
391impl<T> PartialEq for WithDistance<T> {
392 fn eq(&self, other: &Self) -> bool {
393 self.0 == other.0
394 }
395}
396
397impl<T> PartialEq<Distance> for WithDistance<T> {
398 fn eq(&self, other: &Distance) -> bool {
399 self.0 == *other
400 }
401}
402
// Equality compares only the `Distance`, which is itself `Eq`, so it is a
// total equivalence for every payload type `T`.
impl<T> Eq for WithDistance<T> {}
404
405impl<T> PartialOrd for WithDistance<T> {
406 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
407 self.0.partial_cmp(&other.0)
408 }
409}
410
411impl<T: Eq> PartialOrd<Distance> for WithDistance<T> {
412 fn partial_cmp(&self, other: &Distance) -> Option<std::cmp::Ordering> {
413 self.0.partial_cmp(other)
414 }
415}
416
417impl<T: Eq> Ord for WithDistance<T> {
418 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
419 self.0.cmp(&other.0)
420 }
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    use bumpalo::Bump;
    use rust_decimal_macros::dec;

    /// Shorthand for a calendar date in test fixtures.
    fn ymd(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    #[test]
    fn price_db_computes_direct_price() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let chf = ctx.commodities.ensure("CHF");
        let eur = ctx.commodities.ensure("EUR");

        // 1 EUR = 0.8 CHF from 2024-10-01 onwards.
        let mut builder = PriceRepositoryBuilder::default();
        builder.insert_price(
            PriceSource::Ledger,
            PriceEvent {
                date: ymd(2024, 10, 1),
                price_x: SingleAmount::from_value(dec!(1), eur),
                price_y: SingleAmount::from_value(dec!(0.8), chf),
            },
        );
        let db = builder.build_naive();

        // Before the only recorded price, no rate exists in either direction,
        // so the input comes back unchanged as the error.
        let day_before = ymd(2024, 9, 30);
        assert_eq!(
            db.convert(&ctx, SingleAmount::from_value(dec!(1), eur), chf, day_before),
            Err(SingleAmount::from_value(dec!(1), eur))
        );
        assert_eq!(
            db.convert(&ctx, SingleAmount::from_value(dec!(10), chf), eur, day_before),
            Err(SingleAmount::from_value(dec!(10), chf))
        );

        // On the price date both directions resolve: 1 EUR -> 0.8 CHF, and
        // 10 CHF -> 10 / 0.8 = 12.5 EUR.
        let price_day = ymd(2024, 10, 1);
        assert_eq!(
            db.convert(&ctx, SingleAmount::from_value(dec!(1.0), eur), chf, price_day),
            Ok(SingleAmount::from_value(dec!(0.8), chf))
        );
        assert_eq!(
            db.convert(&ctx, SingleAmount::from_value(dec!(10.0), chf), eur, price_day),
            Ok(SingleAmount::from_value(dec!(12.5), eur))
        );
    }

    #[test]
    fn price_db_computes_indirect_price() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let chf = ctx.commodities.ensure("CHF");
        let eur = ctx.commodities.ensure("EUR");
        let usd = ctx.commodities.ensure("USD");
        let jpy = ctx.commodities.ensure("JPY");

        // A chain CHF -> EUR -> USD -> JPY, each link on a different day.
        let links = [
            (ymd(2024, 10, 1), dec!(0.8), chf, dec!(1), eur),
            (ymd(2024, 10, 2), dec!(0.8), eur, dec!(1), usd),
            (ymd(2024, 10, 3), dec!(100), jpy, dec!(1), usd),
        ];
        let mut builder = PriceRepositoryBuilder::default();
        for (date, x_value, x_commodity, y_value, y_commodity) in links {
            builder.insert_price(
                PriceSource::Ledger,
                PriceEvent {
                    date,
                    price_x: SingleAmount::from_value(x_value, x_commodity),
                    price_y: SingleAmount::from_value(y_value, y_commodity),
                },
            );
        }
        let db = builder.build_naive();

        // 1 CHF = 1.25 EUR = 1.5625 USD = 156.25 JPY.
        assert_eq!(
            db.convert(
                &ctx,
                SingleAmount::from_value(dec!(1), chf),
                jpy,
                ymd(2024, 10, 3),
            ),
            Ok(SingleAmount::from_value(dec!(156.25), jpy))
        );
    }
}