1use rust_decimal::Decimal;
7use rustledger_core::{Amount, Directive, InternedStr, NaiveDate, Price as PriceDirective};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone)]
12pub struct PriceEntry {
13 pub date: NaiveDate,
15 pub price: Decimal,
17 pub currency: InternedStr,
19}
20
21#[derive(Debug, Default)]
26pub struct PriceDatabase {
27 prices: HashMap<InternedStr, Vec<PriceEntry>>,
30}
31
32impl PriceDatabase {
33 pub fn new() -> Self {
35 Self {
36 prices: HashMap::new(),
37 }
38 }
39
40 pub fn from_directives(directives: &[Directive]) -> Self {
42 let mut db = Self::new();
43
44 for directive in directives {
45 if let Directive::Price(price) = directive {
46 db.add_price(price);
47 }
48 }
49
50 for entries in db.prices.values_mut() {
52 entries.sort_by_key(|e| e.date);
53 }
54
55 db
56 }
57
58 pub fn add_price(&mut self, price: &PriceDirective) {
60 let entry = PriceEntry {
61 date: price.date,
62 price: price.amount.number,
63 currency: price.amount.currency.clone(),
64 };
65
66 self.prices
67 .entry(price.currency.clone())
68 .or_default()
69 .push(entry);
70 }
71
72 pub fn get_price(&self, base: &str, quote: &str, date: NaiveDate) -> Option<Decimal> {
77 if base == quote {
79 return Some(Decimal::ONE);
80 }
81
82 if let Some(price) = self.get_direct_price(base, quote, date) {
84 return Some(price);
85 }
86
87 if let Some(price) = self.get_direct_price(quote, base, date) {
89 if price != Decimal::ZERO {
90 return Some(Decimal::ONE / price);
91 }
92 }
93
94 self.get_chained_price(base, quote, date)
96 }
97
98 fn get_direct_price(&self, base: &str, quote: &str, date: NaiveDate) -> Option<Decimal> {
100 if let Some(entries) = self.prices.get(base) {
101 for entry in entries.iter().rev() {
102 if entry.date <= date && entry.currency == quote {
103 return Some(entry.price);
104 }
105 }
106 }
107 None
108 }
109
110 fn get_chained_price(&self, base: &str, quote: &str, date: NaiveDate) -> Option<Decimal> {
113 let intermediates: Vec<InternedStr> = if let Some(entries) = self.prices.get(base) {
115 entries
116 .iter()
117 .filter(|e| e.date <= date)
118 .map(|e| e.currency.clone())
119 .collect()
120 } else {
121 Vec::new()
122 };
123
124 for intermediate in intermediates {
126 if intermediate == quote {
127 continue; }
129
130 if let Some(price1) = self.get_direct_price(base, &intermediate, date) {
132 if let Some(price2) = self.get_direct_price(&intermediate, quote, date) {
134 return Some(price1 * price2);
135 }
136 if let Some(price2) = self.get_direct_price(quote, &intermediate, date) {
138 if price2 != Decimal::ZERO {
139 return Some(price1 / price2);
140 }
141 }
142 }
143 }
144
145 for (currency, entries) in &self.prices {
147 for entry in entries.iter().rev() {
148 if entry.date <= date && entry.currency == base && entry.price != Decimal::ZERO {
149 let price1 = Decimal::ONE / entry.price;
151
152 if let Some(price2) = self.get_direct_price(currency, quote, date) {
154 return Some(price1 * price2);
155 }
156 if let Some(price2) = self.get_direct_price(quote, currency, date) {
157 if price2 != Decimal::ZERO {
158 return Some(price1 / price2);
159 }
160 }
161 }
162 }
163 }
164
165 None
166 }
167
168 pub fn get_latest_price(&self, base: &str, quote: &str) -> Option<Decimal> {
170 if let Some(entries) = self.prices.get(base) {
171 for entry in entries.iter().rev() {
173 if entry.currency == quote {
174 return Some(entry.price);
175 }
176 }
177 }
178
179 if let Some(entries) = self.prices.get(quote) {
181 for entry in entries.iter().rev() {
182 if entry.currency == base && entry.price != Decimal::ZERO {
183 return Some(Decimal::ONE / entry.price);
184 }
185 }
186 }
187
188 None
189 }
190
191 pub fn convert(&self, amount: &Amount, to_currency: &str, date: NaiveDate) -> Option<Amount> {
195 if amount.currency == to_currency {
196 return Some(amount.clone());
197 }
198
199 self.get_price(&amount.currency, to_currency, date)
200 .map(|price| Amount::new(amount.number * price, to_currency))
201 }
202
203 pub fn convert_latest(&self, amount: &Amount, to_currency: &str) -> Option<Amount> {
205 if amount.currency == to_currency {
206 return Some(amount.clone());
207 }
208
209 self.get_latest_price(&amount.currency, to_currency)
210 .map(|price| Amount::new(amount.number * price, to_currency))
211 }
212
213 pub fn currencies(&self) -> impl Iterator<Item = &str> {
215 self.prices.keys().map(InternedStr::as_str)
216 }
217
218 pub fn has_prices(&self, currency: &str) -> bool {
220 self.prices.contains_key(currency)
221 }
222
223 pub fn len(&self) -> usize {
225 self.prices.values().map(Vec::len).sum()
226 }
227
228 pub fn is_empty(&self) -> bool {
230 self.prices.is_empty()
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    fn date(y: i32, m: u32, d: u32) -> NaiveDate {
        NaiveDate::from_ymd_opt(y, m, d).unwrap()
    }

    /// Builds a `price` directive `on base number quote` with default metadata.
    fn price_directive(
        on: NaiveDate,
        base: &'static str,
        number: Decimal,
        quote: &'static str,
    ) -> PriceDirective {
        PriceDirective {
            date: on,
            currency: base.into(),
            amount: Amount::new(number, quote),
            meta: Default::default(),
        }
    }

    /// Feeds each `(date, base, number, quote)` tuple through `add_price`, then sorts
    /// every per-commodity list by date.
    fn db_with(prices: &[(NaiveDate, &'static str, Decimal, &'static str)]) -> PriceDatabase {
        let mut db = PriceDatabase::new();
        for &(on, base, number, quote) in prices {
            db.add_price(&price_directive(on, base, number, quote));
        }
        for entries in db.prices.values_mut() {
            entries.sort_by_key(|e| e.date);
        }
        db
    }

    #[test]
    fn test_price_lookup() {
        let db = db_with(&[
            (date(2024, 1, 1), "AAPL", dec!(150.00), "USD"),
            (date(2024, 6, 1), "AAPL", dec!(180.00), "USD"),
        ]);
        let aapl_in_usd = |on| db.get_price("AAPL", "USD", on);

        // Exact date match.
        assert_eq!(aapl_in_usd(date(2024, 1, 1)), Some(dec!(150.00)));
        // After the newest price: the newest price applies.
        assert_eq!(aapl_in_usd(date(2024, 6, 15)), Some(dec!(180.00)));
        // Between two prices: the earlier one applies.
        assert_eq!(aapl_in_usd(date(2024, 3, 15)), Some(dec!(150.00)));
        // Before any price is known.
        assert_eq!(aapl_in_usd(date(2023, 12, 31)), None);
    }

    #[test]
    fn test_inverse_price() {
        let db = db_with(&[(date(2024, 1, 1), "USD", dec!(0.92), "EUR")]);
        let on = date(2024, 1, 1);

        assert_eq!(db.get_price("USD", "EUR", on), Some(dec!(0.92)));

        // 1 / 0.92 ~= 1.0869...
        let inverse = db.get_price("EUR", "USD", on).unwrap();
        assert!(inverse > dec!(1.08) && inverse < dec!(1.09));
    }

    #[test]
    fn test_convert() {
        let db = db_with(&[(date(2024, 1, 1), "AAPL", dec!(150.00), "USD")]);

        let shares = Amount::new(dec!(10), "AAPL");
        let usd = db.convert(&shares, "USD", date(2024, 1, 1)).unwrap();

        assert_eq!(usd.number, dec!(1500.00));
        assert_eq!(usd.currency, "USD");
    }

    #[test]
    fn test_same_currency_convert() {
        let db = PriceDatabase::new();

        let result = db
            .convert(&Amount::new(dec!(100), "USD"), "USD", date(2024, 1, 1))
            .unwrap();

        assert_eq!(result.number, dec!(100));
        assert_eq!(result.currency, "USD");
    }

    #[test]
    fn test_from_directives() {
        let on = date(2024, 1, 1);
        let directives = [
            Directive::Price(price_directive(on, "AAPL", dec!(150.00), "USD")),
            Directive::Price(price_directive(on, "EUR", dec!(1.10), "USD")),
        ];

        let db = PriceDatabase::from_directives(&directives);

        assert_eq!(db.len(), 2);
        assert!(db.has_prices("AAPL"));
        assert!(db.has_prices("EUR"));
    }

    #[test]
    fn test_chained_price_lookup() {
        let on = date(2024, 1, 1);
        let db = db_with(&[
            (on, "AAPL", dec!(150.00), "USD"),
            (on, "USD", dec!(0.92), "EUR"),
        ]);

        assert_eq!(db.get_price("AAPL", "USD", on), Some(dec!(150.00)));
        assert_eq!(db.get_price("USD", "EUR", on), Some(dec!(0.92)));

        // AAPL -> USD -> EUR: 150 * 0.92 = 138.
        let chained = db.get_price("AAPL", "EUR", on).unwrap();
        assert_eq!(chained, dec!(138.00));
    }

    #[test]
    fn test_chained_price_with_inverse() {
        let on = date(2024, 1, 1);
        let db = db_with(&[
            (on, "BTC", dec!(40000.00), "USD"),
            (on, "EUR", dec!(1.10), "USD"),
        ]);

        // BTC -> USD, then USD -> EUR via the inverse of EUR -> USD:
        // 40000 / 1.10 ~= 36363.63...
        let chained = db.get_price("BTC", "EUR", on).unwrap();
        assert!(chained > dec!(36363) && chained < dec!(36364));
    }

    #[test]
    fn test_chained_price_no_path() {
        let on = date(2024, 1, 1);
        let db = db_with(&[
            (on, "AAPL", dec!(150.00), "USD"),
            (on, "GBP", dec!(1.17), "EUR"),
        ]);

        // The USD and EUR/GBP islands are not connected.
        assert_eq!(db.get_price("AAPL", "GBP", on), None);
    }
}