tycho_client/
rpc.rs

1//! # Tycho RPC Client
2//!
3//! The objective of this module is to provide swift and simplified access to the Remote Procedure
4//! Call (RPC) endpoints of Tycho. These endpoints are chiefly responsible for facilitating data
5//! queries, especially querying snapshots of data.
6use std::{
7    collections::HashMap,
8    sync::Arc,
9    time::{Duration, SystemTime},
10};
11
12use async_trait::async_trait;
13use backoff::{exponential::ExponentialBackoffBuilder, ExponentialBackoff};
14use futures03::future::try_join_all;
15#[cfg(test)]
16use mockall::automock;
17use reqwest::{header, Client, ClientBuilder, Response, StatusCode, Url};
18use serde::Serialize;
19use thiserror::Error;
20use time::{format_description::well_known::Rfc2822, OffsetDateTime};
21use tokio::{
22    sync::{RwLock, Semaphore},
23    time::sleep,
24};
25use tracing::{debug, error, instrument, trace, warn};
26use tycho_common::{
27    dto::{
28        BlockParam, Chain, ComponentTvlRequestBody, ComponentTvlRequestResponse,
29        EntryPointWithTracingParams, PaginationParams, PaginationResponse, ProtocolComponent,
30        ProtocolComponentRequestResponse, ProtocolComponentsRequestBody, ProtocolStateRequestBody,
31        ProtocolStateRequestResponse, ProtocolSystemsRequestBody, ProtocolSystemsRequestResponse,
32        ResponseToken, StateRequestBody, StateRequestResponse, TokensRequestBody,
33        TokensRequestResponse, TracedEntryPointRequestBody, TracedEntryPointRequestResponse,
34        TracingResult, VersionParam,
35    },
36    models::ComponentId,
37    Bytes,
38};
39
40use crate::{
41    feed::synchronizer::{ComponentWithState, Snapshot},
42    TYCHO_SERVER_VERSION,
43};
44
/// Suggested concurrency level for RPC clients.
///
/// NOTE(review): presumably meant to be passed as the `concurrency` argument of the paginated
/// `RPCClient` methods — confirm against callers.
pub const RPC_CLIENT_CONCURRENCY: usize = 4;
47
/// Request body for fetching a snapshot of protocol states and VM storage.
///
/// This struct helps to coordinate fetching multiple pieces of related data
/// (protocol states, contract storage, TVL, entry points).
///
/// All collection-like inputs are borrowed for `'a`; the struct does not own them.
#[derive(Clone, Debug, PartialEq)]
pub struct SnapshotParameters<'a> {
    /// Which chain to fetch snapshots for
    pub chain: Chain,
    /// Protocol system name, required for correct state resolution
    pub protocol_system: &'a str,
    /// Components to fetch protocol states for
    pub components: &'a HashMap<ComponentId, ProtocolComponent>,
    /// Traced entry points data mapped by component id (`None` unless set explicitly)
    pub entrypoints: Option<&'a HashMap<String, Vec<(EntryPointWithTracingParams, TracingResult)>>>,
    /// Contract addresses to fetch VM storage for
    pub contract_ids: &'a [Bytes],
    /// Block number for versioning
    pub block_number: u64,
    /// Whether to include balance information (`SnapshotParameters::new` sets this to `true`)
    pub include_balances: bool,
    /// Whether to fetch TVL data (`SnapshotParameters::new` sets this to `true`)
    pub include_tvl: bool,
}
71
72impl<'a> SnapshotParameters<'a> {
73    pub fn new(
74        chain: Chain,
75        protocol_system: &'a str,
76        components: &'a HashMap<ComponentId, ProtocolComponent>,
77        contract_ids: &'a [Bytes],
78        block_number: u64,
79    ) -> Self {
80        Self {
81            chain,
82            protocol_system,
83            components,
84            entrypoints: None,
85            contract_ids,
86            block_number,
87            include_balances: true,
88            include_tvl: true,
89        }
90    }
91
92    /// Set whether to include balance information (default: true)
93    pub fn include_balances(mut self, include_balances: bool) -> Self {
94        self.include_balances = include_balances;
95        self
96    }
97
98    /// Set whether to fetch TVL data (default: true)
99    pub fn include_tvl(mut self, include_tvl: bool) -> Self {
100        self.include_tvl = include_tvl;
101        self
102    }
103
104    pub fn entrypoints(
105        mut self,
106        entrypoints: &'a HashMap<String, Vec<(EntryPointWithTracingParams, TracingResult)>>,
107    ) -> Self {
108        self.entrypoints = Some(entrypoints);
109        self
110    }
111}
112
/// Errors produced by the Tycho RPC client.
#[derive(Error, Debug)]
pub enum RPCError {
    /// The passed tycho url failed to parse.
    ///
    /// Carries the offending URL followed by the parser's error message.
    #[error("Failed to parse URL: {0}. Error: {1}")]
    UrlParsing(String, String),

    /// The request data is not correctly formed.
    #[error("Failed to format request: {0}")]
    FormatRequest(String),

    /// Errors forwarded from the HTTP protocol.
    ///
    /// Only the message (first field) is rendered; the `reqwest` error is exposed as the source.
    #[error("Unexpected HTTP client error: {0}")]
    HttpClient(String, #[source] reqwest::Error),

    /// The response from the server could not be parsed correctly.
    #[error("Failed to parse response: {0}")]
    ParseResponse(String),

    /// Other fatal errors.
    #[error("Fatal error: {0}")]
    Fatal(String),

    /// The server answered with HTTP 429.
    ///
    /// Holds the point in time parsed from the `Retry-After` header, if one was present and
    /// parseable.
    #[error("Rate limited until {0:?}")]
    RateLimited(Option<SystemTime>),

    /// The server answered with HTTP 502, 503 or 504.
    ///
    /// Holds the response body, or a fixed placeholder text if the body could not be read.
    #[error("Server unreachable: {0}")]
    ServerUnreachable(String),
}
141
142#[cfg_attr(test, automock)]
143#[async_trait]
144pub trait RPCClient: Send + Sync {
    /// Retrieves a snapshot of contract state.
    ///
    /// This is a single, non-paginating call: `request` is forwarded as given, including its
    /// `pagination` field.
    ///
    /// # Errors
    ///
    /// Returns an [`RPCError`] if the state could not be retrieved.
    async fn get_contract_state(
        &self,
        request: &StateRequestBody,
    ) -> Result<StateRequestResponse, RPCError>;
150
151    async fn get_contract_state_paginated(
152        &self,
153        chain: Chain,
154        ids: &[Bytes],
155        protocol_system: &str,
156        version: &VersionParam,
157        chunk_size: usize,
158        concurrency: usize,
159    ) -> Result<StateRequestResponse, RPCError> {
160        let semaphore = Arc::new(Semaphore::new(concurrency));
161
162        // Sort the ids to maximize server-side cache hits
163        let mut sorted_ids = ids.to_vec();
164        sorted_ids.sort();
165
166        let chunked_bodies = sorted_ids
167            .chunks(chunk_size)
168            .map(|chunk| StateRequestBody {
169                contract_ids: Some(chunk.to_vec()),
170                protocol_system: protocol_system.to_string(),
171                chain,
172                version: version.clone(),
173                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
174            })
175            .collect::<Vec<_>>();
176
177        let mut tasks = Vec::new();
178        for body in chunked_bodies.iter() {
179            let sem = semaphore.clone();
180            tasks.push(async move {
181                let _permit = sem
182                    .acquire()
183                    .await
184                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
185                self.get_contract_state(body).await
186            });
187        }
188
189        // Execute all tasks concurrently with the defined concurrency limit.
190        let responses = try_join_all(tasks).await?;
191
192        // Aggregate the responses into a single result.
193        let accounts = responses
194            .iter()
195            .flat_map(|r| r.accounts.clone())
196            .collect();
197        let total: i64 = responses
198            .iter()
199            .map(|r| r.pagination.total)
200            .sum();
201
202        Ok(StateRequestResponse {
203            accounts,
204            pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
205        })
206    }
207
    /// Retrieves the protocol components selected by `request`.
    ///
    /// This is a single, non-paginating call; see
    /// [`RPCClient::get_protocol_components_paginated`] for a variant that walks all pages.
    ///
    /// # Errors
    ///
    /// Returns an [`RPCError`] if the components could not be retrieved.
    async fn get_protocol_components(
        &self,
        request: &ProtocolComponentsRequestBody,
    ) -> Result<ProtocolComponentRequestResponse, RPCError>;
212
213    async fn get_protocol_components_paginated(
214        &self,
215        request: &ProtocolComponentsRequestBody,
216        chunk_size: usize,
217        concurrency: usize,
218    ) -> Result<ProtocolComponentRequestResponse, RPCError> {
219        let semaphore = Arc::new(Semaphore::new(concurrency));
220
221        // If a set of component IDs is specified, the maximum return size is already known,
222        // allowing us to pre-compute the number of requests to be made.
223        match request.component_ids {
224            Some(ref ids) => {
225                // We can divide the component_ids into chunks of size chunk_size
226                let chunked_bodies = ids
227                    .chunks(chunk_size)
228                    .enumerate()
229                    .map(|(index, _)| ProtocolComponentsRequestBody {
230                        protocol_system: request.protocol_system.clone(),
231                        component_ids: request.component_ids.clone(),
232                        tvl_gt: request.tvl_gt,
233                        chain: request.chain,
234                        pagination: PaginationParams {
235                            page: index as i64,
236                            page_size: chunk_size as i64,
237                        },
238                    })
239                    .collect::<Vec<_>>();
240
241                let mut tasks = Vec::new();
242                for body in chunked_bodies.iter() {
243                    let sem = semaphore.clone();
244                    tasks.push(async move {
245                        let _permit = sem
246                            .acquire()
247                            .await
248                            .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
249                        self.get_protocol_components(body).await
250                    });
251                }
252
253                try_join_all(tasks)
254                    .await
255                    .map(|responses| ProtocolComponentRequestResponse {
256                        protocol_components: responses
257                            .into_iter()
258                            .flat_map(|r| r.protocol_components.into_iter())
259                            .collect(),
260                        pagination: PaginationResponse {
261                            page: 0,
262                            page_size: chunk_size as i64,
263                            total: ids.len() as i64,
264                        },
265                    })
266            }
267            _ => {
268                // If no component ids are specified, we need to make requests based on the total
269                // number of results from the first response.
270
271                let initial_request = ProtocolComponentsRequestBody {
272                    protocol_system: request.protocol_system.clone(),
273                    component_ids: request.component_ids.clone(),
274                    tvl_gt: request.tvl_gt,
275                    chain: request.chain,
276                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
277                };
278                let first_response = self
279                    .get_protocol_components(&initial_request)
280                    .await
281                    .map_err(|err| RPCError::Fatal(err.to_string()))?;
282
283                let total_items = first_response.pagination.total;
284                let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
285
286                // Initialize the final response accumulator
287                let mut accumulated_response = ProtocolComponentRequestResponse {
288                    protocol_components: first_response.protocol_components,
289                    pagination: PaginationResponse {
290                        page: 0,
291                        page_size: chunk_size as i64,
292                        total: total_items,
293                    },
294                };
295
296                let mut page = 1;
297                while page < total_pages {
298                    let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
299
300                    // Create request bodies for parallel requests, respecting the concurrency limit
301                    let chunked_bodies = (0..requests_in_this_iteration)
302                        .map(|iter| ProtocolComponentsRequestBody {
303                            protocol_system: request.protocol_system.clone(),
304                            component_ids: request.component_ids.clone(),
305                            tvl_gt: request.tvl_gt,
306                            chain: request.chain,
307                            pagination: PaginationParams {
308                                page: page + iter,
309                                page_size: chunk_size as i64,
310                            },
311                        })
312                        .collect::<Vec<_>>();
313
314                    let tasks: Vec<_> = chunked_bodies
315                        .iter()
316                        .map(|body| {
317                            let sem = semaphore.clone();
318                            async move {
319                                let _permit = sem.acquire().await.map_err(|_| {
320                                    RPCError::Fatal("Semaphore dropped".to_string())
321                                })?;
322                                self.get_protocol_components(body).await
323                            }
324                        })
325                        .collect();
326
327                    let responses = try_join_all(tasks)
328                        .await
329                        .map(|responses| {
330                            let total = responses[0].pagination.total;
331                            ProtocolComponentRequestResponse {
332                                protocol_components: responses
333                                    .into_iter()
334                                    .flat_map(|r| r.protocol_components.into_iter())
335                                    .collect(),
336                                pagination: PaginationResponse {
337                                    page,
338                                    page_size: chunk_size as i64,
339                                    total,
340                                },
341                            }
342                        });
343
344                    // Update the accumulated response or set the initial response
345                    match responses {
346                        Ok(mut resp) => {
347                            accumulated_response
348                                .protocol_components
349                                .append(&mut resp.protocol_components);
350                        }
351                        Err(e) => return Err(e),
352                    }
353
354                    page += concurrency as i64;
355                }
356                Ok(accumulated_response)
357            }
358        }
359    }
360
    /// Retrieves the protocol states selected by `request`.
    ///
    /// This is a single, non-paginating call: `request` is forwarded as given, including its
    /// `pagination` field.
    ///
    /// # Errors
    ///
    /// Returns an [`RPCError`] if the states could not be retrieved.
    async fn get_protocol_states(
        &self,
        request: &ProtocolStateRequestBody,
    ) -> Result<ProtocolStateRequestResponse, RPCError>;
365
366    #[allow(clippy::too_many_arguments)]
367    async fn get_protocol_states_paginated<T>(
368        &self,
369        chain: Chain,
370        ids: &[T],
371        protocol_system: &str,
372        include_balances: bool,
373        version: &VersionParam,
374        chunk_size: usize,
375        concurrency: usize,
376    ) -> Result<ProtocolStateRequestResponse, RPCError>
377    where
378        T: AsRef<str> + Sync + 'static,
379    {
380        let semaphore = Arc::new(Semaphore::new(concurrency));
381        let chunked_bodies = ids
382            .chunks(chunk_size)
383            .map(|c| ProtocolStateRequestBody {
384                protocol_ids: Some(
385                    c.iter()
386                        .map(|id| id.as_ref().to_string())
387                        .collect(),
388                ),
389                protocol_system: protocol_system.to_string(),
390                chain,
391                include_balances,
392                version: version.clone(),
393                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
394            })
395            .collect::<Vec<_>>();
396
397        let mut tasks = Vec::new();
398        for body in chunked_bodies.iter() {
399            let sem = semaphore.clone();
400            tasks.push(async move {
401                let _permit = sem
402                    .acquire()
403                    .await
404                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
405                self.get_protocol_states(body).await
406            });
407        }
408
409        try_join_all(tasks)
410            .await
411            .map(|responses| {
412                let states = responses
413                    .clone()
414                    .into_iter()
415                    .flat_map(|r| r.states)
416                    .collect();
417                let total = responses
418                    .iter()
419                    .map(|r| r.pagination.total)
420                    .sum();
421                ProtocolStateRequestResponse {
422                    states,
423                    pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
424                }
425            })
426    }
427
    /// This function returns only one chunk of tokens. To get all tokens please call
    /// get_all_tokens.
    ///
    /// # Errors
    ///
    /// Returns an [`RPCError`] if the tokens could not be retrieved.
    async fn get_tokens(
        &self,
        request: &TokensRequestBody,
    ) -> Result<TokensRequestResponse, RPCError>;
434
435    async fn get_all_tokens(
436        &self,
437        chain: Chain,
438        min_quality: Option<i32>,
439        traded_n_days_ago: Option<u64>,
440        chunk_size: usize,
441        concurrency: usize,
442    ) -> Result<Vec<ResponseToken>, RPCError> {
443        let semaphore = Arc::new(Semaphore::new(concurrency));
444
445        // Make initial request to get total count
446        let page_size = chunk_size.try_into().map_err(|_| {
447            RPCError::FormatRequest("Failed to convert chunk_size into i64".to_string())
448        })?;
449
450        let initial_request = TokensRequestBody {
451            token_addresses: None,
452            min_quality,
453            traded_n_days_ago,
454            pagination: PaginationParams { page: 0, page_size },
455            chain,
456        };
457
458        let first_response = self
459            .get_tokens(&initial_request)
460            .await?;
461        let total_items = first_response.pagination.total;
462        let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
463
464        let mut all_tokens = first_response.tokens;
465
466        // If only one page, return immediately
467        if total_pages <= 1 {
468            return Ok(all_tokens);
469        }
470
471        // Create a task for each remaining page
472        let tasks: Vec<_> = (1..total_pages)
473            .map(|page| {
474                let sem = semaphore.clone();
475                let request = TokensRequestBody {
476                    token_addresses: None,
477                    min_quality,
478                    traded_n_days_ago,
479                    pagination: PaginationParams { page, page_size },
480                    chain,
481                };
482
483                async move {
484                    // Semaphore controls how many requests are actually in-flight
485                    let _permit = sem
486                        .acquire()
487                        .await
488                        .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
489                    self.get_tokens(&request).await
490                }
491            })
492            .collect();
493
494        // Join all tasks - semaphore ensures only 'concurrency' execute at once
495        let responses = try_join_all(tasks).await?;
496
497        // Aggregate all tokens from all pages
498        for mut response in responses {
499            all_tokens.append(&mut response.tokens);
500        }
501
502        Ok(all_tokens)
503    }
504
    /// Retrieves the protocol systems selected by `request`.
    ///
    /// # Errors
    ///
    /// Returns an [`RPCError`] if the protocol systems could not be retrieved.
    async fn get_protocol_systems(
        &self,
        request: &ProtocolSystemsRequestBody,
    ) -> Result<ProtocolSystemsRequestResponse, RPCError>;
509
    /// Retrieves the TVL of the components selected by `request`.
    ///
    /// This is a single, non-paginating call; see [`RPCClient::get_component_tvl_paginated`]
    /// for a variant that walks all pages.
    ///
    /// # Errors
    ///
    /// Returns an [`RPCError`] if the TVL data could not be retrieved.
    async fn get_component_tvl(
        &self,
        request: &ComponentTvlRequestBody,
    ) -> Result<ComponentTvlRequestResponse, RPCError>;
514
515    async fn get_component_tvl_paginated(
516        &self,
517        request: &ComponentTvlRequestBody,
518        chunk_size: usize,
519        concurrency: usize,
520    ) -> Result<ComponentTvlRequestResponse, RPCError> {
521        let semaphore = Arc::new(Semaphore::new(concurrency));
522
523        match request.component_ids {
524            Some(ref ids) => {
525                let chunked_requests = ids
526                    .chunks(chunk_size)
527                    .enumerate()
528                    .map(|(index, _)| ComponentTvlRequestBody {
529                        chain: request.chain,
530                        protocol_system: request.protocol_system.clone(),
531                        component_ids: Some(ids.clone()),
532                        pagination: PaginationParams {
533                            page: index as i64,
534                            page_size: chunk_size as i64,
535                        },
536                    })
537                    .collect::<Vec<_>>();
538
539                let tasks: Vec<_> = chunked_requests
540                    .into_iter()
541                    .map(|req| {
542                        let sem = semaphore.clone();
543                        async move {
544                            let _permit = sem
545                                .acquire()
546                                .await
547                                .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
548                            self.get_component_tvl(&req).await
549                        }
550                    })
551                    .collect();
552
553                let responses = try_join_all(tasks).await?;
554
555                let mut merged_tvl = HashMap::new();
556                for resp in responses {
557                    for (key, value) in resp.tvl {
558                        *merged_tvl.entry(key).or_insert(0.0) = value;
559                    }
560                }
561
562                Ok(ComponentTvlRequestResponse {
563                    tvl: merged_tvl,
564                    pagination: PaginationResponse {
565                        page: 0,
566                        page_size: chunk_size as i64,
567                        total: ids.len() as i64,
568                    },
569                })
570            }
571            _ => {
572                let first_request = ComponentTvlRequestBody {
573                    chain: request.chain,
574                    protocol_system: request.protocol_system.clone(),
575                    component_ids: request.component_ids.clone(),
576                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
577                };
578
579                let first_response = self
580                    .get_component_tvl(&first_request)
581                    .await?;
582                let total_items = first_response.pagination.total;
583                let total_pages = (total_items as f64 / chunk_size as f64).ceil() as i64;
584
585                let mut merged_tvl = first_response.tvl;
586
587                let mut page = 1;
588                while page < total_pages {
589                    let requests_in_this_iteration = (total_pages - page).min(concurrency as i64);
590
591                    let chunked_requests: Vec<_> = (0..requests_in_this_iteration)
592                        .map(|i| ComponentTvlRequestBody {
593                            chain: request.chain,
594                            protocol_system: request.protocol_system.clone(),
595                            component_ids: request.component_ids.clone(),
596                            pagination: PaginationParams {
597                                page: page + i,
598                                page_size: chunk_size as i64,
599                            },
600                        })
601                        .collect();
602
603                    let tasks: Vec<_> = chunked_requests
604                        .into_iter()
605                        .map(|req| {
606                            let sem = semaphore.clone();
607                            async move {
608                                let _permit = sem.acquire().await.map_err(|_| {
609                                    RPCError::Fatal("Semaphore dropped".to_string())
610                                })?;
611                                self.get_component_tvl(&req).await
612                            }
613                        })
614                        .collect();
615
616                    let responses = try_join_all(tasks).await?;
617
618                    // merge hashmap
619                    for resp in responses {
620                        for (key, value) in resp.tvl {
621                            *merged_tvl.entry(key).or_insert(0.0) += value;
622                        }
623                    }
624
625                    page += concurrency as i64;
626                }
627
628                Ok(ComponentTvlRequestResponse {
629                    tvl: merged_tvl,
630                    pagination: PaginationResponse {
631                        page: 0,
632                        page_size: chunk_size as i64,
633                        total: total_items,
634                    },
635                })
636            }
637        }
638    }
639
    /// Retrieves the traced entry points selected by `request`.
    ///
    /// This is a single, non-paginating call: `request` is forwarded as given, including its
    /// `pagination` field.
    ///
    /// # Errors
    ///
    /// Returns an [`RPCError`] if the traced entry points could not be retrieved.
    async fn get_traced_entry_points(
        &self,
        request: &TracedEntryPointRequestBody,
    ) -> Result<TracedEntryPointRequestResponse, RPCError>;
644
645    async fn get_traced_entry_points_paginated(
646        &self,
647        chain: Chain,
648        protocol_system: &str,
649        component_ids: &[String],
650        chunk_size: usize,
651        concurrency: usize,
652    ) -> Result<TracedEntryPointRequestResponse, RPCError> {
653        let semaphore = Arc::new(Semaphore::new(concurrency));
654        let chunked_bodies = component_ids
655            .chunks(chunk_size)
656            .map(|c| TracedEntryPointRequestBody {
657                chain,
658                protocol_system: protocol_system.to_string(),
659                component_ids: Some(c.to_vec()),
660                pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
661            })
662            .collect::<Vec<_>>();
663
664        let mut tasks = Vec::new();
665        for body in chunked_bodies.iter() {
666            let sem = semaphore.clone();
667            tasks.push(async move {
668                let _permit = sem
669                    .acquire()
670                    .await
671                    .map_err(|_| RPCError::Fatal("Semaphore dropped".to_string()))?;
672                self.get_traced_entry_points(body).await
673            });
674        }
675
676        try_join_all(tasks)
677            .await
678            .map(|responses| {
679                let traced_entry_points = responses
680                    .clone()
681                    .into_iter()
682                    .flat_map(|r| r.traced_entry_points)
683                    .collect();
684                let total = responses
685                    .iter()
686                    .map(|r| r.pagination.total)
687                    .sum();
688                TracedEntryPointRequestResponse {
689                    traced_entry_points,
690                    pagination: PaginationResponse { page: 0, page_size: chunk_size as i64, total },
691                }
692            })
693    }
694
    /// Builds a [`Snapshot`] from the data described by `request`.
    ///
    /// NOTE(review): `chunk_size` and `concurrency` presumably bound the size and parallelism of
    /// the underlying paginated requests, as in the other methods of this trait — confirm
    /// against the implementations.
    ///
    /// # Errors
    ///
    /// Returns an [`RPCError`] if the snapshot could not be assembled.
    async fn get_snapshots<'a>(
        &self,
        request: &SnapshotParameters<'a>,
        chunk_size: usize,
        concurrency: usize,
    ) -> Result<Snapshot, RPCError>;
701}
702
/// Configuration options for HttpRPCClient
#[derive(Debug, Clone)]
pub struct HttpRPCClientOptions {
    /// Optional API key for authentication
    pub auth_key: Option<String>,
    /// Enable compression for requests (default: true)
    /// When enabled, adds Accept-Encoding: zstd header
    pub compression: bool,
}

impl Default for HttpRPCClientOptions {
    /// Options without an auth key and with compression switched on.
    fn default() -> Self {
        Self { auth_key: None, compression: true }
    }
}

impl HttpRPCClientOptions {
    /// Create new options with default values (compression enabled)
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the authentication key
    pub fn with_auth_key(self, auth_key: Option<String>) -> Self {
        Self { auth_key, ..self }
    }

    /// Set whether to enable compression (default: true)
    pub fn with_compression(self, compression: bool) -> Self {
        Self { compression, ..self }
    }
}
737
/// HTTP client for the Tycho RPC endpoints, built on top of `reqwest`.
///
/// Cloning is cheap with respect to `retry_after`: clones share the same `Arc`.
#[derive(Debug, Clone)]
pub struct HttpRPCClient {
    /// Underlying `reqwest` client; `HttpRPCClient::new` configures its default headers.
    http_client: Client,
    /// URL parsed from the `base_uri` passed to `HttpRPCClient::new`.
    url: Url,
    /// Point in time until which requests should wait after the client was rate limited;
    /// `None` when no rate limit is known.
    retry_after: Arc<RwLock<Option<SystemTime>>>,
    /// Exponential backoff settings; NOTE(review): presumably applied when retrying transient
    /// errors — confirm against the request code.
    backoff_policy: ExponentialBackoff,
    /// NOTE(review): presumably how long a server restart is expected to take (120s by
    /// default) — confirm against its usage.
    server_restart_duration: Duration,
}
746
747impl HttpRPCClient {
748    pub fn new(base_uri: &str, options: HttpRPCClientOptions) -> Result<Self, RPCError> {
749        let uri = base_uri
750            .parse::<Url>()
751            .map_err(|e| RPCError::UrlParsing(base_uri.to_string(), e.to_string()))?;
752
753        // Add default headers
754        let mut headers = header::HeaderMap::new();
755        headers.insert(header::CONTENT_TYPE, header::HeaderValue::from_static("application/json"));
756        let user_agent = format!("tycho-client-{version}", version = env!("CARGO_PKG_VERSION"));
757        headers.insert(
758            header::USER_AGENT,
759            header::HeaderValue::from_str(&user_agent)
760                .map_err(|e| RPCError::FormatRequest(format!("Invalid user agent format: {e}")))?,
761        );
762
763        // Add Authorization if one is given
764        if let Some(key) = options.auth_key.as_deref() {
765            let mut auth_value = header::HeaderValue::from_str(key).map_err(|e| {
766                RPCError::FormatRequest(format!("Invalid authorization key format: {e}"))
767            })?;
768            auth_value.set_sensitive(true);
769            headers.insert(header::AUTHORIZATION, auth_value);
770        }
771
772        let mut client_builder = ClientBuilder::new()
773            .default_headers(headers)
774            .http2_prior_knowledge();
775
776        // When compression is disabled, turn off all automatic compression
777        if !options.compression {
778            client_builder = client_builder.no_zstd();
779        }
780
781        let client = client_builder
782            .build()
783            .map_err(|e| RPCError::HttpClient(e.to_string(), e))?;
784
785        Ok(Self {
786            http_client: client,
787            url: uri,
788            retry_after: Arc::new(RwLock::new(None)),
789            backoff_policy: ExponentialBackoffBuilder::new()
790                .with_initial_interval(Duration::from_millis(250))
791                // increase backoff time by 75% each failure
792                .with_multiplier(1.75)
793                // keep retrying every 30s
794                .with_max_interval(Duration::from_secs(30))
795                // if all retries take longer than 2m, give up
796                .with_max_elapsed_time(Some(Duration::from_secs(125)))
797                .build(),
798            server_restart_duration: Duration::from_secs(120),
799        })
800    }
801
802    #[cfg(test)]
803    pub fn with_test_backoff_policy(mut self) -> Self {
804        // Extremely short intervals for very fast testing
805        self.backoff_policy = ExponentialBackoffBuilder::new()
806            .with_initial_interval(Duration::from_millis(1))
807            .with_multiplier(1.1)
808            .with_max_interval(Duration::from_millis(5))
809            .with_max_elapsed_time(Some(Duration::from_millis(50)))
810            .build();
811        self.server_restart_duration = Duration::from_millis(50);
812        self
813    }
814
815    /// Converts a error response to a Result.
816    ///
817    /// Raises an error if the response status code id 429, 502, 503 or 504. In the 429
818    /// case it will try to look for a retry-after header an parse it accordingly. The
819    /// parsed value is then passed as part of the error.
820    async fn error_for_response(
821        &self,
822        response: reqwest::Response,
823    ) -> Result<reqwest::Response, RPCError> {
824        match response.status() {
825            StatusCode::TOO_MANY_REQUESTS => {
826                let retry_after_raw = response
827                    .headers()
828                    .get(reqwest::header::RETRY_AFTER)
829                    .and_then(|h| h.to_str().ok())
830                    .and_then(parse_retry_value);
831
832                Err(RPCError::RateLimited(retry_after_raw))
833            }
834            StatusCode::BAD_GATEWAY |
835            StatusCode::SERVICE_UNAVAILABLE |
836            StatusCode::GATEWAY_TIMEOUT => Err(RPCError::ServerUnreachable(
837                response
838                    .text()
839                    .await
840                    .unwrap_or_else(|_| "Server Unreachable".to_string()),
841            )),
842            _ => Ok(response),
843        }
844    }
845
846    /// Classifies errors into transient or permanent ones.
847    ///
848    /// Transient errors are retried with a potential backoff, permanent ones are not.
849    /// If the error is RateLimited, this method will set the self.retry_after value so
850    /// future requests wait until the rate limit has been reset.
851    async fn handle_error_for_backoff(&self, e: RPCError) -> backoff::Error<RPCError> {
852        match e {
853            RPCError::ServerUnreachable(_) => {
854                backoff::Error::retry_after(e, self.server_restart_duration)
855            }
856            RPCError::RateLimited(Some(until)) => {
857                let mut retry_after_guard = self.retry_after.write().await;
858                *retry_after_guard = Some(
859                    retry_after_guard
860                        .unwrap_or(until)
861                        .max(until),
862                );
863
864                if let Ok(duration) = until.duration_since(SystemTime::now()) {
865                    backoff::Error::retry_after(e, duration)
866                } else {
867                    e.into()
868                }
869            }
870            RPCError::RateLimited(None) => e.into(),
871            _ => backoff::Error::permanent(e),
872        }
873    }
874
875    /// Waits until the current rate limit time has passed.
876    ///
877    /// Only waits if there is a time and that time is in the future, else return
878    /// immediately.
879    async fn wait_until_retry_after(&self) {
880        if let Some(&until) = self.retry_after.read().await.as_ref() {
881            let now = SystemTime::now();
882            if until > now {
883                if let Ok(duration) = until.duration_since(now) {
884                    sleep(duration).await
885                }
886            }
887        }
888    }
889
890    /// Makes a post request handling transient failures.
891    ///
892    /// If a retry-after header is received it will be respected. Else the configured
893    /// backoff policy is used to deal with transient network or server errors.
894    async fn make_post_request<T: Serialize + ?Sized>(
895        &self,
896        request: &T,
897        uri: &String,
898    ) -> Result<Response, RPCError> {
899        self.wait_until_retry_after().await;
900        let response = backoff::future::retry(self.backoff_policy.clone(), || async {
901            let server_response = self
902                .http_client
903                .post(uri)
904                .json(request)
905                .send()
906                .await
907                .map_err(|e| RPCError::HttpClient(e.to_string(), e))?;
908
909            match self
910                .error_for_response(server_response)
911                .await
912            {
913                Ok(response) => Ok(response),
914                Err(e) => Err(self.handle_error_for_backoff(e).await),
915            }
916        })
917        .await?;
918        Ok(response)
919    }
920}
921
922fn parse_retry_value(val: &str) -> Option<SystemTime> {
923    if let Ok(secs) = val.parse::<u64>() {
924        return Some(SystemTime::now() + Duration::from_secs(secs));
925    }
926    if let Ok(date) = OffsetDateTime::parse(val, &Rfc2822) {
927        return Some(date.into());
928    }
929    None
930}
931
932#[async_trait]
933impl RPCClient for HttpRPCClient {
934    #[instrument(skip(self, request))]
935    async fn get_contract_state(
936        &self,
937        request: &StateRequestBody,
938    ) -> Result<StateRequestResponse, RPCError> {
939        // Check if contract ids are specified
940        if request
941            .contract_ids
942            .as_ref()
943            .is_none_or(|ids| ids.is_empty())
944        {
945            warn!("No contract ids specified in request.");
946        }
947
948        let uri = format!(
949            "{}/{}/contract_state",
950            self.url
951                .to_string()
952                .trim_end_matches('/'),
953            TYCHO_SERVER_VERSION
954        );
955        debug!(%uri, "Sending contract_state request to Tycho server");
956        trace!(?request, "Sending request to Tycho server");
957        let response = self
958            .make_post_request(request, &uri)
959            .await?;
960        trace!(?response, "Received response from Tycho server");
961
962        let body = response
963            .text()
964            .await
965            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
966        if body.is_empty() {
967            // Pure native protocols will return empty contract states
968            return Ok(StateRequestResponse {
969                accounts: vec![],
970                pagination: PaginationResponse {
971                    page: request.pagination.page,
972                    page_size: request.pagination.page,
973                    total: 0,
974                },
975            });
976        }
977
978        let accounts = serde_json::from_str::<StateRequestResponse>(&body)
979            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
980        trace!(?accounts, "Received contract_state response from Tycho server");
981
982        Ok(accounts)
983    }
984
985    async fn get_protocol_components(
986        &self,
987        request: &ProtocolComponentsRequestBody,
988    ) -> Result<ProtocolComponentRequestResponse, RPCError> {
989        let uri = format!(
990            "{}/{}/protocol_components",
991            self.url
992                .to_string()
993                .trim_end_matches('/'),
994            TYCHO_SERVER_VERSION,
995        );
996        debug!(%uri, "Sending protocol_components request to Tycho server");
997        trace!(?request, "Sending request to Tycho server");
998
999        let response = self
1000            .make_post_request(request, &uri)
1001            .await?;
1002
1003        trace!(?response, "Received response from Tycho server");
1004
1005        let body = response
1006            .text()
1007            .await
1008            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1009        let components = serde_json::from_str::<ProtocolComponentRequestResponse>(&body)
1010            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1011        trace!(?components, "Received protocol_components response from Tycho server");
1012
1013        Ok(components)
1014    }
1015
1016    async fn get_protocol_states(
1017        &self,
1018        request: &ProtocolStateRequestBody,
1019    ) -> Result<ProtocolStateRequestResponse, RPCError> {
1020        // Check if protocol ids are specified
1021        if request
1022            .protocol_ids
1023            .as_ref()
1024            .is_none_or(|ids| ids.is_empty())
1025        {
1026            warn!("No protocol ids specified in request.");
1027        }
1028
1029        let uri = format!(
1030            "{}/{}/protocol_state",
1031            self.url
1032                .to_string()
1033                .trim_end_matches('/'),
1034            TYCHO_SERVER_VERSION
1035        );
1036        debug!(%uri, "Sending protocol_states request to Tycho server");
1037        trace!(?request, "Sending request to Tycho server");
1038
1039        let response = self
1040            .make_post_request(request, &uri)
1041            .await?;
1042        trace!(?response, "Received response from Tycho server");
1043
1044        let body = response
1045            .text()
1046            .await
1047            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1048
1049        if body.is_empty() {
1050            // Pure VM protocols will return empty states
1051            return Ok(ProtocolStateRequestResponse {
1052                states: vec![],
1053                pagination: PaginationResponse {
1054                    page: request.pagination.page,
1055                    page_size: request.pagination.page_size,
1056                    total: 0,
1057                },
1058            });
1059        }
1060
1061        let states = serde_json::from_str::<ProtocolStateRequestResponse>(&body)
1062            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1063        trace!(?states, "Received protocol_states response from Tycho server");
1064
1065        Ok(states)
1066    }
1067
1068    async fn get_tokens(
1069        &self,
1070        request: &TokensRequestBody,
1071    ) -> Result<TokensRequestResponse, RPCError> {
1072        let uri = format!(
1073            "{}/{}/tokens",
1074            self.url
1075                .to_string()
1076                .trim_end_matches('/'),
1077            TYCHO_SERVER_VERSION
1078        );
1079        debug!(%uri, "Sending tokens request to Tycho server");
1080
1081        let response = self
1082            .make_post_request(request, &uri)
1083            .await?;
1084
1085        let body = response
1086            .text()
1087            .await
1088            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1089        let tokens = serde_json::from_str::<TokensRequestResponse>(&body)
1090            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1091
1092        Ok(tokens)
1093    }
1094
1095    async fn get_protocol_systems(
1096        &self,
1097        request: &ProtocolSystemsRequestBody,
1098    ) -> Result<ProtocolSystemsRequestResponse, RPCError> {
1099        let uri = format!(
1100            "{}/{}/protocol_systems",
1101            self.url
1102                .to_string()
1103                .trim_end_matches('/'),
1104            TYCHO_SERVER_VERSION
1105        );
1106        debug!(%uri, "Sending protocol_systems request to Tycho server");
1107        trace!(?request, "Sending request to Tycho server");
1108        let response = self
1109            .make_post_request(request, &uri)
1110            .await?;
1111        trace!(?response, "Received response from Tycho server");
1112        let body = response
1113            .text()
1114            .await
1115            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1116        let protocol_systems = serde_json::from_str::<ProtocolSystemsRequestResponse>(&body)
1117            .map_err(|err| RPCError::ParseResponse(format!("Error: {err}, Body: {body}")))?;
1118        trace!(?protocol_systems, "Received protocol_systems response from Tycho server");
1119        Ok(protocol_systems)
1120    }
1121
1122    async fn get_component_tvl(
1123        &self,
1124        request: &ComponentTvlRequestBody,
1125    ) -> Result<ComponentTvlRequestResponse, RPCError> {
1126        let uri = format!(
1127            "{}/{}/component_tvl",
1128            self.url
1129                .to_string()
1130                .trim_end_matches('/'),
1131            TYCHO_SERVER_VERSION
1132        );
1133        debug!(%uri, "Sending get_component_tvl request to Tycho server");
1134        trace!(?request, "Sending request to Tycho server");
1135        let response = self
1136            .make_post_request(request, &uri)
1137            .await?;
1138        trace!(?response, "Received response from Tycho server");
1139        let body = response
1140            .text()
1141            .await
1142            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1143        let component_tvl =
1144            serde_json::from_str::<ComponentTvlRequestResponse>(&body).map_err(|err| {
1145                error!("Failed to parse component_tvl response: {:?}", &body);
1146                RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
1147            })?;
1148        trace!(?component_tvl, "Received component_tvl response from Tycho server");
1149        Ok(component_tvl)
1150    }
1151
1152    async fn get_traced_entry_points(
1153        &self,
1154        request: &TracedEntryPointRequestBody,
1155    ) -> Result<TracedEntryPointRequestResponse, RPCError> {
1156        let uri = format!(
1157            "{}/{TYCHO_SERVER_VERSION}/traced_entry_points",
1158            self.url
1159                .to_string()
1160                .trim_end_matches('/')
1161        );
1162        debug!(%uri, "Sending traced_entry_points request to Tycho server");
1163        trace!(?request, "Sending request to Tycho server");
1164
1165        let response = self
1166            .make_post_request(request, &uri)
1167            .await?;
1168
1169        trace!(?response, "Received response from Tycho server");
1170
1171        let body = response
1172            .text()
1173            .await
1174            .map_err(|e| RPCError::ParseResponse(e.to_string()))?;
1175        let entrypoints =
1176            serde_json::from_str::<TracedEntryPointRequestResponse>(&body).map_err(|err| {
1177                error!("Failed to parse traced_entry_points response: {:?}", &body);
1178                RPCError::ParseResponse(format!("Error: {err}, Body: {body}"))
1179            })?;
1180        trace!(?entrypoints, "Received traced_entry_points response from Tycho server");
1181        Ok(entrypoints)
1182    }
1183
1184    async fn get_snapshots<'a>(
1185        &self,
1186        request: &SnapshotParameters<'a>,
1187        chunk_size: usize,
1188        concurrency: usize,
1189    ) -> Result<Snapshot, RPCError> {
1190        let component_ids: Vec<_> = request
1191            .components
1192            .keys()
1193            .cloned()
1194            .collect();
1195
1196        let version = VersionParam::new(
1197            None,
1198            Some({
1199                #[allow(deprecated)]
1200                BlockParam {
1201                    hash: None,
1202                    chain: Some(request.chain),
1203                    number: Some(request.block_number as i64),
1204                }
1205            }),
1206        );
1207
1208        let component_tvl = if request.include_tvl && !component_ids.is_empty() {
1209            let body = ComponentTvlRequestBody::id_filtered(component_ids.clone(), request.chain);
1210            self.get_component_tvl_paginated(&body, chunk_size, concurrency)
1211                .await?
1212                .tvl
1213        } else {
1214            HashMap::new()
1215        };
1216
1217        let mut protocol_states = if !component_ids.is_empty() {
1218            self.get_protocol_states_paginated(
1219                request.chain,
1220                &component_ids,
1221                request.protocol_system,
1222                request.include_balances,
1223                &version,
1224                chunk_size,
1225                concurrency,
1226            )
1227            .await?
1228            .states
1229            .into_iter()
1230            .map(|state| (state.component_id.clone(), state))
1231            .collect()
1232        } else {
1233            HashMap::new()
1234        };
1235
1236        // Convert to ComponentWithState, which includes entrypoint information.
1237        let states = request
1238            .components
1239            .values()
1240            .filter_map(|component| {
1241                if let Some(state) = protocol_states.remove(&component.id) {
1242                    Some((
1243                        component.id.clone(),
1244                        ComponentWithState {
1245                            state,
1246                            component: component.clone(),
1247                            component_tvl: component_tvl
1248                                .get(&component.id)
1249                                .cloned(),
1250                            entrypoints: request
1251                                .entrypoints
1252                                .as_ref()
1253                                .and_then(|map| map.get(&component.id))
1254                                .cloned()
1255                                .unwrap_or_default(),
1256                        },
1257                    ))
1258                } else if component_ids.contains(&component.id) {
1259                    // only emit error event if we requested this component
1260                    let component_id = &component.id;
1261                    error!(?component_id, "Missing state for native component!");
1262                    None
1263                } else {
1264                    None
1265                }
1266            })
1267            .collect();
1268
1269        let vm_storage = if !request.contract_ids.is_empty() {
1270            let contract_states = self
1271                .get_contract_state_paginated(
1272                    request.chain,
1273                    request.contract_ids,
1274                    request.protocol_system,
1275                    &version,
1276                    chunk_size,
1277                    concurrency,
1278                )
1279                .await?
1280                .accounts
1281                .into_iter()
1282                .map(|acc| (acc.address.clone(), acc))
1283                .collect::<HashMap<_, _>>();
1284
1285            trace!(states=?&contract_states, "Retrieved ContractState");
1286
1287            let contract_address_to_components = request
1288                .components
1289                .iter()
1290                .filter_map(|(id, comp)| {
1291                    if component_ids.contains(id) {
1292                        Some(
1293                            comp.contract_ids
1294                                .iter()
1295                                .map(|address| (address.clone(), comp.id.clone())),
1296                        )
1297                    } else {
1298                        None
1299                    }
1300                })
1301                .flatten()
1302                .fold(HashMap::<Bytes, Vec<String>>::new(), |mut acc, (addr, c_id)| {
1303                    acc.entry(addr).or_default().push(c_id);
1304                    acc
1305                });
1306
1307            request
1308                .contract_ids
1309                .iter()
1310                .filter_map(|address| {
1311                    if let Some(state) = contract_states.get(address) {
1312                        Some((address.clone(), state.clone()))
1313                    } else if let Some(ids) = contract_address_to_components.get(address) {
1314                        // only emit error even if we did actually request this address
1315                        error!(
1316                            ?address,
1317                            ?ids,
1318                            "Component with lacking contract storage encountered!"
1319                        );
1320                        None
1321                    } else {
1322                        None
1323                    }
1324                })
1325                .collect()
1326        } else {
1327            HashMap::new()
1328        };
1329
1330        Ok(Snapshot { states, vm_storage })
1331    }
1332}
1333
1334#[cfg(test)]
1335mod tests {
1336    use std::{
1337        collections::{HashMap, HashSet},
1338        str::FromStr,
1339    };
1340
1341    use mockito::Server;
1342    use rstest::rstest;
1343    // TODO: remove once deprecated ProtocolId struct is removed
1344    #[allow(deprecated)]
1345    use tycho_common::dto::ProtocolId;
1346    use tycho_common::dto::{AddressStorageLocation, TracingParams};
1347
1348    use super::*;
1349
1350    // Dummy implementation of `get_protocol_states_paginated` for backwards compatibility testing
1351    // purposes
1352    impl MockRPCClient {
1353        #[allow(clippy::too_many_arguments)]
1354        async fn test_get_protocol_states_paginated<T>(
1355            &self,
1356            chain: Chain,
1357            ids: &[T],
1358            protocol_system: &str,
1359            include_balances: bool,
1360            version: &VersionParam,
1361            chunk_size: usize,
1362            _concurrency: usize,
1363        ) -> Vec<ProtocolStateRequestBody>
1364        where
1365            T: AsRef<str> + Clone + Send + Sync + 'static,
1366        {
1367            ids.chunks(chunk_size)
1368                .map(|chunk| ProtocolStateRequestBody {
1369                    protocol_ids: Some(
1370                        chunk
1371                            .iter()
1372                            .map(|id| id.as_ref().to_string())
1373                            .collect(),
1374                    ),
1375                    protocol_system: protocol_system.to_string(),
1376                    chain,
1377                    include_balances,
1378                    version: version.clone(),
1379                    pagination: PaginationParams { page: 0, page_size: chunk_size as i64 },
1380                })
1381                .collect()
1382        }
1383    }
1384
    // Canned JSON body for a `contract_state` response: a single account plus pagination
    // info. Served by the mock server in `test_get_contract_state`.
    const GET_CONTRACT_STATE_RESP: &str = r#"
        {
            "accounts": [
                {
                    "chain": "ethereum",
                    "address": "0x0000000000000000000000000000000000000000",
                    "title": "",
                    "slots": {},
                    "native_balance": "0x01f4",
                    "token_balances": {},
                    "code": "0x00",
                    "code_hash": "0x5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e",
                    "balance_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
                    "code_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
                    "creation_tx": null
                }
            ],
            "pagination": {
                "page": 0,
                "page_size": 20,
                "total": 10
            }
        }
        "#;
1409
1410    // TODO: remove once deprecated ProtocolId struct is removed
1411    #[allow(deprecated)]
1412    #[rstest]
1413    #[case::protocol_id_input(vec![
1414        ProtocolId { id: "id1".to_string(), chain: Chain::Ethereum },
1415        ProtocolId { id: "id2".to_string(), chain: Chain::Ethereum }
1416    ])]
1417    #[case::string_input(vec![
1418        "id1".to_string(),
1419        "id2".to_string()
1420    ])]
1421    #[tokio::test]
1422    async fn test_get_protocol_states_paginated_backwards_compatibility<T>(#[case] ids: Vec<T>)
1423    where
1424        T: AsRef<str> + Clone + Send + Sync + 'static,
1425    {
1426        let mock_client = MockRPCClient::new();
1427
1428        let request_bodies = mock_client
1429            .test_get_protocol_states_paginated(
1430                Chain::Ethereum,
1431                &ids,
1432                "test_system",
1433                true,
1434                &VersionParam::default(),
1435                2,
1436                2,
1437            )
1438            .await;
1439
1440        // Verify that the request bodies have been created correctly
1441        assert_eq!(request_bodies.len(), 1);
1442        assert_eq!(
1443            request_bodies[0]
1444                .protocol_ids
1445                .as_ref()
1446                .unwrap()
1447                .len(),
1448            2
1449        );
1450    }
1451
1452    #[tokio::test]
1453    async fn test_get_contract_state() {
1454        let mut server = Server::new_async().await;
1455        let server_resp = GET_CONTRACT_STATE_RESP;
1456        // test that the response is deserialized correctly
1457        serde_json::from_str::<StateRequestResponse>(server_resp).expect("deserialize");
1458
1459        let mocked_server = server
1460            .mock("POST", "/v1/contract_state")
1461            .expect(1)
1462            .with_body(server_resp)
1463            .create_async()
1464            .await;
1465
1466        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1467            .expect("create client");
1468
1469        let response = client
1470            .get_contract_state(&Default::default())
1471            .await
1472            .expect("get state");
1473        let accounts = response.accounts;
1474
1475        mocked_server.assert();
1476        assert_eq!(accounts.len(), 1);
1477        assert_eq!(accounts[0].slots, HashMap::new());
1478        assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
1479        assert_eq!(accounts[0].code, [0].to_vec());
1480        assert_eq!(
1481            accounts[0].code_hash,
1482            hex::decode("5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e")
1483                .unwrap()
1484        );
1485    }
1486
1487    #[tokio::test]
1488    async fn test_get_protocol_components() {
1489        let mut server = Server::new_async().await;
1490        let server_resp = r#"
1491        {
1492            "protocol_components": [
1493                {
1494                    "id": "State1",
1495                    "protocol_system": "ambient",
1496                    "protocol_type_name": "Pool",
1497                    "chain": "ethereum",
1498                    "tokens": [
1499                        "0x0000000000000000000000000000000000000000",
1500                        "0x0000000000000000000000000000000000000001"
1501                    ],
1502                    "contract_ids": [
1503                        "0x0000000000000000000000000000000000000000"
1504                    ],
1505                    "static_attributes": {
1506                        "attribute_1": "0x00000000000003e8"
1507                    },
1508                    "change": "Creation",
1509                    "creation_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
1510                    "created_at": "2022-01-01T00:00:00"
1511                }
1512            ],
1513            "pagination": {
1514                "page": 0,
1515                "page_size": 20,
1516                "total": 10
1517            }
1518        }
1519        "#;
1520        // test that the response is deserialized correctly
1521        serde_json::from_str::<ProtocolComponentRequestResponse>(server_resp).expect("deserialize");
1522
1523        let mocked_server = server
1524            .mock("POST", "/v1/protocol_components")
1525            .expect(1)
1526            .with_body(server_resp)
1527            .create_async()
1528            .await;
1529
1530        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1531            .expect("create client");
1532
1533        let response = client
1534            .get_protocol_components(&Default::default())
1535            .await
1536            .expect("get state");
1537        let components = response.protocol_components;
1538
1539        mocked_server.assert();
1540        assert_eq!(components.len(), 1);
1541        assert_eq!(components[0].id, "State1");
1542        assert_eq!(components[0].protocol_system, "ambient");
1543        assert_eq!(components[0].protocol_type_name, "Pool");
1544        assert_eq!(components[0].tokens.len(), 2);
1545        let expected_attributes =
1546            [("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]
1547                .iter()
1548                .cloned()
1549                .collect::<HashMap<String, Bytes>>();
1550        assert_eq!(components[0].static_attributes, expected_attributes);
1551    }
1552
1553    #[tokio::test]
1554    async fn test_get_protocol_states() {
1555        let mut server = Server::new_async().await;
1556        let server_resp = r#"
1557        {
1558            "states": [
1559                {
1560                    "component_id": "State1",
1561                    "attributes": {
1562                        "attribute_1": "0x00000000000003e8"
1563                    },
1564                    "balances": {
1565                        "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": "0x01f4"
1566                    }
1567                }
1568            ],
1569            "pagination": {
1570                "page": 0,
1571                "page_size": 20,
1572                "total": 10
1573            }
1574        }
1575        "#;
1576        // test that the response is deserialized correctly
1577        serde_json::from_str::<ProtocolStateRequestResponse>(server_resp).expect("deserialize");
1578
1579        let mocked_server = server
1580            .mock("POST", "/v1/protocol_state")
1581            .expect(1)
1582            .with_body(server_resp)
1583            .create_async()
1584            .await;
1585        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1586            .expect("create client");
1587
1588        let response = client
1589            .get_protocol_states(&Default::default())
1590            .await
1591            .expect("get state");
1592        let states = response.states;
1593
1594        mocked_server.assert();
1595        assert_eq!(states.len(), 1);
1596        assert_eq!(states[0].component_id, "State1");
1597        let expected_attributes =
1598            [("attribute_1".to_string(), Bytes::from(1000_u64.to_be_bytes()))]
1599                .iter()
1600                .cloned()
1601                .collect::<HashMap<String, Bytes>>();
1602        assert_eq!(states[0].attributes, expected_attributes);
1603        let expected_balances = [(
1604            Bytes::from_str("0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2")
1605                .expect("Unsupported address format"),
1606            Bytes::from_str("0x01f4").unwrap(),
1607        )]
1608        .iter()
1609        .cloned()
1610        .collect::<HashMap<Bytes, Bytes>>();
1611        assert_eq!(states[0].balances, expected_balances);
1612    }
1613
1614    #[tokio::test]
1615    async fn test_get_tokens() {
1616        let mut server = Server::new_async().await;
1617        let server_resp = r#"
1618        {
1619            "tokens": [
1620              {
1621                "chain": "ethereum",
1622                "address": "0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2",
1623                "symbol": "WETH",
1624                "decimals": 18,
1625                "tax": 0,
1626                "gas": [
1627                  29962
1628                ],
1629                "quality": 100
1630              },
1631              {
1632                "chain": "ethereum",
1633                "address": "0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48",
1634                "symbol": "USDC",
1635                "decimals": 6,
1636                "tax": 0,
1637                "gas": [
1638                  40652
1639                ],
1640                "quality": 100
1641              }
1642            ],
1643            "pagination": {
1644              "page": 0,
1645              "page_size": 20,
1646              "total": 10
1647            }
1648          }
1649        "#;
1650        // test that the response is deserialized correctly
1651        serde_json::from_str::<TokensRequestResponse>(server_resp).expect("deserialize");
1652
1653        let mocked_server = server
1654            .mock("POST", "/v1/tokens")
1655            .expect(1)
1656            .with_body(server_resp)
1657            .create_async()
1658            .await;
1659        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1660            .expect("create client");
1661
1662        let response = client
1663            .get_tokens(&Default::default())
1664            .await
1665            .expect("get tokens");
1666
1667        let expected = vec![
1668            ResponseToken {
1669                chain: Chain::Ethereum,
1670                address: Bytes::from_str("0xc02aaa39b223fe8d0a0e5c4f27ead9083c756cc2").unwrap(),
1671                symbol: "WETH".to_string(),
1672                decimals: 18,
1673                tax: 0,
1674                gas: vec![Some(29962)],
1675                quality: 100,
1676            },
1677            ResponseToken {
1678                chain: Chain::Ethereum,
1679                address: Bytes::from_str("0xa0b86991c6218b36c1d19d4a2e9eb0ce3606eb48").unwrap(),
1680                symbol: "USDC".to_string(),
1681                decimals: 6,
1682                tax: 0,
1683                gas: vec![Some(40652)],
1684                quality: 100,
1685            },
1686        ];
1687
1688        mocked_server.assert();
1689        assert_eq!(response.tokens, expected);
1690        assert_eq!(response.pagination, PaginationResponse { page: 0, page_size: 20, total: 10 });
1691    }
1692
1693    #[tokio::test]
1694    async fn test_get_protocol_systems() {
1695        let mut server = Server::new_async().await;
1696        let server_resp = r#"
1697        {
1698            "protocol_systems": [
1699                "system1",
1700                "system2"
1701            ],
1702            "pagination": {
1703                "page": 0,
1704                "page_size": 20,
1705                "total": 10
1706            }
1707        }
1708        "#;
1709        // test that the response is deserialized correctly
1710        serde_json::from_str::<ProtocolSystemsRequestResponse>(server_resp).expect("deserialize");
1711
1712        let mocked_server = server
1713            .mock("POST", "/v1/protocol_systems")
1714            .expect(1)
1715            .with_body(server_resp)
1716            .create_async()
1717            .await;
1718        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1719            .expect("create client");
1720
1721        let response = client
1722            .get_protocol_systems(&Default::default())
1723            .await
1724            .expect("get protocol systems");
1725        let protocol_systems = response.protocol_systems;
1726
1727        mocked_server.assert();
1728        assert_eq!(protocol_systems, vec!["system1", "system2"]);
1729    }
1730
1731    #[tokio::test]
1732    async fn test_get_component_tvl() {
1733        let mut server = Server::new_async().await;
1734        let server_resp = r#"
1735        {
1736            "tvl": {
1737                "component1": 100.0
1738            },
1739            "pagination": {
1740                "page": 0,
1741                "page_size": 20,
1742                "total": 10
1743            }
1744        }
1745        "#;
1746        // test that the response is deserialized correctly
1747        serde_json::from_str::<ComponentTvlRequestResponse>(server_resp).expect("deserialize");
1748
1749        let mocked_server = server
1750            .mock("POST", "/v1/component_tvl")
1751            .expect(1)
1752            .with_body(server_resp)
1753            .create_async()
1754            .await;
1755        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1756            .expect("create client");
1757
1758        let response = client
1759            .get_component_tvl(&Default::default())
1760            .await
1761            .expect("get protocol systems");
1762        let component_tvl = response.tvl;
1763
1764        mocked_server.assert();
1765        assert_eq!(component_tvl.get("component1"), Some(&100.0));
1766    }
1767
1768    #[tokio::test]
1769    async fn test_get_traced_entry_points() {
1770        let mut server = Server::new_async().await;
1771        let server_resp = r#"
1772        {
1773            "traced_entry_points": {
1774                "component_1": [
1775                    [
1776                        {
1777                            "entry_point": {
1778                                "external_id": "entrypoint_a",
1779                                "target": "0x0000000000000000000000000000000000000001",
1780                                "signature": "sig()"
1781                            },
1782                            "params": {
1783                                "method": "rpctracer",
1784                                "caller": "0x000000000000000000000000000000000000000a",
1785                                "calldata": "0x000000000000000000000000000000000000000b"
1786                            }
1787                        },
1788                        {
1789                            "retriggers": [
1790                                [
1791                                    "0x00000000000000000000000000000000000000aa",
1792                                    {"key": "0x0000000000000000000000000000000000000aaa", "offset": 12}
1793                                ]
1794                            ],
1795                            "accessed_slots": {
1796                                "0x0000000000000000000000000000000000aaaa": [
1797                                    "0x0000000000000000000000000000000000aaaa"
1798                                ]
1799                            }
1800                        }
1801                    ]
1802                ]
1803            },
1804            "pagination": {
1805                "page": 0,
1806                "page_size": 20,
1807                "total": 1
1808            }
1809        }
1810        "#;
1811        // test that the response is deserialized correctly
1812        serde_json::from_str::<TracedEntryPointRequestResponse>(server_resp).expect("deserialize");
1813
1814        let mocked_server = server
1815            .mock("POST", "/v1/traced_entry_points")
1816            .expect(1)
1817            .with_body(server_resp)
1818            .create_async()
1819            .await;
1820        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1821            .expect("create client");
1822
1823        let response = client
1824            .get_traced_entry_points(&Default::default())
1825            .await
1826            .expect("get traced entry points");
1827        let entrypoints = response.traced_entry_points;
1828
1829        mocked_server.assert();
1830        assert_eq!(entrypoints.len(), 1);
1831        let comp1_entrypoints = entrypoints
1832            .get("component_1")
1833            .expect("component_1 entrypoints should exist");
1834        assert_eq!(comp1_entrypoints.len(), 1);
1835
1836        let (entrypoint, trace_result) = &comp1_entrypoints[0];
1837        assert_eq!(entrypoint.entry_point.external_id, "entrypoint_a");
1838        assert_eq!(
1839            entrypoint.entry_point.target,
1840            Bytes::from_str("0x0000000000000000000000000000000000000001").unwrap()
1841        );
1842        assert_eq!(entrypoint.entry_point.signature, "sig()");
1843        let TracingParams::RPCTracer(rpc_params) = &entrypoint.params;
1844        assert_eq!(
1845            rpc_params.caller,
1846            Some(Bytes::from("0x000000000000000000000000000000000000000a"))
1847        );
1848        assert_eq!(rpc_params.calldata, Bytes::from("0x000000000000000000000000000000000000000b"));
1849
1850        assert_eq!(
1851            trace_result.retriggers,
1852            HashSet::from([(
1853                Bytes::from("0x00000000000000000000000000000000000000aa"),
1854                AddressStorageLocation::new(
1855                    Bytes::from("0x0000000000000000000000000000000000000aaa"),
1856                    12
1857                )
1858            )])
1859        );
1860        assert_eq!(trace_result.accessed_slots.len(), 1);
1861        assert_eq!(
1862            trace_result.accessed_slots,
1863            HashMap::from([(
1864                Bytes::from("0x0000000000000000000000000000000000aaaa"),
1865                HashSet::from([Bytes::from("0x0000000000000000000000000000000000aaaa")])
1866            )])
1867        );
1868    }
1869
1870    #[tokio::test]
1871    async fn test_parse_retry_value_numeric() {
1872        let result = parse_retry_value("60");
1873        assert!(result.is_some());
1874
1875        let expected_time = SystemTime::now() + Duration::from_secs(60);
1876        let actual_time = result.unwrap();
1877
1878        // Allow for small timing differences during test execution
1879        let diff = if actual_time > expected_time {
1880            actual_time
1881                .duration_since(expected_time)
1882                .unwrap()
1883        } else {
1884            expected_time
1885                .duration_since(actual_time)
1886                .unwrap()
1887        };
1888        assert!(diff < Duration::from_secs(1), "Time difference too large: {:?}", diff);
1889    }
1890
1891    #[tokio::test]
1892    async fn test_parse_retry_value_rfc2822() {
1893        // Use a fixed future date in RFC2822 format
1894        let rfc2822_date = "Sat, 01 Jan 2030 12:00:00 +0000";
1895        let result = parse_retry_value(rfc2822_date);
1896        assert!(result.is_some());
1897
1898        let parsed_time = result.unwrap();
1899        assert!(parsed_time > SystemTime::now());
1900    }
1901
1902    #[tokio::test]
1903    async fn test_parse_retry_value_invalid_formats() {
1904        // Test various invalid formats
1905        assert!(parse_retry_value("invalid").is_none());
1906        assert!(parse_retry_value("").is_none());
1907        assert!(parse_retry_value("not_a_number").is_none());
1908        assert!(parse_retry_value("Mon, 32 Jan 2030 25:00:00 +0000").is_none()); // Invalid date
1909    }
1910
1911    #[tokio::test]
1912    async fn test_parse_retry_value_zero_seconds() {
1913        let result = parse_retry_value("0");
1914        assert!(result.is_some());
1915
1916        let expected_time = SystemTime::now();
1917        let actual_time = result.unwrap();
1918
1919        // Should be very close to current time
1920        let diff = if actual_time > expected_time {
1921            actual_time
1922                .duration_since(expected_time)
1923                .unwrap()
1924        } else {
1925            expected_time
1926                .duration_since(actual_time)
1927                .unwrap()
1928        };
1929        assert!(diff < Duration::from_secs(1));
1930    }
1931
1932    #[tokio::test]
1933    async fn test_error_for_response_rate_limited() {
1934        let mut server = Server::new_async().await;
1935        let mock = server
1936            .mock("GET", "/test")
1937            .with_status(429)
1938            .with_header("Retry-After", "60")
1939            .create_async()
1940            .await;
1941
1942        let client = reqwest::Client::new();
1943        let response = client
1944            .get(format!("{}/test", server.url()))
1945            .send()
1946            .await
1947            .unwrap();
1948
1949        let http_client =
1950            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1951                .unwrap()
1952                .with_test_backoff_policy();
1953        let result = http_client
1954            .error_for_response(response)
1955            .await;
1956
1957        mock.assert();
1958        assert!(matches!(result, Err(RPCError::RateLimited(_))));
1959        if let Err(RPCError::RateLimited(retry_after)) = result {
1960            assert!(retry_after.is_some());
1961        }
1962    }
1963
1964    #[tokio::test]
1965    async fn test_error_for_response_rate_limited_no_header() {
1966        let mut server = Server::new_async().await;
1967        let mock = server
1968            .mock("GET", "/test")
1969            .with_status(429)
1970            .create_async()
1971            .await;
1972
1973        let client = reqwest::Client::new();
1974        let response = client
1975            .get(format!("{}/test", server.url()))
1976            .send()
1977            .await
1978            .unwrap();
1979
1980        let http_client =
1981            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
1982                .unwrap()
1983                .with_test_backoff_policy();
1984        let result = http_client
1985            .error_for_response(response)
1986            .await;
1987
1988        mock.assert();
1989        assert!(matches!(result, Err(RPCError::RateLimited(None))));
1990    }
1991
1992    #[tokio::test]
1993    async fn test_error_for_response_server_errors() {
1994        let test_cases =
1995            vec![(502, "Bad Gateway"), (503, "Service Unavailable"), (504, "Gateway Timeout")];
1996
1997        for (status_code, expected_body) in test_cases {
1998            let mut server = Server::new_async().await;
1999            let mock = server
2000                .mock("GET", "/test")
2001                .with_status(status_code)
2002                .with_body(expected_body)
2003                .create_async()
2004                .await;
2005
2006            let client = reqwest::Client::new();
2007            let response = client
2008                .get(format!("{}/test", server.url()))
2009                .send()
2010                .await
2011                .unwrap();
2012
2013            let http_client =
2014                HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2015                    .unwrap()
2016                    .with_test_backoff_policy();
2017            let result = http_client
2018                .error_for_response(response)
2019                .await;
2020
2021            mock.assert();
2022            assert!(matches!(result, Err(RPCError::ServerUnreachable(_))));
2023            if let Err(RPCError::ServerUnreachable(body)) = result {
2024                assert_eq!(body, expected_body);
2025            }
2026        }
2027    }
2028
2029    #[tokio::test]
2030    async fn test_error_for_response_success() {
2031        let mut server = Server::new_async().await;
2032        let mock = server
2033            .mock("GET", "/test")
2034            .with_status(200)
2035            .with_body("success")
2036            .create_async()
2037            .await;
2038
2039        let client = reqwest::Client::new();
2040        let response = client
2041            .get(format!("{}/test", server.url()))
2042            .send()
2043            .await
2044            .unwrap();
2045
2046        let http_client =
2047            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2048                .unwrap()
2049                .with_test_backoff_policy();
2050        let result = http_client
2051            .error_for_response(response)
2052            .await;
2053
2054        mock.assert();
2055        assert!(result.is_ok());
2056
2057        let response = result.unwrap();
2058        assert_eq!(response.status(), 200);
2059    }
2060
2061    #[tokio::test]
2062    async fn test_handle_error_for_backoff_server_unreachable() {
2063        let http_client =
2064            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2065                .unwrap()
2066                .with_test_backoff_policy();
2067        let error = RPCError::ServerUnreachable("Service down".to_string());
2068
2069        let backoff_error = http_client
2070            .handle_error_for_backoff(error)
2071            .await;
2072
2073        match backoff_error {
2074            backoff::Error::Transient { err: RPCError::ServerUnreachable(msg), retry_after } => {
2075                assert_eq!(msg, "Service down");
2076                assert_eq!(retry_after, Some(Duration::from_millis(50))); // Fast test duration
2077            }
2078            _ => panic!("Expected transient error for ServerUnreachable"),
2079        }
2080    }
2081
2082    #[tokio::test]
2083    async fn test_handle_error_for_backoff_rate_limited_with_retry_after() {
2084        let http_client =
2085            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2086                .unwrap()
2087                .with_test_backoff_policy();
2088        let future_time = SystemTime::now() + Duration::from_secs(30);
2089        let error = RPCError::RateLimited(Some(future_time));
2090
2091        let backoff_error = http_client
2092            .handle_error_for_backoff(error)
2093            .await;
2094
2095        match backoff_error {
2096            backoff::Error::Transient { err: RPCError::RateLimited(retry_after), .. } => {
2097                assert_eq!(retry_after, Some(future_time));
2098            }
2099            _ => panic!("Expected transient error for RateLimited"),
2100        }
2101
2102        // Verify that retry_after was stored in the client state
2103        let stored_retry_after = http_client.retry_after.read().await;
2104        assert_eq!(*stored_retry_after, Some(future_time));
2105    }
2106
2107    #[tokio::test]
2108    async fn test_handle_error_for_backoff_rate_limited_no_retry_after() {
2109        let http_client =
2110            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2111                .unwrap()
2112                .with_test_backoff_policy();
2113        let error = RPCError::RateLimited(None);
2114
2115        let backoff_error = http_client
2116            .handle_error_for_backoff(error)
2117            .await;
2118
2119        match backoff_error {
2120            backoff::Error::Transient { err: RPCError::RateLimited(None), .. } => {
2121                // This is expected - no retry-after still allows retries with default policy
2122            }
2123            _ => panic!("Expected transient error for RateLimited without retry-after"),
2124        }
2125    }
2126
2127    #[tokio::test]
2128    async fn test_handle_error_for_backoff_other_errors() {
2129        let http_client =
2130            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2131                .unwrap()
2132                .with_test_backoff_policy();
2133        let error = RPCError::ParseResponse("Invalid JSON".to_string());
2134
2135        let backoff_error = http_client
2136            .handle_error_for_backoff(error)
2137            .await;
2138
2139        match backoff_error {
2140            backoff::Error::Permanent(RPCError::ParseResponse(msg)) => {
2141                assert_eq!(msg, "Invalid JSON");
2142            }
2143            _ => panic!("Expected permanent error for ParseResponse"),
2144        }
2145    }
2146
2147    #[tokio::test]
2148    async fn test_wait_until_retry_after_no_retry_time() {
2149        let http_client =
2150            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2151                .unwrap()
2152                .with_test_backoff_policy();
2153
2154        let start = std::time::Instant::now();
2155        http_client
2156            .wait_until_retry_after()
2157            .await;
2158        let elapsed = start.elapsed();
2159
2160        // Should return immediately if no retry time is set
2161        assert!(elapsed < Duration::from_millis(100));
2162    }
2163
2164    #[tokio::test]
2165    async fn test_wait_until_retry_after_past_time() {
2166        let http_client =
2167            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2168                .unwrap()
2169                .with_test_backoff_policy();
2170
2171        // Set a retry time in the past
2172        let past_time = SystemTime::now() - Duration::from_secs(10);
2173        *http_client.retry_after.write().await = Some(past_time);
2174
2175        let start = std::time::Instant::now();
2176        http_client
2177            .wait_until_retry_after()
2178            .await;
2179        let elapsed = start.elapsed();
2180
2181        // Should return immediately if retry time is in the past
2182        assert!(elapsed < Duration::from_millis(100));
2183    }
2184
2185    #[tokio::test]
2186    async fn test_wait_until_retry_after_future_time() {
2187        let http_client =
2188            HttpRPCClient::new("http://localhost:8080", HttpRPCClientOptions::default())
2189                .unwrap()
2190                .with_test_backoff_policy();
2191
2192        // Set a retry time 100ms in the future
2193        let future_time = SystemTime::now() + Duration::from_millis(100);
2194        *http_client.retry_after.write().await = Some(future_time);
2195
2196        let start = std::time::Instant::now();
2197        http_client
2198            .wait_until_retry_after()
2199            .await;
2200        let elapsed = start.elapsed();
2201
2202        // Should wait approximately the specified duration
2203        assert!(elapsed >= Duration::from_millis(80)); // Allow some tolerance
2204        assert!(elapsed <= Duration::from_millis(200)); // Upper bound for test stability
2205    }
2206
2207    #[tokio::test]
2208    async fn test_make_post_request_success() {
2209        let mut server = Server::new_async().await;
2210        let server_resp = r#"{"success": true}"#;
2211
2212        let mock = server
2213            .mock("POST", "/test")
2214            .with_status(200)
2215            .with_body(server_resp)
2216            .create_async()
2217            .await;
2218
2219        let http_client =
2220            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2221                .unwrap()
2222                .with_test_backoff_policy();
2223        let request_body = serde_json::json!({"test": "data"});
2224        let uri = format!("{}/test", server.url());
2225
2226        let result = http_client
2227            .make_post_request(&request_body, &uri)
2228            .await;
2229
2230        mock.assert();
2231        assert!(result.is_ok());
2232
2233        let response = result.unwrap();
2234        assert_eq!(response.status(), 200);
2235        assert_eq!(response.text().await.unwrap(), server_resp);
2236    }
2237
2238    #[tokio::test]
2239    async fn test_make_post_request_retry_on_server_error() {
2240        let mut server = Server::new_async().await;
2241        // First request fails with 503, second succeeds
2242        let error_mock = server
2243            .mock("POST", "/test")
2244            .with_status(503)
2245            .with_body("Service Unavailable")
2246            .expect(1)
2247            .create_async()
2248            .await;
2249
2250        let success_mock = server
2251            .mock("POST", "/test")
2252            .with_status(200)
2253            .with_body(r#"{"success": true}"#)
2254            .expect(1)
2255            .create_async()
2256            .await;
2257
2258        let http_client =
2259            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2260                .unwrap()
2261                .with_test_backoff_policy();
2262        let request_body = serde_json::json!({"test": "data"});
2263        let uri = format!("{}/test", server.url());
2264
2265        let result = http_client
2266            .make_post_request(&request_body, &uri)
2267            .await;
2268
2269        error_mock.assert();
2270        success_mock.assert();
2271        assert!(result.is_ok());
2272    }
2273
2274    #[tokio::test]
2275    async fn test_make_post_request_respect_retry_after_header() {
2276        let mut server = Server::new_async().await;
2277
2278        // First request returns 429 with retry-after, second succeeds
2279        let rate_limit_mock = server
2280            .mock("POST", "/test")
2281            .with_status(429)
2282            .with_header("Retry-After", "1") // 1 second
2283            .expect(1)
2284            .create_async()
2285            .await;
2286
2287        let success_mock = server
2288            .mock("POST", "/test")
2289            .with_status(200)
2290            .with_body(r#"{"success": true}"#)
2291            .expect(1)
2292            .create_async()
2293            .await;
2294
2295        let http_client =
2296            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2297                .unwrap()
2298                .with_test_backoff_policy();
2299        let request_body = serde_json::json!({"test": "data"});
2300        let uri = format!("{}/test", server.url());
2301
2302        let start = std::time::Instant::now();
2303        let result = http_client
2304            .make_post_request(&request_body, &uri)
2305            .await;
2306        let elapsed = start.elapsed();
2307
2308        rate_limit_mock.assert();
2309        success_mock.assert();
2310        assert!(result.is_ok());
2311
2312        // Should have waited at least 1 second due to retry-after header
2313        assert!(elapsed >= Duration::from_millis(900)); // Allow some tolerance
2314        assert!(elapsed <= Duration::from_millis(2000)); // Upper bound for test stability
2315    }
2316
2317    #[tokio::test]
2318    async fn test_make_post_request_permanent_error() {
2319        let mut server = Server::new_async().await;
2320
2321        let mock = server
2322            .mock("POST", "/test")
2323            .with_status(400) // Bad Request - should not be retried
2324            .with_body("Bad Request")
2325            .expect(1)
2326            .create_async()
2327            .await;
2328
2329        let http_client =
2330            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2331                .unwrap()
2332                .with_test_backoff_policy();
2333        let request_body = serde_json::json!({"test": "data"});
2334        let uri = format!("{}/test", server.url());
2335
2336        let result = http_client
2337            .make_post_request(&request_body, &uri)
2338            .await;
2339
2340        mock.assert();
2341        assert!(result.is_ok()); // 400 doesn't trigger retry logic, just returns the response
2342
2343        let response = result.unwrap();
2344        assert_eq!(response.status(), 400);
2345    }
2346
2347    #[tokio::test]
2348    async fn test_concurrent_requests_with_different_retry_after() {
2349        let mut server = Server::new_async().await;
2350
2351        // First request gets rate limited with 1 second retry-after
2352        let rate_limit_mock_1 = server
2353            .mock("POST", "/test1")
2354            .with_status(429)
2355            .with_header("Retry-After", "1")
2356            .expect(1)
2357            .create_async()
2358            .await;
2359
2360        // Second request gets rate limited with 2 second retry-after
2361        let rate_limit_mock_2 = server
2362            .mock("POST", "/test2")
2363            .with_status(429)
2364            .with_header("Retry-After", "2")
2365            .expect(1)
2366            .create_async()
2367            .await;
2368
2369        // Success mocks for retries
2370        let success_mock_1 = server
2371            .mock("POST", "/test1")
2372            .with_status(200)
2373            .with_body(r#"{"result": "success1"}"#)
2374            .expect(1)
2375            .create_async()
2376            .await;
2377
2378        let success_mock_2 = server
2379            .mock("POST", "/test2")
2380            .with_status(200)
2381            .with_body(r#"{"result": "success2"}"#)
2382            .expect(1)
2383            .create_async()
2384            .await;
2385
2386        let http_client =
2387            HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2388                .unwrap()
2389                .with_test_backoff_policy();
2390        let request_body = serde_json::json!({"test": "data"});
2391
2392        let uri1 = format!("{}/test1", server.url());
2393        let uri2 = format!("{}/test2", server.url());
2394
2395        // Start both requests concurrently
2396        let start = std::time::Instant::now();
2397        let (result1, result2) = tokio::join!(
2398            http_client.make_post_request(&request_body, &uri1),
2399            http_client.make_post_request(&request_body, &uri2)
2400        );
2401        let elapsed = start.elapsed();
2402
2403        rate_limit_mock_1.assert();
2404        rate_limit_mock_2.assert();
2405        success_mock_1.assert();
2406        success_mock_2.assert();
2407
2408        assert!(result1.is_ok());
2409        assert!(result2.is_ok());
2410
2411        // Both requests should succeed, but the second should take longer due to the 2s retry-after
2412        // The total time should be at least 2 seconds since the shared retry_after state
2413        // gets updated by both requests
2414        assert!(elapsed >= Duration::from_millis(1800)); // Allow some tolerance
2415        assert!(elapsed <= Duration::from_millis(3000)); // Upper bound
2416
2417        // Check the final retry_after state - should be the latest (higher) value
2418        let final_retry_after = http_client.retry_after.read().await;
2419        assert!(final_retry_after.is_some());
2420
2421        // The retry_after should be set to the latest (higher) value from the two requests
2422        if let Some(retry_time) = *final_retry_after {
2423            // The retry_after time might be in the past now since we waited,
2424            // but it should be reasonable (not too far in past/future)
2425            let now = SystemTime::now();
2426            let diff = if retry_time > now {
2427                retry_time.duration_since(now).unwrap()
2428            } else {
2429                now.duration_since(retry_time).unwrap()
2430            };
2431
2432            // Should be within a reasonable range (the 2s retry-after plus some buffer)
2433            assert!(diff <= Duration::from_secs(3), "Retry time difference too large: {:?}", diff);
2434        }
2435    }
2436
2437    #[tokio::test]
2438    async fn test_get_snapshots() {
2439        let mut server = Server::new_async().await;
2440
2441        // Mock protocol states response
2442        let protocol_states_resp = r#"
2443        {
2444            "states": [
2445                {
2446                    "component_id": "component1",
2447                    "attributes": {
2448                        "attribute_1": "0x00000000000003e8"
2449                    },
2450                    "balances": {
2451                        "0xC02aaA39b223FE8D0A0e5C4F27eAD9083C756Cc2": "0x01f4"
2452                    }
2453                }
2454            ],
2455            "pagination": {
2456                "page": 0,
2457                "page_size": 100,
2458                "total": 1
2459            }
2460        }
2461        "#;
2462
2463        // Mock contract state response
2464        let contract_state_resp = r#"
2465        {
2466            "accounts": [
2467                {
2468                    "chain": "ethereum",
2469                    "address": "0x1111111111111111111111111111111111111111",
2470                    "title": "",
2471                    "slots": {},
2472                    "native_balance": "0x01f4",
2473                    "token_balances": {},
2474                    "code": "0x00",
2475                    "code_hash": "0x5c06b7c5b3d910fd33bc2229846f9ddaf91d584d9b196e16636901ac3a77077e",
2476                    "balance_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
2477                    "code_modify_tx": "0x0000000000000000000000000000000000000000000000000000000000000000",
2478                    "creation_tx": null
2479                }
2480            ],
2481            "pagination": {
2482                "page": 0,
2483                "page_size": 100,
2484                "total": 1
2485            }
2486        }
2487        "#;
2488
2489        // Mock component TVL response
2490        let tvl_resp = r#"
2491        {
2492            "tvl": {
2493                "component1": 1000000.0
2494            },
2495            "pagination": {
2496                "page": 0,
2497                "page_size": 100,
2498                "total": 1
2499            }
2500        }
2501        "#;
2502
2503        let protocol_states_mock = server
2504            .mock("POST", "/v1/protocol_state")
2505            .expect(1)
2506            .with_body(protocol_states_resp)
2507            .create_async()
2508            .await;
2509
2510        let contract_state_mock = server
2511            .mock("POST", "/v1/contract_state")
2512            .expect(1)
2513            .with_body(contract_state_resp)
2514            .create_async()
2515            .await;
2516
2517        let tvl_mock = server
2518            .mock("POST", "/v1/component_tvl")
2519            .expect(1)
2520            .with_body(tvl_resp)
2521            .create_async()
2522            .await;
2523
2524        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2525            .expect("create client");
2526
2527        #[allow(deprecated)]
2528        let component = tycho_common::dto::ProtocolComponent {
2529            id: "component1".to_string(),
2530            protocol_system: "test_protocol".to_string(),
2531            protocol_type_name: "test_type".to_string(),
2532            chain: Chain::Ethereum,
2533            tokens: vec![],
2534            contract_ids: vec![
2535                Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap()
2536            ],
2537            static_attributes: HashMap::new(),
2538            change: tycho_common::dto::ChangeType::Creation,
2539            creation_tx: Bytes::from_str(
2540                "0x0000000000000000000000000000000000000000000000000000000000000000",
2541            )
2542            .unwrap(),
2543            created_at: chrono::Utc::now().naive_utc(),
2544        };
2545
2546        let mut components = HashMap::new();
2547        components.insert("component1".to_string(), component);
2548
2549        let contract_ids =
2550            vec![Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap()];
2551
2552        let request = SnapshotParameters::new(
2553            Chain::Ethereum,
2554            "test_protocol",
2555            &components,
2556            &contract_ids,
2557            12345,
2558        );
2559
2560        let response = client
2561            .get_snapshots(&request, 100, RPC_CLIENT_CONCURRENCY)
2562            .await
2563            .expect("get snapshots");
2564
2565        // Verify all mocks were called
2566        protocol_states_mock.assert();
2567        contract_state_mock.assert();
2568        tvl_mock.assert();
2569
2570        // Assert states
2571        assert_eq!(response.states.len(), 1);
2572        assert!(response
2573            .states
2574            .contains_key("component1"));
2575
2576        // Check that the state has the expected TVL
2577        let component_state = response
2578            .states
2579            .get("component1")
2580            .unwrap();
2581        assert_eq!(component_state.component_tvl, Some(1000000.0));
2582
2583        // Assert VM storage
2584        assert_eq!(response.vm_storage.len(), 1);
2585        let contract_addr = Bytes::from_str("0x1111111111111111111111111111111111111111").unwrap();
2586        assert!(response
2587            .vm_storage
2588            .contains_key(&contract_addr));
2589    }
2590
2591    #[tokio::test]
2592    async fn test_get_snapshots_empty_components() {
2593        let server = Server::new_async().await;
2594        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2595            .expect("create client");
2596
2597        let components = HashMap::new();
2598        let contract_ids = vec![];
2599
2600        let request = SnapshotParameters::new(
2601            Chain::Ethereum,
2602            "test_protocol",
2603            &components,
2604            &contract_ids,
2605            12345,
2606        );
2607
2608        let response = client
2609            .get_snapshots(&request, 100, RPC_CLIENT_CONCURRENCY)
2610            .await
2611            .expect("get snapshots");
2612
2613        // Should return empty response without making any requests
2614        assert!(response.states.is_empty());
2615        assert!(response.vm_storage.is_empty());
2616    }
2617
2618    #[tokio::test]
2619    async fn test_get_snapshots_without_tvl() {
2620        let mut server = Server::new_async().await;
2621
2622        let protocol_states_resp = r#"
2623        {
2624            "states": [
2625                {
2626                    "component_id": "component1",
2627                    "attributes": {},
2628                    "balances": {}
2629                }
2630            ],
2631            "pagination": {
2632                "page": 0,
2633                "page_size": 100,
2634                "total": 1
2635            }
2636        }
2637        "#;
2638
2639        let protocol_states_mock = server
2640            .mock("POST", "/v1/protocol_state")
2641            .expect(1)
2642            .with_body(protocol_states_resp)
2643            .create_async()
2644            .await;
2645
2646        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2647            .expect("create client");
2648
2649        // Create test component
2650        #[allow(deprecated)]
2651        let component = tycho_common::dto::ProtocolComponent {
2652            id: "component1".to_string(),
2653            protocol_system: "test_protocol".to_string(),
2654            protocol_type_name: "test_type".to_string(),
2655            chain: Chain::Ethereum,
2656            tokens: vec![],
2657            contract_ids: vec![],
2658            static_attributes: HashMap::new(),
2659            change: tycho_common::dto::ChangeType::Creation,
2660            creation_tx: Bytes::from_str(
2661                "0x0000000000000000000000000000000000000000000000000000000000000000",
2662            )
2663            .unwrap(),
2664            created_at: chrono::Utc::now().naive_utc(),
2665        };
2666
2667        let mut components = HashMap::new();
2668        components.insert("component1".to_string(), component);
2669        let contract_ids = vec![];
2670
2671        let request = SnapshotParameters::new(
2672            Chain::Ethereum,
2673            "test_protocol",
2674            &components,
2675            &contract_ids,
2676            12345,
2677        )
2678        .include_balances(false)
2679        .include_tvl(false);
2680
2681        let response = client
2682            .get_snapshots(&request, 100, RPC_CLIENT_CONCURRENCY)
2683            .await
2684            .expect("get snapshots");
2685
2686        // Verify only necessary mocks were called
2687        protocol_states_mock.assert();
2688        // No contract_state_mock.assert() since contract_ids is empty
2689        // No tvl_mock.assert() since include_tvl is false
2690
2691        assert_eq!(response.states.len(), 1);
2692        // Check that TVL is None since we didn't request it
2693        let component_state = response
2694            .states
2695            .get("component1")
2696            .unwrap();
2697        assert_eq!(component_state.component_tvl, None);
2698    }
2699
2700    #[tokio::test]
2701    async fn test_compression_enabled() {
2702        let mut server = Server::new_async().await;
2703        let server_resp = GET_CONTRACT_STATE_RESP;
2704
2705        // Compress the response using zstd
2706        let compressed_body =
2707            zstd::encode_all(server_resp.as_bytes(), 0).expect("compression failed");
2708
2709        let mocked_server = server
2710            .mock("POST", "/v1/contract_state")
2711            .expect(1)
2712            .with_header("Content-Encoding", "zstd")
2713            .with_body(compressed_body)
2714            .create_async()
2715            .await;
2716
2717        // Create client with compression enabled
2718        let client = HttpRPCClient::new(
2719            server.url().as_str(),
2720            HttpRPCClientOptions::new().with_compression(true),
2721        )
2722        .expect("create client");
2723
2724        let response = client
2725            .get_contract_state(&Default::default())
2726            .await
2727            .expect("get state");
2728        let accounts = response.accounts;
2729
2730        mocked_server.assert();
2731        assert_eq!(accounts.len(), 1);
2732        assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
2733    }
2734
2735    #[tokio::test]
2736    async fn test_compression_disabled() {
2737        let mut server = Server::new_async().await;
2738        let server_resp = GET_CONTRACT_STATE_RESP;
2739
2740        // Verify client does NOT send Accept-Encoding: zstd when compression is disabled
2741        // Instead, server should receive request without compression headers
2742        let mocked_server = server
2743            .mock("POST", "/v1/contract_state")
2744            .expect(1)
2745            .match_header("Accept-Encoding", mockito::Matcher::Missing)
2746            .with_status(200)
2747            .with_body(server_resp)
2748            .create_async()
2749            .await;
2750
2751        // Create client with compression disabled
2752        let client = HttpRPCClient::new(
2753            server.url().as_str(),
2754            HttpRPCClientOptions::new().with_compression(false),
2755        )
2756        .expect("create client");
2757
2758        let response = client
2759            .get_contract_state(&Default::default())
2760            .await
2761            .expect("get state");
2762        let accounts = response.accounts;
2763
2764        // Verify the mock was called (client sent request without Accept-Encoding header)
2765        mocked_server.assert();
2766        assert_eq!(accounts.len(), 1);
2767        assert_eq!(accounts[0].native_balance, Bytes::from(500u16.to_be_bytes()));
2768    }
2769
2770    #[rstest]
2771    #[case::single_page(2, 10, 1000)]
2772    #[case::multiple_pages_within_concurrency(10, 10, 2)]
2773    #[case::exceeds_concurrency_limit(60, 10, 2)]
2774    #[tokio::test]
2775    async fn test_get_all_tokens_pagination_and_concurrency(
2776        #[case] total_tokens: usize,
2777        #[case] allowed_concurrency: usize,
2778        #[case] page_size: usize,
2779    ) {
2780        use std::sync::atomic::{AtomicUsize, Ordering};
2781
2782        let concurrent_requests = Arc::new(AtomicUsize::new(0));
2783        let max_concurrent = Arc::new(AtomicUsize::new(0));
2784
2785        let mut server = Server::new_async().await;
2786
2787        let total_pages = (total_tokens as f64 / page_size as f64).ceil() as i64;
2788
2789        // Mock all required pages
2790        for page in 0..total_pages {
2791            let concurrent = concurrent_requests.clone();
2792            let max_conc = max_concurrent.clone();
2793
2794            let tokens_in_page = {
2795                let start_idx = (page as usize) * page_size;
2796                let end_idx = ((page as usize + 1) * page_size).min(total_tokens);
2797                (start_idx..end_idx)
2798                    .map(|i| {
2799                        format!(
2800                            r#"{{
2801                            "chain": "ethereum",
2802                            "address": "0x{i:040x}",
2803                            "symbol": "TOKEN_{i}",
2804                            "decimals": 18,
2805                            "tax": 0,
2806                            "gas": [30000],
2807                            "quality": 100
2808                        }}"#
2809                        )
2810                    })
2811                    .collect::<Vec<_>>()
2812            };
2813
2814            let tokens_json = tokens_in_page.join(",");
2815            let response = format!(
2816                r#"{{
2817                    "tokens": [{tokens_json}],
2818                    "pagination": {{
2819                        "page": {page},
2820                        "page_size": {page_size},
2821                        "total": {total_tokens}
2822                    }}
2823                }}"#,
2824            );
2825
2826            server
2827                .mock("POST", "/v1/tokens")
2828                .expect(1)
2829                .with_chunked_body(move |w| {
2830                    // Track concurrent requests
2831                    let current = concurrent.fetch_add(1, Ordering::SeqCst);
2832                    max_conc.fetch_max(current + 1, Ordering::SeqCst);
2833
2834                    // Simulate some work to increase likelihood of concurrent requests
2835                    std::thread::sleep(Duration::from_millis(10));
2836
2837                    concurrent.fetch_sub(1, Ordering::SeqCst);
2838
2839                    w.write_all(response.as_bytes())
2840                })
2841                .create_async()
2842                .await;
2843        }
2844
2845        let client = HttpRPCClient::new(server.url().as_str(), HttpRPCClientOptions::default())
2846            .expect("create client");
2847
2848        let tokens = client
2849            .get_all_tokens(Chain::Ethereum, None, None, page_size, allowed_concurrency)
2850            .await
2851            .expect("get all tokens");
2852
2853        // Verify concurrency was respected
2854        let max = max_concurrent.load(Ordering::SeqCst);
2855        let expected_max_concurrency = (total_pages as usize)
2856            .saturating_sub(1)
2857            .min(allowed_concurrency);
2858        assert!(
2859            max <= allowed_concurrency,
2860            "Expected max concurrent requests <= {allowed_concurrency}, got {max}"
2861        );
2862
2863        // For cases with multiple pages, verify we actually used concurrency
2864        if total_pages > 1 && expected_max_concurrency > 1 {
2865            assert!(
2866                max > 0,
2867                "Expected some concurrent requests for multi-page response, got {max}"
2868            );
2869        }
2870
2871        // Verify we got all expected tokens
2872        assert_eq!(
2873            tokens.len(),
2874            total_tokens,
2875            "Expected {total_tokens} tokens, got {}",
2876            tokens.len()
2877        );
2878
2879        // Verify tokens are in the expected order
2880        for (i, token) in tokens.iter().enumerate() {
2881            assert_eq!(token.symbol, format!("TOKEN_{i}"), "Token at index {i} has wrong symbol");
2882        }
2883    }
2884}