use std::collections::{HashMap, HashSet};
use async_trait::async_trait;
pub use petgraph::graph::EdgeIndex;
use petgraph::{graph::NodeIndex, stable_graph};
use tracing::{debug, trace};
use tycho_simulation::tycho_common::models::Address;
use super::GraphManager;
use crate::{
feed::{
events::{EventError, MarketEvent, MarketEventHandler},
market_data::SharedMarketData,
},
graph::GraphError,
types::ComponentId,
};
/// Payload attached to every directed edge of the token graph.
///
/// `D` is the optional per-edge weight type; it defaults to `()` for graphs that only need
/// topology.
#[derive(Debug, Clone, Default)]
pub struct EdgeData<D = ()> {
    /// Identifier of the component (e.g. a liquidity pool) this edge belongs to.
    pub component_id: ComponentId,
    /// Edge weight, or `None` until one has been computed and assigned.
    pub data: Option<D>,
}

impl<D> EdgeData<D> {
    /// Creates edge data for `component_id` with no weight attached.
    pub fn new(component_id: ComponentId) -> Self {
        Self { component_id, data: None }
    }

    /// Creates edge data for `component_id` carrying the weight `data`.
    pub fn with_data(component_id: ComponentId, data: D) -> Self {
        Self { component_id, data: Some(data) }
    }
}
/// Directed token graph: nodes are token addresses, edges carry [`EdgeData`].
///
/// A petgraph `StableGraph` is used so the `NodeIndex`/`EdgeIndex` values cached in the
/// manager's lookup maps stay valid when other edges are removed.
pub type StableDiGraph<D> = stable_graph::StableDiGraph<Address, EdgeData<D>>;

/// Owns a [`StableDiGraph`] plus the lookup tables needed to add and remove components
/// incrementally.
///
/// The `Clone` requirement on `D` lives on the impl blocks that need it, not on the type
/// itself.
pub struct PetgraphStableDiGraphManager<D> {
    /// The token graph itself.
    graph: StableDiGraph<D>,
    /// Component id -> every directed edge created for that component. Presence of a key is
    /// also how the manager decides whether a component is already tracked.
    edge_map: HashMap<ComponentId, Vec<EdgeIndex>>,
    /// Token address -> its node in `graph`.
    node_map: HashMap<Address, NodeIndex>,
}
impl<D: Clone> PetgraphStableDiGraphManager<D> {
    /// Creates an empty manager with no nodes, edges or tracked components.
    pub fn new() -> Self {
        Self { graph: StableDiGraph::default(), edge_map: HashMap::new(), node_map: HashMap::new() }
    }

    /// Looks up the graph node for a token address.
    ///
    /// # Errors
    ///
    /// Returns [`GraphError::TokenNotFound`] if no node exists for `addr`.
    pub(crate) fn find_node(&self, addr: &Address) -> Result<NodeIndex, GraphError> {
        self.node_map
            .get(addr)
            .copied()
            .ok_or_else(|| GraphError::TokenNotFound(addr.clone()))
    }

    /// Returns the node for `addr`, inserting a new node (and recording it in `node_map`) if
    /// the token has not been seen before.
    fn get_or_create_node(&mut self, addr: &Address) -> NodeIndex {
        // Plain map lookup: going through `find_node` would build, then discard, a
        // `GraphError` for every previously unseen token.
        if let Some(&node_idx) = self.node_map.get(addr) {
            return node_idx;
        }
        let node_idx = self.graph.add_node(addr.clone());
        self.node_map
            .insert(addr.clone(), node_idx);
        node_idx
    }

    /// Adds one directed edge `from_idx -> to_idx` labelled with `component_id` (no weight
    /// yet) and records its index under that component in `edge_map`.
    fn add_edge(&mut self, from_idx: NodeIndex, to_idx: NodeIndex, component_id: &ComponentId) {
        let edge_idx = self
            .graph
            .add_edge(from_idx, to_idx, EdgeData::new(component_id.clone()));
        self.edge_map
            .entry(component_id.clone())
            .or_default()
            .push(edge_idx);
    }

    /// Connects every unordered pair of distinct nodes in `node_indices` with one edge in each
    /// direction, all labelled with `component_id`.
    ///
    /// Repeated node indices are collapsed first, keeping first-seen order. A token listed
    /// twice therefore produces neither a self-loop nor a second set of parallel edges for the
    /// same pair. With fewer than two distinct nodes no edge is added, so the component gets
    /// no `edge_map` entry.
    fn add_component_edges(&mut self, component_id: &ComponentId, node_indices: &[NodeIndex]) {
        // Components list only a handful of tokens and pair generation below is quadratic
        // anyway, so a linear `contains` check is adequate here.
        let mut unique: Vec<NodeIndex> = Vec::with_capacity(node_indices.len());
        for &idx in node_indices {
            if !unique.contains(&idx) {
                unique.push(idx);
            }
        }
        for (i, &from_idx) in unique.iter().enumerate() {
            for &to_idx in &unique[i + 1..] {
                self.add_edge(from_idx, to_idx, component_id);
                self.add_edge(to_idx, from_idx, component_id);
            }
        }
    }

    /// Returns `true` if `tokens` contains at least two different addresses.
    fn has_distinct_pair(tokens: &[Address]) -> bool {
        match tokens.split_first() {
            Some((first, rest)) => rest.iter().any(|token| token != first),
            None => false,
        }
    }

    /// Adds edges for every component in `components` that is not already tracked.
    ///
    /// Components already present in `edge_map` are skipped. Each one is logged at `trace`,
    /// and the total is summarised at `debug`. Nodes for previously unseen tokens are created
    /// on demand.
    ///
    /// # Errors
    ///
    /// Returns [`GraphError::InvalidComponents`] listing every component with fewer than two
    /// distinct tokens. Valid components in the same batch are still added.
    fn add_components(
        &mut self,
        components: &HashMap<ComponentId, Vec<Address>>,
    ) -> Result<(), GraphError> {
        let mut invalid_components = Vec::new();
        let mut skipped_duplicates = 0usize;
        for (comp_id, tokens) in components {
            if self.edge_map.contains_key(comp_id) {
                trace!(component_id = %comp_id, "skipping already-tracked component");
                skipped_duplicates += 1;
                continue;
            }
            // A component needs two *distinct* tokens to form an edge. Checking only the
            // length would accept e.g. `[a, a]`, which yields nothing routable.
            if !Self::has_distinct_pair(tokens) {
                invalid_components.push(comp_id.clone());
                continue;
            }
            let node_indices: Vec<NodeIndex> = tokens
                .iter()
                .map(|token| self.get_or_create_node(token))
                .collect();
            self.add_component_edges(comp_id, &node_indices);
        }
        if skipped_duplicates > 0 {
            debug!(skipped_duplicates, "skipped duplicate components during add");
        }
        if !invalid_components.is_empty() {
            return Err(GraphError::InvalidComponents(invalid_components));
        }
        Ok(())
    }

    /// Removes every edge belonging to each component in `components` and stops tracking it.
    ///
    /// Nodes are left in place, even if they end up with no incident edges.
    ///
    /// # Errors
    ///
    /// Returns [`GraphError::ComponentsNotFound`] listing every id that was not tracked.
    /// Tracked components in the same batch are still removed.
    fn remove_components(&mut self, components: &[ComponentId]) -> Result<(), GraphError> {
        let mut missing_components = Vec::new();
        for comp_id in components {
            if let Some(edge_indices) = self.edge_map.remove(comp_id) {
                for edge_idx in edge_indices {
                    self.graph.remove_edge(edge_idx);
                }
            } else {
                missing_components.push(comp_id.clone());
            }
        }
        if !missing_components.is_empty() {
            return Err(GraphError::ComponentsNotFound(missing_components));
        }
        Ok(())
    }

    /// Test-only helper that sets `data` on the edge of `component_id` running
    /// `token_in -> token_out`. When `bidirectional` is set, it also sets the reverse edge.
    ///
    /// # Errors
    ///
    /// - [`GraphError::TokenNotFound`] if either token has no node.
    /// - [`GraphError::ComponentsNotFound`] if `component_id` is not tracked.
    /// - [`GraphError::MissingComponentBetweenTokens`] if the component has no matching edge
    ///   between the two tokens.
    #[cfg(test)]
    pub(crate) fn set_edge_weight(
        &mut self,
        component_id: &ComponentId,
        token_in: &Address,
        token_out: &Address,
        data: D,
        bidirectional: bool,
    ) -> Result<(), GraphError> {
        let from_idx = self.find_node(token_in)?;
        let to_idx = self.find_node(token_out)?;
        let edge_indices = self
            .edge_map
            .get(component_id)
            .ok_or_else(|| GraphError::ComponentsNotFound(vec![component_id.clone()]))?;
        let mut updated = false;
        for &edge_idx in edge_indices {
            let (edge_from, edge_to) = match self.graph.edge_endpoints(edge_idx) {
                Some(endpoints) => endpoints,
                None => continue,
            };
            let should_update = if bidirectional {
                (edge_from == from_idx && edge_to == to_idx) ||
                    (edge_from == to_idx && edge_to == from_idx)
            } else {
                edge_from == from_idx && edge_to == to_idx
            };
            if should_update {
                let edge_data = self
                    .graph
                    .edge_weight_mut(edge_idx)
                    .ok_or_else(|| GraphError::ComponentsNotFound(vec![component_id.clone()]))?;
                if edge_data.component_id == *component_id {
                    edge_data.data = Some(data.clone());
                    updated = true;
                }
            }
        }
        if !updated {
            return Err(GraphError::MissingComponentBetweenTokens(
                token_in.clone(),
                token_out.clone(),
                component_id.clone(),
            ));
        }
        Ok(())
    }
}
impl<D: Clone + super::EdgeWeightFromSimAndDerived> PetgraphStableDiGraphManager<D> {
pub fn update_edge_weights_with_derived(
&mut self,
market: &SharedMarketData,
derived: &crate::derived::DerivedData,
) -> usize {
let tokens = market.token_registry_ref();
let updates: Vec<_> = self
.graph
.edge_indices()
.filter_map(|edge_idx| {
let edge_data = self.graph.edge_weight(edge_idx)?;
let component_id = &edge_data.component_id;
let sim_state = market.get_simulation_state(component_id)?;
let (source_idx, target_idx) = self.graph.edge_endpoints(edge_idx)?;
let source_addr = &self.graph[source_idx];
let target_addr = &self.graph[target_idx];
let token_in = tokens.get(source_addr)?;
let token_out = tokens.get(target_addr)?;
let weight =
D::from_sim_and_derived(sim_state, component_id, token_in, token_out, derived)?;
Some((edge_idx, weight))
})
.collect();
let updated = updates.len();
for (edge_idx, weight) in updates {
if let Some(edge_data) = self.graph.edge_weight_mut(edge_idx) {
edge_data.data = Some(weight);
}
}
updated
}
}
impl<D: Clone + super::EdgeWeightFromSimAndDerived> super::EdgeWeightUpdaterWithDerived
for PetgraphStableDiGraphManager<D>
{
fn update_edge_weights_with_derived(
&mut self,
market: &SharedMarketData,
derived: &crate::derived::DerivedData,
) -> usize {
self.update_edge_weights_with_derived(market, derived)
}
}
impl<D: Clone> Default for PetgraphStableDiGraphManager<D> {
fn default() -> Self {
Self::new()
}
}
impl<D: Clone + Send + Sync> GraphManager<StableDiGraph<D>> for PetgraphStableDiGraphManager<D> {
    /// Discards all existing state and rebuilds the graph from `component_topology`.
    ///
    /// Every distinct token gets exactly one node. Every token pair inside a component is
    /// connected in both directions. Unlike `add_components`, no validation is performed:
    /// a component with fewer than two tokens contributes nodes but no edges, and gets no
    /// `edge_map` entry.
    fn initialize_graph(&mut self, component_topology: &HashMap<ComponentId, Vec<Address>>) {
        // Start from a clean slate so the result mirrors `component_topology` exactly.
        self.edge_map.clear();
        self.node_map.clear();
        self.graph = StableDiGraph::default();

        // Pass 1: one node per distinct token across all components.
        let distinct_tokens: HashSet<&Address> = component_topology
            .values()
            .flatten()
            .collect();
        for token in distinct_tokens {
            let node_idx = self.graph.add_node(token.clone());
            self.node_map
                .insert(token.clone(), node_idx);
        }

        // Pass 2: wire up each component. Every token was registered in pass 1, so the
        // index lookups below cannot miss.
        for (comp_id, tokens) in component_topology {
            let mut node_indices: Vec<NodeIndex> = Vec::with_capacity(tokens.len());
            for token in tokens {
                node_indices.push(self.node_map[token]);
            }
            self.add_component_edges(comp_id, &node_indices);
        }
    }

    /// Read-only access to the underlying graph.
    fn graph(&self) -> &StableDiGraph<D> {
        &self.graph
    }
}
#[async_trait]
impl<D: Clone + Send> MarketEventHandler for PetgraphStableDiGraphManager<D> {
    /// Applies a market update to the graph by adding new components, then removing dropped
    /// ones.
    ///
    /// Both steps always run. If either fails, its error is collected (add first, then remove)
    /// and all collected errors are returned together as [`EventError::GraphErrors`].
    async fn handle_event(&mut self, event: &MarketEvent) -> Result<(), EventError> {
        match event {
            MarketEvent::MarketUpdated { added_components, removed_components, .. } => {
                let add_outcome = self.add_components(added_components);
                let remove_outcome = self.remove_components(removed_components);

                let errors: Vec<_> = add_outcome
                    .err()
                    .into_iter()
                    .chain(remove_outcome.err())
                    .collect();

                if errors.is_empty() {
                    Ok(())
                } else {
                    Err(EventError::GraphErrors(errors))
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use std::str::FromStr;

    use super::*;

    // Distinct, well-formed token addresses shared by the tests below.
    const TOKEN_A: &str = "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2";
    const TOKEN_B: &str = "0xA0b86991c6218b36c1d19D4a2e9Eb0cE3606eB48";
    const TOKEN_C: &str = "0x6B175474E89094C44Da98b954EedeAC495271d0F";
    const TOKEN_D: &str = "0xdAC17F958D2ee523a2206206994597C13D831ec7";

    /// Parses a hex string into an [`Address`], panicking on malformed input.
    fn addr(s: &str) -> Address {
        Address::from_str(s).expect("Invalid address hex string")
    }

    /// An empty topology yields an empty graph.
    #[test]
    fn test_initialize_graph_empty() {
        let mut manager = PetgraphStableDiGraphManager::<()>::new();
        let topology = HashMap::new();
        manager.initialize_graph(&topology);

        let graph = manager.graph();
        assert_eq!(graph.node_count(), 0);
        assert_eq!(graph.edge_count(), 0);
    }

    /// Each token pair inside a component gets one edge per direction, labelled with the
    /// component id. A token shared between components maps to a single node.
    #[test]
    fn test_initialize_graph_comprehensive() {
        let mut manager = PetgraphStableDiGraphManager::<()>::new();
        let token_a = addr(TOKEN_A);
        let token_b = addr(TOKEN_B);
        let token_c = addr(TOKEN_C);
        let token_d = addr(TOKEN_D);

        let mut topology = HashMap::new();
        topology.insert(
            "pool1".to_string(),
            vec![token_a.clone(), token_b.clone(), token_c.clone()],
        );
        topology.insert("pool2".to_string(), vec![token_c.clone(), token_d.clone()]);
        manager.initialize_graph(&topology);

        let graph = manager.graph();
        // pool1 has 3 token pairs and pool2 has 1; each pair gets two directed edges.
        assert_eq!(graph.node_count(), 4);
        assert_eq!(graph.edge_count(), 8);

        let node_a = manager.find_node(&token_a).unwrap();
        let node_b = manager.find_node(&token_b).unwrap();
        let node_c = manager.find_node(&token_c).unwrap();
        let node_d = manager.find_node(&token_d).unwrap();

        let expected_edges = [
            (node_a, node_b, "pool1"),
            (node_b, node_a, "pool1"),
            (node_a, node_c, "pool1"),
            (node_c, node_a, "pool1"),
            (node_b, node_c, "pool1"),
            (node_c, node_b, "pool1"),
            (node_c, node_d, "pool2"),
            (node_d, node_c, "pool2"),
        ];
        for (from, to, pool) in expected_edges {
            let edge = graph
                .find_edge(from, to)
                .expect("expected an edge between the two tokens");
            assert_eq!(
                graph
                    .edge_weight(edge)
                    .unwrap()
                    .component_id,
                pool.to_string()
            );
        }
    }

    /// Several components over the same token pair produce parallel edges, one per component.
    #[test]
    fn test_initialize_graph_multiple_edges_same_pair() {
        let mut manager = PetgraphStableDiGraphManager::<()>::new();
        let token_a = addr(TOKEN_A);
        let token_b = addr(TOKEN_B);

        let mut topology = HashMap::new();
        for pool in ["pool1", "pool2", "pool3"] {
            topology.insert(pool.to_string(), vec![token_a.clone(), token_b.clone()]);
        }
        manager.initialize_graph(&topology);

        let graph = manager.graph();
        assert_eq!(graph.node_count(), 2);
        assert_eq!(graph.edge_count(), 6);

        let node_a = manager.find_node(&token_a).unwrap();
        let node_b = manager.find_node(&token_b).unwrap();
        let edges: Vec<_> = graph
            .edges_connecting(node_a, node_b)
            .collect();
        assert_eq!(edges.len(), 3);

        let component_ids: Vec<_> = edges
            .iter()
            .map(|e| &e.weight().component_id)
            .collect();
        for pool in ["pool1", "pool2", "pool3"] {
            assert!(component_ids.contains(&&pool.to_string()), "missing edge for {}", pool);
        }
    }

    /// Adding a component whose tokens already have nodes must reuse those nodes.
    #[test]
    fn test_add_components_shared_tokens() {
        let mut manager = PetgraphStableDiGraphManager::<()>::new();
        let token_a = addr(TOKEN_A);
        let token_b = addr(TOKEN_B);

        let mut components = HashMap::new();
        components.insert("pool1".to_string(), vec![token_a.clone(), token_b.clone()]);
        manager
            .add_components(&components)
            .unwrap();
        assert_eq!(manager.graph().node_count(), 2);

        components.clear();
        components.insert("pool2".to_string(), vec![token_a.clone(), token_b.clone()]);
        manager
            .add_components(&components)
            .unwrap();
        assert_eq!(manager.graph().node_count(), 2, "Should not create duplicate nodes");
    }

    /// Components without tokens are reported, while valid ones in the same batch are added.
    #[test]
    fn test_add_tokenless_components_error() {
        let mut manager = PetgraphStableDiGraphManager::<()>::new();
        let token_a = addr(TOKEN_A);
        let token_b = addr(TOKEN_B);

        let mut components = HashMap::new();
        components.insert("pool1".to_string(), vec![token_a.clone(), token_b.clone()]);
        components.insert("pool2".to_string(), vec![]);
        components.insert("pool3".to_string(), vec![]);

        let result = manager.add_components(&components);
        assert!(result.is_err());
        match result.unwrap_err() {
            GraphError::InvalidComponents(ids) => {
                assert_eq!(ids.len(), 2);
                assert!(ids.contains(&"pool2".to_string()));
                assert!(ids.contains(&"pool3".to_string()));
            }
            _ => panic!("Expected InvalidComponents error"),
        }

        // Only pool1 made it into the graph: two nodes, one edge per direction.
        assert_eq!(manager.graph().node_count(), 2);
        assert_eq!(manager.graph().edge_count(), 2);
    }

    /// Unknown ids are reported, while known ids in the same batch are still removed.
    #[test]
    fn test_remove_components_not_found_error() {
        let mut manager = PetgraphStableDiGraphManager::<()>::new();
        let token_a = addr(TOKEN_A);
        let token_b = addr(TOKEN_B);

        let mut components = HashMap::new();
        components.insert("pool1".to_string(), vec![token_a.clone(), token_b.clone()]);
        components.insert("pool2".to_string(), vec![token_a.clone(), token_b.clone()]);
        manager
            .add_components(&components)
            .unwrap();

        let result = manager.remove_components(&[
            "pool1".to_string(),
            "pool3".to_string(),
            "pool4".to_string(),
        ]);
        assert!(result.is_err());
        match result.unwrap_err() {
            GraphError::ComponentsNotFound(ids) => {
                assert_eq!(ids.len(), 2, "Expected 2 missing components");
                assert!(ids.contains(&"pool3".to_string()));
                assert!(ids.contains(&"pool4".to_string()));
            }
            _ => panic!("Expected ComponentsNotFound error"),
        }

        // Only pool2's edges should remain.
        for edge in manager.graph().edge_indices() {
            assert_eq!(
                manager
                    .graph()
                    .edge_weight(edge)
                    .unwrap()
                    .component_id,
                "pool2".to_string()
            );
        }
    }

    /// `set_edge_weight` reports unknown components, unknown tokens, and token pairs the
    /// component does not connect.
    #[test]
    fn test_set_edge_weight_errors() {
        let mut manager = PetgraphStableDiGraphManager::<()>::new();
        let token_a = addr(TOKEN_A);
        let token_b = addr(TOKEN_B);
        let token_c = addr(TOKEN_C);

        let mut topology = HashMap::new();
        topology.insert("pool1".to_string(), vec![token_a.clone(), token_b.clone()]);
        topology.insert("pool2".to_string(), vec![token_b.clone(), token_c.clone()]);
        manager.initialize_graph(&topology);

        let result = manager.set_edge_weight(&"pool3".to_string(), &token_a, &token_b, (), true);
        assert!(result.is_err());
        match result.unwrap_err() {
            GraphError::ComponentsNotFound(ids) => {
                assert_eq!(ids, vec!["pool3".to_string()]);
            }
            _ => panic!("Expected ComponentsNotFound error"),
        }

        let non_existent_token = addr("0x0000000000000000000000000000000000000000");
        let result =
            manager.set_edge_weight(&"pool1".to_string(), &token_a, &non_existent_token, (), true);
        assert!(result.is_err());
        match result.unwrap_err() {
            GraphError::TokenNotFound(found_addr) => {
                assert_eq!(found_addr, non_existent_token);
            }
            _ => panic!("Expected TokenNotFound error"),
        }

        let result = manager.set_edge_weight(&"pool1".to_string(), &token_a, &token_c, (), true);
        assert!(result.is_err());
        match result.unwrap_err() {
            GraphError::MissingComponentBetweenTokens(in_token, out_token, comp_id) => {
                assert_eq!(in_token, token_a);
                assert_eq!(out_token, token_c);
                assert_eq!(comp_id, "pool1".to_string());
            }
            _ => panic!("Expected MissingComponentBetweenTokens error"),
        }
    }

    /// Errors from both the add and the remove step surface together from `handle_event`.
    #[tokio::test]
    async fn test_handle_event_propagates_errors() {
        let mut manager = PetgraphStableDiGraphManager::<()>::new();
        let event = MarketEvent::MarketUpdated {
            added_components: HashMap::from([("pool1".to_string(), vec![])]),
            removed_components: vec!["pool2".to_string()],
            updated_components: vec![],
        };

        let result = manager.handle_event(&event).await;
        assert!(result.is_err());
        match result.unwrap_err() {
            EventError::GraphErrors(errors) => {
                assert_eq!(errors.len(), 2);
                let has_add_error = errors
                    .iter()
                    .any(|e| matches!(e, GraphError::InvalidComponents(_)));
                let has_remove_error = errors
                    .iter()
                    .any(|e| matches!(e, GraphError::ComponentsNotFound(_)));
                assert!(has_add_error, "Should have InvalidComponents error");
                assert!(has_remove_error, "Should have ComponentsNotFound error");
            }
        }
    }

    /// Re-adding an already-tracked component is a no-op.
    #[test]
    fn test_add_components_skips_duplicates() {
        let mut manager = PetgraphStableDiGraphManager::<()>::new();
        let token_a = addr(TOKEN_A);
        let token_b = addr(TOKEN_B);

        let mut components = HashMap::new();
        components.insert("pool1".to_string(), vec![token_a.clone(), token_b.clone()]);
        manager
            .add_components(&components)
            .unwrap();
        let edge_count_after_first = manager.graph().edge_count();
        assert_eq!(edge_count_after_first, 2);

        manager
            .add_components(&components)
            .unwrap();
        let edge_count_after_second = manager.graph().edge_count();
        assert_eq!(
            edge_count_after_first, edge_count_after_second,
            "Edge count should not change when re-adding the same component"
        );
    }
}