1use std::{
4 collections::{hash_map, BinaryHeap, HashMap},
5 path::Path,
6};
7
8use chrono::{NaiveDate, TimeDelta};
9use rust_decimal::Decimal;
10
11use crate::parse;
12
13use super::{
14 commodity::Commodity,
15 context::ReportContext,
16 eval::{Amount, SingleAmount},
17};
18
/// Errors returned by `PriceRepositoryBuilder::load_price_db`.
#[derive(Debug, thiserror::Error)]
pub enum LoadError {
    /// Reading the price DB file failed.
    #[error("failed to perform IO")]
    IO(#[from] std::io::Error),
    /// An entry of the price DB file could not be parsed.
    #[error("failed to parse price DB entry: {0}")]
    Parse(#[from] parse::ParseError),
}
26
/// Where a price record came from.
///
/// The variant order is significant: the derived `Ord` makes `Ledger < PriceDB`.
/// `PriceRepositoryBuilder` lets records from a greater source replace those of a lesser
/// source for the same commodity pair, and `Distance` counts hops through `Ledger`
/// records separately when ranking conversion paths.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub(super) enum PriceSource {
    /// Price derived from the ledger itself.
    Ledger,
    /// Price loaded from a price DB file (see `load_price_db`).
    PriceDB,
}
36
/// Rate records for one commodity pair.
///
/// `.0` is the [`PriceSource`] associated with the pair; `.1` holds `(date, rate)`
/// records, which are sorted by date once the builder is turned into a repository.
#[derive(Debug)]
struct Entry(PriceSource, Vec<(NaiveDate, Decimal)>);
39
/// Collects price records before they are frozen into a [`PriceRepository`].
#[derive(Debug, Default)]
pub(super) struct PriceRepositoryBuilder<'ctx> {
    // records[price_with][price_of] holds rates such that one unit of `price_of` is worth
    // `rate` units of `price_with` (see `insert_impl`).
    records: HashMap<Commodity<'ctx>, HashMap<Commodity<'ctx>, Entry>>,
}
45
/// One observed exchange: on `date`, the amount `price_x` was worth the amount `price_y`.
///
/// The builder records it in both directions.
#[derive(Debug, PartialEq, Eq)]
pub(super) struct PriceEvent<'ctx> {
    pub date: NaiveDate,
    pub price_x: SingleAmount<'ctx>,
    pub price_y: SingleAmount<'ctx>,
}
53
#[cfg(test)]
impl<'ctx> PriceEvent<'ctx> {
    /// Total ordering key (date, then both commodity names, then both values) used to
    /// make test output deterministic.
    fn sort_key<'a>(&'a self) -> (NaiveDate, &'ctx str, &'ctx str, &'a Decimal, &'a Decimal) {
        let x = &self.price_x;
        let y = &self.price_y;
        (
            self.date,
            x.commodity.as_str(),
            y.commodity.as_str(),
            &x.value,
            &y.value,
        )
    }
}
79
80impl<'ctx> PriceRepositoryBuilder<'ctx> {
81 pub fn insert_price(&mut self, source: PriceSource, event: PriceEvent<'ctx>) {
82 if event.price_x.commodity == event.price_y.commodity {
83 log::error!("price log should not contain the self-mention rate");
85 }
86 self.insert_impl(source, event.date, event.price_x, event.price_y);
87 self.insert_impl(source, event.date, event.price_y, event.price_x);
88 }
89
90 fn insert_impl(
91 &mut self,
92 source: PriceSource,
93 date: NaiveDate,
94 price_of: SingleAmount<'ctx>,
95 price_with: SingleAmount<'ctx>,
96 ) {
97 let Entry(stored_source, entries): &mut _ = self
98 .records
99 .entry(price_with.commodity)
100 .or_default()
101 .entry(price_of.commodity)
102 .or_insert(Entry(PriceSource::Ledger, Vec::new()));
103 if *stored_source < source {
104 *stored_source = source;
105 entries.clear();
106 }
107 entries.push((date, price_with.value / price_of.value));
113 }
114
115 pub fn load_price_db(
117 &mut self,
118 ctx: &mut ReportContext<'ctx>,
119 path: &Path,
120 ) -> Result<(), LoadError> {
121 let content = std::fs::read_to_string(path)?;
124 for entry in parse::price::parse_price_db(&parse::ParseOptions::default(), &content) {
125 let (_, entry) = entry?;
126 let target = ctx.commodities.ensure(entry.target.as_ref());
130 let rate: SingleAmount<'ctx> = SingleAmount::from_value(
131 entry.rate.value.value,
132 ctx.commodities.ensure(&entry.rate.commodity),
133 );
134 self.insert_price(
135 PriceSource::PriceDB,
136 PriceEvent {
137 price_x: SingleAmount::from_value(Decimal::ONE, target),
138 price_y: rate,
139 date: entry.datetime.date(),
140 },
141 );
142 }
143 Ok(())
144 }
145
146 #[cfg(test)]
147 pub fn into_events(self) -> Vec<PriceEvent<'ctx>> {
148 let mut ret = Vec::new();
149 for (price_with, v) in self.records {
150 for (price_of, Entry(_, v)) in v {
151 for (date, rate) in v {
152 ret.push(PriceEvent {
153 price_x: SingleAmount::from_value(Decimal::ONE, price_of),
154 price_y: SingleAmount::from_value(rate, price_with),
155 date,
156 });
157 }
158 }
159 }
160 ret.sort_by(|x, y| x.sort_key().cmp(&y.sort_key()));
161 ret
162 }
163
164 pub fn build(self) -> PriceRepository<'ctx> {
165 PriceRepository::new(self.build_naive())
166 }
167
168 fn build_naive(mut self) -> NaivePriceRepository<'ctx> {
169 self.records
170 .values_mut()
171 .for_each(|x| x.values_mut().for_each(|x| x.1.sort()));
172 NaivePriceRepository {
173 records: self.records,
174 }
175 }
176}
177
/// Errors returned when converting amounts between commodities.
#[derive(Debug, thiserror::Error)]
pub enum ConversionError<'ctx> {
    /// No chain of rates dated on or before the given date links the amount's commodity
    /// to the target commodity. Fields: (amount to convert, target commodity, date).
    #[error("commodity rate {0} into {1} at {2} not found")]
    RateNotFound(SingleAmount<'ctx>, Commodity<'ctx>, NaiveDate),
}
183
184pub fn convert_amount<'ctx>(
186 price_repos: &mut PriceRepository<'ctx>,
187 amount: &Amount<'ctx>,
188 commodity_with: Commodity<'ctx>,
189 date: NaiveDate,
190) -> Result<Amount<'ctx>, ConversionError<'ctx>> {
191 let mut result = Amount::zero();
192 for v in amount.iter() {
193 result += price_repos.convert_single(v, commodity_with, date)?;
194 }
195 Ok(result)
196}
197
198#[derive(Debug)]
200pub struct PriceRepository<'ctx> {
201 inner: NaivePriceRepository<'ctx>,
202 cache: HashMap<(Commodity<'ctx>, NaiveDate), HashMap<Commodity<'ctx>, WithDistance<Decimal>>>,
206}
207
208impl<'ctx> PriceRepository<'ctx> {
209 fn new(inner: NaivePriceRepository<'ctx>) -> Self {
210 Self {
211 inner,
212 cache: HashMap::new(),
213 }
214 }
215
216 pub fn convert_single(
220 &mut self,
221 value: SingleAmount<'ctx>,
222 commodity_with: Commodity<'ctx>,
223 date: NaiveDate,
224 ) -> Result<SingleAmount<'ctx>, ConversionError<'ctx>> {
225 if value.commodity == commodity_with {
226 return Ok(value);
227 }
228 let rate = self
229 .cache
230 .entry((commodity_with, date))
231 .or_insert_with(|| self.inner.compute_price_table(commodity_with, date))
232 .get(&value.commodity);
233 match rate {
234 Some(WithDistance(_, rate)) => {
235 Ok(SingleAmount::from_value(value.value * rate, commodity_with))
236 }
237 None => Err(ConversionError::RateNotFound(value, commodity_with, date)),
238 }
239 }
240}
241
242#[derive(Debug)]
243struct NaivePriceRepository<'ctx> {
244 records: HashMap<Commodity<'ctx>, HashMap<Commodity<'ctx>, Entry>>,
248}
249
250impl<'ctx> NaivePriceRepository<'ctx> {
251 #[cfg(test)]
253 fn convert(
254 &self,
255 value: SingleAmount<'ctx>,
256 commodity_with: Commodity<'ctx>,
257 date: NaiveDate,
258 ) -> Result<SingleAmount<'ctx>, SingleAmount<'ctx>> {
259 if value.commodity == commodity_with {
260 return Ok(value);
261 }
262 let rate = self
263 .compute_price_table(commodity_with, date)
264 .get(&value.commodity)
265 .map(|x| x.1);
266 match rate {
267 Some(rate) => Ok(SingleAmount::from_value(value.value * rate, commodity_with)),
268 None => Err(value),
269 }
270 }
271
272 fn compute_price_table(
273 &self,
274 price_with: Commodity<'ctx>,
275 date: NaiveDate,
276 ) -> HashMap<Commodity<'ctx>, WithDistance<Decimal>> {
277 let mut queue: BinaryHeap<WithDistance<(Commodity<'ctx>, Decimal)>> = BinaryHeap::new();
279 let mut distances: HashMap<Commodity<'ctx>, WithDistance<Decimal>> = HashMap::new();
280 queue.push(WithDistance(
281 Distance {
282 num_ledger_conversions: 0,
283 num_all_conversions: 0,
284 staleness: TimeDelta::zero(),
285 },
286 (price_with, Decimal::ONE),
287 ));
288 while let Some(curr) = queue.pop() {
289 log::debug!("curr: {:?}", curr);
290 let WithDistance(curr_dist, (prev, prev_rate)) = curr;
291 if let Some(WithDistance(prev_dist, _)) = distances.get(&prev) {
292 if *prev_dist < curr_dist {
293 log::debug!(
294 "no need to update, prev_dist {:?} is smaller than curr_dist {:?}",
295 prev_dist,
296 curr_dist
297 );
298 continue;
299 }
300 }
301 for (j, Entry(source, rates)) in match self.records.get(&prev) {
302 None => continue,
303 Some(x) => x,
304 } {
305 let bound = rates.partition_point(|(record_date, _)| record_date <= &date);
306 log::debug!(
307 "found next commodity {} with date bound {}",
308 j.as_str(),
309 bound
310 );
311 if bound == 0 {
312 continue;
315 }
316 let (record_date, rate) = rates[bound - 1];
317 let next_dist = curr_dist.extend(*source, date - record_date);
318 let rate = prev_rate * rate;
319 let next = WithDistance(next_dist.clone(), (*j, rate));
320 let updated = match distances.entry(*j) {
321 hash_map::Entry::Occupied(mut e) => {
322 if e.get() <= &next_dist {
323 false
324 } else {
325 e.insert(WithDistance(next_dist, rate));
326 true
327 }
328 }
329 hash_map::Entry::Vacant(e) => {
330 e.insert(WithDistance(next_dist, rate));
331 true
332 }
333 };
334 if !updated {
335 continue;
336 }
337 queue.push(next);
338 }
339 }
340 distances
341 }
342}
343
344#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
349struct Distance {
350 num_ledger_conversions: usize,
354 num_all_conversions: usize,
356 staleness: TimeDelta,
358}
359
360impl Distance {
361 fn extend(&self, source: PriceSource, staleness: TimeDelta) -> Self {
362 let num_ledger_conversions = self.num_ledger_conversions
363 + match source {
364 PriceSource::Ledger => 1,
365 PriceSource::PriceDB => 0,
366 };
367 Self {
368 num_ledger_conversions,
369 num_all_conversions: self.num_all_conversions + 1,
370 staleness: std::cmp::max(self.staleness, staleness),
371 }
372 }
373}
374
375#[derive(Debug)]
376struct WithDistance<T>(Distance, T);
377
378impl<T> PartialEq for WithDistance<T> {
379 fn eq(&self, other: &Self) -> bool {
380 self.0 == other.0
381 }
382}
383
384impl<T> PartialEq<Distance> for WithDistance<T> {
385 fn eq(&self, other: &Distance) -> bool {
386 self.0 == *other
387 }
388}
389
390impl<T> Eq for WithDistance<T> {}
391
392impl<T> PartialOrd for WithDistance<T> {
393 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
394 self.0.partial_cmp(&other.0)
395 }
396}
397
398impl<T: Eq> PartialOrd<Distance> for WithDistance<T> {
399 fn partial_cmp(&self, other: &Distance) -> Option<std::cmp::Ordering> {
400 self.0.partial_cmp(other)
401 }
402}
403
404impl<T: Eq> Ord for WithDistance<T> {
405 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
406 self.0.cmp(&other.0)
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::*;

    use bumpalo::Bump;
    use rust_decimal_macros::dec;

    /// Shorthand for a calendar date that is known to be valid.
    fn ymd(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    #[test]
    fn price_db_computes_direct_price() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let chf = ctx.commodities.ensure("CHF");
        let eur = ctx.commodities.ensure("EUR");

        let mut builder = PriceRepositoryBuilder::default();
        builder.insert_price(
            PriceSource::Ledger,
            PriceEvent {
                date: ymd(2024, 10, 1),
                price_x: SingleAmount::from_value(dec!(1), eur),
                price_y: SingleAmount::from_value(dec!(0.8), chf),
            },
        );
        let db = builder.build_naive();

        // Before the only recorded price, neither direction can be converted.
        let day_before = ymd(2024, 9, 30);
        assert_eq!(
            db.convert(SingleAmount::from_value(dec!(1), eur), chf, day_before),
            Err(SingleAmount::from_value(dec!(1), eur))
        );
        assert_eq!(
            db.convert(SingleAmount::from_value(dec!(10), chf), eur, day_before),
            Err(SingleAmount::from_value(dec!(10), chf))
        );

        // On the recorded day, both directions work.
        let same_day = ymd(2024, 10, 1);
        assert_eq!(
            db.convert(SingleAmount::from_value(dec!(1.0), eur), chf, same_day),
            Ok(SingleAmount::from_value(dec!(0.8), chf))
        );
        assert_eq!(
            db.convert(SingleAmount::from_value(dec!(10.0), chf), eur, same_day),
            Ok(SingleAmount::from_value(dec!(12.5), eur))
        );
    }

    #[test]
    fn price_db_computes_indirect_price() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let chf = ctx.commodities.ensure("CHF");
        let eur = ctx.commodities.ensure("EUR");
        let usd = ctx.commodities.ensure("USD");
        let jpy = ctx.commodities.ensure("JPY");

        // Chain CHF -> EUR -> USD -> JPY, each link recorded on a different day.
        let links = [
            (
                ymd(2024, 10, 1),
                SingleAmount::from_value(dec!(0.8), chf),
                SingleAmount::from_value(dec!(1), eur),
            ),
            (
                ymd(2024, 10, 2),
                SingleAmount::from_value(dec!(0.8), eur),
                SingleAmount::from_value(dec!(1), usd),
            ),
            (
                ymd(2024, 10, 3),
                SingleAmount::from_value(dec!(100), jpy),
                SingleAmount::from_value(dec!(1), usd),
            ),
        ];
        let mut builder = PriceRepositoryBuilder::default();
        for (date, price_x, price_y) in links {
            builder.insert_price(
                PriceSource::Ledger,
                PriceEvent {
                    date,
                    price_x,
                    price_y,
                },
            );
        }
        let db = builder.build_naive();

        assert_eq!(
            db.convert(SingleAmount::from_value(dec!(1), chf), jpy, ymd(2024, 10, 3)),
            Ok(SingleAmount::from_value(dec!(156.25), jpy))
        );
    }
}