tycho_client/
rpc.rs

1//! # Tycho RPC Client
2//!
3//! The objective of this module is to provide swift and simplified access to the Remote Procedure
4//! Call (RPC) endpoints of Tycho. These endpoints are chiefly responsible for facilitating data
5//! queries, especially querying snapshots of data.
6use std::{
7    collections::HashMap,
8    sync::Arc,
9    time::{Duration, SystemTime},
10};
11
12use async_trait::async_trait;
13use backoff::{exponential::ExponentialBackoffBuilder, ExponentialBackoff};
14use futures03::future::try_join_all;
15#[cfg(test)]
16use mockall::automock;
17use reqwest::{header, Client, ClientBuilder, Response, StatusCode, Url};
18use serde::Serialize;
19use thiserror::Error;
20use time::{format_description::well_known::Rfc2822, OffsetDateTime};
21use tokio::{
22    sync::{RwLock, Semaphore},
23    time::sleep,
24};
25use tracing::{debug, error, instrument, trace, warn};
26use tycho_common::{
27    dto::{
28        BlockParam, Chain, ComponentTvlRequestBody, ComponentTvlRequestResponse,
29        EntryPointWithTracingParams, PaginationParams, PaginationResponse, ProtocolComponent,
30        ProtocolComponentRequestResponse, ProtocolComponentsRequestBody, ProtocolStateRequestBody,
31        ProtocolStateRequestResponse, ProtocolSystemsRequestBody, ProtocolSystemsRequestResponse,
32        ResponseToken, StateRequestBody, StateRequestResponse, TokensRequestBody,
33        TokensRequestResponse, TracedEntryPointRequestBody, TracedEntryPointRequestResponse,
34        TracingResult, VersionParam,
35    },
36    models::ComponentId,
37    Bytes,
38};
39
40use crate::{
41    feed::synchronizer::{ComponentWithState, Snapshot},
42    TYCHO_SERVER_VERSION,
43};
44
45/// Request body for fetching a snapshot of protocol states and VM storage.
46///
47/// This struct helps to coordinate fetching  multiple pieces of related data
48/// (protocol states, contract storage, TVL, entry points).
49#[derive(Clone, Debug, PartialEq)]
50pub struct SnapshotParameters<'a> {
51    /// Which chain to fetch snapshots for
52    pub chain: Chain,
53    /// Protocol system name, required for correct state resolution
54    pub protocol_system: &'a str,
55    /// Components to fetch protocol states for
56    pub components: &'a HashMap<ComponentId, ProtocolComponent>,
57    /// Traced entry points data mapped by component id
58    pub entrypoints: Option<&'a HashMap<String, Vec<(EntryPointWithTracingParams, TracingResult)>>>,
59    /// Contract addresses to fetch VM storage for
60    pub contract_ids: &'a [Bytes],
61    /// Block number for versioning
62    pub block_number: u64,
63    /// Whether to include balance information
64    pub include_balances: bool,
65    /// Whether to fetch TVL data
66    pub include_tvl: bool,
67}
68
69impl<'a> SnapshotParameters<'a> {
70    pub fn new(
71        chain: Chain,
72        protocol_system: &'a str,
73        components: &'a HashMap<ComponentId, ProtocolComponent>,
74        contract_ids: &'a [Bytes],
75        block_number: u64,
76    ) -> Self {
77        Self {
78            chain,
79            protocol_system,
80            components,
81            entrypoints: None,
82            contract_ids,
83            block_number,
84            include_balances: true,
85            include_tvl: true,
86        }
87    }
88
89    /// Set whether to include balance information (default: true)
90    pub fn include_balances(mut self, include_balances: bool) -> Self {
91        self.include_balances = include_balances;
92        self
93    }
94
95    /// Set whether to fetch TVL data (default: true)
96    pub fn include_tvl(mut self, include_tvl: bool) -> Self {
97        self.include_tvl = include_tvl;
98        self
99    }
100
101    pub fn entrypoints(
102        mut self,
103        entrypoints: &'a HashMap<String, Vec<(EntryPointWithTracingParams, TracingResult)>>,
104    ) -> Self {
105        self.entrypoints = Some(entrypoints);
106        self
107    }
108}
109
110#[derive(Error, Debug)]
111pub enum RPCError {
112    /// The passed tycho url failed to parse.
113    #[error("Failed to parse URL: {0}. Error: {1}")]
114    UrlParsing(String, String),
115
116    /// The request data is not correctly formed.
117    #[error("Failed to format request: {0}")]
118    FormatRequest(String),
119
120    /// Errors forwarded from the HTTP protocol.
121    #[error("Unexpected HTTP client error: {0}")]
122    HttpClient(String, #[source] reqwest::Error),
123
124    /// The response from the server could not be parsed correctly.
125    #[error("Failed to parse response: {0}")]
126    ParseResponse(String),
127
128    /// Other fatal errors.
129    #[error("Fatal error: {0}")]
130    Fatal(String),
131
132    #[error("Rate limited until {0:?}")]
133    RateLimited(Option<SystemTime>),
134
135    #[error("Server unreachable: {0}")]
136    ServerUnreachable(String),
137}
138
139#[cfg_attr(test, automock)]
140#[async_trait]
141pub trait RPCClient: Send + Sync {
142    /// Retrieves a snapshot of contract state.
143    async fn get_contract_state(
144        &self,
145        request: &StateRequestBody,
146    ) -> Result<StateRequestResponse, RPCError>;
147
148    async fn get_contract_state_paginated(
149        &self,
150        chain: Chain,
151        ids: &[Bytes],
152        protocol_system: &str,
153        version: &VersionParam,
154        chunk_size: usize,
155        concurrency: usize,
156    ) -> Result<StateRequestResponse, RPCError> {
157        let semaphore = Arc::new(Semaphore::new(concurrency));
158
159        // Sort the ids to maximize server-side cache hits
160        let mut sorted_ids = ids.to_vec();
161        sorted_ids.sort();
162
163        let chunked_bodies = sorted_ids
164            .chunks(chunk_size)
165            .map(|chunk| StateRequestBody {
166                contract_ids: Some(chunk.to_vec()),
167                protocol_system: protocol_system.to_string(),
168                chain,
169                version: version.clone(),
170                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
171            })
172            .collect::<Vec<_>>();
173
174        let mut tasks = Vec::new();
175        for body in chunked_bodies.iter() {
176            let sem = semaphore.clone();
177            tasks.push(async move {
178                let _permit = sem
179                    .acquire()
180                    .await
181                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
182                self.get_contract_state(body).await
183            });
184        }
185
186        // Execute all tasks concurrently with the defined concurrency limit.
187        let responses = try_join_all(tasks).await?;
188
189        // Aggregate the responses into a single result.
190        let accounts = responses
191            .iter()
192            .flat_map(|r| r.accounts.clone())
193            .collect();
194        let total: i64 = responses
195            .iter()
196            .map(|r| r.pagination.total)
197            .sum();
198
199        Ok(StateRequestResponse {
200            accounts,
201            pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
202        })
203    }
204
205    async fn get_protocol_components(
206        &self,
207        request: &ProtocolComponentsRequestBody,
208    ) -> Result<ProtocolComponentRequestResponse, RPCError>;
209
210    async fn get_protocol_components_paginated(
211        &self,
212        request: &ProtocolComponentsRequestBody,
213        chunk_size: usize,
214        concurrency: usize,
215    ) -> Result<ProtocolComponentRequestResponse, RPCError> {
216        let semaphore = Arc::new(Semaphore::new(concurrency));
217
218        // If a set of component IDs is specified, the maximum return size is already known,
219        // allowing us to pre-compute the number of requests to be made.
220        match request.component_ids {
221            Some(ref ids) => {
222                // We can divide the component_ids into chunks of size chunk_size
223                let chunked_bodies = ids
224                    .chunks(chunk_size)
225                    .enumerate()
226                    .map(|(index, _)| ProtocolComponentsRequestBody {
227                        protocol_system: request.protocol_system.clone(),
228                        component_ids: request.component_ids.clone(),
229                        tvl_gt: request.tvl_gt,
230                        chain: request.chain,
231                        pagination: PaginationParams {
232                            page: index as i64,
233                            page_size: chunk_size as i64,
234                        },
235                    })
236                    .collect::<Vec<_>>();
237
238                let mut tasks = Vec::new();
239                for body in chunked_bodies.iter() {
240                    let sem = semaphore.clone();
241                    tasks.push(async move {
242                        let _permit = sem
243                            .acquire()
244                            .await
245                            .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
246                        self.get_protocol_components(body).await
247                    });
248                }
249
250                try_join_all(tasks)
251                    .await
252                    .map(|responses| ProtocolComponentRequestResponse {
253                        protocol_components: responses
254                            .into_iter()
255                            .flat_map(|r| r.protocol_components.into_iter())
256                            .collect(),
257                        pagination: PaginationResponse {
258                            page: 0,
259                            page_size: chunk_size as i64,
260                            total: ids.len() as i64,
261                        },
262                    })
263            }
264            _ => {
265                // If no component ids are specified, we need to make requests based on the total
266                // number of results from the first response.
267
268                let initial_request = ProtocolComponentsRequestBody {
269                    protocol_system: request.protocol_system.clone(),
270                    component_ids: request.component_ids.clone(),
271                    tvl_gt: request.tvl_gt,
272                    chain: request.chain,
273                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
274                };
275                let first_response = self
276                    .get_protocol_components(&initial_request)
277                    .await
278                    .map_err(|err| RPCError::Fatal(err.to_string()))?;
279
280                let total_items = first_response.pagination.total;
281                let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
282
283                // Initialize the final response accumulator
284                let mut accumulated_response = ProtocolComponentRequestResponse {
285                    protocol_components: first_response.protocol_components,
286                    pagination: PaginationResponse {
287                        page: 0,
288                        page_size: chunk_size as i64,
289                        total: total_items,
290                    },
291                };
292
293                let mut page = 1;
294                while page < total_pages {
295                    let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
296
297                    // Create request bodies for parallel requests, respecting the concurrency limit
298                    let chunked_bodies = (0..requests_in_this_iteration)
299                        .map(|iter| ProtocolComponentsRequestBody {
300                            protocol_system: request.protocol_system.clone(),
301                            component_ids: request.component_ids.clone(),
302                            tvl_gt: request.tvl_gt,
303                            chain: request.chain,
304                            pagination: PaginationParams {
305                                page: page + iter,
306                                page_size: chunk_size as i64,
307                            },
308                        })
309                        .collect::<Vec<_>>();
310
311                    let tasks: Vec<_> = chunked_bodies
312                        .iter()
313                        .map(|body| {
314                            let sem = semaphore.clone();
315                            async move {
316                                let _permit = sem.acquire().await.map_err(|_| {
317                                    RPCError::Fatal("Semaphore dropped".to_string())
318                                })?;
319                                self.get_protocol_components(body).await
320                            }
321                        })
322                        .collect();
323
324                    let responses = try_join_all(tasks)
325                        .await
326                        .map(|responses| {
327                            let total = responses[0].pagination.total;
328                            ProtocolComponentRequestResponse {
329                                protocol_components: responses
330                                    .into_iter()
331                                    .flat_map(|r| r.protocol_components.into_iter())
332                                    .collect(),
333                                pagination: PaginationResponse {
334                                    page,
335                                    page_size: chunk_size as i64,
336                                    total,
337                                },
338                            }
339                        });
340
341                    // Update the accumulated response or set the initial response
342                    match responses {
343                        Ok(mut resp) => {
344                            accumulated_response
345                                .protocol_components
346                                .append(&mut resp.protocol_components);
347                        }
348                        Err(e) => return Err(e),
349                    }
350
351                    page += concurrency as i64;
352                }
353                Ok(accumulated_response)
354            }
355        }
356    }
357
358    async fn get_protocol_states(
359        &self,
360        request: &ProtocolStateRequestBody,
361    ) -> Result<ProtocolStateRequestResponse, RPCError>;
362
363    #[allow(clippy::too_many_arguments)]
364    async fn get_protocol_states_paginated<T>(
365        &self,
366        chain: Chain,
367        ids: &[T],
368        protocol_system: &str,
369        include_balances: bool,
370        version: &VersionParam,
371        chunk_size: usize,
372        concurrency: usize,
373    ) -> Result<ProtocolStateRequestResponse, RPCError>
374    where
375        T: AsRef<str> + Sync + 'static,
376    {
377        let semaphore = Arc::new(Semaphore::new(concurrency));
378        let chunked_bodies = ids
379            .chunks(chunk_size)
380            .map(|c| ProtocolStateRequestBody {
381                protocol_ids: Some(
382                    c.iter()
383                        .map(|id| id.as_ref().to_string())
384                        .collect(),
385                ),
386                protocol_system: protocol_system.to_string(),
387                chain,
388                include_balances,
389                version: version.clone(),
390                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
391            })
392            .collect::<Vec<_>>();
393
394        let mut tasks = Vec::new();
395        for body in chunked_bodies.iter() {
396            let sem = semaphore.clone();
397            tasks.push(async move {
398                let _permit = sem
399                    .acquire()
400                    .await
401                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
402                self.get_protocol_states(body).await
403            });
404        }
405
406        try_join_all(tasks)
407            .await
408            .map(|responses| {
409                let states = responses
410                    .clone()
411                    .into_iter()
412                    .flat_map(|r| r.states)
413                    .collect();
414                let total = responses
415                    .iter()
416                    .map(|r| r.pagination.total)
417                    .sum();
418                ProtocolStateRequestResponse {
419                    states,
420                    pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
421                }
422            })
423    }
424
425    /// This function returns only one chunk of tokens. To get all tokens please call
426    /// get_all_tokens.
427    async fn get_tokens(
428        &self,
429        request: &TokensRequestBody,
430    ) -> Result<TokensRequestResponse, RPCError>;
431
432    async fn get_all_tokens(
433        &self,
434        chain: Chain,
435        min_quality: Option<i32>,
436        traded_n_days_ago: Option<u64>,
437        chunk_size: usize,
438    ) -> Result<Vec<ResponseToken>, RPCError> {
439        let mut request_page = 0;
440        let mut all_tokens = Vec::new();
441        loop {
442            let mut response = self
443                .get_tokens(&TokensRequestBody {
444                    token_addresses: None,
445                    min_quality,
446                    traded_n_days_ago,
447                    pagination: PaginationParams {
448                        page: request_page,
449                        page_size: chunk_size.try_into().map_err(|_| {
450                            RPCError::FormatRequest(
451                                "Failed to convert chunk_size into i64".to_string(),
452                            )
453                        })?,
454                    },
455                    chain,
456                })
457                .await?;
458
459            let num_tokens = response.tokens.len();
460            all_tokens.append(&mut response.tokens);
461            request_page += 1;
462
463            if num_tokens < chunk_size {
464                break;
465            }
466        }
467        Ok(all_tokens)
468    }
469
470    async fn get_protocol_systems(
471        &self,
472        request: &ProtocolSystemsRequestBody,
473    ) -> Result<ProtocolSystemsRequestResponse, RPCError>;
474
475    async fn get_component_tvl(
476        &self,
477        request: &ComponentTvlRequestBody,
478    ) -> Result<ComponentTvlRequestResponse, RPCError>;
479
480    async fn get_component_tvl_paginated(
481        &self,
482        request: &ComponentTvlRequestBody,
483        chunk_size: usize,
484        concurrency: usize,
485    ) -> Result<ComponentTvlRequestResponse, RPCError> {
486        let semaphore = Arc::new(Semaphore::new(concurrency));
487
488        match request.component_ids {
489            Some(ref ids) => {
490                let chunked_requests = ids
491                    .chunks(chunk_size)
492                    .enumerate()
493                    .map(|(index, _)| ComponentTvlRequestBody {
494                        chain: request.chain,
495                        protocol_system: request.protocol_system.clone(),
496                        component_ids: Some(ids.clone()),
497                        pagination: PaginationParams {
498                            page: index as i64,
499                            page_size: chunk_size as i64,
500                        },
501                    })
502                    .collect::<Vec<_>>();
503
504                let tasks: Vec<_> = chunked_requests
505                    .into_iter()
506                    .map(|req| {
507                        let sem = semaphore.clone();
508                        async move {
509                            let _permit = sem
510                                .acquire()
511                                .await
512                                .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
513                            self.get_component_tvl(&req).await
514                        }
515                    })
516                    .collect();
517
518                let responses = try_join_all(tasks).await?;
519
520                let mut merged_tvl = HashMap::new();
521                for resp in responses {
522                    for (key, value) in resp.tvl {
523                        *merged_tvl.entry(key).or_insert(0.0) = value;
524                    }
525                }
526
527                Ok(ComponentTvlRequestResponse {
528                    tvl: merged_tvl,
529                    pagination: PaginationResponse {
530                        page: 0,
531                        page_size: chunk_size as i64,
532                        total: ids.len() as i64,
533                    },
534                })
535            }
536            _ => {
537                let first_request = ComponentTvlRequestBody {
538                    chain: request.chain,
539                    protocol_system: request.protocol_system.clone(),
540                    component_ids: request.component_ids.clone(),
541                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
542                };
543
544                let first_response = self
545                    .get_component_tvl(&first_request)
546                    .await?;
547                let total_items = first_response.pagination.total;
548                let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
549
550                let mut merged_tvl = first_response.tvl;
551
552                let mut page = 1;
553                while page < total_pages {
554                    let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
555
556                    let chunked_requests: Vec<_> = (0..requests_in_this_iteration)
557                        .map(|i| ComponentTvlRequestBody {
558                            chain: request.chain,
559                            protocol_system: request.protocol_system.clone(),
560                            component_ids: request.component_ids.clone(),
561                            pagination: PaginationParams {
562                                page: page + i,
563                                page_size: chunk_size as i64,
564                            },
565                        })
566                        .collect();
567
568                    let tasks: Vec<_> = chunked_requests
569                        .into_iter()
570                        .map(|req| {
571                            let sem = semaphore.clone();
572                            async move {
573                                let _permit = sem.acquire().await.map_err(|_| {
574                                    RPCError::Fatal("Semaphore dropped".to_string())
575                                })?;
576                                self.get_component_tvl(&req).await
577                            }
578                        })
579                        .collect();
580
581                    let responses = try_join_all(tasks).await?;
582
583                    // merge hashmap
584                    for resp in responses {
585                        for (key, value) in resp.tvl {
586                            *merged_tvl.entry(key).or_insert(0.0) += value;
587                        }
588                    }
589
590                    page += concurrency as i64;
591                }
592
593                Ok(ComponentTvlRequestResponse {
594                    tvl: merged_tvl,
595                    pagination: PaginationResponse {
596                        page: 0,
597                        page_size: chunk_size as i64,
598                        total: total_items,
599                    },
600                })
601            }
602        }
603    }
604
605    async fn get_traced_entry_points(
606        &self,
607        request: &TracedEntryPointRequestBody,
608    ) -> Result<TracedEntryPointRequestResponse, RPCError>;
609
610    async fn get_traced_entry_points_paginated(
611        &self,
612        chain: Chain,
613        protocol_system: &str,
614        component_ids: &[String],
615        chunk_size: usize,
616        concurrency: usize,
617    ) -> Result<TracedEntryPointRequestResponse, RPCError> {
618        let semaphore = Arc::new(Semaphore::new(concurrency));
619        let chunked_bodies = component_ids
620            .chunks(chunk_size)
621            .map(|c| TracedEntryPointRequestBody {
622                chain,
623                protocol_system: protocol_system.to_string(),
624                component_ids: Some(c.to_vec()),
625                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
626            })
627            .collect::<Vec<_>>();
628
629        let mut tasks = Vec::new();
630        for body in chunked_bodies.iter() {
631            let sem = semaphore.clone();
632            tasks.push(async move {
633                let _permit = sem
634                    .acquire()
635                    .await
636                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
637                self.get_traced_entry_points(body).await
638            });
639        }
640
641        try_join_all(tasks)
642            .await
643            .map(|responses| {
644                let traced_entry_points = responses
645                    .clone()
646                    .into_iter()
647                    .flat_map(|r| r.traced_entry_points)
648                    .collect();
649                let total = responses
650                    .iter()
651                    .map(|r| r.pagination.total)
652                    .sum();
653                TracedEntryPointRequestResponse {
654                    traced_entry_points,
655                    pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
656                }
657            })
658    }
659
660    async fn get_snapshots<'a>(
661        &self,
662        request: &SnapshotParameters<'a>,
663        chunk_size: usize,
664        concurrency: usize,
665    ) -> Result<Snapshot, RPCError>;
666}
667
/// Settings used when building an `HttpRPCClient`.
#[derive(Debug, Clone)]
pub struct HttpRPCClientOptions {
    /// API key sent as the `Authorization` header, if any
    pub auth_key: Option<String>,
    /// Enable compression for requests (default: true)
    /// When disabled, the client is built without zstd support (`no_zstd`)
    pub compression: bool,
}

impl Default for HttpRPCClientOptions {
    /// Equivalent to `HttpRPCClientOptions::new()`.
    fn default() -> Self {
        Self::new()
    }
}

impl HttpRPCClientOptions {
    /// Returns options with no auth key and compression turned on.
    pub fn new() -> Self {
        Self { compression: true, auth_key: None }
    }

    /// Replaces the authentication key; passing `None` removes it.
    pub fn with_auth_key(self, auth_key: Option<String>) -> Self {
        Self { auth_key, ..self }
    }

    /// Turns compression on or off (default: true).
    pub fn with_compression(self, compression: bool) -> Self {
        Self { compression, ..self }
    }
}
702
703#[derive(Debug, Clone)]
704pub struct HttpRPCClient {
705    http_client: Client,
706    url: Url,
707    retry_after: Arc<RwLock<Option<SystemTime>>>,
708    backoff_policy: ExponentialBackoff,
709    server_restart_duration: Duration,
710}
711
712impl HttpRPCClient {
713    pub fn new(base_uri: &str, options: HttpRPCClientOptions) -> Result<Self, RPCError> {
714        let uri = base_uri
715            .parse::<Url>()
716            .map_err(|e| RPCError::UrlParsing(base_uri.to_string(), e.to_string()))?;
717
718        // Add default headers
719        let mut headers = header::HeaderMap::new();
720        headers.insert(header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"));
721        let user_agent = format!("tycho-client-{version}", version = env!("CARGO_PKG_VERSION"));
722        headers.insert(
723            header::USER_AGENT,
724            header::HeaderValue::from_str(&user_agent)
725                .map_err(|e| RPCError::FormatRequest(format!("Invalid user agent format: {e}")))?,
726        );
727
728        // Add Authorization if one is given
729        if let Some(key) = options.auth_key.as_deref() {
730            let mut auth_value = header::HeaderValue::from_str(key).map_err(|e| {
731                RPCError::FormatRequest(format!("Invalid authorization key format: {e}"))
732            })?;
733            auth_value.set_sensitive(true);
734            headers.insert(header::AUTHORIZATION, auth_value);
735        }
736
737        let mut client_builder = ClientBuilder::new()
738            .default_headers(headers)
739            .http2_prior_knowledge();
740
741        // When compression is disabled, turn off all automatic compression
742        if !options.compression {
743            client_builder = client_builder.no_zstd();
744        }
745
746        let client = client_builder
747            .build()
748            .map_err(|e| RPCError::HttpClient(e.to_string(), e))?;
749
750        Ok(Self {
751            http_client: client,
752            url: uri,
753            retry_after: Arc::new(RwLock::new(None)),
754            backoff_policy: ExponentialBackoffBuilder::new()
755                .with_initial_interval(Duration::from_millis(250))
756                // increase backoff time by 75% each failure
757                .with_multiplier(1.75)
758                // keep retrying every 30s
759                .with_max_interval(Duration::from_secs(30))
760                // if all retries take longer than 2m, give up
761                .with_max_elapsed_time(Some(Duration::from_secs(125)))
762                .build(),
763            server_restart_duration: Duration::from_secs(120),
764        })
765    }
766
767    #[cfg(test)]
768    pub fn with_test_backoff_policy(mut self) -> Self {
769        // Extremely short intervals for very fast testing
770        self.backoff_policy = ExponentialBackoffBuilder::new()
771            .with_initial_interval(Duration::from_millis(1))
772            .with_multiplier(1.1)
773            .with_max_interval(Duration::from_millis(5))
774            .with_max_elapsed_time(Some(Duration::from_millis(50)))
775            .build();
776        self.server_restart_duration = Duration::from_millis(50);
777        self
778    }
779
780    /// Converts a error response to a Result.
781    ///
782    /// Raises an error if the response status code id 429, 502, 503 or 504. In the 429
783    /// case it will try to look for a retry-after header an parse it accordingly. The
784    /// parsed value is then passed as part of the error.
785    async fn error_for_response(
786        &self,
787        response: reqwest::Response,
788    ) -> Result<reqwest::Response, RPCError> {
789        match response.status() {
790            StatusCode::TOO_MANY_REQUESTS => {
791                let retry_after_raw = response
792                    .headers()
793                    .get(reqwest::header::RETRY_AFTER)
794                    .and_then(|h| h.to_str().ok())
795                    .and_then(parse_retry_value);
796
797                Err(RPCError::RateLimited(retry_after_raw))
798            }
799            StatusCode::BAD_GATEWAY |
800            StatusCode::SERVICE_UNAVAILABLE |
801            StatusCode::GATEWAY_TIMEOUT => Err(RPCError::ServerUnreachable(
802                response
803                    .text()
804                    .await
805                    .unwrap_or_else(|_| "Server Unreachable".to_string()),
806            )),
807            _ => Ok(response),
808        }
809    }
810
811    /// Classifies errors into transient or permanent ones.
812    ///
813    /// Transient errors are retried with a potential backoff, permanent ones are not.
814    /// If the error is RateLimited, this method will set the self.retry_after value so
815    /// future requests wait until the rate limit has been reset.
816    async fn handle_error_for_backoff(&self, e: RPCError) -> backoff::Error<RPCError> {
817        match e {
818            RPCError::ServerUnreachable(_) => {
819                backoff::Error::retry_after(e, self.server_restart_duration)
820            }
821            RPCError::RateLimited(Some(until)) => {
822                let mut retry_after_guard = self.retry_after.write().await;
823                *retry_after_guard = Some(
824                    retry_after_guard
825                        .unwrap_or(until)
826                        .max(until),
827                );
828
829                if let Ok(duration) = until.duration_since(SystemTime::now()) {
830                    backoff::Error::retry_after(e, duration)
831                } else {
832                    e.into()
833                }
834            }
835            RPCError::RateLimited(None) => e.into(),
836            _ => backoff::Error::permanent(e),
837        }
838    }
839
840    /// Waits until the current rate limit time has passed.
841    ///
842    /// Only waits if there is a time and that time is in the future, else return
843    /// immediately.
844    async fn wait_until_retry_after(&self) {
845        if let Some(&until) = self.retry_after.read().await.as_ref() {
846            let now = SystemTime::now();
847            if until > now {
848                if let Ok(duration) = until.duration_since(now) {
849                    sleep(duration).await
850                }
851            }
852        }
853    }
854
855    /// Makes a post request handling transient failures.
856    ///
857    /// If a retry-after header is received it will be respected. Else the configured
858    /// backoff policy is used to deal with transient network or server errors.
859    async fn make_post_request<T: Serialize + ?Sized>(
860        &self,
861        request: &T,
862        uri: &String,
863    ) -> Result<Response, RPCError> {
864        self.wait_until_retry_after().await;
865        let response = backoff::future::retry(self.backoff_policy.clone(), || async {
866            let server_response = self
867                .http_client
868                .post(uri)
869                .json(request)
870                .send()
871                .await
872                .map_err(|e| RPCError::HttpClient(e.to_string(), e))?;
873
874            match self
875                .error_for_response(server_response)
876                .await
877            {
878                Ok(response) => Ok(response),
879                Err(e) => Err(self.handle_error_for_backoff(e).await),
880            }
881        })
882        .await?;
883        Ok(response)
884    }
885}
886
887fn parse_retry_value(val: &str) -> Option<SystemTime> {
888    if let Ok(secs) = val.parse::<u64>() {
889        return Some(SystemTime::now() + Duration::from_secs(secs));
890    }
891    if let Ok(date) = OffsetDateTime::parse(val, &Rfc2822) {
892        return Some(date.into());
893    }
894    None
895}
896
897#[async_trait]
898impl RPCClient for HttpRPCClient {
899    #[instrument(skip(self, request))]
900    async fn get_contract_state(
901        &self,
902        request: &StateRequestBody,
903    ) -> Result<StateRequestResponse, RPCError> {
904        // Check if contract ids are specified
905        if request
906            .contract_ids
907            .as_ref()
908            .is_none_or(|ids| ids.is_empty())
909        {
910            warn!("No contract ids specified in request.");
911        }
912
913        let uri = format!(
914            "{}/{}/contract_state",
915            self.url
916                .to_string()
917                .trim_end_matches('/'),
918            TYCHO_SERVER_VERSION
919        );
920        debug!(%uri, "Sending contract_state request to Tycho server");
921        trace!(?request, "Sending request to Tycho server");
922        let response = self
923            .make_post_request(request, &uri)
924            .await?;
925        trace!(?response, "Received response from Tycho server");
926
927        let body = response
928            .text()
929            .await
930            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
931        if body.is_empty() {
932            // Pure native protocols will return empty contract states
933            return Ok(StateRequestResponse {
934                accounts: vec![],
935                pagination: PaginationResponse {
936                    page: request.pagination.page,
937                    page_size: request.pagination.page,
938                    total: 0,
939                },
940            });
941        }
942
943        let accounts = serde_json::from_str::<StateRequestResponse>(&body)
944            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
945        trace!(?accounts, "Received contract_state response from Tycho server");
946
947        Ok(accounts)
948    }
949
950    async fn get_protocol_components(
951        &self,
952        request: &ProtocolComponentsRequestBody,
953    ) -> Result<ProtocolComponentRequestResponse, RPCError> {
954        let uri = format!(
955            "{}/{}/protocol_components",
956            self.url
957                .to_string()
958                .trim_end_matches('/'),
959            TYCHO_SERVER_VERSION,
960        );
961        debug!(%uri, "Sending protocol_components request to Tycho server");
962        trace!(?request, "Sending request to Tycho server");
963
964        let response = self
965            .make_post_request(request, &uri)
966            .await?;
967
968        trace!(?response, "Received response from Tycho server");
969
970        let body = response
971            .text()
972            .await
973            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
974        let components = serde_json::from_str::<ProtocolComponentRequestResponse>(&body)
975            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
976        trace!(?components, "Received protocol_components response from Tycho server");
977
978        Ok(components)
979    }
980
981    async fn get_protocol_states(
982        &self,
983        request: &ProtocolStateRequestBody,
984    ) -> Result<ProtocolStateRequestResponse, RPCError> {
985        // Check if protocol ids are specified
986        if request
987            .protocol_ids
988            .as_ref()
989            .is_none_or(|ids| ids.is_empty())
990        {
991            warn!("No protocol ids specified in request.");
992        }
993
994        let uri = format!(
995            "{}/{}/protocol_state",
996            self.url
997                .to_string()
998                .trim_end_matches('/'),
999            TYCHO_SERVER_VERSION
1000        );
1001        debug!(%uri, "Sending protocol_states request to Tycho server");
1002        trace!(?request, "Sending request to Tycho server");
1003
1004        let response = self
1005            .make_post_request(request, &uri)
1006            .await?;
1007        trace!(?response, "Received response from Tycho server");
1008
1009        let body = response
1010            .text()
1011            .await
1012            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1013
1014        if body.is_empty() {
1015            // Pure VM protocols will return empty states
1016            return Ok(ProtocolStateRequestResponse {
1017                states: vec![],
1018                pagination: PaginationResponse {
1019                    page: request.pagination.page,
1020                    page_size: request.pagination.page_size,
1021                    total: 0,
1022                },
1023            });
1024        }
1025
1026        let states = serde_json::from_str::<ProtocolStateRequestResponse>(&body)
1027            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1028        trace!(?states, "Received protocol_states response from Tycho server");
1029
1030        Ok(states)
1031    }
1032
1033    async fn get_tokens(
1034        &self,
1035        request: &TokensRequestBody,
1036    ) -> Result<TokensRequestResponse, RPCError> {
1037        let uri = format!(
1038            "{}/{}/tokens",
1039            self.url
1040                .to_string()
1041                .trim_end_matches('/'),
1042            TYCHO_SERVER_VERSION
1043        );
1044        debug!(%uri, "Sending tokens request to Tycho server");
1045
1046        let response = self
1047            .make_post_request(request, &uri)
1048            .await?;
1049
1050        let body = response
1051            .text()
1052            .await
1053            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1054        let tokens = serde_json::from_str::<TokensRequestResponse>(&body)
1055            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1056
1057        Ok(tokens)
1058    }
1059
1060    async fn get_protocol_systems(
1061        &self,
1062        request: &ProtocolSystemsRequestBody,
1063    ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
1064        let uri = format!(
1065            "{}/{}/protocol_systems",
1066            self.url
1067                .to_string()
1068                .trim_end_matches('/'),
1069            TYCHO_SERVER_VERSION
1070        );
1071        debug!(%uri, "Sending protocol_systems request to Tycho server");
1072        trace!(?request, "Sending request to Tycho server");
1073        let response = self
1074            .make_post_request(request, &uri)
1075            .await?;
1076        trace!(?response, "Received response from Tycho server");
1077        let body = response
1078            .text()
1079            .await
1080            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1081        let protocol_systems = serde_json::from_str::<ProtocolSystemsRequestResponse>(&body)
1082            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1083        trace!(?protocol_systems, "Received protocol_systems response from Tycho server");
1084        Ok(protocol_systems)
1085    }
1086
1087    async fn get_component_tvl(
1088        &self,
1089        request: &ComponentTvlRequestBody,
1090    ) -> Result<ComponentTvlRequestResponse, RPCError> {
1091        let uri = format!(
1092            "{}/{}/component_tvl",
1093            self.url
1094                .to_string()
1095                .trim_end_matches('/'),
1096            TYCHO_SERVER_VERSION
1097        );
1098        debug!(%uri, "Sending get_component_tvl request to Tycho server");
1099        trace!(?request, "Sending request to Tycho server");
1100        let response = self
1101            .make_post_request(request, &uri)
1102            .await?;
1103        trace!(?response, "Received response from Tycho server");
1104        let body = response
1105            .text()
1106            .await
1107            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1108        let component_tvl =
1109            serde_json::from_str::<ComponentTvlRequestResponse>(&body).map_err(|err| {
1110                error!("Failed to parse component_tvl response: {:?}", &body);
1111                RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
1112            })?;
1113        trace!(?component_tvl, "Received component_tvl response from Tycho server");
1114        Ok(component_tvl)
1115    }
1116
1117    async fn get_traced_entry_points(
1118        &self,
1119        request: &TracedEntryPointRequestBody,
1120    ) -> Result<TracedEntryPointRequestResponse, RPCError> {
1121        let uri = format!(
1122            "{}/{TYCHO_SERVER_VERSION}/traced_entry_points",
1123            self.url
1124                .to_string()
1125                .trim_end_matches('/')
1126        );
1127        debug!(%uri, "Sending traced_entry_points request to Tycho server");
1128        trace!(?request, "Sending request to Tycho server");
1129
1130        let response = self
1131            .make_post_request(request, &uri)
1132            .await?;
1133
1134        trace!(?response, "Received response from Tycho server");
1135
1136        let body = response
1137            .text()
1138            .await
1139            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1140        let entrypoints =
1141            serde_json::from_str::<TracedEntryPointRequestResponse>(&body).map_err(|err| {
1142                error!("Failed to parse traced_entry_points response: {:?}", &body);
1143                RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
1144            })?;
1145        trace!(?entrypoints, "Received traced_entry_points response from Tycho server");
1146        Ok(entrypoints)
1147    }
1148
1149    async fn get_snapshots<'a>(
1150        &self,
1151        request: &SnapshotParameters<'a>,
1152        chunk_size: usize,
1153        concurrency: usize,
1154    ) -> Result<Snapshot, RPCError> {
1155        let component_ids: Vec<_> = request
1156            .components
1157            .keys()
1158            .cloned()
1159            .collect();
1160
1161        let version = VersionParam::new(
1162            None,
1163            Some({
1164                #[allow(deprecated)]
1165                BlockParam {
1166                    hash: None,
1167                    chain: Some(request.chain),
1168                    number: Some(request.block_number as i64),
1169                }
1170            }),
1171        );
1172
1173        let component_tvl = if request.include_tvl && !component_ids.is_empty() {
1174            let body = ComponentTvlRequestBody::id_filtered(component_ids.clone(), request.chain);
1175            self.get_component_tvl_paginated(&body, chunk_size, concurrency)
1176                .await?
1177                .tvl
1178        } else {
1179            HashMap::new()
1180        };
1181
1182        let mut protocol_states = if !component_ids.is_empty() {
1183            self.get_protocol_states_paginated(
1184                request.chain,
1185                &component_ids,
1186                request.protocol_system,
1187                request.include_balances,
1188                &version,
1189                chunk_size,
1190                concurrency,
1191            )
1192            .await?
1193            .states
1194            .into_iter()
1195            .map(|state| (state.component_id.clone(), state))
1196            .collect()
1197        } else {
1198            HashMap::new()
1199        };
1200
1201        // Convert to ComponentWithState, which includes entrypoint information.
1202        let states = request
1203            .components
1204            .values()
1205            .filter_map(|component| {
1206                if let Some(state) = protocol_states.remove(&component.id) {
1207                    Some((
1208                        component.id.clone(),
1209                        ComponentWithState {
1210                            state,
1211                            component: component.clone(),
1212                            component_tvl: component_tvl
1213                                .get(&component.id)
1214                                .cloned(),
1215                            entrypoints: request
1216                                .entrypoints
1217                                .as_ref()
1218                                .and_then(|map| map.get(&component.id))
1219                                .cloned()
1220                                .unwrap_or_default(),
1221                        },
1222                    ))
1223                } else if component_ids.contains(&component.id) {
1224                    // only emit error event if we requested this component
1225                    let component_id = &component.id;
1226                    error!(?component_id, "Missing state for native component!");
1227                    None
1228                } else {
1229                    None
1230                }
1231            })
1232            .collect();
1233
1234        let vm_storage = if !request.contract_ids.is_empty() {
1235            let contract_states = self
1236                .get_contract_state_paginated(
1237                    request.chain,
1238                    request.contract_ids,
1239                    request.protocol_system,
1240                    &version,
1241                    chunk_size,
1242                    concurrency,
1243                )
1244                .await?
1245                .accounts
1246                .into_iter()
1247                .map(|acc| (acc.address.clone(), acc))
1248                .collect::<HashMap<_, _>>();
1249
1250            trace!(states=?&contract_states, "Retrieved ContractState");
1251
1252            let contract_address_to_components = request
1253                .components
1254                .iter()
1255                .filter_map(|(id, comp)| {
1256                    if component_ids.contains(id) {
1257                        Some(
1258                            comp.contract_ids
1259                                .iter()
1260                                .map(|address| (address.clone(), comp.id.clone())),
1261                        )
1262                    } else {
1263                        None
1264                    }
1265                })
1266                .flatten()
1267                .fold(HashMap::<Bytes, Vec<String>>::new(), |mut acc, (addr, c_id)| {
1268                    acc.entry(addr).or_default().push(c_id);
1269                    acc
1270                });
1271
1272            request
1273                .contract_ids
1274                .iter()
1275                .filter_map(|address| {
1276                    if let Some(state) = contract_states.get(address) {
1277                        Some((address.clone(), state.clone()))
1278                    } else if let Some(ids) = contract_address_to_components.get(address) {
1279                        // only emit error even if we did actually request this address
1280                        error!(
1281                            ?address,
1282                            ?ids,
1283                            "Component with lacking contract storage encountered!"
1284                        );
1285                        None
1286                    } else {
1287                        None
1288                    }
1289                })
1290                .collect()
1291        } else {
1292            HashMap::new()
1293        };
1294
1295        Ok(Snapshot { states, vm_storage })
1296    }
1297}
1298
1299#[cfg(test)]
1300mod tests {
1301    use std::{
1302        collections::{HashMap, HashSet},
1303        str::FromStr,
1304    };
1305
1306    use mockito::Server;
1307    use rstest::rstest;
1308    // TODO: remove once deprecated ProtocolId struct is removed
1309    #[allow(deprecated)]
1310    use tycho_common::dto::ProtocolId;
1311    use tycho_common::dto::{AddressStorageLocation, TracingParams};
1312
1313    use super::*;
1314
1315    // Dummy implementation of `get_protocol_states_paginated` for backwards compatibility testing
1316    // purposes
1317    impl MockRPCClient {
1318        #[allow(clippy::too_many_arguments)]
1319        async fn test_get_protocol_states_paginated<T>(
1320            &self,
1321            chain: Chain,
1322            ids: &[T],
1323            protocol_system: &str,
1324            include_balances: bool,
1325            version: &VersionParam,
1326            chunk_size: usize,
1327            _concurrency: usize,
1328        ) -> Vec<ProtocolStateRequestBody>
1329        where
1330            T: AsRef<str> + Clone + Send + Sync + 'static,
1331        {
1332            ids.chunks(chunk_size)
1333                .map(|chunk| ProtocolStateRequestBody {
1334                    protocol_ids: Some(
1335                        chunk
1336                            .iter()
1337                            .map(|id| id.as_ref().to_string())
1338                            .collect(),
1339                    ),
1340                    protocol_system: protocol_system.to_string(),
1341                    chain,
1342                    include_balances,
1343                    version: version.clone(),
1344                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
1345                })
1346                .collect()
1347        }
1348    }
1349
    /// Canned `/v1/contract_state` response body used by `test_get_contract_state`.
    ///
    /// It contains a single account with empty `slots`, `native_balance` `0x01f4` (500), code
    /// `0x00`, and a pagination block of page 0, page_size 20, total 10.
    const GET_CONTRACT_STATE_RESP: &str = r#"
        {
            "accounts": [
                {
                    "chain": "ethereum",
                    "address": "0x0000000000000000000000000000000000000000",
                    "title": "",
                    "slots": {},
                    "native_balance": "0x01f4",
                    "token_balances": {},
                    "code": "0x00",
                    "code_hash": "0x5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e",
                    "balance_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
                    "code_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
                    "creation_tx": null
                }
            ],
            "pagination": {
                "page": 0,
                "page_size": 20,
                "total": 10
            }
        }
        "#;
1374
1375    // TODO: remove once deprecated ProtocolId struct is removed
1376    #[allow(deprecated)]
1377    #[rstest]
1378    #[case::protocol_id_input(vec![
1379        ProtocolId { id: "id1".to_string(), chain: Chain::Ethereum },
1380        ProtocolId { id: "id2".to_string(), chain: Chain::Ethereum }
1381    ])]
1382    #[case::string_input(vec![
1383        "id1".to_string(),
1384        "id2".to_string()
1385    ])]
1386    #[tokio::test]
1387    async fn test_get_protocol_states_paginated_backwards_compatibility<T>(#[case] ids: Vec<T>)
1388    where
1389        T: AsRef<str> + Clone + Send + Sync + 'static,
1390    {
1391        let mock_client = MockRPCClient::new();
1392
1393        let request_bodies = mock_client
1394            .test_get_protocol_states_paginated(
1395                Chain::Ethereum,
1396                &ids,
1397                "test_system",
1398                true,
1399                &VersionParam::default(),
1400                2,
1401                2,
1402            )
1403            .await;
1404
1405        // Verify that the request bodies have been created correctly
1406        assert_eq!(request_bodies.len(), 1);
1407        assert_eq!(
1408            request_bodies[0]
1409                .protocol_ids
1410                .as_ref()
1411                .unwrap()
1412                .len(),
1413            2
1414        );
1415    }
1416
1417    #[tokio::test]
1418    async fn test_get_contract_state() {
1419        let mut server = Server::new_async().await;
1420        let server_resp = GET_CONTRACT_STATE_RESP;
1421        // test that the response is deserialized correctly
1422        serde_json::from_str::<StateRequestResponse>(server_resp).expect("deserialize");
1423
1424        let mocked_server = server
1425            .mock("POST", "/v1/contract_state")
1426            .expect(1)
1427            .with_body(server_resp)
1428            .create_async()
1429            .await;
1430
1431        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1432            .expect("create client");
1433
1434        let response = client
1435            .get_contract_state(&Default::default())
1436            .await
1437            .expect("get state");
1438        let accounts = response.accounts;
1439
1440        mocked_server.assert();
1441        assert_eq!(accounts.len(), 1);
1442        assert_eq!(accounts[0].slots, HashMap::new());
1443        assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
1444        assert_eq!(accounts[0].code, [0].to_vec());
1445        assert_eq!(
1446            accounts[0].code_hash,
1447            hex::decode("5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e")
1448                .unwrap()
1449        );
1450    }
1451
1452    #[tokio::test]
1453    async fn test_get_protocol_components() {
1454        let mut server = Server::new_async().await;
1455        let server_resp = r#"
1456        {
1457            "protocol_components": [
1458                {
1459                    "id": "State1",
1460                    "protocol_system": "ambient",
1461                    "protocol_type_name": "Pool",
1462                    "chain": "ethereum",
1463                    "tokens": [
1464                        "0x0000000000000000000000000000000000000000",
1465                        "0x0000000000000000000000000000000000000001"
1466                    ],
1467                    "contract_ids": [
1468                        "0x0000000000000000000000000000000000000000"
1469                    ],
1470                    "static_attributes": {
1471                        "attribute_1": "0x00000000000003e8"
1472                    },
1473                    "change": "Creation",
1474                    "creation_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
1475                    "created_at": "2022-01-01T00:00:00"
1476                }
1477            ],
1478            "pagination": {
1479                "page": 0,
1480                "page_size": 20,
1481                "total": 10
1482            }
1483        }
1484        "#;
1485        // test that the response is deserialized correctly
1486        serde_json::from_str::<ProtocolComponentRequestResponse>(server_resp).expect("deserialize");
1487
1488        let mocked_server = server
1489            .mock("POST", "/v1/protocol_components")
1490            .expect(1)
1491            .with_body(server_resp)
1492            .create_async()
1493            .await;
1494
1495        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1496            .expect("create client");
1497
1498        let response = client
1499            .get_protocol_components(&Default::default())
1500            .await
1501            .expect("get state");
1502        let components = response.protocol_components;
1503
1504        mocked_server.assert();
1505        assert_eq!(components.len(), 1);
1506        assert_eq!(components[0].id, "State1");
1507        assert_eq!(components[0].protocol_system, "ambient");
1508        assert_eq!(components[0].protocol_type_name, "Pool");
1509        assert_eq!(components[0].tokens.len(), 2);
1510        let expected_attributes =
1511            [("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]
1512                .iter()
1513                .cloned()
1514                .collect::<HashMap<String, Bytes>>();
1515        assert_eq!(components[0].static_attributes, expected_attributes);
1516    }
1517
1518    #[tokio::test]
1519    async fn test_get_protocol_states() {
1520        let mut server = Server::new_async().await;
1521        let server_resp = r#"
1522        {
1523            "states": [
1524                {
1525                    "component_id": "State1",
1526                    "attributes": {
1527                        "attribute_1": "0x00000000000003e8"
1528                    },
1529                    "balances": {
1530                        "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": "0x01f4"
1531                    }
1532                }
1533            ],
1534            "pagination": {
1535                "page": 0,
1536                "page_size": 20,
1537                "total": 10
1538            }
1539        }
1540        "#;
1541        // test that the response is deserialized correctly
1542        serde_json::from_str::<ProtocolStateRequestResponse>(server_resp).expect("deserialize");
1543
1544        let mocked_server = server
1545            .mock("POST", "/v1/protocol_state")
1546            .expect(1)
1547            .with_body(server_resp)
1548            .create_async()
1549            .await;
1550        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1551            .expect("create client");
1552
1553        let response = client
1554            .get_protocol_states(&Default::default())
1555            .await
1556            .expect("get state");
1557        let states = response.states;
1558
1559        mocked_server.assert();
1560        assert_eq!(states.len(), 1);
1561        assert_eq!(states[0].component_id, "State1");
1562        let expected_attributes =
1563            [("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]
1564                .iter()
1565                .cloned()
1566                .collect::<HashMap<String, Bytes>>();
1567        assert_eq!(states[0].attributes, expected_attributes);
1568        let expected_balances = [(
1569            Bytes::from_str("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2")
1570                .expect("Unsupported address format"),
1571            Bytes::from_str("0x01f4").unwrap(),
1572        )]
1573        .iter()
1574        .cloned()
1575        .collect::<HashMap<Bytes, Bytes>>();
1576        assert_eq!(states[0].balances, expected_balances);
1577    }
1578
1579    #[tokio::test]
1580    async fn test_get_tokens() {
1581        let mut server = Server::new_async().await;
1582        let server_resp = r#"
1583        {
1584            "tokens": [
1585              {
1586                "chain": "ethereum",
1587                "address": "0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2",
1588                "symbol": "WETH",
1589                "decimals": 18,
1590                "tax": 0,
1591                "gas": [
1592                  29962
1593                ],
1594                "quality": 100
1595              },
1596              {
1597                "chain": "ethereum",
1598                "address": "0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48",
1599                "symbol": "USDC",
1600                "decimals": 6,
1601                "tax": 0,
1602                "gas": [
1603                  40652
1604                ],
1605                "quality": 100
1606              }
1607            ],
1608            "pagination": {
1609              "page": 0,
1610              "page_size": 20,
1611              "total": 10
1612            }
1613          }
1614        "#;
1615        // test that the response is deserialized correctly
1616        serde_json::from_str::<TokensRequestResponse>(server_resp).expect("deserialize");
1617
1618        let mocked_server = server
1619            .mock("POST", "/v1/tokens")
1620            .expect(1)
1621            .with_body(server_resp)
1622            .create_async()
1623            .await;
1624        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1625            .expect("create client");
1626
1627        let response = client
1628            .get_tokens(&Default::default())
1629            .await
1630            .expect("get tokens");
1631
1632        let expected = vec![
1633            ResponseToken {
1634                chain: Chain::Ethereum,
1635                address: Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(),
1636                symbol: "WETH".to_string(),
1637                decimals: 18,
1638                tax: 0,
1639                gas: vec![Some(29962)],
1640                quality: 100,
1641            },
1642            ResponseToken {
1643                chain: Chain::Ethereum,
1644                address: Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap(),
1645                symbol: "USDC".to_string(),
1646                decimals: 6,
1647                tax: 0,
1648                gas: vec![Some(40652)],
1649                quality: 100,
1650            },
1651        ];
1652
1653        mocked_server.assert();
1654        assert_eq!(response.tokens, expected);
1655        assert_eq!(response.pagination, PaginationResponse { page: 0, page_size: 20, total: 10 });
1656    }
1657
1658    #[tokio::test]
1659    async fn test_get_protocol_systems() {
1660        let mut server = Server::new_async().await;
1661        let server_resp = r#"
1662        {
1663            "protocol_systems": [
1664                "system1",
1665                "system2"
1666            ],
1667            "pagination": {
1668                "page": 0,
1669                "page_size": 20,
1670                "total": 10
1671            }
1672        }
1673        "#;
1674        // test that the response is deserialized correctly
1675        serde_json::from_str::<ProtocolSystemsRequestResponse>(server_resp).expect("deserialize");
1676
1677        let mocked_server = server
1678            .mock("POST", "/v1/protocol_systems")
1679            .expect(1)
1680            .with_body(server_resp)
1681            .create_async()
1682            .await;
1683        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1684            .expect("create client");
1685
1686        let response = client
1687            .get_protocol_systems(&Default::default())
1688            .await
1689            .expect("get protocol systems");
1690        let protocol_systems = response.protocol_systems;
1691
1692        mocked_server.assert();
1693        assert_eq!(protocol_systems, vec!["system1", "system2"]);
1694    }
1695
1696    #[tokio::test]
1697    async fn test_get_component_tvl() {
1698        let mut server = Server::new_async().await;
1699        let server_resp = r#"
1700        {
1701            "tvl": {
1702                "component1": 100.0
1703            },
1704            "pagination": {
1705                "page": 0,
1706                "page_size": 20,
1707                "total": 10
1708            }
1709        }
1710        "#;
1711        // test that the response is deserialized correctly
1712        serde_json::from_str::<ComponentTvlRequestResponse>(server_resp).expect("deserialize");
1713
1714        let mocked_server = server
1715            .mock("POST", "/v1/component_tvl")
1716            .expect(1)
1717            .with_body(server_resp)
1718            .create_async()
1719            .await;
1720        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1721            .expect("create client");
1722
1723        let response = client
1724            .get_component_tvl(&Default::default())
1725            .await
1726            .expect("get protocol systems");
1727        let component_tvl = response.tvl;
1728
1729        mocked_server.assert();
1730        assert_eq!(component_tvl.get("component1"), Some(&100.0));
1731    }
1732
1733    #[tokio::test]
1734    async fn test_get_traced_entry_points() {
1735        let mut server = Server::new_async().await;
1736        let server_resp = r#"
1737        {
1738            "traced_entry_points": {
1739                "component_1": [
1740                    [
1741                        {
1742                            "entry_point": {
1743                                "external_id": "entrypoint_a",
1744                                "target": "0x0000000000000000000000000000000000000001",
1745                                "signature": "sig()"
1746                            },
1747                            "params": {
1748                                "method": "rpctracer",
1749                                "caller": "0x000000000000000000000000000000000000000a",
1750                                "calldata": "0x000000000000000000000000000000000000000b"
1751                            }
1752                        },
1753                        {
1754                            "retriggers": [
1755                                [
1756                                    "0x00000000000000000000000000000000000000aa",
1757                                    {"key": "0x0000000000000000000000000000000000000aaa", "offset": 12}
1758                                ]
1759                            ],
1760                            "accessed_slots": {
1761                                "0x0000000000000000000000000000000000aaaa": [
1762                                    "0x0000000000000000000000000000000000aaaa"
1763                                ]
1764                            }
1765                        }
1766                    ]
1767                ]
1768            },
1769            "pagination": {
1770                "page": 0,
1771                "page_size": 20,
1772                "total": 1
1773            }
1774        }
1775        "#;
1776        // test that the response is deserialized correctly
1777        serde_json::from_str::<TracedEntryPointRequestResponse>(server_resp).expect("deserialize");
1778
1779        let mocked_server = server
1780            .mock("POST", "/v1/traced_entry_points")
1781            .expect(1)
1782            .with_body(server_resp)
1783            .create_async()
1784            .await;
1785        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1786            .expect("create client");
1787
1788        let response = client
1789            .get_traced_entry_points(&Default::default())
1790            .await
1791            .expect("get traced entry points");
1792        let entrypoints = response.traced_entry_points;
1793
1794        mocked_server.assert();
1795        assert_eq!(entrypoints.len(), 1);
1796        let comp1_entrypoints = entrypoints
1797            .get("component_1")
1798            .expect("component_1 entrypoints should exist");
1799        assert_eq!(comp1_entrypoints.len(), 1);
1800
1801        let (entrypoint, trace_result) = &comp1_entrypoints[0];
1802        assert_eq!(entrypoint.entry_point.external_id, "entrypoint_a");
1803        assert_eq!(
1804            entrypoint.entry_point.target,
1805            Bytes::from_str("0x0000000000000000000000000000000000000001").unwrap()
1806        );
1807        assert_eq!(entrypoint.entry_point.signature, "sig()");
1808        let TracingParams::RPCTracer(rpc_params) = &entrypoint.params;
1809        assert_eq!(
1810            rpc_params.caller,
1811            Some(Bytes::from("0x000000000000000000000000000000000000000a"))
1812        );
1813        assert_eq!(rpc_params.calldata, Bytes::from("0x000000000000000000000000000000000000000b"));
1814
1815        assert_eq!(
1816            trace_result.retriggers,
1817            HashSet::from([(
1818                Bytes::from("0x00000000000000000000000000000000000000aa"),
1819                AddressStorageLocation::new(
1820                    Bytes::from("0x0000000000000000000000000000000000000aaa"),
1821                    12
1822                )
1823            )])
1824        );
1825        assert_eq!(trace_result.accessed_slots.len(), 1);
1826        assert_eq!(
1827            trace_result.accessed_slots,
1828            HashMap::from([(
1829                Bytes::from("0x0000000000000000000000000000000000aaaa"),
1830                HashSet::from([Bytes::from("0x0000000000000000000000000000000000aaaa")])
1831            )])
1832        );
1833    }
1834
1835    #[tokio::test]
1836    async fn test_parse_retry_value_numeric() {
1837        let result = parse_retry_value("60");
1838        assert!(result.is_some());
1839
1840        let expected_time = SystemTime::now() + Duration::from_secs(60);
1841        let actual_time = result.unwrap();
1842
1843        // Allow for small timing differences during test execution
1844        let diff = if actual_time > expected_time {
1845            actual_time
1846                .duration_since(expected_time)
1847                .unwrap()
1848        } else {
1849            expected_time
1850                .duration_since(actual_time)
1851                .unwrap()
1852        };
1853        assert!(diff < Duration::from_secs(1), "Time difference too large: {:?}", diff);
1854    }
1855
1856    #[tokio::test]
1857    async fn test_parse_retry_value_rfc2822() {
1858        // Use a fixed future date in RFC2822 format
1859        let rfc2822_date = "Sat, 01 Jan 2030 12:00:00 +0000";
1860        let result = parse_retry_value(rfc2822_date);
1861        assert!(result.is_some());
1862
1863        let parsed_time = result.unwrap();
1864        assert!(parsed_time > SystemTime::now());
1865    }
1866
1867    #[tokio::test]
1868    async fn test_parse_retry_value_invalid_formats() {
1869        // Test various invalid formats
1870        assert!(parse_retry_value("invalid").is_none());
1871        assert!(parse_retry_value("").is_none());
1872        assert!(parse_retry_value("not_a_number").is_none());
1873        assert!(parse_retry_value("Mon, 32 Jan 2030 25:00:00 +0000").is_none()); // Invalid date
1874    }
1875
1876    #[tokio::test]
1877    async fn test_parse_retry_value_zero_seconds() {
1878        let result = parse_retry_value("0");
1879        assert!(result.is_some());
1880
1881        let expected_time = SystemTime::now();
1882        let actual_time = result.unwrap();
1883
1884        // Should be very close to current time
1885        let diff = if actual_time > expected_time {
1886            actual_time
1887                .duration_since(expected_time)
1888                .unwrap()
1889        } else {
1890            expected_time
1891                .duration_since(actual_time)
1892                .unwrap()
1893        };
1894        assert!(diff < Duration::from_secs(1));
1895    }
1896
1897    #[tokio::test]
1898    async fn test_error_for_response_rate_limited() {
1899        let mut server = Server::new_async().await;
1900        let mock = server
1901            .mock("GET", "/test")
1902            .with_status(429)
1903            .with_header("Retry-After", "60")
1904            .create_async()
1905            .await;
1906
1907        let client = reqwest::Client::new();
1908        let response = client
1909            .get(format!("{}/test", server.url()))
1910            .send()
1911            .await
1912            .unwrap();
1913
1914        let http_client =
1915            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1916                .unwrap()
1917                .with_test_backoff_policy();
1918        let result = http_client
1919            .error_for_response(response)
1920            .await;
1921
1922        mock.assert();
1923        assert!(matches!(result, Err(RPCError::RateLimited(_))));
1924        if let Err(RPCError::RateLimited(retry_after)) = result {
1925            assert!(retry_after.is_some());
1926        }
1927    }
1928
1929    #[tokio::test]
1930    async fn test_error_for_response_rate_limited_no_header() {
1931        let mut server = Server::new_async().await;
1932        let mock = server
1933            .mock("GET", "/test")
1934            .with_status(429)
1935            .create_async()
1936            .await;
1937
1938        let client = reqwest::Client::new();
1939        let response = client
1940            .get(format!("{}/test", server.url()))
1941            .send()
1942            .await
1943            .unwrap();
1944
1945        let http_client =
1946            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1947                .unwrap()
1948                .with_test_backoff_policy();
1949        let result = http_client
1950            .error_for_response(response)
1951            .await;
1952
1953        mock.assert();
1954        assert!(matches!(result, Err(RPCError::RateLimited(None))));
1955    }
1956
1957    #[tokio::test]
1958    async fn test_error_for_response_server_errors() {
1959        let test_cases =
1960            vec![(502, "Bad Gateway"), (503, "Service Unavailable"), (504, "Gateway Timeout")];
1961
1962        for (status_code, expected_body) in test_cases {
1963            let mut server = Server::new_async().await;
1964            let mock = server
1965                .mock("GET", "/test")
1966                .with_status(status_code)
1967                .with_body(expected_body)
1968                .create_async()
1969                .await;
1970
1971            let client = reqwest::Client::new();
1972            let response = client
1973                .get(format!("{}/test", server.url()))
1974                .send()
1975                .await
1976                .unwrap();
1977
1978            let http_client =
1979                HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1980                    .unwrap()
1981                    .with_test_backoff_policy();
1982            let result = http_client
1983                .error_for_response(response)
1984                .await;
1985
1986            mock.assert();
1987            assert!(matches!(result, Err(RPCError::ServerUnreachable(_))));
1988            if let Err(RPCError::ServerUnreachable(body)) = result {
1989                assert_eq!(body, expected_body);
1990            }
1991        }
1992    }
1993
1994    #[tokio::test]
1995    async fn test_error_for_response_success() {
1996        let mut server = Server::new_async().await;
1997        let mock = server
1998            .mock("GET", "/test")
1999            .with_status(200)
2000            .with_body("success")
2001            .create_async()
2002            .await;
2003
2004        let client = reqwest::Client::new();
2005        let response = client
2006            .get(format!("{}/test", server.url()))
2007            .send()
2008            .await
2009            .unwrap();
2010
2011        let http_client =
2012            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2013                .unwrap()
2014                .with_test_backoff_policy();
2015        let result = http_client
2016            .error_for_response(response)
2017            .await;
2018
2019        mock.assert();
2020        assert!(result.is_ok());
2021
2022        let response = result.unwrap();
2023        assert_eq!(response.status(), 200);
2024    }
2025
2026    #[tokio::test]
2027    async fn test_handle_error_for_backoff_server_unreachable() {
2028        let http_client =
2029            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2030                .unwrap()
2031                .with_test_backoff_policy();
2032        let error = RPCError::ServerUnreachable("Service down".to_string());
2033
2034        let backoff_error = http_client
2035            .handle_error_for_backoff(error)
2036            .await;
2037
2038        match backoff_error {
2039            backoff::Error::Transient { err: RPCError::ServerUnreachable(msg), retry_after } => {
2040                assert_eq!(msg, "Service down");
2041                assert_eq!(retry_after, Some(Duration::from_millis(50))); // Fast test duration
2042            }
2043            _ => panic!("Expected transient error for ServerUnreachable"),
2044        }
2045    }
2046
2047    #[tokio::test]
2048    async fn test_handle_error_for_backoff_rate_limited_with_retry_after() {
2049        let http_client =
2050            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2051                .unwrap()
2052                .with_test_backoff_policy();
2053        let future_time = SystemTime::now() + Duration::from_secs(30);
2054        let error = RPCError::RateLimited(Some(future_time));
2055
2056        let backoff_error = http_client
2057            .handle_error_for_backoff(error)
2058            .await;
2059
2060        match backoff_error {
2061            backoff::Error::Transient { err: RPCError::RateLimited(retry_after), .. } => {
2062                assert_eq!(retry_after, Some(future_time));
2063            }
2064            _ => panic!("Expected transient error for RateLimited"),
2065        }
2066
2067        // Verify that retry_after was stored in the client state
2068        let stored_retry_after = http_client.retry_after.read().await;
2069        assert_eq!(*stored_retry_after, Some(future_time));
2070    }
2071
2072    #[tokio::test]
2073    async fn test_handle_error_for_backoff_rate_limited_no_retry_after() {
2074        let http_client =
2075            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2076                .unwrap()
2077                .with_test_backoff_policy();
2078        let error = RPCError::RateLimited(None);
2079
2080        let backoff_error = http_client
2081            .handle_error_for_backoff(error)
2082            .await;
2083
2084        match backoff_error {
2085            backoff::Error::Transient { err: RPCError::RateLimited(None), .. } => {
2086                // This is expected - no retry-after still allows retries with default policy
2087            }
2088            _ => panic!("Expected transient error for RateLimited without retry-after"),
2089        }
2090    }
2091
2092    #[tokio::test]
2093    async fn test_handle_error_for_backoff_other_errors() {
2094        let http_client =
2095            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2096                .unwrap()
2097                .with_test_backoff_policy();
2098        let error = RPCError::ParseResponse("Invalid JSON".to_string());
2099
2100        let backoff_error = http_client
2101            .handle_error_for_backoff(error)
2102            .await;
2103
2104        match backoff_error {
2105            backoff::Error::Permanent(RPCError::ParseResponse(msg)) => {
2106                assert_eq!(msg, "Invalid JSON");
2107            }
2108            _ => panic!("Expected permanent error for ParseResponse"),
2109        }
2110    }
2111
2112    #[tokio::test]
2113    async fn test_wait_until_retry_after_no_retry_time() {
2114        let http_client =
2115            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2116                .unwrap()
2117                .with_test_backoff_policy();
2118
2119        let start = std::time::Instant::now();
2120        http_client
2121            .wait_until_retry_after()
2122            .await;
2123        let elapsed = start.elapsed();
2124
2125        // Should return immediately if no retry time is set
2126        assert!(elapsed < Duration::from_millis(100));
2127    }
2128
2129    #[tokio::test]
2130    async fn test_wait_until_retry_after_past_time() {
2131        let http_client =
2132            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2133                .unwrap()
2134                .with_test_backoff_policy();
2135
2136        // Set a retry time in the past
2137        let past_time = SystemTime::now() - Duration::from_secs(10);
2138        *http_client.retry_after.write().await = Some(past_time);
2139
2140        let start = std::time::Instant::now();
2141        http_client
2142            .wait_until_retry_after()
2143            .await;
2144        let elapsed = start.elapsed();
2145
2146        // Should return immediately if retry time is in the past
2147        assert!(elapsed < Duration::from_millis(100));
2148    }
2149
2150    #[tokio::test]
2151    async fn test_wait_until_retry_after_future_time() {
2152        let http_client =
2153            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2154                .unwrap()
2155                .with_test_backoff_policy();
2156
2157        // Set a retry time 100ms in the future
2158        let future_time = SystemTime::now() + Duration::from_millis(100);
2159        *http_client.retry_after.write().await = Some(future_time);
2160
2161        let start = std::time::Instant::now();
2162        http_client
2163            .wait_until_retry_after()
2164            .await;
2165        let elapsed = start.elapsed();
2166
2167        // Should wait approximately the specified duration
2168        assert!(elapsed >= Duration::from_millis(80)); // Allow some tolerance
2169        assert!(elapsed <= Duration::from_millis(200)); // Upper bound for test stability
2170    }
2171
2172    #[tokio::test]
2173    async fn test_make_post_request_success() {
2174        let mut server = Server::new_async().await;
2175        let server_resp = r#"{"success": true}"#;
2176
2177        let mock = server
2178            .mock("POST", "/test")
2179            .with_status(200)
2180            .with_body(server_resp)
2181            .create_async()
2182            .await;
2183
2184        let http_client =
2185            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2186                .unwrap()
2187                .with_test_backoff_policy();
2188        let request_body = serde_json::json!({"test": "data"});
2189        let uri = format!("{}/test", server.url());
2190
2191        let result = http_client
2192            .make_post_request(&request_body, &uri)
2193            .await;
2194
2195        mock.assert();
2196        assert!(result.is_ok());
2197
2198        let response = result.unwrap();
2199        assert_eq!(response.status(), 200);
2200        assert_eq!(response.text().await.unwrap(), server_resp);
2201    }
2202
2203    #[tokio::test]
2204    async fn test_make_post_request_retry_on_server_error() {
2205        let mut server = Server::new_async().await;
2206        // First request fails with 503, second succeeds
2207        let error_mock = server
2208            .mock("POST", "/test")
2209            .with_status(503)
2210            .with_body("Service Unavailable")
2211            .expect(1)
2212            .create_async()
2213            .await;
2214
2215        let success_mock = server
2216            .mock("POST", "/test")
2217            .with_status(200)
2218            .with_body(r#"{"success": true}"#)
2219            .expect(1)
2220            .create_async()
2221            .await;
2222
2223        let http_client =
2224            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2225                .unwrap()
2226                .with_test_backoff_policy();
2227        let request_body = serde_json::json!({"test": "data"});
2228        let uri = format!("{}/test", server.url());
2229
2230        let result = http_client
2231            .make_post_request(&request_body, &uri)
2232            .await;
2233
2234        error_mock.assert();
2235        success_mock.assert();
2236        assert!(result.is_ok());
2237    }
2238
2239    #[tokio::test]
2240    async fn test_make_post_request_respect_retry_after_header() {
2241        let mut server = Server::new_async().await;
2242
2243        // First request returns 429 with retry-after, second succeeds
2244        let rate_limit_mock = server
2245            .mock("POST", "/test")
2246            .with_status(429)
2247            .with_header("Retry-After", "1") // 1 second
2248            .expect(1)
2249            .create_async()
2250            .await;
2251
2252        let success_mock = server
2253            .mock("POST", "/test")
2254            .with_status(200)
2255            .with_body(r#"{"success": true}"#)
2256            .expect(1)
2257            .create_async()
2258            .await;
2259
2260        let http_client =
2261            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2262                .unwrap()
2263                .with_test_backoff_policy();
2264        let request_body = serde_json::json!({"test": "data"});
2265        let uri = format!("{}/test", server.url());
2266
2267        let start = std::time::Instant::now();
2268        let result = http_client
2269            .make_post_request(&request_body, &uri)
2270            .await;
2271        let elapsed = start.elapsed();
2272
2273        rate_limit_mock.assert();
2274        success_mock.assert();
2275        assert!(result.is_ok());
2276
2277        // Should have waited at least 1 second due to retry-after header
2278        assert!(elapsed >= Duration::from_millis(900)); // Allow some tolerance
2279        assert!(elapsed <= Duration::from_millis(2000)); // Upper bound for test stability
2280    }
2281
2282    #[tokio::test]
2283    async fn test_make_post_request_permanent_error() {
2284        let mut server = Server::new_async().await;
2285
2286        let mock = server
2287            .mock("POST", "/test")
2288            .with_status(400) // Bad Request - should not be retried
2289            .with_body("Bad Request")
2290            .expect(1)
2291            .create_async()
2292            .await;
2293
2294        let http_client =
2295            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2296                .unwrap()
2297                .with_test_backoff_policy();
2298        let request_body = serde_json::json!({"test": "data"});
2299        let uri = format!("{}/test", server.url());
2300
2301        let result = http_client
2302            .make_post_request(&request_body, &uri)
2303            .await;
2304
2305        mock.assert();
2306        assert!(result.is_ok()); // 400 doesn't trigger retry logic, just returns the response
2307
2308        let response = result.unwrap();
2309        assert_eq!(response.status(), 400);
2310    }
2311
    /// Two concurrent POSTs through the same client are both rate limited, with `Retry-After`
    /// values of 1s and 2s, and both succeed on retry.
    ///
    /// Asserts: each mock is hit exactly once, both results are `Ok`, the total elapsed time
    /// falls in [1.8s, 3s], and the client's shared `retry_after` state ends up set to an
    /// instant within 3 seconds of the end of the test.
    #[tokio::test]
    async fn test_concurrent_requests_with_different_retry_after() {
        let mut server = Server::new_async().await;

        // First endpoint: rate limited with a 1 second retry-after.
        let rate_limit_mock_1 = server
            .mock("POST", "/test1")
            .with_status(429)
            .with_header("Retry-After", "1")
            .expect(1)
            .create_async()
            .await;

        // Second endpoint: rate limited with a 2 second retry-after.
        let rate_limit_mock_2 = server
            .mock("POST", "/test2")
            .with_status(429)
            .with_header("Retry-After", "2")
            .expect(1)
            .create_async()
            .await;

        // Success mocks, expected to serve the single retry on each endpoint.
        let success_mock_1 = server
            .mock("POST", "/test1")
            .with_status(200)
            .with_body(r#"{"result": "success1"}"#)
            .expect(1)
            .create_async()
            .await;

        let success_mock_2 = server
            .mock("POST", "/test2")
            .with_status(200)
            .with_body(r#"{"result": "success2"}"#)
            .expect(1)
            .create_async()
            .await;

        let http_client =
            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
                .unwrap()
                .with_test_backoff_policy();
        let request_body = serde_json::json!({"test": "data"});

        let uri1 = format!("{}/test1", server.url());
        let uri2 = format!("{}/test2", server.url());

        // Drive both requests concurrently on the same client (and thus the same retry_after state).
        let start = std::time::Instant::now();
        let (result1, result2) = tokio::join!(
            http_client.make_post_request(&request_body, &uri1),
            http_client.make_post_request(&request_body, &uri2)
        );
        let elapsed = start.elapsed();

        rate_limit_mock_1.assert();
        rate_limit_mock_2.assert();
        success_mock_1.assert();
        success_mock_2.assert();

        assert!(result1.is_ok());
        assert!(result2.is_ok());

        // Both requests succeed, and the join completes only after the slower (2s) wait.
        // Total time is therefore expected to be at least ~2 seconds; the bounds allow for
        // timer tolerance and scheduling overhead.
        assert!(elapsed >= Duration::from_millis(1800)); // Allow some tolerance
        assert!(elapsed <= Duration::from_millis(3000)); // Upper bound

        // After both rate-limit responses, the shared retry_after state must hold a deadline.
        let final_retry_after = http_client.retry_after.read().await;
        assert!(final_retry_after.is_some());

        // NOTE(review): this only checks that the stored deadline lies within 3s of now (past or
        // future). It does NOT verify that the higher (2s) value won over the 1s value.
        if let Some(retry_time) = *final_retry_after {
            // The stored deadline has most likely passed by now since we already waited, so take
            // the absolute difference in whichever direction applies.
            let now = SystemTime::now();
            let diff = if retry_time > now {
                retry_time.duration_since(now).unwrap()
            } else {
                now.duration_since(retry_time).unwrap()
            };

            // Loose sanity bound: the largest retry-after (2s) plus some buffer.
            assert!(diff <= Duration::from_secs(3), "Retry time difference too large: {:?}", diff);
        }
    }
2401
2402    #[tokio::test]
2403    async fn test_get_snapshots() {
2404        let mut server = Server::new_async().await;
2405
2406        // Mock protocol states response
2407        let protocol_states_resp = r#"
2408        {
2409            "states": [
2410                {
2411                    "component_id": "component1",
2412                    "attributes": {
2413                        "attribute_1": "0x00000000000003e8"
2414                    },
2415                    "balances": {
2416                        "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": "0x01f4"
2417                    }
2418                }
2419            ],
2420            "pagination": {
2421                "page": 0,
2422                "page_size": 100,
2423                "total": 1
2424            }
2425        }
2426        "#;
2427
2428        // Mock contract state response
2429        let contract_state_resp = r#"
2430        {
2431            "accounts": [
2432                {
2433                    "chain": "ethereum",
2434                    "address": "0x1111111111111111111111111111111111111111",
2435                    "title": "",
2436                    "slots": {},
2437                    "native_balance": "0x01f4",
2438                    "token_balances": {},
2439                    "code": "0x00",
2440                    "code_hash": "0x5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e",
2441                    "balance_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
2442                    "code_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
2443                    "creation_tx": null
2444                }
2445            ],
2446            "pagination": {
2447                "page": 0,
2448                "page_size": 100,
2449                "total": 1
2450            }
2451        }
2452        "#;
2453
2454        // Mock component TVL response
2455        let tvl_resp = r#"
2456        {
2457            "tvl": {
2458                "component1": 1000000.0
2459            },
2460            "pagination": {
2461                "page": 0,
2462                "page_size": 100,
2463                "total": 1
2464            }
2465        }
2466        "#;
2467
2468        let protocol_states_mock = server
2469            .mock("POST", "/v1/protocol_state")
2470            .expect(1)
2471            .with_body(protocol_states_resp)
2472            .create_async()
2473            .await;
2474
2475        let contract_state_mock = server
2476            .mock("POST", "/v1/contract_state")
2477            .expect(1)
2478            .with_body(contract_state_resp)
2479            .create_async()
2480            .await;
2481
2482        let tvl_mock = server
2483            .mock("POST", "/v1/component_tvl")
2484            .expect(1)
2485            .with_body(tvl_resp)
2486            .create_async()
2487            .await;
2488
2489        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2490            .expect("create client");
2491
2492        #[allow(deprecated)]
2493        let component = tycho_common::dto::ProtocolComponent {
2494            id: "component1".to_string(),
2495            protocol_system: "test_protocol".to_string(),
2496            protocol_type_name: "test_type".to_string(),
2497            chain: Chain::Ethereum,
2498            tokens: vec![],
2499            contract_ids: vec![
2500                Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap()
2501            ],
2502            static_attributes: HashMap::new(),
2503            change: tycho_common::dto::ChangeType::Creation,
2504            creation_tx: Bytes::from_str(
2505                "0x0000000000000000000000000000000000000000000000000000000000000000",
2506            )
2507            .unwrap(),
2508            created_at: chrono::Utc::now().naive_utc(),
2509        };
2510
2511        let mut components = HashMap::new();
2512        components.insert("component1".to_string(), component);
2513
2514        let contract_ids =
2515            vec![Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap()];
2516
2517        let request = SnapshotParameters::new(
2518            Chain::Ethereum,
2519            "test_protocol",
2520            &components,
2521            &contract_ids,
2522            12345,
2523        );
2524
2525        let response = client
2526            .get_snapshots(&request, 100, 4)
2527            .await
2528            .expect("get snapshots");
2529
2530        // Verify all mocks were called
2531        protocol_states_mock.assert();
2532        contract_state_mock.assert();
2533        tvl_mock.assert();
2534
2535        // Assert states
2536        assert_eq!(response.states.len(), 1);
2537        assert!(response
2538            .states
2539            .contains_key("component1"));
2540
2541        // Check that the state has the expected TVL
2542        let component_state = response
2543            .states
2544            .get("component1")
2545            .unwrap();
2546        assert_eq!(component_state.component_tvl, Some(1000000.0));
2547
2548        // Assert VM storage
2549        assert_eq!(response.vm_storage.len(), 1);
2550        let contract_addr = Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap();
2551        assert!(response
2552            .vm_storage
2553            .contains_key(&contract_addr));
2554    }
2555
2556    #[tokio::test]
2557    async fn test_get_snapshots_empty_components() {
2558        let server = Server::new_async().await;
2559        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2560            .expect("create client");
2561
2562        let components = HashMap::new();
2563        let contract_ids = vec![];
2564
2565        let request = SnapshotParameters::new(
2566            Chain::Ethereum,
2567            "test_protocol",
2568            &components,
2569            &contract_ids,
2570            12345,
2571        );
2572
2573        let response = client
2574            .get_snapshots(&request, 100, 4)
2575            .await
2576            .expect("get snapshots");
2577
2578        // Should return empty response without making any requests
2579        assert!(response.states.is_empty());
2580        assert!(response.vm_storage.is_empty());
2581    }
2582
2583    #[tokio::test]
2584    async fn test_get_snapshots_without_tvl() {
2585        let mut server = Server::new_async().await;
2586
2587        let protocol_states_resp = r#"
2588        {
2589            "states": [
2590                {
2591                    "component_id": "component1",
2592                    "attributes": {},
2593                    "balances": {}
2594                }
2595            ],
2596            "pagination": {
2597                "page": 0,
2598                "page_size": 100,
2599                "total": 1
2600            }
2601        }
2602        "#;
2603
2604        let protocol_states_mock = server
2605            .mock("POST", "/v1/protocol_state")
2606            .expect(1)
2607            .with_body(protocol_states_resp)
2608            .create_async()
2609            .await;
2610
2611        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2612            .expect("create client");
2613
2614        // Create test component
2615        #[allow(deprecated)]
2616        let component = tycho_common::dto::ProtocolComponent {
2617            id: "component1".to_string(),
2618            protocol_system: "test_protocol".to_string(),
2619            protocol_type_name: "test_type".to_string(),
2620            chain: Chain::Ethereum,
2621            tokens: vec![],
2622            contract_ids: vec![],
2623            static_attributes: HashMap::new(),
2624            change: tycho_common::dto::ChangeType::Creation,
2625            creation_tx: Bytes::from_str(
2626                "0x0000000000000000000000000000000000000000000000000000000000000000",
2627            )
2628            .unwrap(),
2629            created_at: chrono::Utc::now().naive_utc(),
2630        };
2631
2632        let mut components = HashMap::new();
2633        components.insert("component1".to_string(), component);
2634        let contract_ids = vec![];
2635
2636        let request = SnapshotParameters::new(
2637            Chain::Ethereum,
2638            "test_protocol",
2639            &components,
2640            &contract_ids,
2641            12345,
2642        )
2643        .include_balances(false)
2644        .include_tvl(false);
2645
2646        let response = client
2647            .get_snapshots(&request, 100, 4)
2648            .await
2649            .expect("get snapshots");
2650
2651        // Verify only necessary mocks were called
2652        protocol_states_mock.assert();
2653        // No contract_state_mock.assert() since contract_ids is empty
2654        // No tvl_mock.assert() since include_tvl is false
2655
2656        assert_eq!(response.states.len(), 1);
2657        // Check that TVL is None since we didn't request it
2658        let component_state = response
2659            .states
2660            .get("component1")
2661            .unwrap();
2662        assert_eq!(component_state.component_tvl, None);
2663    }
2664
2665    #[tokio::test]
2666    async fn test_compression_enabled() {
2667        let mut server = Server::new_async().await;
2668        let server_resp = GET_CONTRACT_STATE_RESP;
2669
2670        // Compress the response using zstd
2671        let compressed_body =
2672            zstd::encode_all(server_resp.as_bytes(), 0).expect("compression failed");
2673
2674        let mocked_server = server
2675            .mock("POST", "/v1/contract_state")
2676            .expect(1)
2677            .with_header("Content-Encoding", "zstd")
2678            .with_body(compressed_body)
2679            .create_async()
2680            .await;
2681
2682        // Create client with compression enabled
2683        let client = HttpRPCClient::new(
2684            server.url().as_str(),
2685            HttpRPCClientOptions::new().with_compression(true),
2686        )
2687        .expect("create client");
2688
2689        let response = client
2690            .get_contract_state(&Default::default())
2691            .await
2692            .expect("get state");
2693        let accounts = response.accounts;
2694
2695        mocked_server.assert();
2696        assert_eq!(accounts.len(), 1);
2697        assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
2698    }
2699
2700    #[tokio::test]
2701    async fn test_compression_disabled() {
2702        let mut server = Server::new_async().await;
2703        let server_resp = GET_CONTRACT_STATE_RESP;
2704
2705        // Verify client does NOT send Accept-Encoding: zstd when compression is disabled
2706        // Instead, server should receive request without compression headers
2707        let mocked_server = server
2708            .mock("POST", "/v1/contract_state")
2709            .expect(1)
2710            .match_header("Accept-Encoding", mockito::Matcher::Missing)
2711            .with_status(200)
2712            .with_body(server_resp)
2713            .create_async()
2714            .await;
2715
2716        // Create client with compression disabled
2717        let client = HttpRPCClient::new(
2718            server.url().as_str(),
2719            HttpRPCClientOptions::new().with_compression(false),
2720        )
2721        .expect("create client");
2722
2723        let response = client
2724            .get_contract_state(&Default::default())
2725            .await
2726            .expect("get state");
2727        let accounts = response.accounts;
2728
2729        // Verify the mock was called (client sent request without Accept-Encoding header)
2730        mocked_server.assert();
2731        assert_eq!(accounts.len(), 1);
2732        assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
2733    }
2734}