use std::collections::HashMap;
use std::fmt;
use std::sync::RwLock;
use super::aggregate::AggregateResult;
/// A price entry as stored by `PriceStore` for a single currency.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AggregatedPrice {
/// Price value, copied from the aggregate's `value` when the entry is stored.
pub value: f64,
/// The `now` timestamp that was passed to `PriceStore::update` when this entry
/// was written (presumably Unix seconds, since staleness is compared against a
/// `max_staleness_secs` window; confirm against callers).
pub as_of: i64,
/// Copied from the aggregate's `sources` field when the entry is stored.
pub source_count: u8,
}
/// Reasons a price lookup via `PriceStore::get` can fail.
#[derive(Debug, PartialEq, Eq)]
pub enum PriceError {
/// No entry is stored for the requested currency.
NoCurrency,
/// An entry exists, but its age exceeds the staleness window given by the caller.
TooStale,
}
impl fmt::Display for PriceError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
PriceError::NoCurrency => write!(f, "no price available for currency"),
PriceError::TooStale => write!(f, "price is older than the staleness window"),
}
}
}
// Marks `PriceError` as a standard error type; no methods are overridden, so the
// trait's default implementations apply.
impl std::error::Error for PriceError {}
/// Holds the most recently stored `AggregatedPrice` per currency behind an `RwLock`.
#[derive(Debug, Default)]
pub struct PriceStore {
/// Currency key -> stored price. `update` upper-cases keys before inserting and
/// the lookup methods upper-case the requested currency before reading.
inner: RwLock<HashMap<String, AggregatedPrice>>,
}
impl PriceStore {
    /// Creates an empty store.
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores every aggregate in `aggregates`, stamping each entry with `now`.
    ///
    /// Currency keys are upper-cased before insertion, so later lookups are
    /// case-insensitive. Currencies that are not present in `aggregates` keep
    /// whatever entry they already had. An empty batch returns immediately
    /// without taking the write lock.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn update(&self, aggregates: HashMap<String, AggregateResult>, now: i64) {
        if aggregates.is_empty() {
            return;
        }
        let mut w = self.inner.write().expect("price store lock poisoned");
        for (currency, agg) in aggregates {
            w.insert(
                currency.to_uppercase(),
                AggregatedPrice {
                    value: agg.value,
                    as_of: now,
                    source_count: agg.sources,
                },
            );
        }
    }

    /// Returns the stored price value for `currency` if it is fresh enough.
    ///
    /// The entry's age is `now - as_of` (saturating). It is accepted when the
    /// age is less than or equal to `max_staleness_secs`, so an entry exactly
    /// at the limit is still served. If `now` is earlier than `as_of` the age
    /// is negative, which counts as fresh for any non-negative window.
    ///
    /// # Errors
    ///
    /// * [`PriceError::NoCurrency`] if nothing is stored for `currency`.
    /// * [`PriceError::TooStale`] if the entry is older than the window.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn get(
        &self,
        currency: &str,
        max_staleness_secs: i64,
        now: i64,
    ) -> Result<f64, PriceError> {
        let r = self.inner.read().expect("price store lock poisoned");
        // Keys are stored upper-cased, so normalise the lookup key the same way.
        let entry = r
            .get(&currency.to_uppercase())
            .ok_or(PriceError::NoCurrency)?;
        if now.saturating_sub(entry.as_of) <= max_staleness_secs {
            Ok(entry.value)
        } else {
            Err(PriceError::TooStale)
        }
    }

    /// Returns a copy of the stored entry for `currency`, regardless of age,
    /// or `None` if nothing is stored for it.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn snapshot(&self, currency: &str) -> Option<AggregatedPrice> {
        let r = self.inner.read().expect("price store lock poisoned");
        // Same key normalisation as `update` and `get`.
        r.get(&currency.to_uppercase()).copied()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an aggregate batch from `(currency, value, sources)` triples.
    fn results(pairs: &[(&str, f64, u8)]) -> HashMap<String, AggregateResult> {
        let mut batch = HashMap::new();
        for &(code, value, sources) in pairs {
            batch.insert(
                code.to_string(),
                AggregateResult {
                    value,
                    sources,
                    contributors: Vec::new(),
                },
            );
        }
        batch
    }

    /// A price is served up to and including the staleness limit, and rejected
    /// one second past it.
    #[test]
    fn get_fresh_within_and_past_ttl() {
        const STORED_AT: i64 = 1_000;
        const TTL: i64 = 1_800;

        let store = PriceStore::new();
        store.update(results(&[("USD", 50_000.0, 2)]), STORED_AT);

        assert_eq!(store.get("USD", TTL, STORED_AT), Ok(50_000.0));
        assert_eq!(store.get("USD", TTL, STORED_AT + TTL), Ok(50_000.0));
        assert_eq!(
            store.get("USD", TTL, STORED_AT + TTL + 1),
            Err(PriceError::TooStale)
        );
    }

    /// Looking up a currency that was never stored reports `NoCurrency`.
    #[test]
    fn get_missing_currency() {
        let store = PriceStore::new();
        let outcome = store.get("EUR", 1_800, 0);
        assert_eq!(outcome, Err(PriceError::NoCurrency));
    }

    /// Keys stored in lower case are reachable through either casing.
    #[test]
    fn get_is_case_insensitive() {
        let store = PriceStore::new();
        store.update(results(&[("usd", 50_000.0, 1)]), 0);

        for code in ["USD", "usd"] {
            assert_eq!(store.get(code, 1_800, 0), Ok(50_000.0));
        }
    }

    /// A later batch that omits a currency leaves that currency's entry intact.
    #[test]
    fn update_preserves_last_known_good_for_absent_currencies() {
        let store = PriceStore::new();
        let first = results(&[("USD", 50_000.0, 2), ("EUR", 45_000.0, 2)]);
        let second = results(&[("USD", 51_000.0, 2)]);
        store.update(first, 1_000);
        store.update(second, 2_000);

        let usd = store.snapshot("USD").unwrap();
        assert_eq!(usd.as_of, 2_000);
        assert_eq!(usd.value, 51_000.0);

        let eur = store.snapshot("EUR").unwrap();
        assert_eq!(eur.as_of, 1_000, "EUR keeps its older as_of");
        assert_eq!(eur.value, 45_000.0);

        // The preserved entry is still subject to the staleness check.
        assert_eq!(store.get("EUR", 500, 2_000), Err(PriceError::TooStale));
    }

    /// An empty batch must not touch existing entries or their timestamps.
    #[test]
    fn empty_update_is_noop() {
        let store = PriceStore::new();
        store.update(results(&[("USD", 50_000.0, 1)]), 1_000);
        store.update(HashMap::new(), 9_999);

        let usd = store.snapshot("USD").unwrap();
        assert_eq!(usd.as_of, 1_000);
    }
}