use alloy_erc20::LazyToken;
use alloy_primitives::Address;
use alloy_provider::Provider;
use futures::future::join_all;
use std::collections::HashMap;
use tracing::{info, warn};
use crate::errors::PriceCalculationError;
use crate::TokenDecimals;
/// In-memory cache of token decimals, keyed by token address.
///
/// The cache performs no fetching itself; callers populate it through
/// `insert` and query it through `get` / `contains`.
#[derive(Debug, Default, Clone)]
pub(crate) struct TokenDecimalsCache {
/// Decimals recorded so far, at most one entry per token address.
inner: HashMap<Address, TokenDecimals>,
}
impl TokenDecimalsCache {
    /// Builds a cache with no entries.
    pub fn new() -> Self {
        Default::default()
    }

    /// Looks up the decimals stored for `addr`.
    ///
    /// Returns a copy of the stored value, or `None` when the address has
    /// never been inserted.
    pub fn get(&self, addr: &Address) -> Option<TokenDecimals> {
        let stored = self.inner.get(addr);
        stored.copied()
    }

    /// Stores `decimals` for `addr`, replacing any value already present.
    pub fn insert(&mut self, addr: Address, decimals: TokenDecimals) {
        // The previous value (if any) is intentionally discarded.
        let _previous = self.inner.insert(addr, decimals);
    }

    /// Reports whether an entry exists for `addr`.
    pub fn contains(&self, addr: &Address) -> bool {
        let entries = &self.inner;
        entries.contains_key(addr)
    }
}
/// Fetches token metadata (currently decimals) through a borrowed provider `P`.
pub(crate) struct TokenMetadataProvider<'a, P> {
/// Borrowed provider; it is cloned into a `LazyToken` for every lookup.
provider: &'a P,
}
impl<'a, P: Provider + Clone> TokenMetadataProvider<'a, P> {
    /// Creates a metadata provider that performs lookups through `provider`.
    pub fn new(provider: &'a P) -> Self {
        Self { provider }
    }

    /// Returns the decimals for `addr`, using `cache` when possible.
    ///
    /// On a cache miss the value is obtained from `LazyToken::decimals`,
    /// stored in `cache`, and then returned.
    ///
    /// # Errors
    ///
    /// Returns the error built by
    /// `PriceCalculationError::metadata_fetch_failed` when the `decimals()`
    /// call fails. Nothing is written to `cache` in that case.
    pub async fn get_or_fetch(
        &self,
        cache: &mut TokenDecimalsCache,
        addr: Address,
    ) -> Result<TokenDecimals, PriceCalculationError> {
        if let Some(d) = cache.get(&addr) {
            return Ok(d);
        }
        let token = LazyToken::new(addr, self.provider.clone());
        let raw = token
            .decimals()
            .await
            .map_err(|e| PriceCalculationError::metadata_fetch_failed(addr, e))?;
        let decimals = TokenDecimals::new(*raw);
        cache.insert(addr, decimals);
        Ok(decimals)
    }

    /// Best-effort warm-up of `cache` for every address in `addresses`.
    ///
    /// Addresses already cached are skipped, and repeated addresses are
    /// fetched only once. The remaining lookups are driven concurrently with
    /// `join_all`. A failed lookup is logged and left uncached, so a later
    /// `get_or_fetch` for that address will attempt the fetch again.
    pub async fn ensure_decimals(&self, cache: &mut TokenDecimalsCache, addresses: &[Address]) {
        // Track addresses already scheduled so duplicates in the input do not
        // trigger redundant fetches; `filter` keeps first-seen order.
        let mut seen = std::collections::HashSet::with_capacity(addresses.len());
        let uncached: Vec<Address> = addresses
            .iter()
            .copied()
            .filter(|addr| !cache.contains(addr) && seen.insert(*addr))
            .collect();
        if uncached.is_empty() {
            return;
        }
        info!(
            count = uncached.len(),
            "Batch fetching token decimals for uncached tokens"
        );
        let futures = uncached.iter().map(|&addr| {
            // Each future owns its own provider clone so it can be `move`d.
            let provider = self.provider.clone();
            async move {
                let token = LazyToken::new(addr, provider);
                // Copy the value out of the borrowed result before `token` is dropped.
                let result = token.decimals().await.copied();
                (addr, result)
            }
        });
        let results = join_all(futures).await;
        for (addr, result) in results {
            match result {
                Ok(raw) => {
                    cache.insert(addr, TokenDecimals::new(raw));
                }
                Err(e) => {
                    // Deliberately non-fatal: the address stays uncached.
                    warn!(
                        token = ?addr,
                        error = ?e,
                        "Failed to fetch decimals for token, will retry on demand"
                    );
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloy_primitives::address;

    /// Primary fixture token address shared by every test.
    fn token_a() -> Address {
        address!("aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa")
    }

    /// Secondary fixture token address, distinct from `token_a`.
    fn token_b() -> Address {
        address!("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb")
    }

    /// A freshly built cache holds no entry for an arbitrary address.
    #[test]
    fn cache_starts_empty() {
        let empty = TokenDecimalsCache::new();
        let token = token_a();
        assert_eq!(empty.get(&token), None);
        assert!(!empty.contains(&token));
    }

    /// A stored value is visible through both `contains` and `get`.
    #[test]
    fn insert_then_get_returns_value() {
        let token = token_a();
        let mut decimals_cache = TokenDecimalsCache::new();
        decimals_cache.insert(token, TokenDecimals::new(18));
        assert_eq!(decimals_cache.get(&token), Some(TokenDecimals::new(18)));
        assert!(decimals_cache.contains(&token));
    }

    /// Inserting twice for the same address keeps only the latest value.
    #[test]
    fn insert_overwrites_previous_value() {
        let token = token_a();
        let mut decimals_cache = TokenDecimalsCache::new();
        for value in [6, 18] {
            decimals_cache.insert(token, TokenDecimals::new(value));
        }
        assert_eq!(decimals_cache.get(&token), Some(TokenDecimals::new(18)));
    }

    /// Entries for different addresses do not interfere with each other.
    #[test]
    fn entries_are_token_scoped() {
        let (first, second) = (token_a(), token_b());
        let mut decimals_cache = TokenDecimalsCache::new();
        decimals_cache.insert(first, TokenDecimals::new(18));
        decimals_cache.insert(second, TokenDecimals::new(6));
        assert_eq!(decimals_cache.get(&first), Some(TokenDecimals::new(18)));
        assert_eq!(decimals_cache.get(&second), Some(TokenDecimals::new(6)));
    }
}