use std::collections::{BinaryHeap, HashMap};
use std::path::{Path, PathBuf};
use chrono::{NaiveDate, TimeDelta};
use rust_decimal::Decimal;
use crate::load;
use crate::parse;
use crate::report::commodity::{CommodityMap, CommodityTag, OwnedCommodity};
use super::context::ReportContext;
use super::eval::{Amount, SingleAmount};
/// Errors that can occur while loading a price DB file.
///
/// Both variants carry the path of the offending file together with the
/// underlying error, which is exposed as the error source.
#[derive(Debug, thiserror::Error)]
pub enum LoadError {
/// The content of the file could not be read.
#[error("failed to load price DB file {0}")]
IO(PathBuf, #[source] std::io::Error),
/// The file was read, but an entry in it could not be parsed.
#[error("failed to parse price DB file {0}")]
Parse(PathBuf, #[source] parse::ParseError),
}
/// Where a price record came from.
///
/// The derived ordering follows the declaration order, so
/// `Ledger < PriceDB`; the builder uses this ordering to decide which source
/// takes precedence for a commodity pair.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)]
pub(super) enum PriceSource {
/// Price inserted with the ledger as its origin.
// NOTE(review): the producer of ledger prices is not visible in this file.
Ledger,
/// Price loaded from a price DB file.
PriceDB,
}
/// Price history for one ordered pair of commodities: the source the records
/// came from, and the `(date, rate)` records themselves.
#[derive(Debug)]
struct Entry(PriceSource, Vec<(NaiveDate, Decimal)>);
/// Collects price records and turns them into a [`PriceRepository`].
#[derive(Debug, Default)]
pub(super) struct PriceRepositoryBuilder<'ctx> {
// Outer key: the commodity the price is expressed in (`price_with`).
// Inner key: the commodity being priced (`price_of`).
// Each stored rate is `price_with.value / price_of.value`.
records: HashMap<CommodityTag<'ctx>, HashMap<CommodityTag<'ctx>, Entry>>,
}
/// A single price observation: on `date`, the amount `price_x` was exchanged
/// at par with the amount `price_y`.
#[derive(Debug, PartialEq, Eq)]
pub(super) struct PriceEvent<'ctx> {
/// Day the price applies to.
pub date: NaiveDate,
/// One side of the exchange.
pub price_x: SingleAmount<'ctx>,
/// The other side of the exchange.
pub price_y: SingleAmount<'ctx>,
}
#[cfg(test)]
impl<'ctx> PriceEvent<'ctx> {
    /// Key giving events a deterministic order in tests: the date first,
    /// then the indices of the `price_x` and `price_y` commodities.
    /// The amounts' values do not take part in the key.
    fn sort_key(&self) -> (NaiveDate, usize, usize) {
        let x_index = self.price_x.commodity.as_index();
        let y_index = self.price_y.commodity.as_index();
        (self.date, x_index, y_index)
    }
}
impl<'ctx> PriceRepositoryBuilder<'ctx> {
    /// Records that `event.price_x` and `event.price_y` were exchanged at par
    /// on `event.date`.
    ///
    /// The rate is stored in both directions so that either commodity can be
    /// converted into the other. An event relating a commodity to itself is
    /// reported through `log::error!` but is still recorded.
    pub fn insert_price(&mut self, source: PriceSource, event: PriceEvent<'ctx>) {
        if event.price_x.commodity == event.price_y.commodity {
            log::error!("price log should not contain the self-mention rate");
        }
        self.insert_impl(source, event.date, event.price_x, event.price_y);
        self.insert_impl(source, event.date, event.price_y, event.price_x);
    }

    /// Stores the rate that converts `price_of` into `price_with` on `date`.
    ///
    /// A commodity pair keeps records from a single source only:
    /// * a record from a higher-priority source (per [`PriceSource`]'s
    ///   ordering) replaces everything stored so far,
    /// * a record from a lower-priority source is ignored once a
    ///   higher-priority one is present, regardless of insertion order.
    ///
    /// A direction whose rate cannot be computed (zero divisor or overflow)
    /// is skipped with a warning instead of panicking.
    fn insert_impl(
        &mut self,
        source: PriceSource,
        date: NaiveDate,
        price_of: SingleAmount<'ctx>,
        price_with: SingleAmount<'ctx>,
    ) {
        // `Decimal`'s `/` operator panics on a zero divisor. Zero amounts can
        // come straight from a user-written file, so use the checked variant
        // and compute the rate before touching `records`.
        let Some(rate) = price_with.value.checked_div(price_of.value) else {
            log::warn!(
                "ignoring price on {}: cannot compute rate {} / {}",
                date,
                price_with.value,
                price_of.value
            );
            return;
        };
        let Entry(stored_source, entries) = self
            .records
            .entry(price_with.commodity)
            .or_default()
            .entry(price_of.commodity)
            .or_insert_with(|| Entry(source, Vec::new()));
        match (*stored_source).cmp(&source) {
            // Something with higher priority is already stored: keep it
            // untouched rather than mixing sources in one history.
            std::cmp::Ordering::Greater => return,
            // The new record outranks what is stored: start over.
            std::cmp::Ordering::Less => {
                *stored_source = source;
                entries.clear();
            }
            std::cmp::Ordering::Equal => {}
        }
        entries.push((date, rate));
    }

    /// Loads every entry of the price DB file at `path` as a
    /// [`PriceSource::PriceDB`] price.
    ///
    /// The file is read through `filesystem` and parsed with the default
    /// parse options. Each entry is inserted as "one unit of the target
    /// commodity costs `rate`", dated with the date part of the entry's
    /// datetime. Commodities not yet known to `ctx` are registered.
    ///
    /// # Errors
    ///
    /// * [`LoadError::IO`] when the file content cannot be read.
    /// * [`LoadError::Parse`] on the first entry that fails to parse; entries
    ///   preceding it have already been inserted.
    pub fn load_price_db<F: load::FileSystem>(
        &mut self,
        ctx: &mut ReportContext<'ctx>,
        filesystem: &F,
        path: &Path,
    ) -> Result<(), LoadError> {
        let content = filesystem
            .file_content_utf8(path)
            .map_err(|e| LoadError::IO(path.to_owned(), e))?;
        for entry in parse::price::parse_price_db(&parse::ParseOptions::default(), &content) {
            let (_, entry) = entry.map_err(|e| LoadError::Parse(path.to_owned(), e))?;
            let target = ctx.commodities.ensure(entry.target.as_ref());
            let rate: SingleAmount<'ctx> = SingleAmount::from_value(
                ctx.commodities.ensure(&entry.rate.commodity),
                entry.rate.value.value,
            );
            self.insert_price(
                PriceSource::PriceDB,
                PriceEvent {
                    price_x: SingleAmount::from_value(target, Decimal::ONE),
                    price_y: rate,
                    date: entry.datetime.date(),
                },
            );
        }
        Ok(())
    }

    /// Iterates over every stored record as a `(source, event)` pair, where
    /// the event reads "one unit of `price_of` equals `rate` of `price_with`".
    ///
    /// The iteration order follows the underlying hash maps and is therefore
    /// unspecified.
    #[cfg(test)]
    pub fn iter_events(&self) -> impl Iterator<Item = (PriceSource, PriceEvent<'ctx>)> {
        self.records.iter().flat_map(|(price_with, v)| {
            v.iter().flat_map(|(price_of, Entry(source, v))| {
                v.iter().map(|(date, rate)| {
                    (
                        *source,
                        PriceEvent {
                            price_x: SingleAmount::from_value(*price_of, Decimal::ONE),
                            price_y: SingleAmount::from_value(*price_with, *rate),
                            date: *date,
                        },
                    )
                })
            })
        })
    }

    /// Returns all stored events without their source, ordered by
    /// [`PriceEvent::sort_key`].
    #[cfg(test)]
    pub fn to_events(&self) -> Vec<PriceEvent<'ctx>> {
        let mut ret: Vec<PriceEvent<'ctx>> =
            self.iter_events().map(|(_source, event)| event).collect();
        ret.sort_by_key(|x| x.sort_key());
        ret
    }

    /// Finalizes the builder into a caching [`PriceRepository`].
    pub fn build(self) -> PriceRepository<'ctx> {
        PriceRepository::new(self.build_naive())
    }

    /// Finalizes the builder into the uncached repository.
    ///
    /// Every history is sorted ascending by `(date, rate)` so that lookups
    /// can bisect it by date.
    fn build_naive(mut self) -> NaivePriceRepository<'ctx> {
        self.records
            .values_mut()
            .for_each(|x| x.values_mut().for_each(|x| x.1.sort()));
        NaivePriceRepository {
            records: self.records,
        }
    }
}
/// Errors that can occur while converting an amount into another commodity.
#[derive(Debug, thiserror::Error)]
pub enum ConversionError {
/// No rate is available from the first commodity into the second one at
/// the given date.
#[error("commodity rate {0} into {1} at {2} not found")]
RateNotFound(OwnedCommodity, OwnedCommodity, NaiveDate),
}
/// Converts every component of `amount` into `commodity_with` using the
/// rates known at `date`, and returns the sum of the converted components.
///
/// # Errors
///
/// Returns the [`ConversionError`] of the first component that cannot be
/// converted.
pub fn convert_amount<'ctx>(
    ctx: &ReportContext<'ctx>,
    price_repos: &mut PriceRepository<'ctx>,
    amount: &Amount<'ctx>,
    commodity_with: CommodityTag<'ctx>,
    date: NaiveDate,
) -> Result<Amount<'ctx>, ConversionError> {
    amount.iter().try_fold(Amount::zero(), |mut total, single| {
        price_repos
            .convert_single(ctx, single, commodity_with, date)
            .map(|converted| {
                total += converted;
                total
            })
    })
}
/// Price repository that memoizes the computed rate tables.
#[derive(Debug)]
pub struct PriceRepository<'ctx> {
// Source of truth used to compute a table on a cache miss.
inner: NaivePriceRepository<'ctx>,
// Rate tables keyed by (target commodity, date).
cache: HashMap<(CommodityTag<'ctx>, NaiveDate), CommodityMap<WithDistance<Decimal>>>,
}
impl<'ctx> PriceRepository<'ctx> {
    /// Wraps `inner` with an initially empty cache.
    fn new(inner: NaivePriceRepository<'ctx>) -> Self {
        let cache = HashMap::new();
        Self { inner, cache }
    }

    /// Converts `value` into `commodity_with` using the rates known at `date`.
    ///
    /// A value that is already in `commodity_with` is returned unchanged.
    /// Otherwise the rate table for `(commodity_with, date)` is computed on
    /// first use and served from the cache afterwards.
    ///
    /// # Errors
    ///
    /// Returns [`ConversionError::RateNotFound`] when the table has no rate
    /// for the commodity of `value`.
    pub fn convert_single(
        &mut self,
        ctx: &ReportContext<'ctx>,
        value: SingleAmount<'ctx>,
        commodity_with: CommodityTag<'ctx>,
        date: NaiveDate,
    ) -> Result<SingleAmount<'ctx>, ConversionError> {
        if value.commodity == commodity_with {
            return Ok(value);
        }
        let inner = &self.inner;
        let table = self
            .cache
            .entry((commodity_with, date))
            .or_insert_with(|| inner.compute_price_table(ctx, commodity_with, date));
        let Some(WithDistance(_, rate)) = table.get(value.commodity) else {
            return Err(ConversionError::RateNotFound(
                value.commodity.to_owned_lossy(&ctx.commodities),
                commodity_with.to_owned_lossy(&ctx.commodities),
                date,
            ));
        };
        Ok(SingleAmount::from_value(commodity_with, value.value * rate))
    }
}
/// Price repository without any caching: every lookup recomputes the rate
/// table from the raw records.
#[derive(Debug)]
struct NaivePriceRepository<'ctx> {
// Outer key: commodity the price is expressed in; inner key: commodity
// being priced. Lookups bisect each history by date, so the histories are
// expected to be sorted by date.
records: HashMap<CommodityTag<'ctx>, HashMap<CommodityTag<'ctx>, Entry>>,
}
impl<'ctx> NaivePriceRepository<'ctx> {
    /// Converts `value` into `commodity_with` using the rates known at `date`.
    ///
    /// A value already in `commodity_with` is returned unchanged. When no
    /// rate is available, the untouched input is handed back as the error.
    #[cfg(test)]
    fn convert(
        &self,
        ctx: &ReportContext<'ctx>,
        value: SingleAmount<'ctx>,
        commodity_with: CommodityTag<'ctx>,
        date: NaiveDate,
    ) -> Result<SingleAmount<'ctx>, SingleAmount<'ctx>> {
        if value.commodity == commodity_with {
            return Ok(value);
        }
        let rate = self
            .compute_price_table(ctx, commodity_with, date)
            .get(value.commodity)
            .map(|x| x.1);
        match rate {
            Some(rate) => Ok(SingleAmount::from_value(commodity_with, value.value * rate)),
            None => Err(value),
        }
    }

    /// Computes, for every commodity reachable from `price_with`, the rate
    /// that converts one unit of it into `price_with` as of `date`.
    ///
    /// This is a shortest-path search over the recorded commodity pairs,
    /// where the path length is a [`Distance`] and each hop uses the latest
    /// record dated on or before `date`. Rates are multiplied along the path.
    /// The entry for `price_with` itself has distance zero and rate one.
    fn compute_price_table(
        &self,
        ctx: &ReportContext<'ctx>,
        price_with: CommodityTag<'ctx>,
        date: NaiveDate,
    ) -> CommodityMap<WithDistance<Decimal>> {
        // `BinaryHeap` is a max-heap. Wrapping the entries in `Reverse` makes
        // `pop` return the smallest distance first, so each commodity is
        // normally settled the first time it is popped.
        let mut queue: BinaryHeap<std::cmp::Reverse<WithDistance<(CommodityTag<'ctx>, Decimal)>>> =
            BinaryHeap::new();
        let mut distances: CommodityMap<WithDistance<Decimal>> =
            CommodityMap::with_capacity(ctx.commodities.len());
        let origin = Distance {
            num_ledger_conversions: 0,
            num_all_conversions: 0,
            staleness: TimeDelta::zero(),
        };
        // Seed the origin so that it is not re-discovered through the reverse
        // direction of its own neighbours.
        *distances.get_mut(price_with) = Some(WithDistance(origin.clone(), Decimal::ONE));
        queue.push(std::cmp::Reverse(WithDistance(
            origin,
            (price_with, Decimal::ONE),
        )));
        while let Some(std::cmp::Reverse(curr)) = queue.pop() {
            log::debug!("curr: {:?}", curr);
            let WithDistance(curr_dist, (prev, prev_rate)) = curr;
            // A shorter path to `prev` was found after this entry was queued.
            if let Some(WithDistance(prev_dist, _)) = distances.get(prev)
                && *prev_dist < curr_dist
            {
                log::debug!(
                    "no need to update, prev_dist {:?} is smaller than curr_dist {:?}",
                    prev_dist,
                    curr_dist
                );
                continue;
            }
            let Some(neighbors) = self.records.get(&prev) else {
                continue;
            };
            for (j, Entry(source, rates)) in neighbors {
                // Number of records dated on or before `date`.
                let bound = rates.partition_point(|(record_date, _)| record_date <= &date);
                log::debug!(
                    "found next commodity #{} with date bound {}",
                    j.as_index(),
                    bound
                );
                if bound == 0 {
                    continue;
                }
                let (record_date, rate) = rates[bound - 1];
                let next_dist = curr_dist.extend(*source, date - record_date);
                let rate = prev_rate * rate;
                let known: &mut Option<WithDistance<Decimal>> = distances.get_mut(*j);
                // Only a strictly shorter path replaces a known one.
                if known.as_ref().is_some_and(|best| *best <= next_dist) {
                    continue;
                }
                *known = Some(WithDistance(next_dist.clone(), rate));
                queue.push(std::cmp::Reverse(WithDistance(next_dist, (*j, rate))));
            }
        }
        distances
    }
}
/// Cost of a conversion path; a smaller value is preferred by the search.
///
/// The derived ordering compares the fields lexicographically in declaration
/// order.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone)]
struct Distance {
// Number of hops on the path that used a `PriceSource::Ledger` record.
num_ledger_conversions: usize,
// Total number of hops on the path.
num_all_conversions: usize,
// Largest gap, over all hops, between the requested date and the date of
// the record used for that hop.
staleness: TimeDelta,
}
impl Distance {
    /// Returns the distance of this path prolonged by one more hop.
    ///
    /// The hop always counts as a conversion and additionally as a ledger
    /// conversion when `source` is [`PriceSource::Ledger`]. The resulting
    /// staleness is the larger of the current one and the hop's `staleness`.
    fn extend(&self, source: PriceSource, staleness: TimeDelta) -> Self {
        let ledger_hops = match source {
            PriceSource::Ledger => self.num_ledger_conversions + 1,
            PriceSource::PriceDB => self.num_ledger_conversions,
        };
        let all_hops = self.num_all_conversions + 1;
        Self {
            num_ledger_conversions: ledger_hops,
            num_all_conversions: all_hops,
            staleness: Ord::max(self.staleness, staleness),
        }
    }
}
/// A payload of type `T` tagged with a [`Distance`].
///
/// Every comparison below looks at the distance only; the payload never
/// takes part in equality or ordering.
#[derive(Debug, Clone)]
struct WithDistance<T>(Distance, T);

/// Two tagged values are equal when their distances are equal.
impl<T> PartialEq for WithDistance<T> {
    fn eq(&self, other: &Self) -> bool {
        let (Self(lhs, _), Self(rhs, _)) = (self, other);
        lhs == rhs
    }
}

/// A tagged value equals a bare distance when its own distance matches it.
impl<T> PartialEq<Distance> for WithDistance<T> {
    fn eq(&self, other: &Distance) -> bool {
        let Self(own, _) = self;
        own == other
    }
}

impl<T> Eq for WithDistance<T> {}

/// Tagged values are ordered by their distances.
impl<T> PartialOrd for WithDistance<T> {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        PartialOrd::partial_cmp(&self.0, &other.0)
    }
}

/// A tagged value is ordered against a bare distance by its own distance.
impl<T: Eq> PartialOrd<Distance> for WithDistance<T> {
    fn partial_cmp(&self, other: &Distance) -> Option<std::cmp::Ordering> {
        PartialOrd::partial_cmp(&self.0, other)
    }
}

/// Total order by distance.
impl<T: Eq> Ord for WithDistance<T> {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        Ord::cmp(&self.0, &other.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    use bumpalo::Bump;
    use pretty_assertions::assert_eq;
    use rust_decimal_macros::dec;

    /// Shorthand for a calendar date that is known to be valid.
    fn ymd(year: i32, month: u32, day: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(year, month, day).unwrap()
    }

    #[test]
    fn price_db_computes_direct_price() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let chf = ctx.commodities.ensure("CHF");
        let eur = ctx.commodities.ensure("EUR");

        // 1 EUR = 0.8 CHF from 2024-10-01 on.
        let mut builder = PriceRepositoryBuilder::default();
        builder.insert_price(
            PriceSource::Ledger,
            PriceEvent {
                date: ymd(2024, 10, 1),
                price_x: SingleAmount::from_value(eur, dec!(1)),
                price_y: SingleAmount::from_value(chf, dec!(0.8)),
            },
        );
        let repo = builder.build_naive();

        // The day before the price is known, nothing can be converted.
        let got = repo.convert(
            &ctx,
            SingleAmount::from_value(eur, dec!(1)),
            chf,
            ymd(2024, 9, 30),
        );
        assert_eq!(got, Err(SingleAmount::from_value(eur, dec!(1))));

        let got = repo.convert(
            &ctx,
            SingleAmount::from_value(chf, dec!(10)),
            eur,
            ymd(2024, 9, 30),
        );
        assert_eq!(got, Err(SingleAmount::from_value(chf, dec!(10))));

        // On the price date both directions work.
        let got = repo.convert(
            &ctx,
            SingleAmount::from_value(eur, dec!(1.0)),
            chf,
            ymd(2024, 10, 1),
        );
        assert_eq!(got, Ok(SingleAmount::from_value(chf, dec!(0.8))));

        // 10 CHF / 0.8 = 12.5 EUR.
        let got = repo.convert(
            &ctx,
            SingleAmount::from_value(chf, dec!(10.0)),
            eur,
            ymd(2024, 10, 1),
        );
        assert_eq!(got, Ok(SingleAmount::from_value(eur, dec!(12.5))));
    }

    #[test]
    fn price_db_computes_indirect_price() {
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let chf = ctx.commodities.ensure("CHF");
        let eur = ctx.commodities.ensure("EUR");
        let usd = ctx.commodities.ensure("USD");
        let jpy = ctx.commodities.ensure("JPY");

        // Chain of prices: CHF <-> EUR <-> USD <-> JPY.
        let mut builder = PriceRepositoryBuilder::default();
        builder.insert_price(
            PriceSource::Ledger,
            PriceEvent {
                date: ymd(2024, 10, 1),
                price_x: SingleAmount::from_value(chf, dec!(0.8)),
                price_y: SingleAmount::from_value(eur, dec!(1)),
            },
        );
        builder.insert_price(
            PriceSource::Ledger,
            PriceEvent {
                date: ymd(2024, 10, 2),
                price_x: SingleAmount::from_value(eur, dec!(0.8)),
                price_y: SingleAmount::from_value(usd, dec!(1)),
            },
        );
        builder.insert_price(
            PriceSource::Ledger,
            PriceEvent {
                date: ymd(2024, 10, 3),
                price_x: SingleAmount::from_value(jpy, dec!(100)),
                price_y: SingleAmount::from_value(usd, dec!(1)),
            },
        );
        let repo = builder.build_naive();

        // 1 CHF = 1.25 EUR = 1.5625 USD = 156.25 JPY.
        let got = repo.convert(
            &ctx,
            SingleAmount::from_value(chf, dec!(1)),
            jpy,
            ymd(2024, 10, 3),
        );
        assert_eq!(got, Ok(SingleAmount::from_value(jpy, dec!(156.25))));
    }

    #[test]
    fn price_db_load_overrides_ledger_price() {
        let price_db =
            Path::new(env!("CARGO_MANIFEST_DIR")).join("../testdata/report/price_db.txt");
        let arena = Bump::new();
        let mut ctx = ReportContext::new(&arena);
        let chf = ctx.commodities.ensure("CHF");
        let eur = ctx.commodities.ensure("EUR");

        // A ledger price that the price DB is expected to replace.
        let mut builder = PriceRepositoryBuilder::default();
        builder.insert_price(
            PriceSource::Ledger,
            PriceEvent {
                date: ymd(2024, 1, 1),
                price_x: SingleAmount::from_value(chf, dec!(0.8)),
                price_y: SingleAmount::from_value(eur, dec!(1)),
            },
        );
        builder
            .load_price_db(&mut ctx, &load::ProdFileSystem, &price_db)
            .unwrap();

        let target_date = ymd(2024, 1, 31);
        let is_in_scope = |event: &PriceEvent<'_>| {
            let (x, y) = (event.price_x.commodity, event.price_y.commodity);
            event.date == target_date && ((x == chf && y == eur) || (x == eur && y == chf))
        };

        let got: Vec<_> = builder.iter_events().collect();
        // Every record is stored in both directions, none of them from the ledger.
        assert_eq!(got.len(), 17 * 2);
        assert!(
            got.iter()
                .all(|(source, _)| *source == PriceSource::PriceDB)
        );

        let mut filtered: Vec<_> = got
            .into_iter()
            .map(|(_, event)| event)
            .filter(is_in_scope)
            .collect();
        filtered.sort_by_key(|event| event.sort_key());

        let want = vec![
            PriceEvent {
                date: target_date,
                price_x: SingleAmount::from_value(chf, Decimal::ONE),
                price_y: SingleAmount::from_value(eur, Decimal::ONE / dec!(0.9348)),
            },
            PriceEvent {
                date: target_date,
                price_x: SingleAmount::from_value(eur, dec!(1)),
                price_y: SingleAmount::from_value(chf, dec!(0.9348)),
            },
        ];
        assert_eq!(want, filtered);
    }
}