use std::{
collections::{HashMap, HashSet},
sync::Arc,
time::Instant,
};
use async_trait::async_trait;
use tokio::sync::{broadcast, RwLock};
use tracing::{debug, info, trace, warn};
use tycho_simulation::tycho_common::models::Address;
use crate::{feed::market_data::SharedMarketData, types::ComponentId};
/// The set of components touched by a market update, used to scope derived computations.
///
/// `added` maps each newly added component id to its token addresses; `removed` and `updated`
/// carry component ids only. `is_full_recompute` is set when the change set deliberately covers
/// the whole market topology instead of an incremental delta.
#[derive(Clone, Debug, Default)]
pub struct ChangedComponents {
    pub added: HashMap<ComponentId, Vec<Address>>,
    pub removed: Vec<ComponentId>,
    pub updated: Vec<ComponentId>,
    pub is_full_recompute: bool,
}

impl ChangedComponents {
    /// Builds a change set that treats every component of `market`'s topology as newly added.
    ///
    /// The result is flagged as a full recompute and has empty `removed`/`updated` lists.
    pub fn all(market: &SharedMarketData) -> Self {
        let topology = market.component_topology().clone();
        Self { added: topology, is_full_recompute: true, ..Self::default() }
    }

    /// Returns `true` when components were added or removed (as opposed to only updated).
    pub fn is_topology_change(&self) -> bool {
        !(self.added.is_empty() && self.removed.is_empty())
    }

    /// Collects the ids of every added, removed and updated component, deduplicated.
    pub fn all_changed_ids(&self) -> HashSet<ComponentId> {
        self.added
            .keys()
            .chain(&self.removed)
            .chain(&self.updated)
            .cloned()
            .collect()
    }
}
use super::{
computation::DerivedComputation,
computations::{PoolDepthComputation, SpotPriceComputation, TokenGasPriceComputation},
error::ComputationError,
events::DerivedDataEvent,
store::DerivedData,
};
use crate::feed::{
events::{EventError, MarketEvent, MarketEventHandler},
market_data::SharedMarketDataRef,
};
/// Shared, reference-counted handle to the derived-data store, guarded by an async `RwLock`.
pub type SharedDerivedDataRef = std::sync::Arc<RwLock<DerivedData>>;
/// Tunables for [`ComputationManager`], built fluently from [`ComputationManagerConfig::new`].
///
/// Defaults: `Address::zero(20)` as gas token, `max_hop = 2`,
/// `depth_slippage_threshold = 0.01`.
#[derive(Debug, Clone)]
pub struct ComputationManagerConfig {
    /// Token that token prices are computed against.
    gas_token: Address,
    /// Maximum hop count handed to the token price computation.
    max_hop: usize,
    /// Slippage threshold handed to the pool depth computation (validated there).
    depth_slippage_threshold: f64,
}

impl ComputationManagerConfig {
    /// Returns the default configuration; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Default::default()
    }

    /// Replaces the pool-depth slippage threshold.
    pub fn with_depth_slippage_threshold(self, threshold: f64) -> Self {
        Self { depth_slippage_threshold: threshold, ..self }
    }

    /// Replaces the maximum hop count used for token pricing.
    pub fn with_max_hop(self, hop_count: usize) -> Self {
        Self { max_hop: hop_count, ..self }
    }

    /// Replaces the gas token used for token pricing.
    pub fn with_gas_token(self, gas_token: Address) -> Self {
        Self { gas_token, ..self }
    }

    /// The configured gas token.
    pub fn gas_token(&self) -> &Address {
        &self.gas_token
    }

    /// The configured maximum hop count.
    pub fn max_hop(&self) -> usize {
        self.max_hop
    }

    /// The configured pool-depth slippage threshold.
    pub fn depth_slippage_threshold(&self) -> f64 {
        self.depth_slippage_threshold
    }
}

impl Default for ComputationManagerConfig {
    fn default() -> Self {
        let gas_token = Address::zero(20);
        let max_hop = 2;
        let depth_slippage_threshold = 0.01;
        Self { gas_token, max_hop, depth_slippage_threshold }
    }
}
/// Owns the derived computations and drives them in response to market events.
///
/// Computation results are written into `store`; progress and failures are announced to
/// subscribers through `event_tx`.
pub struct ComputationManager {
    /// Market state the computations read from.
    market_data: SharedMarketDataRef,
    /// Derived-data store the computation results are written into.
    store: SharedDerivedDataRef,
    /// Sending half of the [`DerivedDataEvent`] broadcast channel.
    event_tx: broadcast::Sender<DerivedDataEvent>,
    /// Spot price computation; runs before the other two computations.
    spot_price_computation: SpotPriceComputation,
    /// Token price computation, configured with the gas token and max hop count.
    token_price_computation: TokenGasPriceComputation,
    /// Pool depth computation, configured with the depth slippage threshold.
    pool_depth_computation: PoolDepthComputation,
}
impl ComputationManager {
pub fn new(
config: ComputationManagerConfig,
market_data: SharedMarketDataRef,
) -> Result<(Self, broadcast::Receiver<DerivedDataEvent>), ComputationError> {
let pool_depth_computation = PoolDepthComputation::new(config.depth_slippage_threshold)?;
let (event_tx, event_rx) = broadcast::channel(64);
Ok((
Self {
market_data,
store: DerivedData::new_shared(),
token_price_computation: TokenGasPriceComputation::default()
.with_max_hops(config.max_hop)
.with_gas_token(config.gas_token),
spot_price_computation: SpotPriceComputation::new(),
pool_depth_computation,
event_tx,
},
event_rx,
))
}
pub fn store(&self) -> SharedDerivedDataRef {
Arc::clone(&self.store)
}
pub fn event_sender(&self) -> broadcast::Sender<DerivedDataEvent> {
self.event_tx.clone()
}
pub async fn run(
mut self,
mut event_rx: broadcast::Receiver<MarketEvent>,
mut shutdown_rx: broadcast::Receiver<()>,
) {
info!("computation manager started");
loop {
tokio::select! {
biased;
_ = shutdown_rx.recv() => {
info!("computation manager shutting down");
break;
}
event_result = event_rx.recv() => {
match event_result {
Ok(event) => {
if let Err(e) = self.handle_event(&event).await {
warn!(error = ?e, "failed to handle market event");
}
}
Err(broadcast::error::RecvError::Closed) => {
info!("event channel closed, computation manager shutting down");
break;
}
Err(broadcast::error::RecvError::Lagged(skipped)) => {
warn!(
skipped,
"computation manager lagged, skipped {} events. Recomputing from current state.",
skipped
);
let market = self.market_data.read().await;
let changed = ChangedComponents::all(&market);
drop(market);
self.compute_all(&changed).await;
}
}
}
}
}
}
async fn compute_all(&self, changed: &ChangedComponents) {
let total_start = Instant::now();
let Some(block) = self
.market_data
.read()
.await
.last_updated()
.map(|b| b.number())
else {
warn!("market data has no last updated block, skipping computations");
return;
};
let _ = self
.event_tx
.send(DerivedDataEvent::NewBlock { block });
let spot_start = Instant::now();
let spot_prices_result = self
.spot_price_computation
.compute(&self.market_data, &self.store, changed)
.await;
let spot_elapsed = spot_start.elapsed();
match spot_prices_result {
Ok(output) => {
let count = output.data.len();
if output.has_failures() {
warn!(
count,
failed = output.failed_items.len(),
"spot prices partial failures"
);
for item in &output.failed_items {
debug!(key = %item.key, error = %item.error, "spot price failed item");
}
} else {
info!(count, elapsed_ms = spot_elapsed.as_millis(), "spot prices computed");
}
self.store
.write()
.await
.set_spot_prices(
output.data,
output.failed_items.clone(),
block,
changed.is_full_recompute,
);
let _ = self
.event_tx
.send(DerivedDataEvent::ComputationComplete {
computation_id: SpotPriceComputation::ID,
block,
failed_items: output.failed_items,
});
}
Err(e) => {
warn!(error = ?e, elapsed_ms = spot_elapsed.as_millis(), "spot price computation failed");
let _ = self
.event_tx
.send(DerivedDataEvent::ComputationFailed {
computation_id: SpotPriceComputation::ID,
block,
});
let _ = self
.event_tx
.send(DerivedDataEvent::ComputationFailed {
computation_id: TokenGasPriceComputation::ID,
block,
});
let _ = self
.event_tx
.send(DerivedDataEvent::ComputationFailed {
computation_id: PoolDepthComputation::ID,
block,
});
return;
}
}
let (token_prices_result, pool_depths_result) = tokio::join!(
async {
let start = Instant::now();
let result = self
.token_price_computation
.compute(&self.market_data, &self.store, changed)
.await;
(result, start.elapsed())
},
async {
let start = Instant::now();
let result = self
.pool_depth_computation
.compute(&self.market_data, &self.store, changed)
.await;
(result, start.elapsed())
}
);
let (token_prices_result, token_elapsed) = token_prices_result;
let (pool_depths_result, depth_elapsed) = pool_depths_result;
let mut store_write = self.store.write().await;
match token_prices_result {
Ok(output) => {
let count = output.data.len();
if output.has_failures() {
warn!(
count,
failed = output.failed_items.len(),
"token prices partial failures"
);
for item in &output.failed_items {
debug!(key = %item.key, error = %item.error, "token price failed item");
}
} else {
info!(count, elapsed_ms = token_elapsed.as_millis(), "token prices computed");
}
store_write.set_token_prices(
output.data,
output.failed_items.clone(),
block,
changed.is_full_recompute,
);
let _ = self
.event_tx
.send(DerivedDataEvent::ComputationComplete {
computation_id: TokenGasPriceComputation::ID,
block,
failed_items: output.failed_items,
});
}
Err(e) => {
warn!(error = ?e, "token price computation failed");
let _ = self
.event_tx
.send(DerivedDataEvent::ComputationFailed {
computation_id: TokenGasPriceComputation::ID,
block,
});
}
}
match pool_depths_result {
Ok(output) => {
let count = output.data.len();
if output.has_failures() {
warn!(
count,
failed = output.failed_items.len(),
"pool depths partial failures"
);
for item in &output.failed_items {
debug!(key = %item.key, error = %item.error, "pool depth failed item");
}
} else {
info!(count, elapsed_ms = depth_elapsed.as_millis(), "pool depths computed");
}
store_write.set_pool_depths(
output.data,
output.failed_items.clone(),
block,
changed.is_full_recompute,
);
let _ = self
.event_tx
.send(DerivedDataEvent::ComputationComplete {
computation_id: PoolDepthComputation::ID,
block,
failed_items: output.failed_items,
});
}
Err(e) => {
warn!(error = ?e, "pool depth computation failed");
let _ = self
.event_tx
.send(DerivedDataEvent::ComputationFailed {
computation_id: PoolDepthComputation::ID,
block,
});
}
}
let total_elapsed = total_start.elapsed();
info!(block, total_ms = total_elapsed.as_millis(), "all derived computations complete");
}
}
#[async_trait]
impl MarketEventHandler for ComputationManager {
    /// Turns a non-empty `MarketUpdated` event into an incremental recompute.
    ///
    /// Any other event, including a `MarketUpdated` whose added, removed and updated sets are
    /// all empty, is skipped. `compute_all` reports its own failures on the derived-data event
    /// channel, so this handler always returns `Ok(())`.
    async fn handle_event(&mut self, event: &MarketEvent) -> Result<(), EventError> {
        let changed = match event {
            MarketEvent::MarketUpdated {
                added_components,
                removed_components,
                updated_components,
            } if !(added_components.is_empty() &&
                removed_components.is_empty() &&
                updated_components.is_empty()) =>
            {
                trace!(
                    added = added_components.len(),
                    removed = removed_components.len(),
                    updated = updated_components.len(),
                    "market updated, running incremental computations"
                );
                ChangedComponents {
                    added: added_components.clone(),
                    removed: removed_components.clone(),
                    updated: updated_components.clone(),
                    is_full_recompute: false,
                }
            }
            _ => {
                trace!("empty market update, skipping computations");
                return Ok(());
            }
        };
        self.compute_all(&changed).await;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, sync::Arc};

    use tokio::sync::{broadcast, RwLock};

    use super::*;
    use crate::{
        algorithm::test_utils::{component, setup_market, token, MockProtocolSim},
        feed::market_data::SharedMarketData,
        types::BlockInfo,
    };

    /// Collects every event currently buffered on `rx` without waiting.
    ///
    /// Lag notifications are skipped so the remaining buffered events are still returned.
    fn drain_events(rx: &mut broadcast::Receiver<DerivedDataEvent>) -> Vec<DerivedDataEvent> {
        let mut events = vec![];
        loop {
            match rx.try_recv() {
                Ok(e) => events.push(e),
                Err(broadcast::error::TryRecvError::Empty) => break,
                Err(broadcast::error::TryRecvError::Lagged(_)) => continue,
                Err(broadcast::error::TryRecvError::Closed) => break,
            }
        }
        events
    }

    #[test]
    fn invalid_slippage_threshold_returns_error() {
        let (market, _) = setup_market(vec![]);
        let config = ComputationManagerConfig::new().with_depth_slippage_threshold(1.5);
        let result = ComputationManager::new(config, market);
        assert!(matches!(result, Err(ComputationError::InvalidConfiguration(_))));
    }

    #[tokio::test]
    async fn handle_event_runs_computations_on_market_update() {
        let eth = token(1, "ETH");
        let usdc = token(2, "USDC");
        // `&eth` had been mis-encoded as the character `ð` (HTML entity decoding), which made
        // the whole test module fail to compile.
        let (market, _) =
            setup_market(vec![("eth_usdc", &eth, &usdc, MockProtocolSim::new(2000.0).with_gas(0))]);
        let config = ComputationManagerConfig::new().with_gas_token(eth.address.clone());
        let (mut manager, _event_rx) = ComputationManager::new(config, market).unwrap();
        let event = MarketEvent::MarketUpdated {
            added_components: HashMap::from([(
                "eth_usdc".to_string(),
                vec![eth.address.clone(), usdc.address.clone()],
            )]),
            removed_components: vec![],
            updated_components: vec![],
        };
        manager
            .handle_event(&event)
            .await
            .unwrap();
        let store = manager.store();
        let guard = store.read().await;
        assert!(guard.token_prices().is_some());
        assert!(guard.spot_prices().is_some());
    }

    #[tokio::test]
    async fn handle_event_skips_empty_update() {
        let (market, _) = setup_market(vec![]);
        let config = ComputationManagerConfig::new();
        let (mut manager, _event_rx) = ComputationManager::new(config, market).unwrap();
        let event = MarketEvent::MarketUpdated {
            added_components: HashMap::new(),
            removed_components: vec![],
            updated_components: vec![],
        };
        manager
            .handle_event(&event)
            .await
            .unwrap();
        let store = manager.store();
        let guard = store.read().await;
        assert!(guard.token_prices().is_none());
    }

    #[tokio::test]
    async fn run_shuts_down_on_signal() {
        let (market, _) = setup_market(vec![]);
        let config = ComputationManagerConfig::new();
        let (manager, _event_rx) = ComputationManager::new(config, market).unwrap();
        let (_event_tx, event_rx) = broadcast::channel::<MarketEvent>(16);
        let (shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1);
        let handle = tokio::spawn(async move {
            manager.run(event_rx, shutdown_rx).await;
        });
        shutdown_tx.send(()).unwrap();
        tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
            .await
            .expect("manager should shutdown")
            .expect("task should complete successfully");
    }

    /// Market at block 10 with one ETH/USDC component but no simulation state for it.
    fn market_with_component_no_sim_state() -> Arc<RwLock<SharedMarketData>> {
        let eth = token(1, "ETH");
        let usdc = token(2, "USDC");
        let pool = component("pool", &[eth.clone(), usdc.clone()]);
        let mut market = SharedMarketData::new();
        market.update_last_updated(BlockInfo::new(10, "0xhash".into(), 0));
        market.upsert_components(std::iter::once(pool));
        market.upsert_tokens([eth, usdc]);
        Arc::new(RwLock::new(market))
    }

    /// Market at block 10 with two components, only one of which (`eth_usdc`) has sim state.
    fn market_with_mixed_sim_states() -> Arc<RwLock<SharedMarketData>> {
        let eth = token(1, "ETH");
        let usdc = token(2, "USDC");
        let dai = token(3, "DAI");
        let pool1 = component("eth_usdc", &[eth.clone(), usdc.clone()]);
        let pool2 = component("eth_dai", &[eth.clone(), dai.clone()]);
        let mut market = SharedMarketData::new();
        market.update_last_updated(BlockInfo::new(10, "0xhash".into(), 0));
        market.upsert_components([pool1, pool2]);
        market
            .update_states([("eth_usdc".to_string(), Box::new(MockProtocolSim::new(2000.0)) as _)]);
        market.upsert_tokens([eth, usdc, dai]);
        Arc::new(RwLock::new(market))
    }

    /// Market at block 10 with sim state for its single component but no gas price set.
    fn market_with_sim_state_no_gas_price() -> Arc<RwLock<SharedMarketData>> {
        let eth = token(1, "ETH");
        let usdc = token(2, "USDC");
        let pool = component("pool", &[eth.clone(), usdc.clone()]);
        let mut market = SharedMarketData::new();
        market.update_last_updated(BlockInfo::new(10, "0xhash".into(), 0));
        market.upsert_components(std::iter::once(pool));
        market.update_states([("pool".to_string(), Box::new(MockProtocolSim::new(2000.0)) as _)]);
        market.upsert_tokens([eth, usdc]);
        Arc::new(RwLock::new(market))
    }

    #[tokio::test]
    async fn test_spot_price_failure_broadcasts_computation_failed() {
        let market = market_with_component_no_sim_state();
        let config = ComputationManagerConfig::new();
        let (manager, mut event_rx) = ComputationManager::new(config, market).unwrap();
        let changed = ChangedComponents { is_full_recompute: true, ..Default::default() };
        manager.compute_all(&changed).await;
        let events = drain_events(&mut event_rx);
        assert!(
            events.iter().any(|e| matches!(
                e,
                DerivedDataEvent::ComputationFailed { computation_id: "spot_prices", .. }
            )),
            "expected ComputationFailed(spot_prices) in events: {events:?}"
        );
    }

    #[tokio::test]
    async fn test_token_price_failure_broadcasts_computation_failed() {
        let eth = token(1, "ETH");
        let usdc = token(2, "USDC");
        let market = market_with_sim_state_no_gas_price();
        let config = ComputationManagerConfig::new().with_gas_token(eth.address.clone());
        let (mut manager, mut event_rx) = ComputationManager::new(config, market).unwrap();
        let event = MarketEvent::MarketUpdated {
            added_components: HashMap::from([(
                "pool".to_string(),
                vec![eth.address.clone(), usdc.address.clone()],
            )]),
            removed_components: vec![],
            updated_components: vec![],
        };
        manager
            .handle_event(&event)
            .await
            .unwrap();
        let events = drain_events(&mut event_rx);
        assert!(
            events.iter().any(|e| matches!(
                e,
                DerivedDataEvent::ComputationFailed { computation_id: "token_prices", .. }
            )),
            "expected ComputationFailed(token_prices) in events: {events:?}"
        );
    }

    #[tokio::test]
    async fn run_shuts_down_on_channel_close() {
        let (market, _) = setup_market(vec![]);
        let config = ComputationManagerConfig::new();
        let (manager, _event_rx) = ComputationManager::new(config, market).unwrap();
        let (event_tx, event_rx) = broadcast::channel::<MarketEvent>(16);
        let (_shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1);
        let handle = tokio::spawn(async move {
            manager.run(event_rx, shutdown_rx).await;
        });
        drop(event_tx);
        tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
            .await
            .expect("manager should shutdown on channel close")
            .expect("task should complete successfully");
    }

    #[tokio::test]
    async fn partial_spot_price_failure_broadcasts_computation_complete() {
        let market = market_with_mixed_sim_states();
        let config = ComputationManagerConfig::new();
        let (manager, mut event_rx) = ComputationManager::new(config, market).unwrap();
        let changed = ChangedComponents { is_full_recompute: true, ..Default::default() };
        manager.compute_all(&changed).await;
        let events = drain_events(&mut event_rx);
        assert!(
            events.iter().any(|e| matches!(
                e,
                DerivedDataEvent::ComputationComplete { computation_id: "spot_prices", .. }
            )),
            "expected ComputationComplete(spot_prices), got: {events:?}"
        );
        assert!(
            !events.iter().any(|e| matches!(
                e,
                DerivedDataEvent::ComputationFailed { computation_id: "spot_prices", .. }
            )),
            "should not broadcast ComputationFailed for partial failure"
        );
        let complete = events.iter().find(|e| {
            matches!(e, DerivedDataEvent::ComputationComplete { computation_id: "spot_prices", .. })
        });
        if let Some(DerivedDataEvent::ComputationComplete { failed_items, .. }) = complete {
            assert!(
                !failed_items.is_empty(),
                "ComputationComplete should carry failed_items for pool2"
            );
        }
        let eth = token(1, "ETH");
        let dai = token(3, "DAI");
        let store = manager.store();
        let guard = store.read().await;
        // The failure may be recorded under either token direction, so accept both keys.
        let key_eth_dai = ("eth_dai".to_string(), eth.address.clone(), dai.address.clone());
        let key_dai_eth = ("eth_dai".to_string(), dai.address.clone(), eth.address.clone());
        assert!(
            guard
                .spot_price_failure(&key_eth_dai)
                .is_some() ||
                guard
                    .spot_price_failure(&key_dai_eth)
                    .is_some(),
            "store should persist failure reason for eth_dai (missing sim state)"
        );
    }
}