1use std::{
10 collections::{HashMap, HashSet},
11 sync::Arc,
12 time::Instant,
13};
14
15use async_trait::async_trait;
16use tokio::sync::{broadcast, RwLock};
17use tracing::{debug, info, trace, warn};
18use tycho_simulation::tycho_common::models::Address;
19
20use crate::types::ComponentId;
21
/// Set of components touched by a market change, handed to every derived computation.
///
/// Built either from a `MarketEvent::MarketUpdated` (incremental run) or through
/// [`ChangedComponents::all`] (full recompute over the current topology).
#[derive(Debug, Clone, Default)]
pub struct ChangedComponents {
    /// Newly added components keyed by id, each with its associated token addresses
    /// (same shape as the market's component topology).
    pub added: HashMap<ComponentId, Vec<Address>>,
    /// Ids of components reported as removed.
    pub removed: Vec<ComponentId>,
    /// Ids of components reported as updated.
    pub updated: Vec<ComponentId>,
    /// `true` when the whole derived state should be rebuilt; forwarded to the
    /// store's `set_*` methods so they know whether to replace or merge.
    pub is_full_recompute: bool,
}
37
38impl ChangedComponents {
39 pub fn all(market: MarketDataView) -> Self {
43 Self {
44 added: market.component_topology().clone(),
45 removed: vec![],
46 updated: vec![],
47 is_full_recompute: true,
48 }
49 }
50
51 pub fn is_topology_change(&self) -> bool {
53 !self.added.is_empty() || !self.removed.is_empty()
54 }
55
56 pub fn all_changed_ids(&self) -> HashSet<ComponentId> {
58 let mut all = HashSet::new();
59 all.extend(self.added.keys().cloned());
60 all.extend(self.removed.iter().cloned());
61 all.extend(self.updated.iter().cloned());
62 all
63 }
64}
65
66use super::{
67 computation::DerivedComputation,
68 computations::{PoolDepthComputation, SpotPriceComputation, TokenGasPriceComputation},
69 error::ComputationError,
70 events::DerivedDataEvent,
71 store::DerivedData,
72};
73use crate::feed::{
74 events::{EventError, MarketEvent, MarketEventHandler},
75 market_data::{MarketData, MarketDataView},
76};
77
/// Shared, `RwLock`-protected handle to the derived-data store. Readers clone the `Arc`;
/// the computation manager takes the write lock to publish new results.
pub type SharedDerivedDataRef = Arc<RwLock<DerivedData>>;
80
/// Tuning parameters for [`ComputationManager`].
///
/// Build with [`ComputationManagerConfig::new`] (or `Default`) and the `with_*` builders.
#[derive(Debug, Clone)]
pub struct ComputationManagerConfig {
    /// Gas token address handed to the token gas-price computation.
    gas_token: Address,
    /// Maximum hop count handed to the token gas-price computation.
    max_hop: usize,
    /// Slippage threshold handed to `PoolDepthComputation::new`, which may reject it
    /// (e.g. a value of 1.5 yields `ComputationError::InvalidConfiguration`).
    depth_slippage_threshold: f64,
}
95
96impl ComputationManagerConfig {
97 pub fn new() -> Self {
99 Self::default()
100 }
101
102 pub fn with_depth_slippage_threshold(mut self, threshold: f64) -> Self {
104 self.depth_slippage_threshold = threshold;
105 self
106 }
107
108 pub fn with_max_hop(mut self, hop_count: usize) -> Self {
110 self.max_hop = hop_count;
111 self
112 }
113
114 pub fn with_gas_token(mut self, gas_token: Address) -> Self {
116 self.gas_token = gas_token;
117 self
118 }
119
120 pub fn gas_token(&self) -> &Address {
122 &self.gas_token
123 }
124
125 pub fn max_hop(&self) -> usize {
127 self.max_hop
128 }
129
130 pub fn depth_slippage_threshold(&self) -> f64 {
132 self.depth_slippage_threshold
133 }
134}
135
136impl Default for ComputationManagerConfig {
137 fn default() -> Self {
138 Self { gas_token: Address::zero(20), max_hop: 2, depth_slippage_threshold: 0.01 }
139 }
140}
141
/// Recomputes derived market data (spot prices, token gas prices, pool depths) when the
/// market changes, writes the results into a shared [`DerivedData`] store and broadcasts a
/// [`DerivedDataEvent`] for every outcome.
pub struct ComputationManager {
    /// Market state read by every computation.
    market_data: MarketData,
    /// Store that receives computation results; shared with readers via `store()`.
    store: SharedDerivedDataRef,
    /// Token gas-price computation, configured with the gas token and max hop count.
    token_price_computation: TokenGasPriceComputation,
    /// Spot-price computation; runs first and gates the other two.
    spot_price_computation: SpotPriceComputation,
    /// Pool-depth computation, configured with the depth slippage threshold.
    pool_depth_computation: PoolDepthComputation,
    /// Broadcast sender for derived-data lifecycle events.
    event_tx: broadcast::Sender<DerivedDataEvent>,
}
157
158impl ComputationManager {
159 pub fn new(
165 config: ComputationManagerConfig,
166 market_data: MarketData,
167 ) -> Result<(Self, broadcast::Receiver<DerivedDataEvent>), ComputationError> {
168 let pool_depth_computation = PoolDepthComputation::new(config.depth_slippage_threshold)?;
169 let (event_tx, event_rx) = broadcast::channel(64);
170
171 Ok((
172 Self {
173 market_data,
174 store: DerivedData::new_shared(),
175 token_price_computation: TokenGasPriceComputation::default()
176 .with_max_hops(config.max_hop)
177 .with_gas_token(config.gas_token),
178 spot_price_computation: SpotPriceComputation::new(),
179 pool_depth_computation,
180 event_tx,
181 },
182 event_rx,
183 ))
184 }
185
186 pub fn store(&self) -> SharedDerivedDataRef {
188 Arc::clone(&self.store)
189 }
190
191 pub fn event_sender(&self) -> broadcast::Sender<DerivedDataEvent> {
193 self.event_tx.clone()
194 }
195
196 pub async fn run(
200 mut self,
201 mut event_rx: broadcast::Receiver<MarketEvent>,
202 mut shutdown_rx: broadcast::Receiver<()>,
203 ) {
204 info!("computation manager started");
205
206 loop {
207 tokio::select! {
208 biased;
209
210 _ = shutdown_rx.recv() => {
211 info!("computation manager shutting down");
212 break;
213 }
214
215 event_result = event_rx.recv() => {
216 match event_result {
217 Ok(event) => {
218 if let Err(e) = self.handle_event(&event).await {
219 warn!(error = ?e, "failed to handle market event");
220 }
221 }
222 Err(broadcast::error::RecvError::Closed) => {
223 info!("event channel closed, computation manager shutting down");
224 break;
225 }
226 Err(broadcast::error::RecvError::Lagged(skipped)) => {
227 warn!(
228 skipped,
229 "computation manager lagged, skipped {} events. Recomputing from current state.",
230 skipped
231 );
232 let market = self.market_data.read().await;
233 let changed = ChangedComponents::all(market);
234 self.compute_all(&changed).await;
235 }
236 }
237 }
238 }
239 }
240 }
241
242 async fn compute_all(&self, changed: &ChangedComponents) {
252 let total_start = Instant::now();
253
254 let Some(block) = self
256 .market_data
257 .read()
258 .await
259 .last_updated()
260 .map(|b| b.number())
261 else {
262 warn!("market data has no last updated block, skipping computations");
263 return;
264 };
265
266 let _ = self
268 .event_tx
269 .send(DerivedDataEvent::NewBlock { block });
270
271 let spot_start = Instant::now();
273 let spot_prices_result = self
274 .spot_price_computation
275 .compute(&self.market_data, &self.store, changed)
276 .await;
277 let spot_elapsed = spot_start.elapsed();
278
279 match spot_prices_result {
281 Ok(output) => {
282 let count = output.data.len();
283 if output.has_failures() {
284 warn!(
285 count,
286 failed = output.failed_items.len(),
287 "spot prices partial failures"
288 );
289 for item in &output.failed_items {
290 debug!(key = %item.key, error = %item.error, "spot price failed item");
291 }
292 } else {
293 info!(count, elapsed_ms = spot_elapsed.as_millis(), "spot prices computed");
294 }
295 self.store
296 .write()
297 .await
298 .set_spot_prices(
299 output.data,
300 output.failed_items.clone(),
301 block,
302 changed.is_full_recompute,
303 );
304 let _ = self
305 .event_tx
306 .send(DerivedDataEvent::ComputationComplete {
307 computation_id: SpotPriceComputation::ID,
308 block,
309 failed_items: output.failed_items,
310 });
311 }
312 Err(e) => {
313 warn!(error = ?e, elapsed_ms = spot_elapsed.as_millis(), "spot price computation failed");
314 let _ = self
315 .event_tx
316 .send(DerivedDataEvent::ComputationFailed {
317 computation_id: SpotPriceComputation::ID,
318 block,
319 });
320 let _ = self
321 .event_tx
322 .send(DerivedDataEvent::ComputationFailed {
323 computation_id: TokenGasPriceComputation::ID,
324 block,
325 });
326 let _ = self
327 .event_tx
328 .send(DerivedDataEvent::ComputationFailed {
329 computation_id: PoolDepthComputation::ID,
330 block,
331 });
332 return;
334 }
335 }
336
337 let (token_prices_result, pool_depths_result) = tokio::join!(
339 async {
340 let start = Instant::now();
341 let result = self
342 .token_price_computation
343 .compute(&self.market_data, &self.store, changed)
344 .await;
345 (result, start.elapsed())
346 },
347 async {
348 let start = Instant::now();
349 let result = self
350 .pool_depth_computation
351 .compute(&self.market_data, &self.store, changed)
352 .await;
353 (result, start.elapsed())
354 }
355 );
356 let (token_prices_result, token_elapsed) = token_prices_result;
357 let (pool_depths_result, depth_elapsed) = pool_depths_result;
358
359 let mut store_write = self.store.write().await;
361
362 match token_prices_result {
363 Ok(output) => {
364 let count = output.data.len();
365 if output.has_failures() {
366 warn!(
367 count,
368 failed = output.failed_items.len(),
369 "token prices partial failures"
370 );
371 for item in &output.failed_items {
372 debug!(key = %item.key, error = %item.error, "token price failed item");
373 }
374 } else {
375 info!(count, elapsed_ms = token_elapsed.as_millis(), "token prices computed");
376 }
377 store_write.set_token_prices(
378 output.data,
379 output.failed_items.clone(),
380 block,
381 changed.is_full_recompute,
382 );
383 let _ = self
384 .event_tx
385 .send(DerivedDataEvent::ComputationComplete {
386 computation_id: TokenGasPriceComputation::ID,
387 block,
388 failed_items: output.failed_items,
389 });
390 }
391 Err(e) => {
392 warn!(error = ?e, "token price computation failed");
393 let _ = self
394 .event_tx
395 .send(DerivedDataEvent::ComputationFailed {
396 computation_id: TokenGasPriceComputation::ID,
397 block,
398 });
399 }
400 }
401
402 match pool_depths_result {
403 Ok(output) => {
404 let count = output.data.len();
405 if output.has_failures() {
406 warn!(
407 count,
408 failed = output.failed_items.len(),
409 "pool depths partial failures"
410 );
411 for item in &output.failed_items {
412 debug!(key = %item.key, error = %item.error, "pool depth failed item");
413 }
414 } else {
415 info!(count, elapsed_ms = depth_elapsed.as_millis(), "pool depths computed");
416 }
417 store_write.set_pool_depths(
418 output.data,
419 output.failed_items.clone(),
420 block,
421 changed.is_full_recompute,
422 );
423 let _ = self
424 .event_tx
425 .send(DerivedDataEvent::ComputationComplete {
426 computation_id: PoolDepthComputation::ID,
427 block,
428 failed_items: output.failed_items,
429 });
430 }
431 Err(e) => {
432 warn!(error = ?e, "pool depth computation failed");
433 let _ = self
434 .event_tx
435 .send(DerivedDataEvent::ComputationFailed {
436 computation_id: PoolDepthComputation::ID,
437 block,
438 });
439 }
440 }
441
442 let total_elapsed = total_start.elapsed();
443 info!(block, total_ms = total_elapsed.as_millis(), "all derived computations complete");
444 }
445}
446
447#[async_trait]
448impl MarketEventHandler for ComputationManager {
449 async fn handle_event(&mut self, event: &MarketEvent) -> Result<(), EventError> {
450 match event {
451 MarketEvent::MarketUpdated {
452 added_components,
453 removed_components,
454 updated_components,
455 } if !added_components.is_empty() ||
456 !removed_components.is_empty() ||
457 !updated_components.is_empty() =>
458 {
459 trace!(
460 added = added_components.len(),
461 removed = removed_components.len(),
462 updated = updated_components.len(),
463 "market updated, running incremental computations"
464 );
465
466 let changed = ChangedComponents {
467 added: added_components.clone(),
468 removed: removed_components.clone(),
469 updated: updated_components.clone(),
470 is_full_recompute: false,
471 };
472 self.compute_all(&changed).await;
473 }
474 _ => {
475 trace!("empty market update, skipping computations");
476 }
477 }
478
479 Ok(())
480 }
481}
482
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use tokio::sync::broadcast;

    use super::*;
    use crate::{
        algorithm::test_utils::{component, setup_market, token, MockProtocolSim},
        feed::market_data::{MarketData, MarketState},
        types::BlockInfo,
    };

    /// Collects every event currently buffered on `rx` without waiting.
    ///
    /// `Lagged` only means older events were overwritten; draining continues with the
    /// oldest event still buffered.
    fn drain_events(rx: &mut broadcast::Receiver<DerivedDataEvent>) -> Vec<DerivedDataEvent> {
        let mut events = vec![];
        loop {
            match rx.try_recv() {
                Ok(e) => events.push(e),
                Err(broadcast::error::TryRecvError::Empty) => break,
                Err(broadcast::error::TryRecvError::Lagged(_)) => continue,
                Err(broadcast::error::TryRecvError::Closed) => break,
            }
        }
        events
    }

    #[test]
    fn invalid_slippage_threshold_returns_error() {
        let (market, _) = setup_market(vec![]);
        let config = ComputationManagerConfig::new().with_depth_slippage_threshold(1.5);

        let result = ComputationManager::new(config, market);
        assert!(matches!(result, Err(ComputationError::InvalidConfiguration(_))));
    }

    #[tokio::test]
    async fn handle_event_runs_computations_on_market_update() {
        let eth = token(1, "ETH");
        let usdc = token(2, "USDC");

        let (market, _) =
            setup_market(vec![("eth_usdc", &eth, &usdc, MockProtocolSim::new(2000.0).with_gas(0))]);

        let config = ComputationManagerConfig::new().with_gas_token(eth.address.clone());
        let (mut manager, _event_rx) = ComputationManager::new(config, market).unwrap();

        let event = MarketEvent::MarketUpdated {
            added_components: HashMap::from([(
                "eth_usdc".to_string(),
                vec![eth.address.clone(), usdc.address.clone()],
            )]),
            removed_components: vec![],
            updated_components: vec![],
        };

        manager
            .handle_event(&event)
            .await
            .unwrap();

        let store = manager.store();
        let guard = store.read().await;
        assert!(guard.token_prices().is_some());
        assert!(guard.spot_prices().is_some());
    }

    #[tokio::test]
    async fn handle_event_skips_empty_update() {
        let (market, _) = setup_market(vec![]);
        let config = ComputationManagerConfig::new();
        let (mut manager, _event_rx) = ComputationManager::new(config, market).unwrap();

        let event = MarketEvent::MarketUpdated {
            added_components: HashMap::new(),
            removed_components: vec![],
            updated_components: vec![],
        };

        manager
            .handle_event(&event)
            .await
            .unwrap();

        let store = manager.store();
        let guard = store.read().await;
        assert!(guard.token_prices().is_none());
    }

    #[tokio::test]
    async fn run_shuts_down_on_signal() {
        let (market, _) = setup_market(vec![]);
        let config = ComputationManagerConfig::new();
        let (manager, _event_rx) = ComputationManager::new(config, market).unwrap();

        let (_event_tx, event_rx) = broadcast::channel::<MarketEvent>(16);
        let (shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1);

        let handle = tokio::spawn(async move {
            manager.run(event_rx, shutdown_rx).await;
        });

        shutdown_tx.send(()).unwrap();

        tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
            .await
            .expect("manager should shutdown")
            .expect("task should complete successfully");
    }

    /// Market at block 10 with a single ETH/USDC component `pool` whose simulation state
    /// is never registered, so spot-price computation cannot evaluate it.
    fn market_with_component_no_sim_state() -> MarketData {
        let eth = token(1, "ETH");
        let usdc = token(2, "USDC");
        let pool = component("pool", &[eth.clone(), usdc.clone()]);

        let mut market = MarketState::new();
        market.update_last_updated(BlockInfo::new(10, "0xhash".into(), 0));
        market.upsert_components(std::iter::once(pool));
        market.upsert_tokens([eth, usdc]);
        // No `update_states` call: the component has no simulation state.
        MarketData::new(std::sync::Arc::new(tokio::sync::RwLock::new(market)))
    }

    /// Market at block 10 with two components: `eth_usdc` has a simulation state while
    /// `eth_dai` has none, producing a partial spot-price failure.
    fn market_with_mixed_sim_states() -> MarketData {
        let eth = token(1, "ETH");
        let usdc = token(2, "USDC");
        let dai = token(3, "DAI");

        let pool1 = component("eth_usdc", &[eth.clone(), usdc.clone()]);
        let pool2 = component("eth_dai", &[eth.clone(), dai.clone()]);

        let mut market = MarketState::new();
        market.update_last_updated(BlockInfo::new(10, "0xhash".into(), 0));
        market.upsert_components([pool1, pool2]);
        // Only `eth_usdc` receives a simulation state.
        market
            .update_states([("eth_usdc".to_string(), Box::new(MockProtocolSim::new(2000.0)) as _)]);
        market.upsert_tokens([eth, usdc, dai]);
        MarketData::new(std::sync::Arc::new(tokio::sync::RwLock::new(market)))
    }

    /// Market at block 10 with one ETH/USDC component `pool` that has a simulation state
    /// built without `with_gas`; the test below relies on this making the token
    /// gas-price computation fail.
    fn market_with_sim_state_no_gas_price() -> MarketData {
        let eth = token(1, "ETH");
        let usdc = token(2, "USDC");
        let pool = component("pool", &[eth.clone(), usdc.clone()]);

        let mut market = MarketState::new();
        market.update_last_updated(BlockInfo::new(10, "0xhash".into(), 0));
        market.upsert_components(std::iter::once(pool));
        market.update_states([("pool".to_string(), Box::new(MockProtocolSim::new(2000.0)) as _)]);
        market.upsert_tokens([eth, usdc]);
        MarketData::new(std::sync::Arc::new(tokio::sync::RwLock::new(market)))
    }

    #[tokio::test]
    async fn test_spot_price_failure_broadcasts_computation_failed() {
        let market = market_with_component_no_sim_state();
        let config = ComputationManagerConfig::new();
        let (manager, mut event_rx) = ComputationManager::new(config, market).unwrap();

        let changed = ChangedComponents { is_full_recompute: true, ..Default::default() };
        manager.compute_all(&changed).await;

        let events = drain_events(&mut event_rx);

        assert!(
            events.iter().any(|e| matches!(
                e,
                DerivedDataEvent::ComputationFailed { computation_id: "spot_prices", .. }
            )),
            "expected ComputationFailed(spot_prices) in events: {events:?}"
        );
    }

    #[tokio::test]
    async fn test_token_price_failure_broadcasts_computation_failed() {
        let eth = token(1, "ETH");
        let usdc = token(2, "USDC");
        let market = market_with_sim_state_no_gas_price();
        let config = ComputationManagerConfig::new().with_gas_token(eth.address.clone());
        let (mut manager, mut event_rx) = ComputationManager::new(config, market).unwrap();

        // Announce `pool` as added so an incremental run is triggered.
        let event = MarketEvent::MarketUpdated {
            added_components: HashMap::from([(
                "pool".to_string(),
                vec![eth.address.clone(), usdc.address.clone()],
            )]),
            removed_components: vec![],
            updated_components: vec![],
        };
        manager
            .handle_event(&event)
            .await
            .unwrap();

        let events = drain_events(&mut event_rx);
        assert!(
            events.iter().any(|e| matches!(
                e,
                DerivedDataEvent::ComputationFailed { computation_id: "token_prices", .. }
            )),
            "expected ComputationFailed(token_prices) in events: {events:?}"
        );
    }

    #[tokio::test]
    async fn run_shuts_down_on_channel_close() {
        let (market, _) = setup_market(vec![]);
        let config = ComputationManagerConfig::new();
        let (manager, _event_rx) = ComputationManager::new(config, market).unwrap();

        let (event_tx, event_rx) = broadcast::channel::<MarketEvent>(16);
        let (_shutdown_tx, shutdown_rx) = broadcast::channel::<()>(1);

        let handle = tokio::spawn(async move {
            manager.run(event_rx, shutdown_rx).await;
        });

        // Dropping the only sender closes the market event channel.
        drop(event_tx);

        tokio::time::timeout(tokio::time::Duration::from_secs(1), handle)
            .await
            .expect("manager should shutdown on channel close")
            .expect("task should complete successfully");
    }

    #[tokio::test]
    async fn partial_spot_price_failure_broadcasts_computation_complete() {
        // `eth_usdc` can be priced, `eth_dai` cannot: a partial failure must still be
        // reported as completion, carrying the failed items.
        let market = market_with_mixed_sim_states();
        let config = ComputationManagerConfig::new();
        let (manager, mut event_rx) = ComputationManager::new(config, market).unwrap();

        let changed = ChangedComponents { is_full_recompute: true, ..Default::default() };
        manager.compute_all(&changed).await;

        let events = drain_events(&mut event_rx);

        assert!(
            events.iter().any(|e| matches!(
                e,
                DerivedDataEvent::ComputationComplete { computation_id: "spot_prices", .. }
            )),
            "expected ComputationComplete(spot_prices), got: {events:?}"
        );
        assert!(
            !events.iter().any(|e| matches!(
                e,
                DerivedDataEvent::ComputationFailed { computation_id: "spot_prices", .. }
            )),
            "should not broadcast ComputationFailed for partial failure"
        );

        // Fail loudly instead of silently skipping the check if the event is missing.
        let Some(DerivedDataEvent::ComputationComplete { failed_items, .. }) =
            events.iter().find(|e| {
                matches!(e, DerivedDataEvent::ComputationComplete { computation_id: "spot_prices", .. })
            })
        else {
            panic!("expected ComputationComplete(spot_prices), got: {events:?}");
        };
        assert!(!failed_items.is_empty(), "ComputationComplete should carry failed_items for pool2");

        // The failure reason must be persisted in the store in either token direction.
        let eth = token(1, "ETH");
        let dai = token(3, "DAI");
        let store = manager.store();
        let guard = store.read().await;
        let key_eth_dai = ("eth_dai".to_string(), eth.address.clone(), dai.address.clone());
        let key_dai_eth = ("eth_dai".to_string(), dai.address.clone(), eth.address.clone());
        assert!(
            guard
                .spot_price_failure(&key_eth_dai)
                .is_some() ||
                guard
                    .spot_price_failure(&key_dai_eth)
                    .is_some(),
            "store should persist failure reason for eth_dai (missing sim state)"
        );
    }
}