1use std::collections::{HashMap, HashSet};
12
13use async_trait::async_trait;
14pub use petgraph::graph::EdgeIndex;
15use petgraph::{graph::NodeIndex, stable_graph};
16use tracing::{debug, trace};
17use tycho_simulation::tycho_common::models::Address;
18
19use super::GraphManager;
20use crate::{
21 feed::{
22 events::{EventError, MarketEvent, MarketEventHandler},
23 market_data::SharedMarketData,
24 },
25 graph::GraphError,
26 types::ComponentId,
27};
28
29#[derive(Debug, Clone, Default)]
48pub struct EdgeData<D = ()> {
49 pub component_id: ComponentId,
51 pub data: Option<D>,
53}
54
55impl<M> EdgeData<M> {
56 pub fn new(component_id: ComponentId) -> Self {
58 Self { component_id, data: None }
59 }
60
61 pub fn with_data(component_id: ComponentId, data: M) -> Self {
63 Self { component_id, data: Some(data) }
64 }
65}
66
67pub type StableDiGraph<D> = stable_graph::StableDiGraph<Address, EdgeData<D>>;
69
70pub struct PetgraphStableDiGraphManager<D: Clone> {
77 graph: StableDiGraph<D>,
81 edge_map: HashMap<ComponentId, Vec<EdgeIndex>>,
83 node_map: HashMap<Address, NodeIndex>,
85}
86
87impl<D: Clone> PetgraphStableDiGraphManager<D> {
88 pub fn new() -> Self {
90 Self { graph: StableDiGraph::default(), edge_map: HashMap::new(), node_map: HashMap::new() }
91 }
92
93 pub(crate) fn find_node(&self, addr: &Address) -> Result<NodeIndex, GraphError> {
95 self.node_map
96 .get(addr)
97 .copied()
98 .ok_or_else(|| GraphError::TokenNotFound(addr.clone()))
99 }
100
101 fn get_or_create_node(&mut self, addr: &Address) -> NodeIndex {
104 match self.find_node(addr) {
106 Ok(node_idx) => node_idx,
107 Err(_) => {
108 let node_idx = self.graph.add_node(addr.clone());
109 self.node_map
110 .insert(addr.clone(), node_idx);
111 node_idx
112 }
113 }
114 }
115
116 fn add_edge(&mut self, from_idx: NodeIndex, to_idx: NodeIndex, component_id: &ComponentId) {
124 let edge_idx = self
125 .graph
126 .add_edge(from_idx, to_idx, EdgeData::new(component_id.clone()));
127 self.edge_map
128 .entry(component_id.clone())
129 .or_default()
130 .push(edge_idx);
131 }
132
133 fn add_component_edges(&mut self, component_id: &ComponentId, node_indices: &[NodeIndex]) {
136 node_indices
138 .iter()
139 .enumerate()
140 .flat_map(|(i, &from_idx)| {
141 node_indices
142 .iter()
143 .skip(i + 1)
144 .map(move |&to_idx| (from_idx, to_idx))
145 })
146 .for_each(|(from_idx, to_idx)| {
147 self.add_edge(from_idx, to_idx, component_id);
149 self.add_edge(to_idx, from_idx, component_id);
150 });
151 }
152
153 fn add_components(
163 &mut self,
164 components: &HashMap<ComponentId, Vec<Address>>,
165 ) -> Result<(), GraphError> {
166 let mut invalid_components = Vec::new();
167 let mut skipped_duplicates = 0usize;
168
169 for (comp_id, tokens) in components {
170 if self.edge_map.contains_key(comp_id) {
171 trace!(component_id = %comp_id, "skipping already-tracked component");
172 skipped_duplicates += 1;
173 continue;
174 }
175
176 if tokens.len() < 2 {
177 invalid_components.push(comp_id.clone());
178 continue;
179 }
180 let node_indices: Vec<NodeIndex> = tokens
182 .iter()
183 .map(|token| self.get_or_create_node(token))
184 .collect();
185 self.add_component_edges(comp_id, &node_indices);
187 }
188
189 if skipped_duplicates > 0 {
190 debug!(skipped_duplicates, "skipped duplicate components during add");
191 }
192
193 if !invalid_components.is_empty() {
195 return Err(GraphError::InvalidComponents(invalid_components));
196 }
197
198 Ok(())
199 }
200
201 fn remove_components(&mut self, components: &[ComponentId]) -> Result<(), GraphError> {
211 let mut missing_components = Vec::new();
212
213 for comp_id in components {
214 if let Some(edge_indices) = self.edge_map.remove(comp_id) {
216 for edge_idx in edge_indices {
217 self.graph.remove_edge(edge_idx);
218 }
219 } else {
220 missing_components.push(comp_id.clone());
222 }
223 }
224
225 if !missing_components.is_empty() {
227 return Err(GraphError::ComponentsNotFound(missing_components));
228 }
229
230 Ok(())
231 }
232
233 #[cfg(test)]
248 pub(crate) fn set_edge_weight(
249 &mut self,
250 component_id: &ComponentId,
251 token_in: &Address,
252 token_out: &Address,
253 data: D,
254 bidirectional: bool,
255 ) -> Result<(), GraphError> {
256 let from_idx = self.find_node(token_in)?;
257 let to_idx = self.find_node(token_out)?;
258
259 let edge_indices = self
261 .edge_map
262 .get(component_id)
263 .ok_or_else(|| GraphError::ComponentsNotFound(vec![component_id.clone()]))?;
264
265 let mut updated = false;
266 for &edge_idx in edge_indices {
267 let (edge_from, edge_to) = match self.graph.edge_endpoints(edge_idx) {
269 Some(endpoints) => endpoints,
270 None => continue,
271 };
272
273 let should_update = if bidirectional {
275 (edge_from == from_idx && edge_to == to_idx) ||
277 (edge_from == to_idx && edge_to == from_idx)
278 } else {
279 edge_from == from_idx && edge_to == to_idx
281 };
282
283 if should_update {
284 let edge_data = self
286 .graph
287 .edge_weight_mut(edge_idx)
288 .ok_or_else(|| GraphError::ComponentsNotFound(vec![component_id.clone()]))?;
289 if edge_data.component_id == *component_id {
291 edge_data.data = Some(data.clone());
292 updated = true;
293 }
294 }
295 }
296
297 if !updated {
298 return Err(GraphError::MissingComponentBetweenTokens(
299 token_in.clone(),
300 token_out.clone(),
301 component_id.clone(),
302 ));
303 }
304
305 Ok(())
306 }
307}
308
309impl<D: Clone + super::EdgeWeightFromSimAndDerived> PetgraphStableDiGraphManager<D> {
310 pub fn update_edge_weights_with_derived(
325 &mut self,
326 market: &SharedMarketData,
327 derived: &crate::derived::DerivedData,
328 ) -> usize {
329 let tokens = market.token_registry_ref();
330
331 let updates: Vec<_> = self
333 .graph
334 .edge_indices()
335 .filter_map(|edge_idx| {
336 let edge_data = self.graph.edge_weight(edge_idx)?;
337 let component_id = &edge_data.component_id;
338
339 let sim_state = market.get_simulation_state(component_id)?;
340
341 let (source_idx, target_idx) = self.graph.edge_endpoints(edge_idx)?;
342 let source_addr = &self.graph[source_idx];
343 let target_addr = &self.graph[target_idx];
344
345 let token_in = tokens.get(source_addr)?;
346 let token_out = tokens.get(target_addr)?;
347
348 let weight =
349 D::from_sim_and_derived(sim_state, component_id, token_in, token_out, derived)?;
350 Some((edge_idx, weight))
351 })
352 .collect();
353
354 let updated = updates.len();
356 for (edge_idx, weight) in updates {
357 if let Some(edge_data) = self.graph.edge_weight_mut(edge_idx) {
358 edge_data.data = Some(weight);
359 }
360 }
361
362 updated
363 }
364}
365
366impl<D: Clone + super::EdgeWeightFromSimAndDerived> super::EdgeWeightUpdaterWithDerived
367 for PetgraphStableDiGraphManager<D>
368{
369 fn update_edge_weights_with_derived(
370 &mut self,
371 market: &SharedMarketData,
372 derived: &crate::derived::DerivedData,
373 ) -> usize {
374 self.update_edge_weights_with_derived(market, derived)
375 }
376}
377
378impl<D: Clone> Default for PetgraphStableDiGraphManager<D> {
379 fn default() -> Self {
380 Self::new()
381 }
382}
383
384impl<D: Clone + Send + Sync> GraphManager<StableDiGraph<D>> for PetgraphStableDiGraphManager<D> {
385 fn initialize_graph(&mut self, component_topology: &HashMap<ComponentId, Vec<Address>>) {
386 self.graph = StableDiGraph::default();
388 self.edge_map.clear();
389 self.node_map.clear();
390
391 let unique_tokens: HashSet<Address> = component_topology
392 .values()
393 .flat_map(|v| v.iter())
394 .cloned()
395 .collect();
396
397 for token in unique_tokens {
399 let node_idx = self.graph.add_node(token.clone());
400 self.node_map.insert(token, node_idx);
401 }
402
403 for (comp_id, tokens) in component_topology {
405 let node_indices: Vec<NodeIndex> = tokens
406 .iter()
407 .map(|token| self.node_map[token])
408 .collect();
409 self.add_component_edges(comp_id, &node_indices);
410 }
411 }
412
413 fn graph(&self) -> &StableDiGraph<D> {
414 &self.graph
415 }
416}
417
418#[async_trait]
419impl<D: Clone + Send> MarketEventHandler for PetgraphStableDiGraphManager<D> {
420 async fn handle_event(&mut self, event: &MarketEvent) -> Result<(), EventError> {
421 match event {
422 MarketEvent::MarketUpdated { added_components, removed_components, .. } => {
423 let mut errors = Vec::new();
425
426 if let Err(e) = self.add_components(added_components) {
428 errors.push(e);
429 }
430
431 if let Err(e) = self.remove_components(removed_components) {
433 errors.push(e);
434 }
435
436 match errors.len() {
438 0 => Ok(()),
439 _ => Err(EventError::GraphErrors(errors)),
440 }
441 }
442 }
443 }
444}
445
446#[cfg(test)]
447mod tests {
448 use std::str::FromStr;
449
450 use super::*;
451
452 fn addr(s: &str) -> Address {
454 Address::from_str(s).expect("Invalid address hex string")
455 }
456
457 #[test]
458 fn test_initialize_graph_empty() {
459 let mut manager = PetgraphStableDiGraphManager::<()>::new();
460 let topology = HashMap::new();
461
462 manager.initialize_graph(&topology);
463
464 let graph = manager.graph();
465 assert_eq!(graph.node_count(), 0);
466 assert_eq!(graph.edge_count(), 0);
467 }
468
469 #[test]
470 fn test_initialize_graph_comprehensive() {
471 let mut manager = PetgraphStableDiGraphManager::<()>::new();
472 let mut topology = HashMap::new();
473 let token_a = addr("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"); let token_b = addr("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48"); let token_c = addr("0x6B175474E89094C44Da98b954EedeAC495271d0F"); let token_d = addr("0xdAC17F958D2ee523a2206206994597C13D831ec7"); topology
480 .insert("pool1".to_string(), vec![token_a.clone(), token_b.clone(), token_c.clone()]);
481 topology.insert("pool2".to_string(), vec![token_c.clone(), token_d.clone()]);
483
484 manager.initialize_graph(&topology);
485
486 let graph = manager.graph();
487 assert_eq!(graph.node_count(), 4);
489 assert_eq!(graph.edge_count(), 8);
493
494 let node_a = manager.find_node(&token_a).unwrap();
496 let node_b = manager.find_node(&token_b).unwrap();
497 let node_c = manager.find_node(&token_c).unwrap();
498 let node_d = manager.find_node(&token_d).unwrap();
499
500 assert_eq!(
502 graph
503 .edge_weight(graph.find_edge(node_a, node_b).unwrap())
504 .unwrap()
505 .component_id,
506 "pool1".to_string()
507 );
508 assert_eq!(
509 graph
510 .edge_weight(graph.find_edge(node_b, node_a).unwrap())
511 .unwrap()
512 .component_id,
513 "pool1".to_string()
514 );
515 assert_eq!(
516 graph
517 .edge_weight(graph.find_edge(node_a, node_c).unwrap())
518 .unwrap()
519 .component_id,
520 "pool1".to_string()
521 );
522 assert_eq!(
523 graph
524 .edge_weight(graph.find_edge(node_c, node_a).unwrap())
525 .unwrap()
526 .component_id,
527 "pool1".to_string()
528 );
529 assert_eq!(
530 graph
531 .edge_weight(graph.find_edge(node_b, node_c).unwrap())
532 .unwrap()
533 .component_id,
534 "pool1".to_string()
535 );
536 assert_eq!(
537 graph
538 .edge_weight(graph.find_edge(node_c, node_b).unwrap())
539 .unwrap()
540 .component_id,
541 "pool1".to_string()
542 );
543
544 assert_eq!(
546 graph
547 .edge_weight(graph.find_edge(node_c, node_d).unwrap())
548 .unwrap()
549 .component_id,
550 "pool2".to_string()
551 );
552 assert_eq!(
553 graph
554 .edge_weight(graph.find_edge(node_d, node_c).unwrap())
555 .unwrap()
556 .component_id,
557 "pool2".to_string()
558 );
559 }
560
561 #[test]
562 fn test_initialize_graph_multiple_edges_same_pair() {
563 let mut manager = PetgraphStableDiGraphManager::<()>::new();
564 let mut topology = HashMap::new();
565 let token_a = addr("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"); let token_b = addr("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48"); topology.insert("pool1".to_string(), vec![token_a.clone(), token_b.clone()]);
570 topology.insert("pool2".to_string(), vec![token_a.clone(), token_b.clone()]);
571 topology.insert("pool3".to_string(), vec![token_a.clone(), token_b.clone()]);
572
573 manager.initialize_graph(&topology);
574
575 let graph = manager.graph();
576 assert_eq!(graph.node_count(), 2);
578 assert_eq!(graph.edge_count(), 6);
580
581 let node_a = manager.find_node(&token_a).unwrap();
582 let node_b = manager.find_node(&token_b).unwrap();
583
584 let edges: Vec<_> = graph
586 .edges_connecting(node_a, node_b)
587 .collect();
588 assert_eq!(edges.len(), 3);
589
590 let component_ids: Vec<_> = edges
591 .iter()
592 .map(|e| &e.weight().component_id)
593 .collect();
594
595 assert!(component_ids.contains(&&"pool1".to_string()));
597 assert!(component_ids.contains(&&"pool2".to_string()));
598 assert!(component_ids.contains(&&"pool3".to_string()));
599 }
600
601 #[test]
602 fn test_add_components_shared_tokens() {
603 let mut manager = PetgraphStableDiGraphManager::<()>::new();
604 let mut components = HashMap::new();
605 let token_a = addr("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"); let token_b = addr("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48"); components.insert("pool1".to_string(), vec![token_a.clone(), token_b.clone()]);
610 manager
611 .add_components(&components)
612 .unwrap();
613
614 let initial_node_count = manager.graph().node_count();
615 assert_eq!(initial_node_count, 2);
616
617 components.clear();
619 components.insert("pool2".to_string(), vec![token_a.clone(), token_b.clone()]);
620 manager
621 .add_components(&components)
622 .unwrap();
623
624 assert_eq!(manager.graph().node_count(), 2, "Should not create duplicate nodes");
626 }
627
628 #[test]
629 fn test_add_tokenless_components_error() {
630 let mut manager = PetgraphStableDiGraphManager::<()>::new();
631 let mut components = HashMap::new();
632 let token_a = addr("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"); let token_b = addr("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48"); components.insert("pool1".to_string(), vec![token_a.clone(), token_b.clone()]);
637 components.insert("pool2".to_string(), vec![]);
638 components.insert("pool3".to_string(), vec![]);
639 let result = manager.add_components(&components);
640
641 assert!(result.is_err());
642 match result.unwrap_err() {
643 GraphError::InvalidComponents(ids) => {
644 assert_eq!(ids.len(), 2);
645 assert!(ids.contains(&"pool2".to_string()));
646 assert!(ids.contains(&"pool3".to_string()));
647 }
648 _ => panic!("Expected InvalidComponents error"),
649 }
650
651 assert_eq!(manager.graph().node_count(), 2);
653 assert_eq!(manager.graph().edge_count(), 2); }
655
656 #[test]
657 fn test_remove_components_not_found_error() {
658 let mut manager = PetgraphStableDiGraphManager::<()>::new();
659 let mut components = HashMap::new();
660 let token_a = addr("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"); let token_b = addr("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48"); components.insert("pool1".to_string(), vec![token_a.clone(), token_b.clone()]);
665 components.insert("pool2".to_string(), vec![token_a.clone(), token_b.clone()]);
666 manager
667 .add_components(&components)
668 .unwrap();
669
670 let result = manager.remove_components(&[
672 "pool1".to_string(),
673 "pool3".to_string(),
674 "pool4".to_string(),
675 ]);
676
677 assert!(result.is_err());
678 match result.unwrap_err() {
679 GraphError::ComponentsNotFound(ids) => {
680 assert_eq!(ids.len(), 2, "Expected 2 missing components");
681 assert!(ids.contains(&"pool3".to_string()));
682 assert!(ids.contains(&"pool4".to_string()));
683 }
684 _ => panic!("Expected ComponentsNotFound error"),
685 }
686
687 for edge in manager.graph().edge_indices() {
689 assert_eq!(
690 manager
691 .graph()
692 .edge_weight(edge)
693 .unwrap()
694 .component_id,
695 "pool2".to_string()
696 );
697 }
698 }
699
700 #[test]
701 fn test_set_edge_weight_errors() {
702 let mut manager = PetgraphStableDiGraphManager::<()>::new();
703 let mut topology = HashMap::new();
704 let token_a = addr("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2"); let token_b = addr("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48"); let token_c = addr("0x6B175474E89094C44Da98b954EedeAC495271d0F"); topology.insert("pool1".to_string(), vec![token_a.clone(), token_b.clone()]);
710 topology.insert("pool2".to_string(), vec![token_b.clone(), token_c.clone()]);
711 manager.initialize_graph(&topology);
712
713 let result = manager.set_edge_weight(&"pool3".to_string(), &token_a, &token_b, (), true);
715 assert!(result.is_err());
716 match result.unwrap_err() {
717 GraphError::ComponentsNotFound(ids) => {
718 assert_eq!(ids, vec!["pool3".to_string()]);
719 }
720 _ => panic!("Expected ComponentsNotFound error"),
721 }
722
723 let non_existent_token = addr("0x0000000000000000000000000000000000000000");
725 let result = manager.set_edge_weight(
726 &"pool1".to_string(),
727 &token_a,
728 &non_existent_token, (),
730 true,
731 );
732 assert!(result.is_err());
733 match result.unwrap_err() {
734 GraphError::TokenNotFound(found_addr) => {
735 assert_eq!(found_addr, non_existent_token);
736 }
737 _ => panic!("Expected TokenNotFound error"),
738 }
739
740 let result = manager.set_edge_weight(
742 &"pool1".to_string(),
743 &token_a,
744 &token_c, (),
746 true,
747 );
748 assert!(result.is_err());
749 match result.unwrap_err() {
750 GraphError::MissingComponentBetweenTokens(in_token, out_token, comp_id) => {
751 assert_eq!(in_token, token_a);
752 assert_eq!(out_token, token_c);
753 assert_eq!(comp_id, "pool1".to_string());
754 }
755 _ => panic!("Expected MissingComponentBetweenTokens error"),
756 }
757 }
758
759 #[tokio::test]
760 async fn test_handle_event_propagates_errors() {
761 let mut manager = PetgraphStableDiGraphManager::<()>::new();
762 use std::collections::HashMap;
763
764 use crate::feed::events::{EventError, MarketEvent};
765
766 let event = MarketEvent::MarketUpdated {
768 added_components: HashMap::from([("pool1".to_string(), vec![])]),
769 removed_components: vec!["pool2".to_string()],
770 updated_components: vec![],
771 };
772
773 let result = manager.handle_event(&event).await;
774
775 assert!(result.is_err());
777 match result.unwrap_err() {
778 EventError::GraphErrors(errors) => {
779 assert_eq!(errors.len(), 2);
780 let has_add_error = errors
782 .iter()
783 .any(|e| matches!(e, GraphError::InvalidComponents(_)));
784 let has_remove_error = errors
785 .iter()
786 .any(|e| matches!(e, GraphError::ComponentsNotFound(_)));
787 assert!(has_add_error, "Should have InvalidComponents error");
788 assert!(has_remove_error, "Should have ComponentsNotFound error");
789 }
790 }
791 }
792
793 #[test]
794 fn test_add_components_skips_duplicates() {
795 let mut manager = PetgraphStableDiGraphManager::<()>::new();
796 let token_a = addr("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2");
797 let token_b = addr("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48");
798
799 let mut components = HashMap::new();
800 components.insert("pool1".to_string(), vec![token_a.clone(), token_b.clone()]);
801
802 manager
803 .add_components(&components)
804 .unwrap();
805 let edge_count_after_first = manager.graph().edge_count();
806 assert_eq!(edge_count_after_first, 2); manager
810 .add_components(&components)
811 .unwrap();
812 let edge_count_after_second = manager.graph().edge_count();
813 assert_eq!(
814 edge_count_after_first, edge_count_after_second,
815 "Edge count should not change when re-adding the same component"
816 );
817 }
818}