tycho_client/feed/synchronizer.rs

1use std::{collections::HashMap, sync::Arc, time::Duration};
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6use tokio::{
7    select,
8    sync::{
9        mpsc::{channel, error::SendError, Receiver, Sender},
10        oneshot, Mutex,
11    },
12    task::JoinHandle,
13    time::timeout,
14};
15use tracing::{debug, error, info, instrument, trace, warn};
16use tycho_common::{
17    dto::{
18        BlockChanges, BlockParam, ComponentTvlRequestBody, EntryPointWithTracingParams,
19        ExtractorIdentity, ProtocolComponent, ResponseAccount, ResponseProtocolState,
20        TracingResult, VersionParam,
21    },
22    Bytes,
23};
24
25use crate::{
26    deltas::{DeltasClient, SubscriptionOptions},
27    feed::{
28        component_tracker::{ComponentFilter, ComponentTracker},
29        BlockHeader, HeaderLike,
30    },
31    rpc::{RPCClient, RPCError},
32    DeltasError,
33};
34
35#[derive(Error, Debug)]
36pub enum SynchronizerError {
37    /// RPC client failures.
38    #[error("RPC error: {0}")]
39    RPCError(#[from] RPCError),
40
41    /// Failed to send channel message to the consumer.
42    #[error("Failed to send channel message: {0}")]
43    ChannelError(String),
44
45    /// Timeout elapsed errors.
46    #[error("Timeout error: {0}")]
47    Timeout(String),
48
49    /// Failed to close the synchronizer.
50    #[error("Failed to close synchronizer: {0}")]
51    CloseError(String),
52
53    /// Server connection failures or interruptions.
54    #[error("Connection error: {0}")]
55    ConnectionError(String),
56
57    /// Connection closed
58    #[error("Connection closed")]
59    ConnectionClosed,
60}
61
62pub type SyncResult<T> = Result<T, SynchronizerError>;
63
64impl From<SendError<StateSyncMessage<BlockHeader>>> for SynchronizerError {
65    fn from(err: SendError<StateSyncMessage<BlockHeader>>) -> Self {
66        SynchronizerError::ChannelError(err.to_string())
67    }
68}
69
70impl From<DeltasError> for SynchronizerError {
71    fn from(err: DeltasError) -> Self {
72        match err {
73            DeltasError::NotConnected => SynchronizerError::ConnectionClosed,
74            _ => SynchronizerError::ConnectionError(err.to_string()),
75        }
76    }
77}
78
79#[derive(Clone)]
80pub struct ProtocolStateSynchronizer<R: RPCClient, D: DeltasClient> {
81    extractor_id: ExtractorIdentity,
82    retrieve_balances: bool,
83    rpc_client: R,
84    deltas_client: D,
85    max_retries: u64,
86    include_snapshots: bool,
87    component_tracker: Arc<Mutex<ComponentTracker<R>>>,
88    shared: Arc<Mutex<SharedState>>,
89    end_tx: Arc<Mutex<Option<oneshot::Sender<()>>>>,
90    timeout: u64,
91    include_tvl: bool,
92}
93
94#[derive(Debug, Default)]
95struct SharedState {
96    last_synced_block: Option<BlockHeader>,
97}
98
99#[derive(Clone, PartialEq, Debug, Serialize, Deserialize)]
100pub struct ComponentWithState {
101    pub state: ResponseProtocolState,
102    pub component: ProtocolComponent,
103    pub component_tvl: Option<f64>,
104    pub entrypoints: Vec<(EntryPointWithTracingParams, TracingResult)>,
105}
106
107#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
108pub struct Snapshot {
109    pub states: HashMap<String, ComponentWithState>,
110    pub vm_storage: HashMap<Bytes, ResponseAccount>,
111}
112
113impl Snapshot {
114    fn extend(&mut self, other: Snapshot) {
115        self.states.extend(other.states);
116        self.vm_storage.extend(other.vm_storage);
117    }
118
119    pub fn get_states(&self) -> &HashMap<String, ComponentWithState> {
120        &self.states
121    }
122
123    pub fn get_vm_storage(&self) -> &HashMap<Bytes, ResponseAccount> {
124        &self.vm_storage
125    }
126}
127
128#[derive(Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
129pub struct StateSyncMessage<H>
130where
131    H: HeaderLike,
132{
133    /// The block information for this update.
134    pub header: H,
135    /// Snapshot for new components.
136    pub snapshots: Snapshot,
137    /// A single delta contains state updates for all tracked components, as well as additional
138    /// information about the system components e.g. newly added components (even below tvl), tvl
139    /// updates, balance updates.
140    pub deltas: Option<BlockChanges>,
141    /// Components that stopped being tracked.
142    pub removed_components: HashMap<String, ProtocolComponent>,
143}
144
145impl<H> StateSyncMessage<H>
146where
147    H: HeaderLike,
148{
149    pub fn merge(mut self, other: Self) -> Self {
150        // be careful with removed and snapshots attributes here, these can be ambiguous.
151        self.removed_components
152            .retain(|k, _| !other.snapshots.states.contains_key(k));
153        self.snapshots
154            .states
155            .retain(|k, _| !other.removed_components.contains_key(k));
156
157        self.snapshots.extend(other.snapshots);
158        let deltas = match (self.deltas, other.deltas) {
159            (Some(l), Some(r)) => Some(l.merge(r)),
160            (None, Some(r)) => Some(r),
161            (Some(l), None) => Some(l),
162            (None, None) => None,
163        };
164        self.removed_components
165            .extend(other.removed_components);
166        Self {
167            header: other.header,
168            snapshots: self.snapshots,
169            deltas,
170            removed_components: self.removed_components,
171        }
172    }
173}
174
175/// StateSynchronizer
176///
177/// Used to synchronize the state of a single protocol. The synchronizer is responsible for
178/// delivering messages to the client that let him reconstruct subsets of the protocol state.
179///
180/// This involves deciding which components to track according to the clients preferences,
181/// retrieving & emitting snapshots of components which the client has not seen yet and subsequently
182/// delivering delta messages for the components that have changed.
183#[async_trait]
184pub trait StateSynchronizer: Send + Sync + 'static {
185    async fn initialize(&self) -> SyncResult<()>;
186    /// Starts the state synchronization.
187    async fn start(
188        &self,
189    ) -> SyncResult<(JoinHandle<SyncResult<()>>, Receiver<StateSyncMessage<BlockHeader>>)>;
190    /// Ends the synchronization loop.
191    async fn close(&mut self) -> SyncResult<()>;
192}
193
194impl<R, D> ProtocolStateSynchronizer<R, D>
195where
196    // TODO: Consider moving these constraints directly to the
197    // client...
198    R: RPCClient + Clone + Send + Sync + 'static,
199    D: DeltasClient + Clone + Send + Sync + 'static,
200{
201    /// Creates a new state synchronizer.
202    #[allow(clippy::too_many_arguments)]
203    pub fn new(
204        extractor_id: ExtractorIdentity,
205        retrieve_balances: bool,
206        component_filter: ComponentFilter,
207        max_retries: u64,
208        include_snapshots: bool,
209        include_tvl: bool,
210        rpc_client: R,
211        deltas_client: D,
212        timeout: u64,
213    ) -> Self {
214        Self {
215            extractor_id: extractor_id.clone(),
216            retrieve_balances,
217            rpc_client: rpc_client.clone(),
218            include_snapshots,
219            deltas_client,
220            component_tracker: Arc::new(Mutex::new(ComponentTracker::new(
221                extractor_id.chain,
222                extractor_id.name.as_str(),
223                component_filter,
224                rpc_client,
225            ))),
226            max_retries,
227            shared: Arc::new(Mutex::new(SharedState::default())),
228            end_tx: Arc::new(Mutex::new(None)),
229            timeout,
230            include_tvl,
231        }
232    }
233
234    /// Retrieves state snapshots of the requested components
235    #[allow(deprecated)]
236    async fn get_snapshots<'a, I: IntoIterator<Item = &'a String>>(
237        &self,
238        header: BlockHeader,
239        tracked_components: &mut ComponentTracker<R>,
240        ids: Option<I>,
241    ) -> SyncResult<StateSyncMessage<BlockHeader>> {
242        if !self.include_snapshots {
243            return Ok(StateSyncMessage { header, ..Default::default() });
244        }
245        let version = VersionParam::new(
246            None,
247            Some(BlockParam {
248                chain: Some(self.extractor_id.chain),
249                hash: None,
250                number: Some(header.number as i64),
251            }),
252        );
253
254        // Use given ids or use all if not passed
255        let component_ids: Vec<_> = match ids {
256            Some(ids) => ids.into_iter().cloned().collect(),
257            None => tracked_components.get_tracked_component_ids(),
258        };
259
260        if component_ids.is_empty() {
261            return Ok(StateSyncMessage { header, ..Default::default() });
262        }
263
264        let component_tvl = if self.include_tvl {
265            let body = ComponentTvlRequestBody::id_filtered(
266                component_ids.clone(),
267                self.extractor_id.chain,
268            );
269            self.rpc_client
270                .get_component_tvl_paginated(&body, 100, 4)
271                .await?
272                .tvl
273        } else {
274            HashMap::new()
275        };
276
277        // Fetch entrypoints
278        let entrypoints_result = self
279            .rpc_client
280            .get_traced_entry_points_paginated(
281                self.extractor_id.chain,
282                &self.extractor_id.name,
283                &component_ids,
284                100,
285                4,
286            )
287            .await?;
288        tracked_components.process_entrypoints(&entrypoints_result.clone().into())?;
289
290        // Fetch protocol states
291        let mut protocol_states = self
292            .rpc_client
293            .get_protocol_states_paginated(
294                self.extractor_id.chain,
295                &component_ids,
296                &self.extractor_id.name,
297                self.retrieve_balances,
298                &version,
299                100,
300                4,
301            )
302            .await?
303            .states
304            .into_iter()
305            .map(|state| (state.component_id.clone(), state))
306            .collect::<HashMap<_, _>>();
307
308        trace!(states=?&protocol_states, "Retrieved ProtocolStates");
309        let states = tracked_components
310            .components
311            .values()
312            .filter_map(|component| {
313                if let Some(state) = protocol_states.remove(&component.id) {
314                    Some((
315                        component.id.clone(),
316                        ComponentWithState {
317                            state,
318                            component: component.clone(),
319                            component_tvl: component_tvl
320                                .get(&component.id)
321                                .cloned(),
322                            entrypoints: entrypoints_result
323                                .traced_entry_points
324                                .get(&component.id)
325                                .cloned()
326                                .unwrap_or_default(),
327                        },
328                    ))
329                } else if component_ids.contains(&component.id) {
330                    // only emit error event if we requested this component
331                    let component_id = &component.id;
332                    error!(?component_id, "Missing state for native component!");
333                    None
334                } else {
335                    None
336                }
337            })
338            .collect();
339
340        // Fetch contract states
341        let contract_ids = tracked_components.get_contracts_by_component(&component_ids);
342        let vm_storage = if !contract_ids.is_empty() {
343            let ids: Vec<Bytes> = contract_ids
344                .clone()
345                .into_iter()
346                .collect();
347            let contract_states = self
348                .rpc_client
349                .get_contract_state_paginated(
350                    self.extractor_id.chain,
351                    ids.as_slice(),
352                    &self.extractor_id.name,
353                    &version,
354                    100,
355                    4,
356                )
357                .await?
358                .accounts
359                .into_iter()
360                .map(|acc| (acc.address.clone(), acc))
361                .collect::<HashMap<_, _>>();
362
363            trace!(states=?&contract_states, "Retrieved ContractState");
364
365            let contract_address_to_components = tracked_components
366                .components
367                .iter()
368                .filter_map(|(id, comp)| {
369                    if component_ids.contains(id) {
370                        Some(
371                            comp.contract_ids
372                                .iter()
373                                .map(|address| (address.clone(), comp.id.clone())),
374                        )
375                    } else {
376                        None
377                    }
378                })
379                .flatten()
380                .fold(HashMap::<Bytes, Vec<String>>::new(), |mut acc, (addr, c_id)| {
381                    acc.entry(addr).or_default().push(c_id);
382                    acc
383                });
384
385            contract_ids
386                .iter()
387                .filter_map(|address| {
388                    if let Some(state) = contract_states.get(address) {
389                        Some((address.clone(), state.clone()))
390                    } else if let Some(ids) = contract_address_to_components.get(address) {
391                        // only emit error even if we did actually request this address
392                        error!(
393                            ?address,
394                            ?ids,
395                            "Component with lacking contract storage encountered!"
396                        );
397                        None
398                    } else {
399                        None
400                    }
401                })
402                .collect()
403        } else {
404            HashMap::new()
405        };
406
407        Ok(StateSyncMessage {
408            header,
409            snapshots: Snapshot { states, vm_storage },
410            deltas: None,
411            removed_components: HashMap::new(),
412        })
413    }
414
415    /// Main method that does all the work.
416    #[instrument(skip(self, block_tx), fields(extractor_id = %self.extractor_id))]
417    async fn state_sync(
418        self,
419        block_tx: &mut Sender<StateSyncMessage<BlockHeader>>,
420    ) -> SyncResult<()> {
421        // initialisation
422        let mut tracker = self.component_tracker.lock().await;
423
424        let subscription_options = SubscriptionOptions::new().with_state(self.include_snapshots);
425        let (_, mut msg_rx) = self
426            .deltas_client
427            .subscribe(self.extractor_id.clone(), subscription_options)
428            .await?;
429
430        info!("Waiting for deltas...");
431        // wait for first deltas message
432        let mut first_msg = timeout(Duration::from_secs(self.timeout), msg_rx.recv())
433            .await
434            .map_err(|_| {
435                SynchronizerError::Timeout(format!(
436                    "First deltas took longer than {t}s to arrive",
437                    t = self.timeout
438                ))
439            })?
440            .ok_or_else(|| {
441                SynchronizerError::ConnectionError(
442                    "Deltas channel closed before first message".to_string(),
443                )
444            })?;
445        self.filter_deltas(&mut first_msg, &tracker);
446
447        // initial snapshot
448        let block = first_msg.get_block().clone();
449        info!(height = &block.number, "Deltas received. Retrieving snapshot");
450        let header = BlockHeader::from_block(first_msg.get_block(), first_msg.is_revert());
451        let snapshot = self
452            .get_snapshots::<Vec<&String>>(
453                BlockHeader::from_block(&block, false),
454                &mut tracker,
455                None,
456            )
457            .await?
458            .merge(StateSyncMessage {
459                header: BlockHeader::from_block(first_msg.get_block(), first_msg.is_revert()),
460                snapshots: Default::default(),
461                deltas: Some(first_msg),
462                removed_components: Default::default(),
463            });
464
465        let n_components = tracker.components.len();
466        let n_snapshots = snapshot.snapshots.states.len();
467        info!(n_components, n_snapshots, "Initial snapshot retrieved, starting delta message feed");
468
469        {
470            let mut shared = self.shared.lock().await;
471            block_tx.send(snapshot).await?;
472            shared.last_synced_block = Some(header.clone());
473        }
474
475        loop {
476            if let Some(mut deltas) = msg_rx.recv().await {
477                let header = BlockHeader::from_block(deltas.get_block(), deltas.is_revert());
478                debug!(block_number=?header.number, "Received delta message");
479
480                let (snapshots, removed_components) = {
481                    // 1. Remove components based on latest changes
482                    // 2. Add components based on latest changes, query those for snapshots
483                    let (to_add, to_remove) = tracker.filter_updated_components(&deltas);
484
485                    // Only components we don't track yet need a snapshot,
486                    let requiring_snapshot: Vec<_> = to_add
487                        .iter()
488                        .filter(|id| {
489                            !tracker
490                                .components
491                                .contains_key(id.as_str())
492                        })
493                        .collect();
494                    debug!(components=?requiring_snapshot, "SnapshotRequest");
495                    tracker
496                        .start_tracking(requiring_snapshot.as_slice())
497                        .await?;
498                    let snapshots = self
499                        .get_snapshots(header.clone(), &mut tracker, Some(requiring_snapshot))
500                        .await?
501                        .snapshots;
502
503                    let removed_components = if !to_remove.is_empty() {
504                        tracker.stop_tracking(&to_remove)
505                    } else {
506                        Default::default()
507                    };
508
509                    (snapshots, removed_components)
510                };
511
512                // 3. Update entrypoints on the tracker (affects which contracts are tracked)
513                tracker.process_entrypoints(&deltas.dci_update)?;
514
515                // 4. Filter deltas by currently tracked components / contracts
516                self.filter_deltas(&mut deltas, &tracker);
517                let n_changes = deltas.n_changes();
518
519                // 5. Send the message
520                let next = StateSyncMessage {
521                    header: header.clone(),
522                    snapshots,
523                    deltas: Some(deltas),
524                    removed_components,
525                };
526                block_tx.send(next).await?;
527                {
528                    let mut shared = self.shared.lock().await;
529                    shared.last_synced_block = Some(header.clone());
530                }
531
532                debug!(block_number=?header.number, n_changes, "Finished processing delta message");
533            } else {
534                let mut shared = self.shared.lock().await;
535                warn!(shared = ?&shared, "Deltas channel closed, resetting shared state.");
536                shared.last_synced_block = None;
537
538                return Err(SynchronizerError::ConnectionError("Deltas channel closed".to_string()));
539            }
540        }
541    }
542
543    fn filter_deltas(&self, second_msg: &mut BlockChanges, tracker: &ComponentTracker<R>) {
544        second_msg.filter_by_component(|id| tracker.components.contains_key(id));
545        second_msg.filter_by_contract(|id| tracker.contracts.contains(id));
546    }
547}
548
549#[async_trait]
550impl<R, D> StateSynchronizer for ProtocolStateSynchronizer<R, D>
551where
552    R: RPCClient + Clone + Send + Sync + 'static,
553    D: DeltasClient + Clone + Send + Sync + 'static,
554{
555    async fn initialize(&self) -> SyncResult<()> {
556        let mut tracker = self.component_tracker.lock().await;
557        info!("Retrieving relevant protocol components");
558        tracker.initialise_components().await?;
559        info!(
560            n_components = tracker.components.len(),
561            n_contracts = tracker.contracts.len(),
562            "Finished retrieving components",
563        );
564
565        Ok(())
566    }
567
568    async fn start(
569        &self,
570    ) -> SyncResult<(JoinHandle<SyncResult<()>>, Receiver<StateSyncMessage<BlockHeader>>)> {
571        let (mut tx, rx) = channel(15);
572
573        let this = self.clone();
574        let jh = tokio::spawn(async move {
575            let mut retry_count = 0;
576            while retry_count < this.max_retries {
577                info!(extractor_id=%&this.extractor_id, retry_count, "(Re)starting synchronization loop");
578                let (end_tx, end_rx) = oneshot::channel::<()>();
579                {
580                    let mut end_tx_guard = this.end_tx.lock().await;
581                    *end_tx_guard = Some(end_tx);
582                }
583
584                select! {
585                    res = this.clone().state_sync(&mut tx) => {
586                        match res {
587                            Err(e) => {
588                                error!(
589                                    extractor_id=%&this.extractor_id,
590                                    retry_count,
591                                    error=%e,
592                                    "State synchronization errored!"
593                                );
594                                if let SynchronizerError::ConnectionClosed = e {
595                                    // break synchronization loop if connection is closed
596                                    return Err(e);
597                                }
598                            }
599                            _ => {
600                                warn!(
601                                    extractor_id=%&this.extractor_id,
602                                    retry_count,
603                                    "State synchronization exited with Ok(())"
604                                );
605                            }
606                        }
607                    },
608                    _ = end_rx => {
609                        info!(
610                            extractor_id=%&this.extractor_id,
611                            retry_count,
612                            "StateSynchronizer received close signal. Stopping"
613                        );
614                        return Ok(())
615                    }
616                }
617                retry_count += 1;
618            }
619            Err(SynchronizerError::ConnectionError("Max connection retries exceeded".to_string()))
620        });
621
622        Ok((jh, rx))
623    }
624
625    async fn close(&mut self) -> SyncResult<()> {
626        let mut end_tx = self.end_tx.lock().await;
627        if let Some(tx) = end_tx.take() {
628            let _ = tx.send(());
629            Ok(())
630        } else {
631            Err(SynchronizerError::CloseError("Synchronizer not started".to_string()))
632        }
633    }
634}
635
636#[cfg(test)]
637mod test {
638    use std::collections::HashSet;
639
640    use test_log::test;
641    use tycho_common::dto::{
642        Block, Chain, ComponentTvlRequestBody, ComponentTvlRequestResponse, DCIUpdate, EntryPoint,
643        PaginationResponse, ProtocolComponentRequestResponse, ProtocolComponentsRequestBody,
644        ProtocolStateRequestBody, ProtocolStateRequestResponse, ProtocolSystemsRequestBody,
645        ProtocolSystemsRequestResponse, RPCTracerParams, StateRequestBody, StateRequestResponse,
646        TokensRequestBody, TokensRequestResponse, TracedEntryPointRequestBody,
647        TracedEntryPointRequestResponse, TracingParams,
648    };
649    use uuid::Uuid;
650
651    use super::*;
652    use crate::{deltas::MockDeltasClient, rpc::MockRPCClient, DeltasError, RPCError};
653
654    // Required for mock client to implement clone
655    struct ArcRPCClient<T>(Arc<T>);
656
657    // Default derive(Clone) does require T to be Clone as well.
658    impl<T> Clone for ArcRPCClient<T> {
659        fn clone(&self) -> Self {
660            ArcRPCClient(self.0.clone())
661        }
662    }
663
664    #[async_trait]
665    impl<T> RPCClient for ArcRPCClient<T>
666    where
667        T: RPCClient + Sync + Send + 'static,
668    {
669        async fn get_tokens(
670            &self,
671            request: &TokensRequestBody,
672        ) -> Result<TokensRequestResponse, RPCError> {
673            self.0.get_tokens(request).await
674        }
675
676        async fn get_contract_state(
677            &self,
678            request: &StateRequestBody,
679        ) -> Result<StateRequestResponse, RPCError> {
680            self.0.get_contract_state(request).await
681        }
682
683        async fn get_protocol_components(
684            &self,
685            request: &ProtocolComponentsRequestBody,
686        ) -> Result<ProtocolComponentRequestResponse, RPCError> {
687            self.0
688                .get_protocol_components(request)
689                .await
690        }
691
692        async fn get_protocol_states(
693            &self,
694            request: &ProtocolStateRequestBody,
695        ) -> Result<ProtocolStateRequestResponse, RPCError> {
696            self.0
697                .get_protocol_states(request)
698                .await
699        }
700
701        async fn get_protocol_systems(
702            &self,
703            request: &ProtocolSystemsRequestBody,
704        ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
705            self.0
706                .get_protocol_systems(request)
707                .await
708        }
709
710        async fn get_component_tvl(
711            &self,
712            request: &ComponentTvlRequestBody,
713        ) -> Result<ComponentTvlRequestResponse, RPCError> {
714            self.0.get_component_tvl(request).await
715        }
716
717        async fn get_traced_entry_points(
718            &self,
719            request: &TracedEntryPointRequestBody,
720        ) -> Result<TracedEntryPointRequestResponse, RPCError> {
721            self.0
722                .get_traced_entry_points(request)
723                .await
724        }
725    }
726
727    // Required for mock client to implement clone
728    struct ArcDeltasClient<T>(Arc<T>);
729
730    // Default derive(Clone) does require T to be Clone as well.
731    impl<T> Clone for ArcDeltasClient<T> {
732        fn clone(&self) -> Self {
733            ArcDeltasClient(self.0.clone())
734        }
735    }
736
737    #[async_trait]
738    impl<T> DeltasClient for ArcDeltasClient<T>
739    where
740        T: DeltasClient + Sync + Send + 'static,
741    {
742        async fn subscribe(
743            &self,
744            extractor_id: ExtractorIdentity,
745            options: SubscriptionOptions,
746        ) -> Result<(Uuid, Receiver<BlockChanges>), DeltasError> {
747            self.0
748                .subscribe(extractor_id, options)
749                .await
750        }
751
752        async fn unsubscribe(&self, subscription_id: Uuid) -> Result<(), DeltasError> {
753            self.0
754                .unsubscribe(subscription_id)
755                .await
756        }
757
758        async fn connect(&self) -> Result<JoinHandle<Result<(), DeltasError>>, DeltasError> {
759            self.0.connect().await
760        }
761
762        async fn close(&self) -> Result<(), DeltasError> {
763            self.0.close().await
764        }
765    }
766
767    fn with_mocked_clients(
768        native: bool,
769        include_tvl: bool,
770        rpc_client: Option<MockRPCClient>,
771        deltas_client: Option<MockDeltasClient>,
772    ) -> ProtocolStateSynchronizer<ArcRPCClient<MockRPCClient>, ArcDeltasClient<MockDeltasClient>>
773    {
774        let rpc_client = ArcRPCClient(Arc::new(rpc_client.unwrap_or_default()));
775        let deltas_client = ArcDeltasClient(Arc::new(deltas_client.unwrap_or_default()));
776
777        ProtocolStateSynchronizer::new(
778            ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
779            native,
780            ComponentFilter::with_tvl_range(50.0, 50.0),
781            1,
782            true,
783            include_tvl,
784            rpc_client,
785            deltas_client,
786            10_u64,
787        )
788    }
789
790    fn state_snapshot_native() -> ProtocolStateRequestResponse {
791        ProtocolStateRequestResponse {
792            states: vec![ResponseProtocolState {
793                component_id: "Component1".to_string(),
794                ..Default::default()
795            }],
796            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
797        }
798    }
799
800    fn component_tvl_snapshot() -> ComponentTvlRequestResponse {
801        let tvl = HashMap::from([("Component1".to_string(), 100.0)]);
802
803        ComponentTvlRequestResponse {
804            tvl,
805            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
806        }
807    }
808
809    #[test(tokio::test)]
810    async fn test_get_snapshots_native() {
811        let header = BlockHeader::default();
812        let mut rpc = MockRPCClient::new();
813        rpc.expect_get_protocol_states()
814            .returning(|_| Ok(state_snapshot_native()));
815        rpc.expect_get_traced_entry_points()
816            .returning(|_| {
817                Ok(TracedEntryPointRequestResponse {
818                    traced_entry_points: HashMap::new(),
819                    pagination: PaginationResponse::new(0, 20, 0),
820                })
821            });
822        let state_sync = with_mocked_clients(true, false, Some(rpc), None);
823        let mut tracker = ComponentTracker::new(
824            Chain::Ethereum,
825            "uniswap-v2",
826            ComponentFilter::with_tvl_range(0.0, 0.0),
827            state_sync.rpc_client.clone(),
828        );
829        let component = ProtocolComponent { id: "Component1".to_string(), ..Default::default() };
830        tracker
831            .components
832            .insert("Component1".to_string(), component.clone());
833        let components_arg = ["Component1".to_string()];
834        let exp = StateSyncMessage {
835            header: header.clone(),
836            snapshots: Snapshot {
837                states: state_snapshot_native()
838                    .states
839                    .into_iter()
840                    .map(|state| {
841                        (
842                            state.component_id.clone(),
843                            ComponentWithState {
844                                state,
845                                component: component.clone(),
846                                entrypoints: vec![],
847                                component_tvl: None,
848                            },
849                        )
850                    })
851                    .collect(),
852                vm_storage: HashMap::new(),
853            },
854            deltas: None,
855            removed_components: Default::default(),
856        };
857
858        let snap = state_sync
859            .get_snapshots(header, &mut tracker, Some(&components_arg))
860            .await
861            .expect("Retrieving snapshot failed");
862
863        assert_eq!(snap, exp);
864    }
865
866    #[test(tokio::test)]
867    async fn test_get_snapshots_native_with_tvl() {
868        let header = BlockHeader::default();
869        let mut rpc = MockRPCClient::new();
870        rpc.expect_get_protocol_states()
871            .returning(|_| Ok(state_snapshot_native()));
872        rpc.expect_get_component_tvl()
873            .returning(|_| Ok(component_tvl_snapshot()));
874        rpc.expect_get_traced_entry_points()
875            .returning(|_| {
876                Ok(TracedEntryPointRequestResponse {
877                    traced_entry_points: HashMap::new(),
878                    pagination: PaginationResponse::new(0, 20, 0),
879                })
880            });
881        let state_sync = with_mocked_clients(true, true, Some(rpc), None);
882        let mut tracker = ComponentTracker::new(
883            Chain::Ethereum,
884            "uniswap-v2",
885            ComponentFilter::with_tvl_range(0.0, 0.0),
886            state_sync.rpc_client.clone(),
887        );
888        let component = ProtocolComponent { id: "Component1".to_string(), ..Default::default() };
889        tracker
890            .components
891            .insert("Component1".to_string(), component.clone());
892        let components_arg = ["Component1".to_string()];
893        let exp = StateSyncMessage {
894            header: header.clone(),
895            snapshots: Snapshot {
896                states: state_snapshot_native()
897                    .states
898                    .into_iter()
899                    .map(|state| {
900                        (
901                            state.component_id.clone(),
902                            ComponentWithState {
903                                state,
904                                component: component.clone(),
905                                component_tvl: Some(100.0),
906                                entrypoints: vec![],
907                            },
908                        )
909                    })
910                    .collect(),
911                vm_storage: HashMap::new(),
912            },
913            deltas: None,
914            removed_components: Default::default(),
915        };
916
917        let snap = state_sync
918            .get_snapshots(header, &mut tracker, Some(&components_arg))
919            .await
920            .expect("Retrieving snapshot failed");
921
922        assert_eq!(snap, exp);
923    }
924
925    fn state_snapshot_vm() -> StateRequestResponse {
926        StateRequestResponse {
927            accounts: vec![
928                ResponseAccount { address: Bytes::from("0x0badc0ffee"), ..Default::default() },
929                ResponseAccount { address: Bytes::from("0xbabe42"), ..Default::default() },
930            ],
931            pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
932        }
933    }
934
935    fn traced_entry_point_response() -> TracedEntryPointRequestResponse {
936        TracedEntryPointRequestResponse {
937            traced_entry_points: HashMap::from([(
938                "Component1".to_string(),
939                vec![(
940                    EntryPointWithTracingParams {
941                        entry_point: EntryPoint {
942                            external_id: "entrypoint_a".to_string(),
943                            target: Bytes::from("0x0badc0ffee"),
944                            signature: "sig()".to_string(),
945                        },
946                        params: TracingParams::RPCTracer(RPCTracerParams {
947                            caller: Some(Bytes::from("0x0badc0ffee")),
948                            calldata: Bytes::from("0x0badc0ffee"),
949                        }),
950                    },
951                    TracingResult {
952                        retriggers: HashSet::from([(
953                            Bytes::from("0x0badc0ffee"),
954                            Bytes::from("0x0badc0ffee"),
955                        )]),
956                        accessed_slots: HashMap::from([(
957                            Bytes::from("0x0badc0ffee"),
958                            HashSet::from([Bytes::from("0xbadbeef0")]),
959                        )]),
960                    },
961                )],
962            )]),
963            pagination: PaginationResponse::new(0, 20, 0),
964        }
965    }
966
967    #[test(tokio::test)]
968    async fn test_get_snapshots_vm() {
969        let header = BlockHeader::default();
970        let mut rpc = MockRPCClient::new();
971        rpc.expect_get_protocol_states()
972            .returning(|_| Ok(state_snapshot_native()));
973        rpc.expect_get_contract_state()
974            .returning(|_| Ok(state_snapshot_vm()));
975        rpc.expect_get_traced_entry_points()
976            .returning(|_| Ok(traced_entry_point_response()));
977        let state_sync = with_mocked_clients(false, false, Some(rpc), None);
978        let mut tracker = ComponentTracker::new(
979            Chain::Ethereum,
980            "uniswap-v2",
981            ComponentFilter::with_tvl_range(0.0, 0.0),
982            state_sync.rpc_client.clone(),
983        );
984        let component = ProtocolComponent {
985            id: "Component1".to_string(),
986            contract_ids: vec![Bytes::from("0x0badc0ffee"), Bytes::from("0xbabe42")],
987            ..Default::default()
988        };
989        tracker
990            .components
991            .insert("Component1".to_string(), component.clone());
992        let components_arg = ["Component1".to_string()];
993        let exp = StateSyncMessage {
994            header: header.clone(),
995            snapshots: Snapshot {
996                states: [(
997                    component.id.clone(),
998                    ComponentWithState {
999                        state: ResponseProtocolState {
1000                            component_id: "Component1".to_string(),
1001                            ..Default::default()
1002                        },
1003                        component: component.clone(),
1004                        component_tvl: None,
1005                        entrypoints: vec![(
1006                            EntryPointWithTracingParams {
1007                                entry_point: EntryPoint {
1008                                    external_id: "entrypoint_a".to_string(),
1009                                    target: Bytes::from("0x0badc0ffee"),
1010                                    signature: "sig()".to_string(),
1011                                },
1012                                params: TracingParams::RPCTracer(RPCTracerParams {
1013                                    caller: Some(Bytes::from("0x0badc0ffee")),
1014                                    calldata: Bytes::from("0x0badc0ffee"),
1015                                }),
1016                            },
1017                            TracingResult {
1018                                retriggers: HashSet::from([(
1019                                    Bytes::from("0x0badc0ffee"),
1020                                    Bytes::from("0x0badc0ffee"),
1021                                )]),
1022                                accessed_slots: HashMap::from([(
1023                                    Bytes::from("0x0badc0ffee"),
1024                                    HashSet::from([Bytes::from("0xbadbeef0")]),
1025                                )]),
1026                            },
1027                        )],
1028                    },
1029                )]
1030                .into_iter()
1031                .collect(),
1032                vm_storage: state_snapshot_vm()
1033                    .accounts
1034                    .into_iter()
1035                    .map(|state| (state.address.clone(), state))
1036                    .collect(),
1037            },
1038            deltas: None,
1039            removed_components: Default::default(),
1040        };
1041
1042        let snap = state_sync
1043            .get_snapshots(header, &mut tracker, Some(&components_arg))
1044            .await
1045            .expect("Retrieving snapshot failed");
1046
1047        assert_eq!(snap, exp);
1048    }
1049
1050    #[test(tokio::test)]
1051    async fn test_get_snapshots_vm_with_tvl() {
1052        let header = BlockHeader::default();
1053        let mut rpc = MockRPCClient::new();
1054        rpc.expect_get_protocol_states()
1055            .returning(|_| Ok(state_snapshot_native()));
1056        rpc.expect_get_contract_state()
1057            .returning(|_| Ok(state_snapshot_vm()));
1058        rpc.expect_get_component_tvl()
1059            .returning(|_| Ok(component_tvl_snapshot()));
1060        rpc.expect_get_traced_entry_points()
1061            .returning(|_| {
1062                Ok(TracedEntryPointRequestResponse {
1063                    traced_entry_points: HashMap::new(),
1064                    pagination: PaginationResponse::new(0, 20, 0),
1065                })
1066            });
1067        let state_sync = with_mocked_clients(false, true, Some(rpc), None);
1068        let mut tracker = ComponentTracker::new(
1069            Chain::Ethereum,
1070            "uniswap-v2",
1071            ComponentFilter::with_tvl_range(0.0, 0.0),
1072            state_sync.rpc_client.clone(),
1073        );
1074        let component = ProtocolComponent {
1075            id: "Component1".to_string(),
1076            contract_ids: vec![Bytes::from("0x0badc0ffee"), Bytes::from("0xbabe42")],
1077            ..Default::default()
1078        };
1079        tracker
1080            .components
1081            .insert("Component1".to_string(), component.clone());
1082        let components_arg = ["Component1".to_string()];
1083        let exp = StateSyncMessage {
1084            header: header.clone(),
1085            snapshots: Snapshot {
1086                states: [(
1087                    component.id.clone(),
1088                    ComponentWithState {
1089                        state: ResponseProtocolState {
1090                            component_id: "Component1".to_string(),
1091                            ..Default::default()
1092                        },
1093                        component: component.clone(),
1094                        component_tvl: Some(100.0),
1095                        entrypoints: vec![],
1096                    },
1097                )]
1098                .into_iter()
1099                .collect(),
1100                vm_storage: state_snapshot_vm()
1101                    .accounts
1102                    .into_iter()
1103                    .map(|state| (state.address.clone(), state))
1104                    .collect(),
1105            },
1106            deltas: None,
1107            removed_components: Default::default(),
1108        };
1109
1110        let snap = state_sync
1111            .get_snapshots(header, &mut tracker, Some(&components_arg))
1112            .await
1113            .expect("Retrieving snapshot failed");
1114
1115        assert_eq!(snap, exp);
1116    }
1117
1118    fn mock_clients_for_state_sync() -> (MockRPCClient, MockDeltasClient, Sender<BlockChanges>) {
1119        let mut rpc_client = MockRPCClient::new();
1120        // Mocks for the start_tracking call, these need to come first because they are more
1121        // specific, see: https://docs.rs/mockall/latest/mockall/#matching-multiple-calls
1122        rpc_client
1123            .expect_get_protocol_components()
1124            .with(mockall::predicate::function(
1125                move |request_params: &ProtocolComponentsRequestBody| {
1126                    if let Some(ids) = request_params.component_ids.as_ref() {
1127                        ids.contains(&"Component3".to_string())
1128                    } else {
1129                        false
1130                    }
1131                },
1132            ))
1133            .returning(|_| {
1134                // return Component3
1135                Ok(ProtocolComponentRequestResponse {
1136                    protocol_components: vec![
1137                        // this component shall have a tvl update above threshold
1138                        ProtocolComponent { id: "Component3".to_string(), ..Default::default() },
1139                    ],
1140                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1141                })
1142            });
1143        rpc_client
1144            .expect_get_protocol_states()
1145            .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
1146                let expected_id = "Component3".to_string();
1147                if let Some(ids) = request_params.protocol_ids.as_ref() {
1148                    ids.contains(&expected_id)
1149                } else {
1150                    false
1151                }
1152            }))
1153            .returning(|_| {
1154                // return Component3 state
1155                Ok(ProtocolStateRequestResponse {
1156                    states: vec![ResponseProtocolState {
1157                        component_id: "Component3".to_string(),
1158                        ..Default::default()
1159                    }],
1160                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1161                })
1162            });
1163
1164        // mock calls for the initial state snapshots
1165        rpc_client
1166            .expect_get_protocol_components()
1167            .returning(|_| {
1168                // Initial sync of components
1169                Ok(ProtocolComponentRequestResponse {
1170                    protocol_components: vec![
1171                        // this component shall have a tvl update above threshold
1172                        ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
1173                        // this component shall have a tvl update below threshold.
1174                        ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1175                        // a third component will have a tvl update above threshold
1176                    ],
1177                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1178                })
1179            });
1180        rpc_client
1181            .expect_get_protocol_states()
1182            .returning(|_| {
1183                // Initial state snapshot
1184                Ok(ProtocolStateRequestResponse {
1185                    states: vec![
1186                        ResponseProtocolState {
1187                            component_id: "Component1".to_string(),
1188                            ..Default::default()
1189                        },
1190                        ResponseProtocolState {
1191                            component_id: "Component2".to_string(),
1192                            ..Default::default()
1193                        },
1194                    ],
1195                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1196                })
1197            });
1198        rpc_client
1199            .expect_get_component_tvl()
1200            .returning(|_| {
1201                Ok(ComponentTvlRequestResponse {
1202                    tvl: [
1203                        ("Component1".to_string(), 100.0),
1204                        ("Component2".to_string(), 0.0),
1205                        ("Component3".to_string(), 1000.0),
1206                    ]
1207                    .into_iter()
1208                    .collect(),
1209                    pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1210                })
1211            });
1212        rpc_client
1213            .expect_get_traced_entry_points()
1214            .returning(|_| {
1215                Ok(TracedEntryPointRequestResponse {
1216                    traced_entry_points: HashMap::new(),
1217                    pagination: PaginationResponse::new(0, 20, 0),
1218                })
1219            });
1220
1221        // Mock deltas client and messages
1222        let mut deltas_client = MockDeltasClient::new();
1223        let (tx, rx) = channel(1);
1224        deltas_client
1225            .expect_subscribe()
1226            .return_once(move |_, _| {
1227                // Return subscriber id and a channel
1228                Ok((Uuid::default(), rx))
1229            });
1230        (rpc_client, deltas_client, tx)
1231    }
1232
1233    /// Test strategy
1234    ///
1235    /// - initial snapshot retrieval returns two component1 and component2 as snapshots
1236    /// - send 2 dummy messages, containing only blocks
1237    /// - third message contains a new component with some significant tvl, one initial component
1238    ///   slips below tvl threshold, another one is above tvl but does not get re-requested.
1239    #[test(tokio::test)]
1240    async fn test_state_sync() {
1241        let (rpc_client, deltas_client, tx) = mock_clients_for_state_sync();
1242        let deltas = [
1243            BlockChanges {
1244                extractor: "uniswap-v2".to_string(),
1245                chain: Chain::Ethereum,
1246                block: Block {
1247                    number: 1,
1248                    hash: Bytes::from("0x01"),
1249                    parent_hash: Bytes::from("0x00"),
1250                    chain: Chain::Ethereum,
1251                    ts: Default::default(),
1252                },
1253                revert: false,
1254                dci_update: DCIUpdate {
1255                    new_entrypoints: HashMap::from([(
1256                        "Component1".to_string(),
1257                        HashSet::from([EntryPoint {
1258                            external_id: "entrypoint_a".to_string(),
1259                            target: Bytes::from("0x0badc0ffee"),
1260                            signature: "sig()".to_string(),
1261                        }]),
1262                    )]),
1263                    new_entrypoint_params: HashMap::from([(
1264                        "entrypoint_a".to_string(),
1265                        HashSet::from([(
1266                            TracingParams::RPCTracer(RPCTracerParams {
1267                                caller: Some(Bytes::from("0x0badc0ffee")),
1268                                calldata: Bytes::from("0x0badc0ffee"),
1269                            }),
1270                            Some("Component1".to_string()),
1271                        )]),
1272                    )]),
1273                    trace_results: HashMap::from([(
1274                        "entrypoint_a".to_string(),
1275                        TracingResult {
1276                            retriggers: HashSet::from([(
1277                                Bytes::from("0x0badc0ffee"),
1278                                Bytes::from("0x0badc0ffee"),
1279                            )]),
1280                            accessed_slots: HashMap::from([(
1281                                Bytes::from("0x0badc0ffee"),
1282                                HashSet::from([Bytes::from("0xbadbeef0")]),
1283                            )]),
1284                        },
1285                    )]),
1286                },
1287                ..Default::default()
1288            },
1289            BlockChanges {
1290                extractor: "uniswap-v2".to_string(),
1291                chain: Chain::Ethereum,
1292                block: Block {
1293                    number: 2,
1294                    hash: Bytes::from("0x02"),
1295                    parent_hash: Bytes::from("0x01"),
1296                    chain: Chain::Ethereum,
1297                    ts: Default::default(),
1298                },
1299                revert: false,
1300                component_tvl: [
1301                    ("Component1".to_string(), 100.0),
1302                    ("Component2".to_string(), 0.0),
1303                    ("Component3".to_string(), 1000.0),
1304                ]
1305                .into_iter()
1306                .collect(),
1307                ..Default::default()
1308            },
1309        ];
1310        let mut state_sync = with_mocked_clients(true, true, Some(rpc_client), Some(deltas_client));
1311        state_sync
1312            .initialize()
1313            .await
1314            .expect("Init failed");
1315
1316        // Test starts here
1317        let (jh, mut rx) = state_sync
1318            .start()
1319            .await
1320            .expect("Failed to start state synchronizer");
1321        tx.send(deltas[0].clone())
1322            .await
1323            .expect("deltas channel msg 0 closed!");
1324        let first_msg = timeout(Duration::from_millis(100), rx.recv())
1325            .await
1326            .expect("waiting for first state msg timed out!")
1327            .expect("state sync block sender closed!");
1328        tx.send(deltas[1].clone())
1329            .await
1330            .expect("deltas channel msg 1 closed!");
1331        let second_msg = timeout(Duration::from_millis(100), rx.recv())
1332            .await
1333            .expect("waiting for second state msg timed out!")
1334            .expect("state sync block sender closed!");
1335        let _ = state_sync.close().await;
1336        let exit = jh
1337            .await
1338            .expect("state sync task panicked!");
1339
1340        // assertions
1341        let exp1 = StateSyncMessage {
1342            header: BlockHeader {
1343                number: 1,
1344                hash: Bytes::from("0x01"),
1345                parent_hash: Bytes::from("0x00"),
1346                revert: false,
1347                ..Default::default()
1348            },
1349            snapshots: Snapshot {
1350                states: [
1351                    (
1352                        "Component1".to_string(),
1353                        ComponentWithState {
1354                            state: ResponseProtocolState {
1355                                component_id: "Component1".to_string(),
1356                                ..Default::default()
1357                            },
1358                            component: ProtocolComponent {
1359                                id: "Component1".to_string(),
1360                                ..Default::default()
1361                            },
1362                            component_tvl: Some(100.0),
1363                            entrypoints: vec![],
1364                        },
1365                    ),
1366                    (
1367                        "Component2".to_string(),
1368                        ComponentWithState {
1369                            state: ResponseProtocolState {
1370                                component_id: "Component2".to_string(),
1371                                ..Default::default()
1372                            },
1373                            component: ProtocolComponent {
1374                                id: "Component2".to_string(),
1375                                ..Default::default()
1376                            },
1377                            component_tvl: Some(0.0),
1378                            entrypoints: vec![],
1379                        },
1380                    ),
1381                ]
1382                .into_iter()
1383                .collect(),
1384                vm_storage: HashMap::new(),
1385            },
1386            deltas: Some(deltas[0].clone()),
1387            removed_components: Default::default(),
1388        };
1389
1390        let exp2 = StateSyncMessage {
1391            header: BlockHeader {
1392                number: 2,
1393                hash: Bytes::from("0x02"),
1394                parent_hash: Bytes::from("0x01"),
1395                revert: false,
1396                ..Default::default()
1397            },
1398            snapshots: Snapshot {
1399                states: [
1400                    // This is the new component we queried once it passed the tvl threshold.
1401                    (
1402                        "Component3".to_string(),
1403                        ComponentWithState {
1404                            state: ResponseProtocolState {
1405                                component_id: "Component3".to_string(),
1406                                ..Default::default()
1407                            },
1408                            component: ProtocolComponent {
1409                                id: "Component3".to_string(),
1410                                ..Default::default()
1411                            },
1412                            component_tvl: Some(1000.0),
1413                            entrypoints: vec![],
1414                        },
1415                    ),
1416                ]
1417                .into_iter()
1418                .collect(),
1419                vm_storage: HashMap::new(),
1420            },
1421            // Our deltas are empty and since merge methods are
1422            // tested in tycho-common we don't have much to do here.
1423            deltas: Some(BlockChanges {
1424                extractor: "uniswap-v2".to_string(),
1425                chain: Chain::Ethereum,
1426                block: Block {
1427                    number: 2,
1428                    hash: Bytes::from("0x02"),
1429                    parent_hash: Bytes::from("0x01"),
1430                    chain: Chain::Ethereum,
1431                    ts: Default::default(),
1432                },
1433                revert: false,
1434                component_tvl: [
1435                    // "Component2" should not show here.
1436                    ("Component1".to_string(), 100.0),
1437                    ("Component3".to_string(), 1000.0),
1438                ]
1439                .into_iter()
1440                .collect(),
1441                ..Default::default()
1442            }),
1443            // "Component2" was removed, because its tvl changed to 0.
1444            removed_components: [(
1445                "Component2".to_string(),
1446                ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1447            )]
1448            .into_iter()
1449            .collect(),
1450        };
1451        assert_eq!(first_msg, exp1);
1452        assert_eq!(second_msg, exp2);
1453        assert!(exit.is_ok());
1454    }
1455
1456    #[test(tokio::test)]
1457    async fn test_state_sync_with_tvl_range() {
1458        // Define the range for testing
1459        let remove_tvl_threshold = 5.0;
1460        let add_tvl_threshold = 7.0;
1461
1462        let mut rpc_client = MockRPCClient::new();
1463        let mut deltas_client = MockDeltasClient::new();
1464
1465        rpc_client
1466            .expect_get_protocol_components()
1467            .with(mockall::predicate::function(
1468                move |request_params: &ProtocolComponentsRequestBody| {
1469                    if let Some(ids) = request_params.component_ids.as_ref() {
1470                        ids.contains(&"Component3".to_string())
1471                    } else {
1472                        false
1473                    }
1474                },
1475            ))
1476            .returning(|_| {
1477                Ok(ProtocolComponentRequestResponse {
1478                    protocol_components: vec![ProtocolComponent {
1479                        id: "Component3".to_string(),
1480                        ..Default::default()
1481                    }],
1482                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1483                })
1484            });
1485        rpc_client
1486            .expect_get_protocol_states()
1487            .with(mockall::predicate::function(move |request_params: &ProtocolStateRequestBody| {
1488                let expected_id = "Component3".to_string();
1489                if let Some(ids) = request_params.protocol_ids.as_ref() {
1490                    ids.contains(&expected_id)
1491                } else {
1492                    false
1493                }
1494            }))
1495            .returning(|_| {
1496                Ok(ProtocolStateRequestResponse {
1497                    states: vec![ResponseProtocolState {
1498                        component_id: "Component3".to_string(),
1499                        ..Default::default()
1500                    }],
1501                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1502                })
1503            });
1504
1505        // Mock for the initial snapshot retrieval
1506        rpc_client
1507            .expect_get_protocol_components()
1508            .returning(|_| {
1509                Ok(ProtocolComponentRequestResponse {
1510                    protocol_components: vec![
1511                        ProtocolComponent { id: "Component1".to_string(), ..Default::default() },
1512                        ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1513                    ],
1514                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1515                })
1516            });
1517        rpc_client
1518            .expect_get_protocol_states()
1519            .returning(|_| {
1520                Ok(ProtocolStateRequestResponse {
1521                    states: vec![
1522                        ResponseProtocolState {
1523                            component_id: "Component1".to_string(),
1524                            ..Default::default()
1525                        },
1526                        ResponseProtocolState {
1527                            component_id: "Component2".to_string(),
1528                            ..Default::default()
1529                        },
1530                    ],
1531                    pagination: PaginationResponse { page: 0, page_size: 20, total: 1 },
1532                })
1533            });
1534        rpc_client
1535            .expect_get_traced_entry_points()
1536            .returning(|_| {
1537                Ok(TracedEntryPointRequestResponse {
1538                    traced_entry_points: HashMap::new(),
1539                    pagination: PaginationResponse::new(0, 20, 0),
1540                })
1541            });
1542
1543        rpc_client
1544            .expect_get_component_tvl()
1545            .returning(|_| {
1546                Ok(ComponentTvlRequestResponse {
1547                    tvl: [
1548                        ("Component1".to_string(), 6.0),
1549                        ("Component2".to_string(), 2.0),
1550                        ("Component3".to_string(), 10.0),
1551                    ]
1552                    .into_iter()
1553                    .collect(),
1554                    pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1555                })
1556            });
1557
1558        rpc_client
1559            .expect_get_component_tvl()
1560            .returning(|_| {
1561                Ok(ComponentTvlRequestResponse {
1562                    tvl: [
1563                        ("Component1".to_string(), 6.0),
1564                        ("Component2".to_string(), 2.0),
1565                        ("Component3".to_string(), 10.0),
1566                    ]
1567                    .into_iter()
1568                    .collect(),
1569                    pagination: PaginationResponse { page: 0, page_size: 20, total: 3 },
1570                })
1571            });
1572
1573        let (tx, rx) = channel(1);
1574        deltas_client
1575            .expect_subscribe()
1576            .return_once(move |_, _| Ok((Uuid::default(), rx)));
1577
1578        let mut state_sync = ProtocolStateSynchronizer::new(
1579            ExtractorIdentity::new(Chain::Ethereum, "uniswap-v2"),
1580            true,
1581            ComponentFilter::with_tvl_range(remove_tvl_threshold, add_tvl_threshold),
1582            1,
1583            true,
1584            true,
1585            ArcRPCClient(Arc::new(rpc_client)),
1586            ArcDeltasClient(Arc::new(deltas_client)),
1587            10_u64,
1588        );
1589        state_sync
1590            .initialize()
1591            .await
1592            .expect("Init failed");
1593
1594        // Simulate the incoming BlockChanges
1595        let deltas = [
1596            BlockChanges {
1597                extractor: "uniswap-v2".to_string(),
1598                chain: Chain::Ethereum,
1599                block: Block {
1600                    number: 1,
1601                    hash: Bytes::from("0x01"),
1602                    parent_hash: Bytes::from("0x00"),
1603                    chain: Chain::Ethereum,
1604                    ts: Default::default(),
1605                },
1606                revert: false,
1607                ..Default::default()
1608            },
1609            BlockChanges {
1610                extractor: "uniswap-v2".to_string(),
1611                chain: Chain::Ethereum,
1612                block: Block {
1613                    number: 2,
1614                    hash: Bytes::from("0x02"),
1615                    parent_hash: Bytes::from("0x01"),
1616                    chain: Chain::Ethereum,
1617                    ts: Default::default(),
1618                },
1619                revert: false,
1620                component_tvl: [
1621                    ("Component1".to_string(), 6.0), // Within range, should not trigger changes
1622                    ("Component2".to_string(), 2.0), // Below lower threshold, should be removed
1623                    ("Component3".to_string(), 10.0), // Above upper threshold, should be added
1624                ]
1625                .into_iter()
1626                .collect(),
1627                ..Default::default()
1628            },
1629        ];
1630
1631        let (jh, mut rx) = state_sync
1632            .start()
1633            .await
1634            .expect("Failed to start state synchronizer");
1635
1636        // Simulate sending delta messages
1637        tx.send(deltas[0].clone())
1638            .await
1639            .expect("deltas channel msg 0 closed!");
1640
1641        // Expecting to receive the initial state message
1642        let _ = timeout(Duration::from_millis(100), rx.recv())
1643            .await
1644            .expect("waiting for first state msg timed out!")
1645            .expect("state sync block sender closed!");
1646
1647        // Send the third message, which should trigger TVL-based changes
1648        tx.send(deltas[1].clone())
1649            .await
1650            .expect("deltas channel msg 1 closed!");
1651        let second_msg = timeout(Duration::from_millis(100), rx.recv())
1652            .await
1653            .expect("waiting for second state msg timed out!")
1654            .expect("state sync block sender closed!");
1655
1656        let _ = state_sync.close().await;
1657        let exit = jh
1658            .await
1659            .expect("state sync task panicked!");
1660
1661        let expected_second_msg = StateSyncMessage {
1662            header: BlockHeader {
1663                number: 2,
1664                hash: Bytes::from("0x02"),
1665                parent_hash: Bytes::from("0x01"),
1666                revert: false,
1667                ..Default::default()
1668            },
1669            snapshots: Snapshot {
1670                states: [(
1671                    "Component3".to_string(),
1672                    ComponentWithState {
1673                        state: ResponseProtocolState {
1674                            component_id: "Component3".to_string(),
1675                            ..Default::default()
1676                        },
1677                        component: ProtocolComponent {
1678                            id: "Component3".to_string(),
1679                            ..Default::default()
1680                        },
1681                        component_tvl: Some(10.0),
1682                        entrypoints: vec![], // TODO: add entrypoints?
1683                    },
1684                )]
1685                .into_iter()
1686                .collect(),
1687                vm_storage: HashMap::new(),
1688            },
1689            deltas: Some(BlockChanges {
1690                extractor: "uniswap-v2".to_string(),
1691                chain: Chain::Ethereum,
1692                block: Block {
1693                    number: 2,
1694                    hash: Bytes::from("0x02"),
1695                    parent_hash: Bytes::from("0x01"),
1696                    chain: Chain::Ethereum,
1697                    ts: Default::default(),
1698                },
1699                revert: false,
1700                component_tvl: [
1701                    ("Component1".to_string(), 6.0), // Within range, should not trigger changes
1702                    ("Component3".to_string(), 10.0), // Above upper threshold, should be added
1703                ]
1704                .into_iter()
1705                .collect(),
1706                ..Default::default()
1707            }),
1708            removed_components: [(
1709                "Component2".to_string(),
1710                ProtocolComponent { id: "Component2".to_string(), ..Default::default() },
1711            )]
1712            .into_iter()
1713            .collect(),
1714        };
1715
1716        assert_eq!(second_msg, expected_second_msg);
1717        assert!(exit.is_ok());
1718    }
1719}