1use std::{
10 collections::{HashMap, HashSet},
11 sync::Arc,
12 time::Instant,
13};
14
15use async_trait::async_trait;
16use tokio::sync::{broadcast, RwLock};
17use tracing::{info, trace, warn};
18use tycho_simulation::tycho_common::models::Address;
19
20use crate::{feed::market_data::SharedMarketData, types::ComponentId};
21
/// Set of liquidity components whose derived data must be recomputed.
///
/// Built either incrementally from a market update event or, via
/// [`ChangedComponents::all`], for a full recomputation over the whole market.
#[derive(Debug, Clone, Default)]
pub struct ChangedComponents {
    /// Newly added components, each mapped to its token addresses.
    pub added: HashMap<ComponentId, Vec<Address>>,
    /// Components that were removed from the market.
    pub removed: Vec<ComponentId>,
    /// Existing components whose state was updated.
    pub updated: Vec<ComponentId>,
    /// `true` when this set covers the entire market topology rather than an incremental
    /// update (as produced by [`ChangedComponents::all`]).
    pub is_full_recompute: bool,
}
37
38impl ChangedComponents {
39 pub fn all(market: &SharedMarketData) -> Self {
43 Self {
44 added: market.component_topology().clone(),
45 removed: vec![],
46 updated: vec![],
47 is_full_recompute: true,
48 }
49 }
50
51 pub fn is_topology_change(&self) -> bool {
53 !self.added.is_empty() || !self.removed.is_empty()
54 }
55
56 pub fn all_changed_ids(&self) -> HashSet<ComponentId> {
58 let mut all = HashSet::new();
59 all.extend(self.added.keys().cloned());
60 all.extend(self.removed.iter().cloned());
61 all.extend(self.updated.iter().cloned());
62 all
63 }
64}
65
66use super::{
67 computation::DerivedComputation,
68 computations::{PoolDepthComputation, SpotPriceComputation, TokenGasPriceComputation},
69 error::ComputationError,
70 events::DerivedDataEvent,
71 store::DerivedData,
72};
73use crate::feed::{
74 events::{EventError, MarketEvent, MarketEventHandler},
75 market_data::SharedMarketDataRef,
76};
77
/// Shared handle to the derived data store, guarded by an async (tokio) read/write lock.
pub type SharedDerivedDataRef = Arc<RwLock<DerivedData>>;
80
/// Settings for [`ComputationManager`], assembled with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ComputationManagerConfig {
    /// Gas token handed to `TokenGasPriceComputation::with_gas_token`.
    gas_token: Address,
    /// Hop limit handed to `TokenGasPriceComputation::with_max_hops`.
    max_hop: usize,
    /// Slippage threshold handed to `PoolDepthComputation::new`, which may reject it.
    depth_slippage_threshold: f64,
}
95
96impl ComputationManagerConfig {
97 pub fn new() -> Self {
99 Self::default()
100 }
101
102 pub fn with_depth_slippage_threshold(mut self, threshold: f64) -> Self {
104 self.depth_slippage_threshold = threshold;
105 self
106 }
107
108 pub fn with_max_hop(mut self, hop_count: usize) -> Self {
110 self.max_hop = hop_count;
111 self
112 }
113
114 pub fn with_gas_token(mut self, gas_token: Address) -> Self {
116 self.gas_token = gas_token;
117 self
118 }
119
120 pub fn gas_token(&self) -> &Address {
122 &self.gas_token
123 }
124
125 pub fn max_hop(&self) -> usize {
127 self.max_hop
128 }
129
130 pub fn depth_slippage_threshold(&self) -> f64 {
132 self.depth_slippage_threshold
133 }
134}
135
136impl Default for ComputationManagerConfig {
137 fn default() -> Self {
138 Self { gas_token: Address::zero(20), max_hop: 2, depth_slippage_threshold: 0.01 }
139 }
140}
141
/// Runs the derived-data computations (spot prices, token gas prices, pool depths) in
/// response to market events and writes their results to a shared store.
pub struct ComputationManager {
    /// Shared market state that all computations read from.
    market_data: SharedMarketDataRef,
    /// Shared store that computation results are written to.
    store: SharedDerivedDataRef,
    /// Token price computation, configured with the gas token and hop limit.
    token_price_computation: TokenGasPriceComputation,
    /// Spot price computation; runs before the other two computations.
    spot_price_computation: SpotPriceComputation,
    /// Pool depth computation, configured with the slippage threshold.
    pool_depth_computation: PoolDepthComputation,
    /// Broadcasts `DerivedDataEvent`s (new block, computation complete) to subscribers.
    event_tx: broadcast::Sender<DerivedDataEvent>,
}
157
158impl ComputationManager {
159 pub fn new(
165 config: ComputationManagerConfig,
166 market_data: SharedMarketDataRef,
167 ) -> Result<(Self, broadcast::Receiver<DerivedDataEvent>), ComputationError> {
168 let pool_depth_computation = PoolDepthComputation::new(config.depth_slippage_threshold)?;
169 let (event_tx, event_rx) = broadcast::channel(64);
170
171 Ok((
172 Self {
173 market_data,
174 store: DerivedData::new_shared(),
175 token_price_computation: TokenGasPriceComputation::default()
176 .with_max_hops(config.max_hop)
177 .with_gas_token(config.gas_token),
178 spot_price_computation: SpotPriceComputation::new(),
179 pool_depth_computation,
180 event_tx,
181 },
182 event_rx,
183 ))
184 }
185
186 pub fn store(&self) -> SharedDerivedDataRef {
188 Arc::clone(&self.store)
189 }
190
191 pub fn event_sender(&self) -> broadcast::Sender<DerivedDataEvent> {
193 self.event_tx.clone()
194 }
195
196 pub async fn run(
200 mut self,
201 mut event_rx: broadcast::Receiver<MarketEvent>,
202 mut shutdown_rx: broadcast::Receiver<()>,
203 ) {
204 info!("computation manager started");
205
206 loop {
207 tokio::select! {
208 biased;
209
210 _ = shutdown_rx.recv() => {
211 info!("computation manager shutting down");
212 break;
213 }
214
215 event_result = event_rx.recv() => {
216 match event_result {
217 Ok(event) => {
218 if let Err(e) = self.handle_event(&event).await {
219 warn!(error = ?e, "failed to handle market event");
220 }
221 }
222 Err(broadcast::error::RecvError::Closed) => {
223 info!("event channel closed, computation manager shutting down");
224 break;
225 }
226 Err(broadcast::error::RecvError::Lagged(skipped)) => {
227 warn!(
228 skipped,
229 "computation manager lagged, skipped {} events. Recomputing from current state.",
230 skipped
231 );
232 let market = self.market_data.read().await;
233 let changed = ChangedComponents::all(&market);
234 drop(market);
235 self.compute_all(&changed).await;
236 }
237 }
238 }
239 }
240 }
241 }
242
243 async fn compute_all(&self, changed: &ChangedComponents) {
253 let total_start = Instant::now();
254
255 let Some(block) = self
257 .market_data
258 .read()
259 .await
260 .last_updated()
261 .map(|b| b.number())
262 else {
263 warn!("market data has no last updated block, skipping computations");
264 return;
265 };
266
267 let _ = self
269 .event_tx
270 .send(DerivedDataEvent::NewBlock { block });
271
272 let spot_start = Instant::now();
274 let spot_prices_result = self
275 .spot_price_computation
276 .compute(&self.market_data, &self.store, changed)
277 .await;
278 let spot_elapsed = spot_start.elapsed();
279
280 match spot_prices_result {
282 Ok(prices) => {
283 let count = prices.len();
284 self.store
285 .write()
286 .await
287 .set_spot_prices(prices, block);
288 info!(count, elapsed_ms = spot_elapsed.as_millis(), "spot prices computed");
289 let _ = self
290 .event_tx
291 .send(DerivedDataEvent::ComputationComplete {
292 computation_id: SpotPriceComputation::ID,
293 block,
294 });
295 }
296 Err(e) => {
297 warn!(error = ?e, elapsed_ms = spot_elapsed.as_millis(), "spot price computation failed");
298 return;
300 }
301 }
302
303 let (token_prices_result, pool_depths_result) = tokio::join!(
305 async {
306 let start = Instant::now();
307 let result = self
308 .token_price_computation
309 .compute(&self.market_data, &self.store, changed)
310 .await;
311 (result, start.elapsed())
312 },
313 async {
314 let start = Instant::now();
315 let result = self
316 .pool_depth_computation
317 .compute(&self.market_data, &self.store, changed)
318 .await;
319 (result, start.elapsed())
320 }
321 );
322 let (token_prices_result, token_elapsed) = token_prices_result;
323 let (pool_depths_result, depth_elapsed) = pool_depths_result;
324
325 let mut store_write = self.store.write().await;
327
328 match token_prices_result {
329 Ok(prices) => {
330 let count = prices.len();
331 store_write.set_token_prices(prices, block);
332 info!(count, elapsed_ms = token_elapsed.as_millis(), "token prices computed");
333 let _ = self
334 .event_tx
335 .send(DerivedDataEvent::ComputationComplete {
336 computation_id: TokenGasPriceComputation::ID,
337 block,
338 });
339 }
340 Err(e) => {
341 warn!(error = ?e, "token price computation failed");
342 }
343 }
344
345 match pool_depths_result {
346 Ok(depths) => {
347 let count = depths.len();
348 store_write.set_pool_depths(depths, block);
349 info!(count, elapsed_ms = depth_elapsed.as_millis(), "pool depths computed");
350 let _ = self
351 .event_tx
352 .send(DerivedDataEvent::ComputationComplete {
353 computation_id: PoolDepthComputation::ID,
354 block,
355 });
356 }
357 Err(e) => {
358 warn!(error = ?e, "pool depth computation failed");
359 }
360 }
361
362 let total_elapsed = total_start.elapsed();
363 info!(block, total_ms = total_elapsed.as_millis(), "all derived computations complete");
364 }
365}
366
367#[async_trait]
368impl MarketEventHandler for ComputationManager {
369 async fn handle_event(&mut self, event: &MarketEvent) -> Result<(), EventError> {
370 match event {
371 MarketEvent::MarketUpdated {
372 added_components,
373 removed_components,
374 updated_components,
375 } if !added_components.is_empty() ||
376 !removed_components.is_empty() ||
377 !updated_components.is_empty() =>
378 {
379 trace!(
380 added = added_components.len(),
381 removed = removed_components.len(),
382 updated = updated_components.len(),
383 "market updated, running incremental computations"
384 );
385
386 let changed = ChangedComponents {
387 added: added_components.clone(),
388 removed: removed_components.clone(),
389 updated: updated_components.clone(),
390 is_full_recompute: false,
391 };
392 self.compute_all(&changed).await;
393 }
394 _ => {
395 trace!("empty market update, skipping computations");
396 }
397 }
398
399 Ok(())
400 }
401}
402
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use tokio::sync::broadcast;

    use super::*;
    use crate::algorithm::test_utils::{setup_market, token, MockProtocolSim};

    /// A slippage threshold rejected by the pool depth computation must make
    /// `ComputationManager::new` fail with `InvalidConfiguration`.
    #[test]
    fn invalid_slippage_threshold_returns_error() {
        let (market, _) = setup_market(vec![]);
        let config = ComputationManagerConfig::new().with_depth_slippage_threshold(1.5);

        let result = ComputationManager::new(config, market);
        assert!(matches!(result, Err(ComputationError::InvalidConfiguration(_))));
    }

    /// A non-empty market update runs the computations and populates the store.
    #[tokio::test]
    async fn handle_event_runs_computations_on_market_update() {
        let eth = token(1, "ETH");
        let usdc = token(2, "USDC");

        // The reference to `eth` was previously mis-encoded as the character `ð`
        // (an HTML-decoded `&eth`), which does not compile.
        let (market, _) =
            setup_market(vec![("eth_usdc", &eth, &usdc, MockProtocolSim::new(2000.0).with_gas(0))]);

        let config = ComputationManagerConfig::new().with_gas_token(eth.address.clone());
        let (mut manager, _event_rx) = ComputationManager::new(config, market).unwrap();

        let event = MarketEvent::MarketUpdated {
            added_components: HashMap::from([(
                "eth_usdc".to_string(),
                vec![eth.address.clone(), usdc.address.clone()],
            )]),
            removed_components: vec![],
            updated_components: vec![],
        };

        manager
            .handle_event(&event)
            .await
            .unwrap();

        let store = manager.store();
        let guard = store.read().await;
        assert!(guard.token_prices().is_some());
        assert!(guard.spot_prices().is_some());
    }

    /// An update with no added, removed, or updated components leaves the store untouched.
    #[tokio::test]
    async fn handle_event_skips_empty_update() {
        let (market, _) = setup_market(vec![]);
        let config = ComputationManagerConfig::new();
        let (mut manager, _event_rx) = ComputationManager::new(config, market).unwrap();

        let event = MarketEvent::MarketUpdated {
            added_components: HashMap::new(),
            removed_components: vec![],
            updated_components: vec![],
        };

        manager
            .handle_event(&event)
            .await
            .unwrap();

        let store = manager.store();
        let guard = store.read().await;
        assert!(guard.token_prices().is_none());
        assert!(guard.spot_prices().is_none());
    }

    /// Sending on the shutdown channel stops `run` promptly.
    #[tokio::test]
    async fn run_shuts_down_on_signal() {
        let (market, _) = setup_market(vec![]);
        let config = ComputationManagerConfig::new();
        let (manager, _event_rx) = ComputationManager::new(config, market).unwrap();

        let (_event_tx, event_rx) = broadcast::channel::<MarketEvent>(16);
        let (shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1);

        let handle = tokio::spawn(async move {
            manager.run(event_rx, shutdown_rx).await;
        });

        shutdown_tx.send(()).unwrap();

        tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
            .await
            .expect("manager should shutdown")
            .expect("task should complete successfully");
    }

    /// Dropping the last market event sender stops `run` promptly.
    #[tokio::test]
    async fn run_shuts_down_on_channel_close() {
        let (market, _) = setup_market(vec![]);
        let config = ComputationManagerConfig::new();
        let (manager, _event_rx) = ComputationManager::new(config, market).unwrap();

        let (event_tx, event_rx) = broadcast::channel::<MarketEvent>(16);
        let (_shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1);

        let handle = tokio::spawn(async move {
            manager.run(event_rx, shutdown_rx).await;
        });

        drop(event_tx);

        tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
            .await
            .expect("manager should shutdown on channel close")
            .expect("task should complete successfully");
    }
}