1use std::collections::{HashMap, HashSet};
12
13use async_trait::async_trait;
14pub use petgraph::graph::EdgeIndex;
15use petgraph::{graph::NodeIndex, stable_graph};
16use tracing::{debug, trace};
17use tycho_simulation::tycho_common::models::Address;
18
19use super::GraphManager;
20use crate::{
21 feed::{
22 events::{EventError, MarketEvent, MarketEventHandler},
23 market_data::MarketDataView,
24 },
25 graph::GraphError,
26 types::ComponentId,
27};
28
29#[derive(Debug, Clone, Default)]
48pub struct EdgeData<D = ()> {
49 pub component_id: ComponentId,
51 pub data: Option<D>,
53}
54
55impl<M> EdgeData<M> {
56 pub fn new(component_id: ComponentId) -> Self {
58 Self { component_id, data: None }
59 }
60
61 pub fn with_data(component_id: ComponentId, data: M) -> Self {
63 Self { component_id, data: Some(data) }
64 }
65}
66
67pub type StableDiGraph<D> = stable_graph::StableDiGraph<Address, EdgeData<D>>;
69
70pub struct PetgraphStableDiGraphManager<D: Clone> {
77 graph: StableDiGraph<D>,
81 edge_map: HashMap<ComponentId, Vec<EdgeIndex>>,
83 node_map: HashMap<Address, NodeIndex>,
85}
86
87impl<D: Clone> PetgraphStableDiGraphManager<D> {
88 pub fn new() -> Self {
90 Self { graph: StableDiGraph::default(), edge_map: HashMap::new(), node_map: HashMap::new() }
91 }
92
93 pub(crate) fn find_node(&self, addr: &Address) -> Result<NodeIndex, GraphError> {
95 self.node_map
96 .get(addr)
97 .copied()
98 .ok_or_else(|| GraphError::TokenNotFound(addr.clone()))
99 }
100
101 fn get_or_create_node(&mut self, addr: &Address) -> NodeIndex {
104 match self.find_node(addr) {
106 Ok(node_idx) => node_idx,
107 Err(_) => {
108 let node_idx = self.graph.add_node(addr.clone());
109 self.node_map
110 .insert(addr.clone(), node_idx);
111 node_idx
112 }
113 }
114 }
115
116 fn add_edge(&mut self, from_idx: NodeIndex, to_idx: NodeIndex, component_id: &ComponentId) {
124 let edge_idx = self
125 .graph
126 .add_edge(from_idx, to_idx, EdgeData::new(component_id.clone()));
127 self.edge_map
128 .entry(component_id.clone())
129 .or_default()
130 .push(edge_idx);
131 }
132
133 fn add_component_edges(&mut self, component_id: &ComponentId, node_indices: &[NodeIndex]) {
136 node_indices
138 .iter()
139 .enumerate()
140 .flat_map(|(i, &from_idx)| {
141 node_indices
142 .iter()
143 .skip(i + 1)
144 .map(move |&to_idx| (from_idx, to_idx))
145 })
146 .for_each(|(from_idx, to_idx)| {
147 self.add_edge(from_idx, to_idx, component_id);
149 self.add_edge(to_idx, from_idx, component_id);
150 });
151 }
152
153 fn add_components(
163 &mut self,
164 components: &HashMap<ComponentId, Vec<Address>>,
165 ) -> Result<(), GraphError> {
166 let mut invalid_components = Vec::new();
167 let mut skipped_duplicates = 0usize;
168
169 let mut sorted_components: Vec<_> = components.iter().collect();
171 sorted_components.sort_by_key(|(id, _)| *id);
172
173 for (comp_id, tokens) in sorted_components {
174 if self.edge_map.contains_key(comp_id) {
175 trace!(component_id = %comp_id, "skipping already-tracked component");
176 skipped_duplicates += 1;
177 continue;
178 }
179
180 if tokens.len() < 2 {
181 invalid_components.push(comp_id.clone());
182 continue;
183 }
184 let mut sorted_tokens: Vec<&Address> = tokens.iter().collect();
185 sorted_tokens.sort();
186 let node_indices: Vec<NodeIndex> = sorted_tokens
187 .iter()
188 .map(|token| self.get_or_create_node(token))
189 .collect();
190 self.add_component_edges(comp_id, &node_indices);
191 }
192
193 if skipped_duplicates > 0 {
194 debug!(skipped_duplicates, "skipped duplicate components during add");
195 }
196
197 if !invalid_components.is_empty() {
199 return Err(GraphError::InvalidComponents(invalid_components));
200 }
201
202 Ok(())
203 }
204
205 fn remove_components(&mut self, components: &[ComponentId]) -> Result<(), GraphError> {
215 let mut missing_components = Vec::new();
216
217 for comp_id in components {
218 if let Some(edge_indices) = self.edge_map.remove(comp_id) {
220 for edge_idx in edge_indices {
221 self.graph.remove_edge(edge_idx);
222 }
223 } else {
224 missing_components.push(comp_id.clone());
226 }
227 }
228
229 if !missing_components.is_empty() {
231 return Err(GraphError::ComponentsNotFound(missing_components));
232 }
233
234 Ok(())
235 }
236
237 #[cfg(test)]
252 pub(crate) fn set_edge_weight(
253 &mut self,
254 component_id: &ComponentId,
255 token_in: &Address,
256 token_out: &Address,
257 data: D,
258 bidirectional: bool,
259 ) -> Result<(), GraphError> {
260 let from_idx = self.find_node(token_in)?;
261 let to_idx = self.find_node(token_out)?;
262
263 let edge_indices = self
265 .edge_map
266 .get(component_id)
267 .ok_or_else(|| GraphError::ComponentsNotFound(vec![component_id.clone()]))?;
268
269 let mut updated = false;
270 for &edge_idx in edge_indices {
271 let (edge_from, edge_to) = match self.graph.edge_endpoints(edge_idx) {
273 Some(endpoints) => endpoints,
274 None => continue,
275 };
276
277 let should_update = if bidirectional {
279 (edge_from == from_idx && edge_to == to_idx) ||
281 (edge_from == to_idx && edge_to == from_idx)
282 } else {
283 edge_from == from_idx && edge_to == to_idx
285 };
286
287 if should_update {
288 let edge_data = self
290 .graph
291 .edge_weight_mut(edge_idx)
292 .ok_or_else(|| GraphError::ComponentsNotFound(vec![component_id.clone()]))?;
293 if edge_data.component_id == *component_id {
295 edge_data.data = Some(data.clone());
296 updated = true;
297 }
298 }
299 }
300
301 if !updated {
302 return Err(GraphError::MissingComponentBetweenTokens(
303 token_in.clone(),
304 token_out.clone(),
305 component_id.clone(),
306 ));
307 }
308
309 Ok(())
310 }
311}
312
313impl<D: Clone + super::EdgeWeightFromSimAndDerived> PetgraphStableDiGraphManager<D> {
314 pub fn update_edge_weights_with_derived(
329 &mut self,
330 market: MarketDataView<'_>,
331 derived: &crate::derived::DerivedData,
332 ) -> usize {
333 let tokens = market.token_registry_ref();
334
335 let updates: Vec<_> = self
337 .graph
338 .edge_indices()
339 .filter_map(|edge_idx| {
340 let edge_data = self.graph.edge_weight(edge_idx)?;
341 let component_id = &edge_data.component_id;
342
343 let sim_state = market.get_simulation_state(component_id)?;
344
345 let (source_idx, target_idx) = self.graph.edge_endpoints(edge_idx)?;
346 let source_addr = &self.graph[source_idx];
347 let target_addr = &self.graph[target_idx];
348
349 let token_in = tokens.get(source_addr)?;
350 let token_out = tokens.get(target_addr)?;
351
352 let weight =
353 D::from_sim_and_derived(sim_state, component_id, token_in, token_out, derived)?;
354 Some((edge_idx, weight))
355 })
356 .collect();
357
358 let updated = updates.len();
360 for (edge_idx, weight) in updates {
361 if let Some(edge_data) = self.graph.edge_weight_mut(edge_idx) {
362 edge_data.data = Some(weight);
363 }
364 }
365
366 updated
367 }
368}
369
370impl<D: Clone + super::EdgeWeightFromSimAndDerived> super::EdgeWeightUpdaterWithDerived
371 for PetgraphStableDiGraphManager<D>
372{
373 fn update_edge_weights_with_derived(
374 &mut self,
375 market: MarketDataView<'_>,
376 derived: &crate::derived::DerivedData,
377 ) -> usize {
378 self.update_edge_weights_with_derived(market, derived)
379 }
380}
381
382impl<D: Clone> Default for PetgraphStableDiGraphManager<D> {
383 fn default() -> Self {
384 Self::new()
385 }
386}
387
388impl<D: Clone + Send + Sync> GraphManager<StableDiGraph<D>> for PetgraphStableDiGraphManager<D> {
389 fn initialize_graph(&mut self, component_topology: &HashMap<ComponentId, Vec<Address>>) {
390 self.graph = StableDiGraph::default();
392 self.edge_map.clear();
393 self.node_map.clear();
394
395 let mut unique_tokens: Vec<Address> = component_topology
400 .values()
401 .flat_map(|v| v.iter())
402 .cloned()
403 .collect::<HashSet<_>>()
404 .into_iter()
405 .collect();
406 unique_tokens.sort();
407
408 for token in unique_tokens {
409 let node_idx = self.graph.add_node(token.clone());
410 self.node_map.insert(token, node_idx);
411 }
412
413 let mut sorted_components: Vec<_> = component_topology.iter().collect();
415 sorted_components.sort_by_key(|(id, _)| *id);
416
417 for (comp_id, tokens) in sorted_components {
418 let mut sorted_tokens: Vec<&Address> = tokens.iter().collect();
419 sorted_tokens.sort();
420 let node_indices: Vec<NodeIndex> = sorted_tokens
421 .iter()
422 .map(|token| self.node_map[*token])
423 .collect();
424 self.add_component_edges(comp_id, &node_indices);
425 }
426 }
427
428 fn graph(&self) -> &StableDiGraph<D> {
429 &self.graph
430 }
431}
432
433#[async_trait]
434impl<D: Clone + Send> MarketEventHandler for PetgraphStableDiGraphManager<D> {
435 async fn handle_event(&mut self, event: &MarketEvent) -> Result<(), EventError> {
436 match event {
437 MarketEvent::MarketUpdated { added_components, removed_components, .. } => {
438 let mut errors = Vec::new();
440
441 if let Err(e) = self.add_components(added_components) {
443 errors.push(e);
444 }
445
446 if let Err(e) = self.remove_components(removed_components) {
448 errors.push(e);
449 }
450
451 match errors.len() {
453 0 => Ok(()),
454 _ => Err(EventError::GraphErrors(errors)),
455 }
456 }
457 }
458 }
459}
460
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::*;

    /// Test helper: parses a hex address literal, panicking on malformed input.
    fn addr(s: &str) -> Address {
        Address::from_str(s).expect("Invalid address hex string")
    }

    /// An empty topology yields an empty graph.
    #[test]
    fn test_initialize_graph_empty() {
        let mut manager = PetgraphStableDiGraphManager::<()>::new();
        let topology = HashMap::new();

        manager.initialize_graph(&topology);

        let graph = manager.graph();
        assert_eq!(graph.node_count(), 0);
        assert_eq!(graph.edge_count(), 0);
    }

    /// Two components sharing token C: checks node/edge counts and that each directed edge
    /// carries the ID of the component that created it.
    #[test]
    fn test_initialize_graph_comprehensive() {
        let mut manager = PetgraphStableDiGraphManager::<()>::new();
        let mut topology = HashMap::new();
        let token_a = addr("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2");
        let token_b = addr("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48");
        let token_c = addr("0x6B175474E89094C44Da98b954EedeAC495271d0F");
        let token_d = addr("0xdAC17F958D2ee523a2206206994597C13D831ec7");
        topology
            .insert("pool1".to_string(), vec![token_a.clone(), token_b.clone(), token_c.clone()]);
        topology.insert("pool2".to_string(), vec![token_c.clone(), token_d.clone()]);

        manager.initialize_graph(&topology);

        let graph = manager.graph();
        // Four distinct tokens: A, B, C, D.
        assert_eq!(graph.node_count(), 4);
        // pool1 has 3 token pairs and pool2 has 1; each pair gets one edge per direction:
        // (3 + 1) * 2 = 8.
        assert_eq!(graph.edge_count(), 8);

        let node_a = manager.find_node(&token_a).unwrap();
        let node_b = manager.find_node(&token_b).unwrap();
        let node_c = manager.find_node(&token_c).unwrap();
        let node_d = manager.find_node(&token_d).unwrap();

        // pool1 edges: A<->B, A<->C, B<->C.
        assert_eq!(
            graph
                .edge_weight(graph.find_edge(node_a, node_b).unwrap())
                .unwrap()
                .component_id,
            "pool1".to_string()
        );
        assert_eq!(
            graph
                .edge_weight(graph.find_edge(node_b, node_a).unwrap())
                .unwrap()
                .component_id,
            "pool1".to_string()
        );
        assert_eq!(
            graph
                .edge_weight(graph.find_edge(node_a, node_c).unwrap())
                .unwrap()
                .component_id,
            "pool1".to_string()
        );
        assert_eq!(
            graph
                .edge_weight(graph.find_edge(node_c, node_a).unwrap())
                .unwrap()
                .component_id,
            "pool1".to_string()
        );
        assert_eq!(
            graph
                .edge_weight(graph.find_edge(node_b, node_c).unwrap())
                .unwrap()
                .component_id,
            "pool1".to_string()
        );
        assert_eq!(
            graph
                .edge_weight(graph.find_edge(node_c, node_b).unwrap())
                .unwrap()
                .component_id,
            "pool1".to_string()
        );

        // pool2 edges: C<->D.
        assert_eq!(
            graph
                .edge_weight(graph.find_edge(node_c, node_d).unwrap())
                .unwrap()
                .component_id,
            "pool2".to_string()
        );
        assert_eq!(
            graph
                .edge_weight(graph.find_edge(node_d, node_c).unwrap())
                .unwrap()
                .component_id,
            "pool2".to_string()
        );
    }

    /// Several components over the same token pair produce parallel edges, one per component.
    #[test]
    fn test_initialize_graph_multiple_edges_same_pair() {
        let mut manager = PetgraphStableDiGraphManager::<()>::new();
        let mut topology = HashMap::new();
        let token_a = addr("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2");
        let token_b = addr("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48");
        topology.insert("pool1".to_string(), vec![token_a.clone(), token_b.clone()]);
        topology.insert("pool2".to_string(), vec![token_a.clone(), token_b.clone()]);
        topology.insert("pool3".to_string(), vec![token_a.clone(), token_b.clone()]);

        manager.initialize_graph(&topology);

        let graph = manager.graph();
        assert_eq!(graph.node_count(), 2);
        // 3 pools * 2 directions.
        assert_eq!(graph.edge_count(), 6);

        let node_a = manager.find_node(&token_a).unwrap();
        let node_b = manager.find_node(&token_b).unwrap();

        // Exactly one A -> B edge per pool.
        let edges: Vec<_> = graph
            .edges_connecting(node_a, node_b)
            .collect();
        assert_eq!(edges.len(), 3);

        let component_ids: Vec<_> = edges
            .iter()
            .map(|e| &e.weight().component_id)
            .collect();

        assert!(component_ids.contains(&&"pool1".to_string()));
        assert!(component_ids.contains(&&"pool2".to_string()));
        assert!(component_ids.contains(&&"pool3".to_string()));
    }

    /// Adding a second component over already-known tokens reuses the existing nodes.
    #[test]
    fn test_add_components_shared_tokens() {
        let mut manager = PetgraphStableDiGraphManager::<()>::new();
        let mut components = HashMap::new();
        let token_a = addr("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2");
        let token_b = addr("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48");
        components.insert("pool1".to_string(), vec![token_a.clone(), token_b.clone()]);
        manager
            .add_components(&components)
            .unwrap();

        let initial_node_count = manager.graph().node_count();
        assert_eq!(initial_node_count, 2);

        components.clear();
        components.insert("pool2".to_string(), vec![token_a.clone(), token_b.clone()]);
        manager
            .add_components(&components)
            .unwrap();

        assert_eq!(manager.graph().node_count(), 2, "Should not create duplicate nodes");
    }

    /// Components with no tokens are reported as invalid, while valid components in the same
    /// batch are still added.
    #[test]
    fn test_add_tokenless_components_error() {
        let mut manager = PetgraphStableDiGraphManager::<()>::new();
        let mut components = HashMap::new();
        let token_a = addr("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2");
        let token_b = addr("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48");
        components.insert("pool1".to_string(), vec![token_a.clone(), token_b.clone()]);
        components.insert("pool2".to_string(), vec![]);
        components.insert("pool3".to_string(), vec![]);
        let result = manager.add_components(&components);

        assert!(result.is_err());
        match result.unwrap_err() {
            GraphError::InvalidComponents(ids) => {
                assert_eq!(ids.len(), 2);
                assert!(ids.contains(&"pool2".to_string()));
                assert!(ids.contains(&"pool3".to_string()));
            }
            _ => panic!("Expected InvalidComponents error"),
        }

        // Only pool1 made it into the graph: 2 nodes, 1 pair * 2 directions.
        assert_eq!(manager.graph().node_count(), 2);
        assert_eq!(manager.graph().edge_count(), 2);
    }

    /// Removing a mix of known and unknown components removes the known ones and reports
    /// only the unknown IDs.
    #[test]
    fn test_remove_components_not_found_error() {
        let mut manager = PetgraphStableDiGraphManager::<()>::new();
        let mut components = HashMap::new();
        let token_a = addr("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2");
        let token_b = addr("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48");
        components.insert("pool1".to_string(), vec![token_a.clone(), token_b.clone()]);
        components.insert("pool2".to_string(), vec![token_a.clone(), token_b.clone()]);
        manager
            .add_components(&components)
            .unwrap();

        let result = manager.remove_components(&[
            "pool1".to_string(),
            "pool3".to_string(),
            "pool4".to_string(),
        ]);

        assert!(result.is_err());
        match result.unwrap_err() {
            GraphError::ComponentsNotFound(ids) => {
                assert_eq!(ids.len(), 2, "Expected 2 missing components");
                assert!(ids.contains(&"pool3".to_string()));
                assert!(ids.contains(&"pool4".to_string()));
            }
            _ => panic!("Expected ComponentsNotFound error"),
        }

        // pool1 was removed despite the error, so every remaining edge belongs to pool2.
        for edge in manager.graph().edge_indices() {
            assert_eq!(
                manager
                    .graph()
                    .edge_weight(edge)
                    .unwrap()
                    .component_id,
                "pool2".to_string()
            );
        }
    }

    /// Covers each error path of `set_edge_weight`: unknown component, unknown token, and a
    /// token pair the component does not connect.
    #[test]
    fn test_set_edge_weight_errors() {
        let mut manager = PetgraphStableDiGraphManager::<()>::new();
        let mut topology = HashMap::new();
        let token_a = addr("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2");
        let token_b = addr("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48");
        let token_c = addr("0x6B175474E89094C44Da98b954EedeAC495271d0F");
        topology.insert("pool1".to_string(), vec![token_a.clone(), token_b.clone()]);
        topology.insert("pool2".to_string(), vec![token_b.clone(), token_c.clone()]);
        manager.initialize_graph(&topology);

        // Unknown component between known tokens.
        let result = manager.set_edge_weight(&"pool3".to_string(), &token_a, &token_b, (), true);
        assert!(result.is_err());
        match result.unwrap_err() {
            GraphError::ComponentsNotFound(ids) => {
                assert_eq!(ids, vec!["pool3".to_string()]);
            }
            _ => panic!("Expected ComponentsNotFound error"),
        }

        // Token that has no node in the graph.
        let non_existent_token = addr("0x0000000000000000000000000000000000000000");
        let result = manager.set_edge_weight(
            &"pool1".to_string(),
            &token_a,
            &non_existent_token,
            (),
            true,
        );
        assert!(result.is_err());
        match result.unwrap_err() {
            GraphError::TokenNotFound(found_addr) => {
                assert_eq!(found_addr, non_existent_token);
            }
            _ => panic!("Expected TokenNotFound error"),
        }

        // pool1 only connects A and B, so there is no pool1 edge between A and C.
        let result = manager.set_edge_weight(
            &"pool1".to_string(),
            &token_a,
            &token_c,
            (),
            true,
        );
        assert!(result.is_err());
        match result.unwrap_err() {
            GraphError::MissingComponentBetweenTokens(in_token, out_token, comp_id) => {
                assert_eq!(in_token, token_a);
                assert_eq!(out_token, token_c);
                assert_eq!(comp_id, "pool1".to_string());
            }
            _ => panic!("Expected MissingComponentBetweenTokens error"),
        }
    }

    /// When both the add step and the remove step fail, `handle_event` reports both errors.
    #[tokio::test]
    async fn test_handle_event_propagates_errors() {
        let mut manager = PetgraphStableDiGraphManager::<()>::new();
        use std::collections::HashMap;

        use crate::feed::events::{EventError, MarketEvent};

        // pool1 has no tokens (invalid add); pool2 was never added (missing on remove).
        let event = MarketEvent::MarketUpdated {
            added_components: HashMap::from([("pool1".to_string(), vec![])]),
            removed_components: vec!["pool2".to_string()],
            updated_components: vec![],
        };

        let result = manager.handle_event(&event).await;

        assert!(result.is_err());
        match result.unwrap_err() {
            EventError::GraphErrors(errors) => {
                assert_eq!(errors.len(), 2);
                let has_add_error = errors
                    .iter()
                    .any(|e| matches!(e, GraphError::InvalidComponents(_)));
                let has_remove_error = errors
                    .iter()
                    .any(|e| matches!(e, GraphError::ComponentsNotFound(_)));
                assert!(has_add_error, "Should have InvalidComponents error");
                assert!(has_remove_error, "Should have ComponentsNotFound error");
            }
        }
    }

    /// Re-adding an already-tracked component is a no-op rather than duplicating its edges.
    #[test]
    fn test_add_components_skips_duplicates() {
        let mut manager = PetgraphStableDiGraphManager::<()>::new();
        let token_a = addr("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2");
        let token_b = addr("0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48");

        let mut components = HashMap::new();
        components.insert("pool1".to_string(), vec![token_a.clone(), token_b.clone()]);

        manager
            .add_components(&components)
            .unwrap();
        let edge_count_after_first = manager.graph().edge_count();
        assert_eq!(edge_count_after_first, 2);

        // Same batch again: pool1 is already tracked and must be skipped.
        manager
            .add_components(&components)
            .unwrap();
        let edge_count_after_second = manager.graph().edge_count();
        assert_eq!(
            edge_count_after_first, edge_count_after_second,
            "Edge count should not change when re-adding the same component"
        );
    }
}