1use std::{
10 collections::{HashMap, HashSet},
11 sync::Arc,
12 time::Instant,
13};
14
15use async_trait::async_trait;
16use tokio::sync::{broadcast, RwLock};
17use tracing::{debug, info, trace, warn};
18use tycho_simulation::tycho_common::models::Address;
19
20use crate::{feed::market_data::SharedMarketData, types::ComponentId};
21
/// Set of components affected by a market update; used to scope a derived-data recompute.
#[derive(Debug, Clone, Default)]
pub struct ChangedComponents {
    /// Newly added components keyed by id. In this module's tests the value is the list of the
    /// component's token addresses — TODO confirm this holds for every producer.
    pub added: HashMap<ComponentId, Vec<Address>>,
    /// Ids of components removed from the market.
    pub removed: Vec<ComponentId>,
    /// Ids of components reported as updated. These do not count as a topology change (see
    /// `is_topology_change`).
    pub updated: Vec<ComponentId>,
    /// `true` when every component is being recomputed (as set by `ChangedComponents::all`),
    /// `false` for incremental updates. Forwarded to the derived-data store setters.
    pub is_full_recompute: bool,
}
37
38impl ChangedComponents {
39 pub fn all(market: &SharedMarketData) -> Self {
43 Self {
44 added: market.component_topology().clone(),
45 removed: vec![],
46 updated: vec![],
47 is_full_recompute: true,
48 }
49 }
50
51 pub fn is_topology_change(&self) -> bool {
53 !self.added.is_empty() || !self.removed.is_empty()
54 }
55
56 pub fn all_changed_ids(&self) -> HashSet<ComponentId> {
58 let mut all = HashSet::new();
59 all.extend(self.added.keys().cloned());
60 all.extend(self.removed.iter().cloned());
61 all.extend(self.updated.iter().cloned());
62 all
63 }
64}
65
66use super::{
67 computation::DerivedComputation,
68 computations::{PoolDepthComputation, SpotPriceComputation, TokenGasPriceComputation},
69 error::ComputationError,
70 events::DerivedDataEvent,
71 store::DerivedData,
72};
73use crate::feed::{
74 events::{EventError, MarketEvent, MarketEventHandler},
75 market_data::SharedMarketDataRef,
76};
77
/// Shared, lock-protected handle to the derived-data store; `ComputationManager::store`
/// hands out clones of this `Arc`.
pub type SharedDerivedDataRef = Arc<RwLock<DerivedData>>;
80
/// Configuration for `ComputationManager`. Start from `ComputationManagerConfig::new()` (equal to
/// `Default::default()`) and override fields with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ComputationManagerConfig {
    /// Gas token address handed to `TokenGasPriceComputation::with_gas_token`
    /// (presumably the token gas prices are quoted in — confirm).
    gas_token: Address,
    /// Hop limit handed to `TokenGasPriceComputation::with_max_hops`.
    max_hop: usize,
    /// Threshold handed to `PoolDepthComputation::new`, which may reject it (a value of `1.5`
    /// yields `ComputationError::InvalidConfiguration` in this module's tests).
    depth_slippage_threshold: f64,
}
95
96impl ComputationManagerConfig {
97 pub fn new() -> Self {
99 Self::default()
100 }
101
102 pub fn with_depth_slippage_threshold(mut self, threshold: f64) -> Self {
104 self.depth_slippage_threshold = threshold;
105 self
106 }
107
108 pub fn with_max_hop(mut self, hop_count: usize) -> Self {
110 self.max_hop = hop_count;
111 self
112 }
113
114 pub fn with_gas_token(mut self, gas_token: Address) -> Self {
116 self.gas_token = gas_token;
117 self
118 }
119
120 pub fn gas_token(&self) -> &Address {
122 &self.gas_token
123 }
124
125 pub fn max_hop(&self) -> usize {
127 self.max_hop
128 }
129
130 pub fn depth_slippage_threshold(&self) -> f64 {
132 self.depth_slippage_threshold
133 }
134}
135
136impl Default for ComputationManagerConfig {
137 fn default() -> Self {
138 Self { gas_token: Address::zero(20), max_hop: 2, depth_slippage_threshold: 0.01 }
139 }
140}
141
/// Recomputes derived market data (spot prices, token gas prices, pool depths) when the market
/// changes, writes the results into a shared store, and broadcasts `DerivedDataEvent`s
/// describing each outcome.
pub struct ComputationManager {
    /// Shared market state the computations read from.
    market_data: SharedMarketDataRef,
    /// Shared store the computed results are written into.
    store: SharedDerivedDataRef,
    /// Token gas price computation; runs only after spot prices succeed.
    token_price_computation: TokenGasPriceComputation,
    /// Spot price computation; runs first for every block.
    spot_price_computation: SpotPriceComputation,
    /// Pool depth computation; runs concurrently with the token gas price computation.
    pool_depth_computation: PoolDepthComputation,
    /// Sender side of the derived-data event broadcast channel.
    event_tx: broadcast::Sender<DerivedDataEvent>,
}
157
158impl ComputationManager {
159 pub fn new(
165 config: ComputationManagerConfig,
166 market_data: SharedMarketDataRef,
167 ) -> Result<(Self, broadcast::Receiver<DerivedDataEvent>), ComputationError> {
168 let pool_depth_computation = PoolDepthComputation::new(config.depth_slippage_threshold)?;
169 let (event_tx, event_rx) = broadcast::channel(64);
170
171 Ok((
172 Self {
173 market_data,
174 store: DerivedData::new_shared(),
175 token_price_computation: TokenGasPriceComputation::default()
176 .with_max_hops(config.max_hop)
177 .with_gas_token(config.gas_token),
178 spot_price_computation: SpotPriceComputation::new(),
179 pool_depth_computation,
180 event_tx,
181 },
182 event_rx,
183 ))
184 }
185
186 pub fn store(&self) -> SharedDerivedDataRef {
188 Arc::clone(&self.store)
189 }
190
191 pub fn event_sender(&self) -> broadcast::Sender<DerivedDataEvent> {
193 self.event_tx.clone()
194 }
195
196 pub async fn run(
200 mut self,
201 mut event_rx: broadcast::Receiver<MarketEvent>,
202 mut shutdown_rx: broadcast::Receiver<()>,
203 ) {
204 info!("computation manager started");
205
206 loop {
207 tokio::select! {
208 biased;
209
210 _ = shutdown_rx.recv() => {
211 info!("computation manager shutting down");
212 break;
213 }
214
215 event_result = event_rx.recv() => {
216 match event_result {
217 Ok(event) => {
218 if let Err(e) = self.handle_event(&event).await {
219 warn!(error = ?e, "failed to handle market event");
220 }
221 }
222 Err(broadcast::error::RecvError::Closed) => {
223 info!("event channel closed, computation manager shutting down");
224 break;
225 }
226 Err(broadcast::error::RecvError::Lagged(skipped)) => {
227 warn!(
228 skipped,
229 "computation manager lagged, skipped {} events. Recomputing from current state.",
230 skipped
231 );
232 let market = self.market_data.read().await;
233 let changed = ChangedComponents::all(&market);
234 drop(market);
235 self.compute_all(&changed).await;
236 }
237 }
238 }
239 }
240 }
241 }
242
243 async fn compute_all(&self, changed: &ChangedComponents) {
253 let total_start = Instant::now();
254
255 let Some(block) = self
257 .market_data
258 .read()
259 .await
260 .last_updated()
261 .map(|b| b.number())
262 else {
263 warn!("market data has no last updated block, skipping computations");
264 return;
265 };
266
267 let _ = self
269 .event_tx
270 .send(DerivedDataEvent::NewBlock { block });
271
272 let spot_start = Instant::now();
274 let spot_prices_result = self
275 .spot_price_computation
276 .compute(&self.market_data, &self.store, changed)
277 .await;
278 let spot_elapsed = spot_start.elapsed();
279
280 match spot_prices_result {
282 Ok(output) => {
283 let count = output.data.len();
284 if output.has_failures() {
285 warn!(
286 count,
287 failed = output.failed_items.len(),
288 "spot prices partial failures"
289 );
290 for item in &output.failed_items {
291 debug!(key = %item.key, error = %item.error, "spot price failed item");
292 }
293 } else {
294 info!(count, elapsed_ms = spot_elapsed.as_millis(), "spot prices computed");
295 }
296 self.store
297 .write()
298 .await
299 .set_spot_prices(
300 output.data,
301 output.failed_items.clone(),
302 block,
303 changed.is_full_recompute,
304 );
305 let _ = self
306 .event_tx
307 .send(DerivedDataEvent::ComputationComplete {
308 computation_id: SpotPriceComputation::ID,
309 block,
310 failed_items: output.failed_items,
311 });
312 }
313 Err(e) => {
314 warn!(error = ?e, elapsed_ms = spot_elapsed.as_millis(), "spot price computation failed");
315 let _ = self
316 .event_tx
317 .send(DerivedDataEvent::ComputationFailed {
318 computation_id: SpotPriceComputation::ID,
319 block,
320 });
321 let _ = self
322 .event_tx
323 .send(DerivedDataEvent::ComputationFailed {
324 computation_id: TokenGasPriceComputation::ID,
325 block,
326 });
327 let _ = self
328 .event_tx
329 .send(DerivedDataEvent::ComputationFailed {
330 computation_id: PoolDepthComputation::ID,
331 block,
332 });
333 return;
335 }
336 }
337
338 let (token_prices_result, pool_depths_result) = tokio::join!(
340 async {
341 let start = Instant::now();
342 let result = self
343 .token_price_computation
344 .compute(&self.market_data, &self.store, changed)
345 .await;
346 (result, start.elapsed())
347 },
348 async {
349 let start = Instant::now();
350 let result = self
351 .pool_depth_computation
352 .compute(&self.market_data, &self.store, changed)
353 .await;
354 (result, start.elapsed())
355 }
356 );
357 let (token_prices_result, token_elapsed) = token_prices_result;
358 let (pool_depths_result, depth_elapsed) = pool_depths_result;
359
360 let mut store_write = self.store.write().await;
362
363 match token_prices_result {
364 Ok(output) => {
365 let count = output.data.len();
366 if output.has_failures() {
367 warn!(
368 count,
369 failed = output.failed_items.len(),
370 "token prices partial failures"
371 );
372 for item in &output.failed_items {
373 debug!(key = %item.key, error = %item.error, "token price failed item");
374 }
375 } else {
376 info!(count, elapsed_ms = token_elapsed.as_millis(), "token prices computed");
377 }
378 store_write.set_token_prices(
379 output.data,
380 output.failed_items.clone(),
381 block,
382 changed.is_full_recompute,
383 );
384 let _ = self
385 .event_tx
386 .send(DerivedDataEvent::ComputationComplete {
387 computation_id: TokenGasPriceComputation::ID,
388 block,
389 failed_items: output.failed_items,
390 });
391 }
392 Err(e) => {
393 warn!(error = ?e, "token price computation failed");
394 let _ = self
395 .event_tx
396 .send(DerivedDataEvent::ComputationFailed {
397 computation_id: TokenGasPriceComputation::ID,
398 block,
399 });
400 }
401 }
402
403 match pool_depths_result {
404 Ok(output) => {
405 let count = output.data.len();
406 if output.has_failures() {
407 warn!(
408 count,
409 failed = output.failed_items.len(),
410 "pool depths partial failures"
411 );
412 for item in &output.failed_items {
413 debug!(key = %item.key, error = %item.error, "pool depth failed item");
414 }
415 } else {
416 info!(count, elapsed_ms = depth_elapsed.as_millis(), "pool depths computed");
417 }
418 store_write.set_pool_depths(
419 output.data,
420 output.failed_items.clone(),
421 block,
422 changed.is_full_recompute,
423 );
424 let _ = self
425 .event_tx
426 .send(DerivedDataEvent::ComputationComplete {
427 computation_id: PoolDepthComputation::ID,
428 block,
429 failed_items: output.failed_items,
430 });
431 }
432 Err(e) => {
433 warn!(error = ?e, "pool depth computation failed");
434 let _ = self
435 .event_tx
436 .send(DerivedDataEvent::ComputationFailed {
437 computation_id: PoolDepthComputation::ID,
438 block,
439 });
440 }
441 }
442
443 let total_elapsed = total_start.elapsed();
444 info!(block, total_ms = total_elapsed.as_millis(), "all derived computations complete");
445 }
446}
447
448#[async_trait]
449impl MarketEventHandler for ComputationManager {
450 async fn handle_event(&mut self, event: &MarketEvent) -> Result<(), EventError> {
451 match event {
452 MarketEvent::MarketUpdated {
453 added_components,
454 removed_components,
455 updated_components,
456 } if !added_components.is_empty() ||
457 !removed_components.is_empty() ||
458 !updated_components.is_empty() =>
459 {
460 trace!(
461 added = added_components.len(),
462 removed = removed_components.len(),
463 updated = updated_components.len(),
464 "market updated, running incremental computations"
465 );
466
467 let changed = ChangedComponents {
468 added: added_components.clone(),
469 removed: removed_components.clone(),
470 updated: updated_components.clone(),
471 is_full_recompute: false,
472 };
473 self.compute_all(&changed).await;
474 }
475 _ => {
476 trace!("empty market update, skipping computations");
477 }
478 }
479
480 Ok(())
481 }
482}
483
#[cfg(test)]
mod tests {
    use std::{collections::HashMap, sync::Arc};

    use tokio::sync::{broadcast, RwLock};

    use super::*;
    use crate::{
        algorithm::test_utils::{component, setup_market, token, MockProtocolSim},
        feed::market_data::SharedMarketData,
        types::BlockInfo,
    };

    /// Collects every event currently buffered in `rx` without waiting. Lag notifications are
    /// skipped, so only the events still held by the channel are returned.
    fn drain_events(rx: &mut broadcast::Receiver<DerivedDataEvent>) -> Vec<DerivedDataEvent> {
        let mut events = vec![];
        loop {
            match rx.try_recv() {
                Ok(e) => events.push(e),
                Err(broadcast::error::TryRecvError::Empty) => break,
                Err(broadcast::error::TryRecvError::Lagged(_)) => continue,
                Err(broadcast::error::TryRecvError::Closed) => break,
            }
        }
        events
    }

    #[test]
    fn invalid_slippage_threshold_returns_error() {
        let (market, _) = setup_market(vec![]);
        let config = ComputationManagerConfig::new().with_depth_slippage_threshold(1.5);

        let result = ComputationManager::new(config, market);
        assert!(matches!(result, Err(ComputationError::InvalidConfiguration(_))));
    }

    #[tokio::test]
    async fn handle_event_runs_computations_on_market_update() {
        let eth = token(1, "ETH");
        let usdc = token(2, "USDC");

        let (market, _) =
            setup_market(vec![("eth_usdc", &eth, &usdc, MockProtocolSim::new(2000.0).with_gas(0))]);

        let config = ComputationManagerConfig::new().with_gas_token(eth.address.clone());
        let (mut manager, _event_rx) = ComputationManager::new(config, market).unwrap();

        let event = MarketEvent::MarketUpdated {
            added_components: HashMap::from([(
                "eth_usdc".to_string(),
                vec![eth.address.clone(), usdc.address.clone()],
            )]),
            removed_components: vec![],
            updated_components: vec![],
        };

        manager
            .handle_event(&event)
            .await
            .unwrap();

        let store = manager.store();
        let guard = store.read().await;
        assert!(guard.token_prices().is_some());
        assert!(guard.spot_prices().is_some());
    }

    #[tokio::test]
    async fn handle_event_skips_empty_update() {
        let (market, _) = setup_market(vec![]);
        let config = ComputationManagerConfig::new();
        let (mut manager, _event_rx) = ComputationManager::new(config, market).unwrap();

        let event = MarketEvent::MarketUpdated {
            added_components: HashMap::new(),
            removed_components: vec![],
            updated_components: vec![],
        };

        manager
            .handle_event(&event)
            .await
            .unwrap();

        let store = manager.store();
        let guard = store.read().await;
        assert!(guard.token_prices().is_none());
    }

    #[tokio::test]
    async fn run_shuts_down_on_signal() {
        let (market, _) = setup_market(vec![]);
        let config = ComputationManagerConfig::new();
        let (manager, _event_rx) = ComputationManager::new(config, market).unwrap();

        let (_event_tx, event_rx) = broadcast::channel::<MarketEvent>(16);
        let (shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1);

        let handle = tokio::spawn(async move {
            manager.run(event_rx, shutdown_rx).await;
        });

        shutdown_tx.send(()).unwrap();

        tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
            .await
            .expect("manager should shutdown")
            .expect("task should complete successfully");
    }

    /// Market at block 10 with one component ("pool", ETH/USDC) and no simulation state, so the
    /// spot price computation is expected to fail.
    fn market_with_component_no_sim_state() -> Arc<RwLock<SharedMarketData>> {
        let eth = token(1, "ETH");
        let usdc = token(2, "USDC");
        let pool = component("pool", &[eth.clone(), usdc.clone()]);

        let mut market = SharedMarketData::new();
        market.update_last_updated(BlockInfo::new(10, "0xhash".into(), 0));
        market.upsert_components(std::iter::once(pool));
        market.upsert_tokens([eth, usdc]);
        Arc::new(RwLock::new(market))
    }

    /// Market at block 10 with two components. Only "eth_usdc" has simulation state, so
    /// "eth_dai" is expected to show up as a partial failure.
    fn market_with_mixed_sim_states() -> Arc<RwLock<SharedMarketData>> {
        let eth = token(1, "ETH");
        let usdc = token(2, "USDC");
        let dai = token(3, "DAI");

        let pool1 = component("eth_usdc", &[eth.clone(), usdc.clone()]);
        let pool2 = component("eth_dai", &[eth.clone(), dai.clone()]);

        let mut market = SharedMarketData::new();
        market.update_last_updated(BlockInfo::new(10, "0xhash".into(), 0));
        market.upsert_components([pool1, pool2]);
        market
            .update_states([("eth_usdc".to_string(), Box::new(MockProtocolSim::new(2000.0)) as _)]);
        market.upsert_tokens([eth, usdc, dai]);
        Arc::new(RwLock::new(market))
    }

    /// Market at block 10 with one component ("pool", ETH/USDC) that has simulation state.
    /// No gas price is set, so the token gas price computation is expected to fail.
    fn market_with_sim_state_no_gas_price() -> Arc<RwLock<SharedMarketData>> {
        let eth = token(1, "ETH");
        let usdc = token(2, "USDC");
        let pool = component("pool", &[eth.clone(), usdc.clone()]);

        let mut market = SharedMarketData::new();
        market.update_last_updated(BlockInfo::new(10, "0xhash".into(), 0));
        market.upsert_components(std::iter::once(pool));
        market.update_states([("pool".to_string(), Box::new(MockProtocolSim::new(2000.0)) as _)]);
        market.upsert_tokens([eth, usdc]);
        Arc::new(RwLock::new(market))
    }

    #[tokio::test]
    async fn test_spot_price_failure_broadcasts_computation_failed() {
        let market = market_with_component_no_sim_state();
        let config = ComputationManagerConfig::new();
        let (manager, mut event_rx) = ComputationManager::new(config, market).unwrap();

        let changed = ChangedComponents { is_full_recompute: true, ..Default::default() };
        manager.compute_all(&changed).await;

        let events = drain_events(&mut event_rx);

        assert!(
            events.iter().any(|e| matches!(
                e,
                DerivedDataEvent::ComputationFailed { computation_id: "spot_prices", .. }
            )),
            "expected ComputationFailed(spot_prices) in events: {events:?}"
        );
    }

    #[tokio::test]
    async fn test_token_price_failure_broadcasts_computation_failed() {
        let eth = token(1, "ETH");
        let usdc = token(2, "USDC");
        let market = market_with_sim_state_no_gas_price();
        let config = ComputationManagerConfig::new().with_gas_token(eth.address.clone());
        let (mut manager, mut event_rx) = ComputationManager::new(config, market).unwrap();

        let event = MarketEvent::MarketUpdated {
            added_components: HashMap::from([(
                "pool".to_string(),
                vec![eth.address.clone(), usdc.address.clone()],
            )]),
            removed_components: vec![],
            updated_components: vec![],
        };
        manager
            .handle_event(&event)
            .await
            .unwrap();

        let events = drain_events(&mut event_rx);
        assert!(
            events.iter().any(|e| matches!(
                e,
                DerivedDataEvent::ComputationFailed { computation_id: "token_prices", .. }
            )),
            "expected ComputationFailed(token_prices) in events: {events:?}"
        );
    }

    #[tokio::test]
    async fn run_shuts_down_on_channel_close() {
        let (market, _) = setup_market(vec![]);
        let config = ComputationManagerConfig::new();
        let (manager, _event_rx) = ComputationManager::new(config, market).unwrap();

        let (event_tx, event_rx) = broadcast::channel::<MarketEvent>(16);
        let (_shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1);

        let handle = tokio::spawn(async move {
            manager.run(event_rx, shutdown_rx).await;
        });

        drop(event_tx);

        tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
            .await
            .expect("manager should shutdown on channel close")
            .expect("task should complete successfully");
    }

    #[tokio::test]
    async fn partial_spot_price_failure_broadcasts_computation_complete() {
        let market = market_with_mixed_sim_states();
        let config = ComputationManagerConfig::new();
        let (manager, mut event_rx) = ComputationManager::new(config, market).unwrap();

        let changed = ChangedComponents { is_full_recompute: true, ..Default::default() };
        manager.compute_all(&changed).await;

        let events = drain_events(&mut event_rx);

        assert!(
            events.iter().any(|e| matches!(
                e,
                DerivedDataEvent::ComputationComplete { computation_id: "spot_prices", .. }
            )),
            "expected ComputationComplete(spot_prices), got: {events:?}"
        );
        assert!(
            !events.iter().any(|e| matches!(
                e,
                DerivedDataEvent::ComputationFailed { computation_id: "spot_prices", .. }
            )),
            "should not broadcast ComputationFailed for partial failure"
        );

        // Fail loudly (rather than silently skipping the check) if the event is missing.
        let Some(DerivedDataEvent::ComputationComplete { failed_items, .. }) =
            events.iter().find(|e| {
                matches!(e, DerivedDataEvent::ComputationComplete { computation_id: "spot_prices", .. })
            })
        else {
            panic!("expected ComputationComplete(spot_prices), got: {events:?}");
        };
        assert!(
            !failed_items.is_empty(),
            "ComputationComplete should carry failed_items for pool2"
        );

        let eth = token(1, "ETH");
        let dai = token(3, "DAI");
        let store = manager.store();
        let guard = store.read().await;
        let key_eth_dai = ("eth_dai".to_string(), eth.address.clone(), dai.address.clone());
        let key_dai_eth = ("eth_dai".to_string(), dai.address.clone(), eth.address.clone());
        assert!(
            guard
                .spot_price_failure(&key_eth_dai)
                .is_some() ||
                guard
                    .spot_price_failure(&key_dai_eth)
                    .is_some(),
            "store should persist failure reason for eth_dai (missing sim state)"
        );
    }
}