use std::{
collections::{HashMap, HashSet},
sync::Arc,
};
use tokio::sync::RwLock;
use tycho_simulation::{
tycho_client::feed::SynchronizerState,
tycho_common::{
models::{protocol::ProtocolComponent, token::Token, Address},
simulation::protocol_sim::ProtocolSim,
},
tycho_ethereum::gas::BlockGasPrice,
};
use crate::types::{BlockInfo, ComponentId};
/// Identifier for a market snapshot: either the base state's label (which
/// [`MarketData::apply_block_update`] sets to the block number rendered as a string) or the
/// label under which an overlay was registered.
pub type StateLabel = String;
/// Reference-counted map of per-component simulation states that overlay the base state.
pub type OverlayStates = Arc<HashMap<ComponentId, Box<dyn ProtocolSim>>>;
/// A registered overlay: replacement simulation states plus the last block they apply to.
///
/// [`MarketData::apply_block_update`] evicts an entry once the new block number is greater
/// than `valid_until`.
#[derive(Debug, Clone)]
pub struct OverlayEntry {
    /// Per-component states that take precedence over the base market state.
    pub states: OverlayStates,
    /// Last block number (inclusive) for which this overlay is retained.
    pub valid_until: u64,
}
/// Label-keyed registry of overlays, shared by every clone of a [`MarketData`] handle.
type OverlayRegistry = Arc<RwLock<HashMap<StateLabel, OverlayEntry>>>;
/// Error returned by [`MarketData::read_labeled`].
#[derive(Debug, thiserror::Error)]
pub enum ReadLabeledError {
    /// The label matches neither a registered overlay nor the base state's current label.
    #[error("label not found: {0}")]
    NotFound(StateLabel),
}
/// Cheaply clonable handle to the shared base market state and its labeled overlays.
///
/// Cloning copies two `Arc`s, so all clones see the same base state and overlay registry.
#[derive(Clone)]
pub struct MarketData {
    /// Base market state.
    data: Arc<RwLock<MarketState>>,
    /// Labeled overlay states that `read_labeled` layers on top of `data`.
    overlays: OverlayRegistry,
}
impl MarketData {
    /// Wraps an existing shared [`MarketState`] with an empty overlay registry.
    pub fn new(data: Arc<RwLock<MarketState>>) -> Self {
        Self { data, overlays: Arc::new(RwLock::new(HashMap::new())) }
    }

    /// Creates a handle around a fresh, empty [`MarketState`].
    pub fn new_shared() -> Self {
        Self::new(Arc::new(RwLock::new(MarketState::new())))
    }

    /// Waits for a read view of the base state with no overlay applied.
    pub async fn read(&self) -> MarketDataView<'_> {
        MarketDataView { guard: self.data.read().await, overlay: None }
    }

    /// Waits for a read view for `label`.
    ///
    /// If an overlay is registered under `label`, the view layers its states on top of the
    /// base state; this takes precedence over the base label. Otherwise the view is the plain
    /// base state, provided `label` equals the base state's current label.
    ///
    /// # Errors
    ///
    /// Returns [`ReadLabeledError::NotFound`] when `label` matches neither.
    pub async fn read_labeled(
        &self,
        label: &StateLabel,
    ) -> Result<MarketDataView<'_>, ReadLabeledError> {
        // Lock order is always `data` -> `overlays` (shared with `apply_block_update`), so the
        // overlay lookup is consistent with the base snapshot held by `guard`.
        let guard = self.data.read().await;
        // The registry guard is a temporary dropped at the end of this statement; only the
        // `Arc` of states outlives it.
        let overlay = self
            .overlays
            .read()
            .await
            .get(label)
            .map(|entry| Arc::clone(&entry.states));
        if let Some(states) = overlay {
            return Ok(MarketDataView { guard, overlay: Some((label.clone(), states)) });
        }
        if &guard.label == label {
            return Ok(MarketDataView { guard, overlay: None });
        }
        Err(ReadLabeledError::NotFound(label.clone()))
    }

    /// Waits for exclusive write access to the base state.
    ///
    /// This neither evicts overlays nor changes the base label; `apply_block_update` does both.
    pub async fn write(&self) -> tokio::sync::RwLockWriteGuard<'_, MarketState> {
        self.data.write().await
    }

    /// Attempts shared access to the base state; `None` if it cannot be acquired immediately.
    pub fn try_read(&self) -> Option<tokio::sync::RwLockReadGuard<'_, MarketState>> {
        self.data.try_read().ok()
    }

    /// Attempts exclusive access to the base state; `None` if it cannot be acquired immediately.
    pub fn try_write(&self) -> Option<tokio::sync::RwLockWriteGuard<'_, MarketState>> {
        self.data.try_write().ok()
    }

    /// Attempts to obtain a base-state view (no overlay) without waiting.
    ///
    /// Despite the name, this never blocks: it returns `None` when the lock cannot be
    /// acquired immediately.
    pub fn try_read_blocking(&self) -> Option<MarketDataView<'_>> {
        self.data
            .try_read()
            .ok()
            .map(|guard| MarketDataView { guard, overlay: None })
    }

    /// Registers `states` under `label`, replacing any overlay already registered there.
    ///
    /// The overlay is kept until a block update past `valid_until` evicts it.
    pub async fn register_labeled_state(
        &self,
        label: StateLabel,
        states: HashMap<ComponentId, Box<dyn ProtocolSim>>,
        valid_until: u64,
    ) {
        self.overlays
            .write()
            .await
            .insert(label, OverlayEntry { states: Arc::new(states), valid_until });
    }

    /// Removes the overlay registered under `label`, if any.
    ///
    /// Views already handed out keep their own `Arc` of the removed states.
    pub async fn remove_labeled_state(&self, label: &StateLabel) {
        self.overlays
            .write()
            .await
            .remove(label);
    }

    /// Removes every registered overlay.
    pub async fn clear_labeled_states(&self) {
        self.overlays.write().await.clear();
    }

    /// Advances the base state to `new_block_number` and evicts expired overlays.
    ///
    /// Overlays whose `valid_until` is below `new_block_number` are removed. The base label
    /// is set to `new_block_number.to_string()`, and then `update` is applied to the base state.
    pub async fn apply_block_update(
        &self,
        new_block_number: u64,
        update: impl FnOnce(&mut MarketState),
    ) {
        // Take the base write lock first, matching `read_labeled`'s `data` -> `overlays`
        // order. Evicting overlays before this lock is held would let a reader still holding
        // the previous block's base state find its (still valid) overlay already gone.
        let mut data = self.data.write().await;
        // The registry write guard is dropped at the end of this statement, before `update` runs.
        self.overlays
            .write()
            .await
            .retain(|_, entry| entry.valid_until >= new_block_number);
        data.label = new_block_number.to_string();
        update(&mut data);
    }

    /// Labels of all currently registered overlays, in arbitrary order.
    pub async fn labeled_state_ids(&self) -> Vec<StateLabel> {
        self.overlays
            .read()
            .await
            .keys()
            .cloned()
            .collect()
    }
}
/// Read view over the base [`MarketState`], optionally layered with a labeled overlay.
///
/// The view holds the base state's read guard for as long as it is alive.
pub struct MarketDataView<'a> {
    /// Read guard on the base state.
    guard: tokio::sync::RwLockReadGuard<'a, MarketState>,
    /// Overlay label and states; `Some` only when the view was built from a registered overlay.
    overlay: Option<(StateLabel, OverlayStates)>,
}
impl MarketDataView<'_> {
    /// Label of the overlay applied to this view.
    ///
    /// Returns `None` for plain base-state views. The base label is available through
    /// [`Self::base_market_state`].
    pub fn state_label(&self) -> Option<&StateLabel> {
        self.overlay
            .as_ref()
            .map(|(label, _)| label)
    }

    /// Simulation state for `id`.
    ///
    /// The overlay's state is used when it has one; otherwise the base state's.
    pub fn get_simulation_state(&self, id: &str) -> Option<&dyn ProtocolSim> {
        if let Some((_, states)) = &self.overlay {
            if let Some(s) = states.get(id) {
                return Some(s.as_ref());
            }
        }
        self.guard.get_simulation_state(id)
    }

    /// Extracts the base-state subset for `component_ids`, with overlay states applied.
    ///
    /// Consider each requested id that the base knows about, either as a component or as a
    /// simulation state. If the overlay has a state for it, that state replaces (or supplies)
    /// the subset's simulation state. The subset therefore agrees with
    /// [`Self::get_simulation_state`]. When an overlay is present, the subset's label is the
    /// overlay label.
    pub fn extract_subset_with_overlay(&self, component_ids: &HashSet<ComponentId>) -> MarketState {
        let mut subset = self.guard.extract_subset(component_ids);
        if let Some((label, states)) = &self.overlay {
            // Walk only the requested ids instead of scanning the whole overlay.
            for id in component_ids {
                // Never attach a state to an id the base state has no trace of.
                let known = subset.components.contains_key(id)
                    || subset.simulation_states.contains_key(id);
                if !known {
                    continue;
                }
                if let Some(state) = states.get(id) {
                    subset
                        .simulation_states
                        .insert(id.clone(), state.clone_box());
                }
            }
            subset.label = label.clone();
        }
        subset
    }

    /// Token addresses of every base component, keyed by component id.
    pub fn component_topology(&self) -> HashMap<ComponentId, Vec<Address>> {
        self.guard.component_topology()
    }

    /// Base-state subset for `component_ids`, ignoring any overlay.
    pub fn extract_subset(&self, component_ids: &HashSet<ComponentId>) -> MarketState {
        self.guard.extract_subset(component_ids)
    }

    /// The base state's token registry.
    pub fn token_registry_ref(&self) -> &HashMap<Address, Token> {
        self.guard.token_registry_ref()
    }

    /// The base state's gas price, if recorded.
    pub fn gas_price(&self) -> Option<&BlockGasPrice> {
        self.guard.gas_price()
    }

    /// The base state's last-updated block info, if recorded.
    pub fn last_updated(&self) -> Option<&BlockInfo> {
        self.guard.last_updated()
    }

    /// Token at `address` in the base registry.
    pub fn get_token(&self, address: &Address) -> Option<&Token> {
        self.guard.get_token(address)
    }

    /// Base component with the given id.
    pub fn get_component(&self, id: &str) -> Option<&ProtocolComponent> {
        self.guard.get_component(id)
    }

    /// The underlying base state, without overlay states applied.
    pub fn base_market_state(&self) -> &MarketState {
        &self.guard
    }
}
/// Snapshot of market data: components, their simulation states, the token registry, gas
/// price and per-protocol synchronizer status.
#[derive(Debug, Default)]
pub struct MarketState {
    /// Identifies this snapshot; [`MarketData::apply_block_update`] sets it to the block number.
    label: StateLabel,
    /// Protocol components keyed by component id.
    components: HashMap<ComponentId, ProtocolComponent>,
    /// Simulation state per component id.
    simulation_states: HashMap<ComponentId, Box<dyn ProtocolSim>>,
    /// Token registry keyed by token address.
    tokens: HashMap<Address, Token>,
    /// Most recently recorded gas price, if any.
    gas_price: Option<BlockGasPrice>,
    /// Synchronizer state per protocol system.
    protocol_sync_status: HashMap<String, SynchronizerState>,
    /// Block info from the most recent `update_last_updated` call, if any.
    last_updated: Option<BlockInfo>,
}
impl MarketState {
    /// Creates an empty market state with an empty label.
    pub fn new() -> Self {
        Self::default()
    }

    /// Label identifying this snapshot.
    pub fn label(&self) -> &StateLabel {
        &self.label
    }

    /// Block info from the most recent update, if any.
    pub fn last_updated(&self) -> Option<&BlockInfo> {
        self.last_updated.as_ref()
    }

    /// Synchronizer state recorded for `protocol_system`, if any.
    ///
    /// Takes `&str`; existing `&String` arguments still coerce.
    pub fn get_protocol_sync_status(&self, protocol_system: &str) -> Option<&SynchronizerState> {
        self.protocol_sync_status
            .get(protocol_system)
    }

    /// Token addresses of every component, keyed by component id.
    pub fn component_topology(&self) -> HashMap<ComponentId, Vec<Address>> {
        self.components
            .iter()
            .map(|(id, component)| (id.clone(), component.tokens.clone()))
            .collect()
    }

    /// Component with the given id.
    pub fn get_component(&self, id: &str) -> Option<&ProtocolComponent> {
        self.components.get(id)
    }

    /// Simulation state for the component with the given id.
    pub fn get_simulation_state(&self, id: &str) -> Option<&dyn ProtocolSim> {
        self.simulation_states
            .get(id)
            .map(|b| b.as_ref())
    }

    /// Token at `address`.
    pub fn get_token(&self, address: &Address) -> Option<&Token> {
        self.tokens.get(address)
    }

    /// Most recently recorded gas price, if any.
    pub fn gas_price(&self) -> Option<&BlockGasPrice> {
        self.gas_price.as_ref()
    }

    /// The full token registry.
    pub fn token_registry_ref(&self) -> &HashMap<Address, Token> {
        &self.tokens
    }

    /// Inserts components, replacing any existing component with the same id.
    pub fn upsert_components(&mut self, components: impl IntoIterator<Item = ProtocolComponent>) {
        for component in components {
            self.components
                .insert(component.id.clone(), component);
        }
    }

    /// Inserts tokens, replacing any existing token at the same address.
    pub fn upsert_tokens(&mut self, tokens: impl IntoIterator<Item = Token>) {
        for token in tokens {
            self.tokens
                .insert(token.address.clone(), token);
        }
    }

    /// Records synchronizer states, replacing any previous state for the same protocol system.
    pub fn update_protocol_sync_status(
        &mut self,
        sync_states: impl IntoIterator<Item = (String, SynchronizerState)>,
    ) {
        for (protocol_system, status) in sync_states {
            self.protocol_sync_status
                .insert(protocol_system, status);
        }
    }

    /// Removes the given components together with their simulation states.
    pub fn remove_components<'a>(&mut self, ids: impl IntoIterator<Item = &'a ComponentId>) {
        for id in ids {
            self.components.remove(id);
            self.simulation_states.remove(id);
        }
    }

    /// Inserts simulation states, replacing any existing state for the same component id.
    pub fn update_states(
        &mut self,
        states: impl IntoIterator<Item = (ComponentId, Box<dyn ProtocolSim>)>,
    ) {
        for (id, state) in states {
            self.simulation_states.insert(id, state);
        }
    }

    /// Records the current gas price.
    pub fn update_gas_price(&mut self, gas_price: BlockGasPrice) {
        self.gas_price = Some(gas_price);
    }

    /// Records the block info of the latest update.
    pub fn update_last_updated(&mut self, block_info: BlockInfo) {
        self.last_updated = Some(block_info);
    }

    /// Builds an independent copy of this state restricted to `component_ids`.
    ///
    /// The result contains:
    /// - every requested component that exists, and its simulation state;
    /// - exactly the registry tokens those components reference;
    /// - a copy of the label, gas price and `last_updated`.
    ///
    /// Unknown ids are ignored, and protocol sync status is not carried over.
    pub fn extract_subset(&self, component_ids: &HashSet<ComponentId>) -> MarketState {
        // Look up each requested key instead of scanning every stored entry, so the cost
        // scales with the subset rather than the whole market.
        let components: HashMap<ComponentId, ProtocolComponent> = component_ids
            .iter()
            .filter_map(|id| self.components.get_key_value(id))
            .map(|(id, component)| (id.clone(), component.clone()))
            .collect();
        // Dedup first so shared tokens are cloned once.
        let token_addresses: HashSet<&Address> = components
            .values()
            .flat_map(|c| &c.tokens)
            .collect();
        let tokens: HashMap<Address, Token> = token_addresses
            .into_iter()
            .filter_map(|addr| self.tokens.get_key_value(addr))
            .map(|(addr, token)| (addr.clone(), token.clone()))
            .collect();
        let simulation_states: HashMap<ComponentId, Box<dyn ProtocolSim>> = component_ids
            .iter()
            .filter_map(|id| self.simulation_states.get_key_value(id))
            .map(|(id, state)| (id.clone(), state.clone_box()))
            .collect();
        MarketState {
            label: self.label.clone(),
            components,
            simulation_states,
            tokens,
            gas_price: self.gas_price.clone(),
            // NOTE(review): sync status is not copied into subsets — confirm this is intended.
            protocol_sync_status: HashMap::new(),
            last_updated: self.last_updated.clone(),
        }
    }
}
#[cfg(test)]
mod tests {
    use num_bigint::BigUint;
    use tycho_simulation::tycho_ethereum::gas::GasPrice;

    use super::*;
    use crate::algorithm::test_utils::{component, token, MockProtocolSim};

    /// Boxes a mock simulation state quoting `spot_price`.
    fn mock_sim(spot_price: f64) -> Box<dyn ProtocolSim> {
        Box::new(MockProtocolSim::new(spot_price))
    }

    /// One-entry state map assigning a mock quoting `spot_price` to component `id`.
    fn single_state(id: &str, spot_price: f64) -> HashMap<ComponentId, Box<dyn ProtocolSim>> {
        HashMap::from([(id.to_string(), mock_sim(spot_price))])
    }

    /// Collects string ids into a component id set.
    fn id_set(ids: &[&str]) -> HashSet<ComponentId> {
        ids.iter()
            .map(|id| id.to_string())
            .collect()
    }

    #[test]
    fn extract_subset_filters_by_component_ids() {
        let (a, b, c) = (token(0x0A, "A"), token(0x0B, "B"), token(0x0C, "C"));
        let mut market = MarketState::new();
        market.upsert_components([
            component("pool_ab", &[a.clone(), b.clone()]),
            component("pool_bc", &[b.clone(), c.clone()]),
        ]);
        market.upsert_tokens([a.clone(), b.clone(), c.clone()]);
        market.update_states([
            ("pool_ab".to_string(), mock_sim(2.0)),
            ("pool_bc".to_string(), mock_sim(3.0)),
        ]);
        market.update_gas_price(BlockGasPrice {
            block_number: 1,
            block_hash: Default::default(),
            block_timestamp: 0,
            pricing: GasPrice::Legacy { gas_price: BigUint::from(1u64) },
        });
        market.update_last_updated(BlockInfo::new(12345, "0xabc".to_string(), 0));

        let subset = market.extract_subset(&id_set(&["pool_ab"]));
        assert_eq!(subset.components.len(), 1);
        assert!(subset.components.contains_key("pool_ab"));
        assert_eq!(subset.tokens.len(), 2);
        for included in [&a, &b] {
            assert!(subset.tokens.contains_key(&included.address));
        }
        assert!(!subset.tokens.contains_key(&c.address));
        assert_eq!(subset.simulation_states.len(), 1);
        assert!(subset.simulation_states.contains_key("pool_ab"));
        assert_eq!(subset.gas_price, market.gas_price);
        assert!(subset.last_updated.is_some());

        let empty = market.extract_subset(&HashSet::new());
        assert!(empty.components.is_empty());
        assert!(empty.tokens.is_empty());
        assert!(empty.simulation_states.is_empty());
    }

    #[tokio::test]
    async fn register_and_retrieve_overlay_via_labeled_read() {
        let market = MarketData::new_shared();
        let label = "test_label".to_string();
        market
            .register_labeled_state(label.clone(), single_state("pool_ab", 99.0), u64::MAX)
            .await;
        let view = market
            .read_labeled(&label)
            .await
            .expect("label was just registered");
        assert!(view.get_simulation_state("pool_ab").is_some());
    }

    #[tokio::test]
    async fn read_without_label_returns_no_overlay() {
        let market = MarketData::new_shared();
        market
            .register_labeled_state("my_label".to_string(), single_state("pool1", 5.0), u64::MAX)
            .await;
        let view = market.read().await;
        assert!(view.get_simulation_state("pool1").is_none());
    }

    #[tokio::test]
    async fn remove_labeled_state_clears_overlay() {
        let market = MarketData::new_shared();
        let label = "lbl".to_string();
        market
            .register_labeled_state(label.clone(), single_state("pool", 1.0), u64::MAX)
            .await;
        market.remove_labeled_state(&label).await;
        assert!(market.labeled_state_ids().await.is_empty());
    }

    #[tokio::test]
    async fn clear_labeled_states_removes_all() {
        let market = MarketData::new_shared();
        for i in 0..3u8 {
            let pool = format!("pool_{i}");
            market
                .register_labeled_state(format!("label_{i}"), single_state(&pool, f64::from(i)), u64::MAX)
                .await;
        }
        market.clear_labeled_states().await;
        assert!(market.labeled_state_ids().await.is_empty());
    }

    #[tokio::test]
    async fn clone_shares_overlay_registry() {
        let base = MarketData::new_shared();
        let (clone_a, clone_b) = (base.clone(), base.clone());
        let label = "shared".to_string();
        base.register_labeled_state(label.clone(), single_state("pool_x", 7.0), u64::MAX)
            .await;
        // Each view is dropped at the end of its iteration before the next clone reads.
        for handle in [&clone_a, &clone_b] {
            let view = handle
                .read_labeled(&label)
                .await
                .expect("label was just registered");
            assert!(view.get_simulation_state("pool_x").is_some());
        }
    }

    #[tokio::test]
    async fn extract_subset_with_overlay_replaces_matching_states() {
        let market = MarketData::new_shared();
        let (a, b) = (token(0x01, "A"), token(0x02, "B"));
        {
            let mut data = market.write().await;
            data.upsert_components([component("pool_ab", &[a.clone(), b.clone()])]);
            data.upsert_tokens([a, b]);
            data.update_states([("pool_ab".to_string(), mock_sim(2.0))]);
        }

        let label = "overlay".to_string();
        market
            .register_labeled_state(label.clone(), single_state("pool_ab", 99.0), u64::MAX)
            .await;
        let view = market
            .read_labeled(&label)
            .await
            .expect("label was just registered");
        let subset = view.extract_subset_with_overlay(&id_set(&["pool_ab"]));
        let mock = subset
            .get_simulation_state("pool_ab")
            .unwrap()
            .as_any()
            .downcast_ref::<MockProtocolSim>()
            .unwrap();
        assert_eq!(mock.spot_price, 99.0, "overlay state should replace base state");
    }

    #[tokio::test]
    async fn apply_block_update_evicts_stale_overlays() {
        let market = MarketData::new_shared();
        let overlays = [("stale", "pool_stale", 1.0, 10), ("fresh", "pool_fresh", 2.0, 20)];
        for (label, pool, price, valid_until) in overlays {
            market
                .register_labeled_state(label.to_string(), single_state(pool, price), valid_until)
                .await;
        }
        market.apply_block_update(11, |_data| {}).await;

        let ids = market.labeled_state_ids().await;
        assert!(!ids.contains(&"stale".to_string()), "stale overlay must be evicted");
        assert!(ids.contains(&"fresh".to_string()), "fresh overlay must survive");
    }

    #[tokio::test]
    async fn apply_block_update_applies_mutation() {
        let market = MarketData::new_shared();
        market
            .apply_block_update(1, |data| {
                data.update_last_updated(BlockInfo::new(1, "0xabc".to_string(), 0));
            })
            .await;
        let view = market.read().await;
        let info = view
            .last_updated()
            .expect("last_updated must be set");
        assert_eq!(info.number(), 1);
    }
}