use std::collections::{HashMap, HashSet};
use tracing::{debug, instrument, warn};
use tycho_common::{
dto::{BlockChanges, Chain, DCIUpdate, ProtocolComponent, ProtocolComponentsRequestBody},
models::{Address, ComponentId, ProtocolSystem},
};
use crate::{
rpc::{RPCClient, RPC_CLIENT_CONCURRENCY},
RPCError,
};
/// Internal representation of a [`ComponentFilter`].
#[derive(Clone, Debug)]
pub(crate) enum ComponentFilterVariant {
    /// Track exactly these component ids (lowercased by `ComponentFilter::Ids`).
    Ids(Vec<ComponentId>),
    /// Track components by TVL, using separate remove/add thresholds.
    MinimumTVLRange {
        /// `(remove_tvl_threshold, add_tvl_threshold)`.
        range: (f64, f64),
        /// Lowercased ids that are never tracked, regardless of TVL.
        blocklisted_ids: HashSet<ComponentId>,
    },
}
/// Selects which protocol components a [`ComponentTracker`] follows.
///
/// Build one with [`ComponentFilter::with_tvl_range`] or [`ComponentFilter::Ids`], optionally
/// followed by [`ComponentFilter::blocklist`].
#[derive(Clone, Debug)]
pub struct ComponentFilter {
    variant: ComponentFilterVariant,
}
impl ComponentFilter {
#[allow(non_snake_case)] #[deprecated(since = "0.9.2", note = "Please use with_tvl_range instead")]
pub fn MinimumTVL(min_tvl: f64) -> ComponentFilter {
ComponentFilter {
variant: ComponentFilterVariant::MinimumTVLRange {
range: (min_tvl, min_tvl),
blocklisted_ids: HashSet::new(),
},
}
}
pub fn with_tvl_range(remove_tvl_threshold: f64, add_tvl_threshold: f64) -> ComponentFilter {
ComponentFilter {
variant: ComponentFilterVariant::MinimumTVLRange {
range: (remove_tvl_threshold, add_tvl_threshold),
blocklisted_ids: HashSet::new(),
},
}
}
#[allow(non_snake_case)] pub fn Ids(ids: Vec<ComponentId>) -> ComponentFilter {
ComponentFilter {
variant: ComponentFilterVariant::Ids(
ids.into_iter()
.map(|id| id.to_lowercase())
.collect(),
),
}
}
pub fn blocklist(mut self, ids: impl IntoIterator<Item = ComponentId>) -> Self {
match &mut self.variant {
ComponentFilterVariant::Ids(_) => {
warn!(
"blocklist() has no effect on ComponentFilter::Ids; \
remove the component from the ID list instead"
);
}
ComponentFilterVariant::MinimumTVLRange { blocklisted_ids, .. } => {
blocklisted_ids.extend(
ids.into_iter()
.map(|id| id.to_lowercase()),
);
}
}
self
}
pub fn is_blocklisted(&self, id: &str) -> bool {
match &self.variant {
ComponentFilterVariant::Ids(_) => false,
ComponentFilterVariant::MinimumTVLRange { blocklisted_ids, .. } => {
blocklisted_ids.contains(&id.to_lowercase())
}
}
}
}
/// What the tracker knows about one entrypoint, accumulated from `DCIUpdate`s.
#[derive(Default)]
struct EntrypointRelations {
    /// Components linked to this entrypoint (from `new_entrypoints`).
    components: HashSet<ComponentId>,
    /// Contracts seen in this entrypoint's traces (keys of `accessed_slots`).
    contracts: HashSet<Address>,
}
/// Tracks the protocol components of one protocol system on one chain, together with the
/// contracts those components depend on.
pub struct ComponentTracker<R: RPCClient> {
    chain: Chain,
    protocol_system: ProtocolSystem,
    /// Decides which components are tracked.
    filter: ComponentFilter,
    /// Currently tracked components, keyed by component id.
    pub components: HashMap<ComponentId, ProtocolComponent>,
    /// Relations per entrypoint id (linked components and traced contracts).
    entrypoints: HashMap<String, EntrypointRelations>,
    /// Contracts of tracked components, plus contracts of entrypoints linked to at least one
    /// tracked component.
    pub contracts: HashSet<Address>,
    /// Client used to fetch component metadata.
    rpc_client: R,
}
impl<R> ComponentTracker<R>
where
R: RPCClient,
{
pub fn new(chain: Chain, protocol_system: &str, filter: ComponentFilter, rpc: R) -> Self {
Self {
chain,
protocol_system: protocol_system.to_string(),
filter,
components: Default::default(),
contracts: Default::default(),
rpc_client: rpc,
entrypoints: Default::default(),
}
}
/// Replaces the tracked component set with a fresh fetch from the RPC.
///
/// With an `Ids` filter, exactly those ids are requested. With a TVL-range filter, every
/// component of the protocol system above the add threshold is requested. Blocklisted
/// components are dropped, and the contract set is then rebuilt.
///
/// # Errors
/// Propagates any `RPCError` from the paginated component request.
pub async fn initialise_components(&mut self) -> Result<(), RPCError> {
    let body = match &self.filter.variant {
        ComponentFilterVariant::Ids(ids) => {
            ProtocolComponentsRequestBody::id_filtered(&self.protocol_system, ids.clone(), self.chain)
        }
        ComponentFilterVariant::MinimumTVLRange { range, .. } => {
            // The initial fetch uses the upper (add) threshold of the range.
            let add_tvl_threshold = range.1;
            ProtocolComponentsRequestBody::system_filtered(
                &self.protocol_system,
                Some(add_tvl_threshold),
                self.chain,
            )
        }
    };

    let response = self
        .rpc_client
        .get_protocol_components_paginated(&body, None, RPC_CLIENT_CONCURRENCY)
        .await?;

    let mut tracked = HashMap::with_capacity(response.protocol_components.len());
    for component in response.protocol_components {
        if self.filter.is_blocklisted(&component.id) {
            continue;
        }
        tracked.insert(component.id.clone(), component);
    }
    self.components = tracked;

    self.reinitialize_contracts();
    Ok(())
}
fn reinitialize_contracts(&mut self) {
self.contracts = self
.components
.values()
.flat_map(|comp| comp.contract_ids.iter().cloned())
.collect();
let tracked_component_ids = self
.components
.keys()
.cloned()
.collect::<HashSet<_>>();
for entrypoint in self.entrypoints.values() {
if !entrypoint
.components
.is_disjoint(&tracked_component_ids)
{
self.contracts
.extend(entrypoint.contracts.iter().cloned());
}
}
}
fn update_contracts(&mut self, components: Vec<ComponentId>) {
let mut tracked_component_ids = HashSet::new();
for comp in components {
if let Some(component) = self.components.get(&comp) {
self.contracts
.extend(component.contract_ids.iter().cloned());
tracked_component_ids.insert(comp);
}
}
for entrypoint in self.entrypoints.values() {
if !entrypoint
.components
.is_disjoint(&tracked_component_ids)
{
self.contracts
.extend(entrypoint.contracts.iter().cloned());
}
}
}
#[instrument(skip(self, new_components))]
pub async fn start_tracking(
&mut self,
new_components: &[&ComponentId],
) -> Result<(), RPCError> {
let new_components: Vec<_> = new_components
.iter()
.filter(|id| !self.filter.is_blocklisted(id))
.copied()
.collect();
if new_components.is_empty() {
return Ok(());
}
let request = ProtocolComponentsRequestBody::id_filtered(
&self.protocol_system,
new_components
.iter()
.map(|&id| id.to_string())
.collect(),
self.chain,
);
let components = self
.rpc_client
.get_protocol_components(&request)
.await?
.protocol_components
.into_iter()
.map(|pc| (pc.id.clone(), pc))
.collect::<HashMap<_, _>>();
let component_ids: Vec<_> = components.keys().cloned().collect();
let component_count = component_ids.len();
self.components.extend(components);
self.update_contracts(component_ids);
debug!(n_components = component_count, "StartedTracking");
Ok(())
}
/// Stops tracking the given component ids and rebuilds the contract set.
///
/// Returns the components that were actually removed, keyed by id. Ids that were not tracked
/// are skipped silently.
#[instrument(skip(self, to_remove))]
pub fn stop_tracking<'a, I: IntoIterator<Item = &'a ComponentId> + std::fmt::Debug>(
    &mut self,
    to_remove: I,
) -> HashMap<ComponentId, ProtocolComponent> {
    let removed: HashMap<ComponentId, ProtocolComponent> = to_remove
        .into_iter()
        .filter_map(|id| {
            self.components
                .remove(id)
                .map(|component| (id.clone(), component))
        })
        .collect();

    // Contracts may be shared across components or entrypoints, so rebuild rather than
    // subtracting.
    self.reinitialize_contracts();
    debug!(n_components = removed.len(), "StoppedTracking");
    removed
}
pub fn process_entrypoints(&mut self, dci_update: &DCIUpdate) {
for (entrypoint, traces) in &dci_update.trace_results {
self.entrypoints
.entry(entrypoint.clone())
.or_default()
.contracts
.extend(traces.accessed_slots.keys().cloned());
}
for (component, entrypoints) in &dci_update.new_entrypoints {
for entrypoint in entrypoints {
let entrypoint_info = self
.entrypoints
.entry(entrypoint.external_id.clone())
.or_default();
entrypoint_info
.components
.insert(component.clone());
if self.components.contains_key(component) {
self.contracts.extend(
entrypoint_info
.contracts
.iter()
.cloned(),
);
}
}
}
}
/// Returns the union of contracts for the given component ids.
///
/// For each tracked id, this covers the component's own `contract_ids` plus the contracts of
/// every entrypoint linked to it. Untracked ids are skipped with a warning.
pub fn get_contracts_by_component<'a, I: IntoIterator<Item = &'a String>>(
    &self,
    ids: I,
) -> HashSet<Address> {
    let mut contracts = HashSet::new();
    for cid in ids {
        match self.components.get(cid) {
            Some(comp) => {
                // Extend the single result set directly. This avoids cloning `contract_ids`
                // and building a throw-away set per component.
                contracts.extend(comp.contract_ids.iter().cloned());
                contracts.extend(
                    self.entrypoints
                        .values()
                        .filter(|ep| ep.components.contains(cid))
                        .flat_map(|ep| ep.contracts.iter().cloned()),
                );
            }
            None => {
                warn!(
                    "Requested component is not tracked: {cid}. Skipping fetching contracts..."
                );
            }
        }
    }
    contracts
}
/// Returns the ids of all currently tracked components, in the map's iteration order.
pub fn get_tracked_component_ids(&self) -> Vec<ComponentId> {
    let mut ids = Vec::with_capacity(self.components.len());
    ids.extend(self.components.keys().cloned());
    ids
}
/// Splits the TVL changes in `deltas` into components to start and to stop tracking.
///
/// Returns `(to_add, to_remove)`:
/// - `to_add`: ids whose new TVL is strictly above the add threshold, excluding blocklisted
///   ids. Ids that are already tracked are not filtered out.
/// - `to_remove`: ids whose new TVL is strictly below the remove threshold, whether tracked
///   or not. These are followed by every currently tracked id that is blocklisted.
///
/// With an `Ids` filter, membership never changes, so both lists are empty.
pub fn filter_updated_components(
    &self,
    deltas: &BlockChanges,
) -> (Vec<ComponentId>, Vec<ComponentId>) {
    match &self.filter.variant {
        ComponentFilterVariant::Ids(_) => (Default::default(), Default::default()),
        ComponentFilterVariant::MinimumTVLRange {
            range: (remove_tvl, add_tvl),
            blocklisted_ids,
        } => {
            // Partition the (id, tvl) pairs directly, so each TVL is read once. The previous
            // approach re-indexed `component_tvl` inside the partition predicate.
            let (above, below): (Vec<_>, Vec<_>) = deltas
                .component_tvl
                .iter()
                .filter(|(_, &tvl)| tvl < *remove_tvl || tvl > *add_tvl)
                .partition(|(_, &tvl)| tvl > *add_tvl);

            let mut to_add: Vec<ComponentId> = above
                .into_iter()
                .map(|(id, _)| id.clone())
                .collect();
            let mut to_remove: Vec<ComponentId> = below
                .into_iter()
                .map(|(id, _)| id.clone())
                .collect();

            // With an empty blocklist nothing can match. Skip the per-id lowercase allocations
            // that `is_blocklisted` would otherwise perform for every tracked component.
            if !blocklisted_ids.is_empty() {
                to_add.retain(|id| !self.filter.is_blocklisted(id));
                // Evict blocklisted components that are still tracked.
                for id in self.components.keys() {
                    if self.filter.is_blocklisted(id) && !to_remove.contains(id) {
                        to_remove.push(id.clone());
                    }
                }
            }
            (to_add, to_remove)
        }
    }
}
}
// Unit tests for `ComponentFilter` and `ComponentTracker`. All RPC traffic goes through
// `MockRPCClient`, and each test registers only the expectations it needs.
#[cfg(test)]
mod test {
    use tycho_common::{
        dto::{PaginationResponse, ProtocolComponentRequestResponse},
        Bytes,
    };

    use super::*;
    use crate::rpc::MockRPCClient;

    // Tracker for "uniswap-v2" on Ethereum. It uses a (0.0, 0.0) TVL range, an empty blocklist
    // and a mock RPC client without any expectations.
    fn with_mocked_rpc() -> ComponentTracker<MockRPCClient> {
        let rpc = MockRPCClient::new();
        ComponentTracker::new(
            Chain::Ethereum,
            "uniswap-v2",
            ComponentFilter::with_tvl_range(0.0, 0.0),
            rpc,
        )
    }

    // Fixture: a component with id "Component1" that owns two contracts.
    // Returns (contract ids, component).
    fn components_response() -> (Vec<Address>, ProtocolComponent) {
        let contract_ids = vec![Bytes::from("0x1234"), Bytes::from("0xbabe")];
        let component = ProtocolComponent {
            id: "Component1".to_string(),
            contract_ids: contract_ids.clone(),
            ..Default::default()
        };
        (contract_ids, component)
    }

    // The initial fetch stores the returned component and derives `contracts` from it.
    #[tokio::test]
    async fn test_initialise_components() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let exp_component = component.clone();
        tracker
            .rpc_client
            .expect_get_protocol_components_paginated()
            .returning(move |_, _, _| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![component.clone()],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });
        tracker
            .initialise_components()
            .await
            .expect("Retrieving components failed");
        assert_eq!(
            tracker
                .components
                .get("Component1")
                .expect("Component1 not tracked"),
            &exp_component
        );
        assert_eq!(tracker.contracts, contract_ids.into_iter().collect());
    }

    // A newly started component is fetched by id, and its contracts become tracked.
    #[tokio::test]
    async fn test_start_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        let exp_contracts = contract_ids.into_iter().collect();
        let component_id = component.id.clone();
        let components_arg = [&component_id];
        tracker
            .rpc_client
            .expect_get_protocol_components()
            .returning(move |_| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![component.clone()],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });
        tracker
            .start_tracking(&components_arg)
            .await
            .expect("Tracking components failed");
        assert_eq!(&tracker.contracts, &exp_contracts);
        assert!(tracker
            .components
            .contains_key("Component1"));
    }

    // Only tracked components are returned ("Component2" never was tracked), and their
    // contracts are dropped from the tracked set.
    #[test]
    fn test_stop_tracking() {
        let mut tracker = with_mocked_rpc();
        let (contract_ids, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component.clone());
        tracker.contracts.extend(contract_ids);
        let components_arg = ["Component1".to_string(), "Component2".to_string()];
        let exp = [("Component1".to_string(), component)]
            .into_iter()
            .collect();
        let res = tracker.stop_tracking(&components_arg);
        assert_eq!(res, exp);
        assert!(tracker.contracts.is_empty());
    }

    // A tracked component's own contracts are returned. No entrypoints are registered here.
    #[test]
    fn test_get_contracts_by_component() {
        let mut tracker = with_mocked_rpc();
        let (exp_contracts, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);
        let components_arg = ["Component1".to_string()];
        let res = tracker.get_contracts_by_component(&components_arg);
        assert_eq!(res, exp_contracts.into_iter().collect());
    }

    // The tracked ids reflect the keys of `components`.
    #[test]
    fn test_get_tracked_component_ids() {
        let mut tracker = with_mocked_rpc();
        let (_, component) = components_response();
        tracker
            .components
            .insert("Component1".to_string(), component);
        let exp = vec!["Component1".to_string()];
        let res = tracker.get_tracked_component_ids();
        assert_eq!(res, exp);
    }

    // Like `with_mocked_rpc`, but the filter also blocklists the given ids.
    fn with_mocked_rpc_and_blocklist(blocklisted: Vec<&str>) -> ComponentTracker<MockRPCClient> {
        let rpc = MockRPCClient::new();
        let filter = ComponentFilter::with_tvl_range(0.0, 0.0).blocklist(
            blocklisted
                .into_iter()
                .map(String::from),
        );
        ComponentTracker::new(Chain::Ethereum, "uniswap-v2", filter, rpc)
    }

    // The blocklist entry is "component1" while the fetched id is "Component1", so this also
    // covers case-insensitive matching.
    #[tokio::test]
    async fn test_initialise_skips_blocklisted_components() {
        let mut tracker = with_mocked_rpc_and_blocklist(vec!["component1"]);
        let (_, component) = components_response();
        tracker
            .rpc_client
            .expect_get_protocol_components_paginated()
            .returning(move |_, _, _| {
                Ok(ProtocolComponentRequestResponse {
                    protocol_components: vec![component.clone()],
                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
                })
            });
        tracker
            .initialise_components()
            .await
            .expect("Retrieving components failed");
        assert!(tracker.components.is_empty(), "Blocklisted component should not be in tracker");
    }

    // No RPC expectation is registered, because the only requested id is blocklisted and
    // must be dropped before any request is made.
    // NOTE(review): this relies on the mock rejecting unexpected calls — confirm for
    // MockRPCClient.
    #[tokio::test]
    async fn test_start_tracking_skips_blocklisted() {
        let mut tracker = with_mocked_rpc_and_blocklist(vec!["component1"]);
        let component_id = "Component1".to_string();
        let components_arg = [&component_id];
        tracker
            .start_tracking(&components_arg)
            .await
            .expect("start_tracking should succeed");
        assert!(tracker.components.is_empty(), "Blocklisted component should not be tracked");
    }

    // The filter is replaced with a (5.0, 10.0) range. Both pools (TVL 100.0) qualify for
    // addition, but only the non-blocklisted one may be returned. Nothing is tracked and no
    // TVL is below 5.0, so nothing should be removed.
    #[test]
    fn test_filter_updated_blocks_blocklisted_add() {
        let mut tracker = with_mocked_rpc_and_blocklist(vec!["blocklisted_pool"]);
        tracker.filter = ComponentFilter::with_tvl_range(5.0, 10.0)
            .blocklist(vec!["blocklisted_pool".to_string()]);
        let deltas = BlockChanges {
            component_tvl: HashMap::from([
                ("blocklisted_pool".to_string(), 100.0),
                ("allowed_pool".to_string(), 100.0),
            ]),
            ..Default::default()
        };
        let (to_add, to_remove) = tracker.filter_updated_components(&deltas);
        assert!(
            !to_add.contains(&"blocklisted_pool".to_string()),
            "Blocklisted component should not be in to_add"
        );
        assert!(
            to_add.contains(&"allowed_pool".to_string()),
            "Non-blocklisted component should be in to_add"
        );
        assert!(to_remove.is_empty());
    }
}