forester_utils/
rpc_pool.rs

1use std::{cmp::min, time::Duration};
2
3use async_trait::async_trait;
4use bb8::{Pool, PooledConnection};
5use light_client::rpc::{LightClientConfig, Rpc, RpcError};
6use solana_sdk::commitment_config::CommitmentConfig;
7use thiserror::Error;
8use tokio::time::sleep;
9use tracing::{error, trace, warn};
10
11use crate::rate_limiter::RateLimiter;
12
13#[derive(Error, Debug)]
14pub enum PoolError {
15    #[error("Failed to create RPC client: {0}")]
16    ClientCreation(String),
17    #[error("RPC request failed: {0}")]
18    RpcRequest(#[from] RpcError),
19    #[error("Pool error: {0}")]
20    Pool(String),
21    #[error("Failed to get connection after {0} retries: {1}")]
22    MaxRetriesExceeded(u32, String),
23    #[error("Missing required field for RpcPoolBuilder: {0}")]
24    BuilderMissingField(String),
25}
26
27pub struct SolanaConnectionManager<R: Rpc + 'static> {
28    url: String,
29    commitment: CommitmentConfig,
30    // TODO: implement Rpc for SolanaConnectionManager and rate limit requests.
31    _rpc_rate_limiter: Option<RateLimiter>,
32    _send_tx_rate_limiter: Option<RateLimiter>,
33    _phantom: std::marker::PhantomData<R>,
34}
35
36impl<R: Rpc + 'static> SolanaConnectionManager<R> {
37    pub fn new(
38        url: String,
39        commitment: CommitmentConfig,
40        rpc_rate_limiter: Option<RateLimiter>,
41        send_tx_rate_limiter: Option<RateLimiter>,
42    ) -> Self {
43        Self {
44            url,
45            commitment,
46            _rpc_rate_limiter: rpc_rate_limiter,
47            _send_tx_rate_limiter: send_tx_rate_limiter,
48            _phantom: std::marker::PhantomData,
49        }
50    }
51}
52
53#[async_trait]
54impl<R: Rpc + 'static> bb8::ManageConnection for SolanaConnectionManager<R> {
55    type Connection = R;
56    type Error = PoolError;
57
58    async fn connect(&self) -> Result<Self::Connection, Self::Error> {
59        let config = LightClientConfig {
60            url: self.url.to_string(),
61            commitment_config: Some(self.commitment),
62            with_indexer: false,
63            fetch_active_tree: false,
64        };
65
66        Ok(R::new(config).await?)
67    }
68
69    async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> {
70        conn.health().await.map_err(PoolError::RpcRequest)
71    }
72
73    fn has_broken(&self, _conn: &mut Self::Connection) -> bool {
74        false
75    }
76}
77
78#[derive(Debug)]
79pub struct SolanaRpcPool<R: Rpc + 'static> {
80    pool: Pool<SolanaConnectionManager<R>>,
81    max_retries: u32,
82    initial_retry_delay: Duration,
83    max_retry_delay: Duration,
84}
85
86#[derive(Debug)]
87pub struct SolanaRpcPoolBuilder<R: Rpc> {
88    url: Option<String>,
89    commitment: Option<CommitmentConfig>,
90
91    max_size: u32,
92    connection_timeout_secs: u64,
93    idle_timeout_secs: u64,
94    max_retries: u32,
95    initial_retry_delay_ms: u64,
96    max_retry_delay_ms: u64,
97
98    rpc_rate_limiter: Option<RateLimiter>,
99    send_tx_rate_limiter: Option<RateLimiter>,
100    _phantom: std::marker::PhantomData<R>,
101}
102
103impl<R: Rpc> Default for SolanaRpcPoolBuilder<R> {
104    fn default() -> Self {
105        Self::new()
106    }
107}
108
109impl<R: Rpc> SolanaRpcPoolBuilder<R> {
110    pub fn new() -> Self {
111        Self {
112            url: None,
113            commitment: None,
114            max_size: 50,
115            connection_timeout_secs: 15,
116            idle_timeout_secs: 300,
117            max_retries: 3,
118            initial_retry_delay_ms: 1000,
119            max_retry_delay_ms: 16000,
120            rpc_rate_limiter: None,
121            send_tx_rate_limiter: None,
122            _phantom: std::marker::PhantomData,
123        }
124    }
125
126    pub fn url(mut self, url: String) -> Self {
127        self.url = Some(url);
128        self
129    }
130
131    pub fn commitment(mut self, commitment: CommitmentConfig) -> Self {
132        self.commitment = Some(commitment);
133        self
134    }
135
136    pub fn max_size(mut self, max_size: u32) -> Self {
137        self.max_size = max_size;
138        self
139    }
140
141    pub fn connection_timeout_secs(mut self, secs: u64) -> Self {
142        self.connection_timeout_secs = secs;
143        self
144    }
145
146    pub fn idle_timeout_secs(mut self, secs: u64) -> Self {
147        self.idle_timeout_secs = secs;
148        self
149    }
150
151    pub fn max_retries(mut self, retries: u32) -> Self {
152        self.max_retries = retries;
153        self
154    }
155
156    pub fn initial_retry_delay_ms(mut self, ms: u64) -> Self {
157        self.initial_retry_delay_ms = ms;
158        self
159    }
160
161    pub fn max_retry_delay_ms(mut self, ms: u64) -> Self {
162        self.max_retry_delay_ms = ms;
163        self
164    }
165
166    pub fn rpc_rate_limiter(mut self, limiter: RateLimiter) -> Self {
167        self.rpc_rate_limiter = Some(limiter);
168        self
169    }
170
171    pub fn send_tx_rate_limiter(mut self, limiter: RateLimiter) -> Self {
172        self.send_tx_rate_limiter = Some(limiter);
173        self
174    }
175
176    pub async fn build(self) -> Result<SolanaRpcPool<R>, PoolError> {
177        let url = self
178            .url
179            .ok_or_else(|| PoolError::BuilderMissingField("url".to_string()))?;
180        let commitment = self
181            .commitment
182            .ok_or_else(|| PoolError::BuilderMissingField("commitment".to_string()))?;
183
184        let manager = SolanaConnectionManager::new(
185            url,
186            commitment,
187            self.rpc_rate_limiter,
188            self.send_tx_rate_limiter,
189        );
190
191        let pool = Pool::builder()
192            .max_size(self.max_size)
193            .connection_timeout(Duration::from_secs(self.connection_timeout_secs))
194            .idle_timeout(Some(Duration::from_secs(self.idle_timeout_secs)))
195            .build(manager)
196            .await
197            .map_err(|e| PoolError::Pool(e.to_string()))?;
198
199        Ok(SolanaRpcPool {
200            pool,
201            max_retries: self.max_retries,
202            initial_retry_delay: Duration::from_millis(self.initial_retry_delay_ms),
203            max_retry_delay: Duration::from_millis(self.max_retry_delay_ms),
204        })
205    }
206}
207
208impl<R: Rpc> SolanaRpcPool<R> {
209    pub async fn get_connection(
210        &self,
211    ) -> Result<PooledConnection<'_, SolanaConnectionManager<R>>, PoolError> {
212        let mut current_retries = 0;
213        let mut current_delay = self.initial_retry_delay;
214
215        loop {
216            trace!(
217                "Attempting to get RPC connection... (Attempt {})",
218                current_retries + 1
219            );
220            match self.pool.get().await {
221                Ok(conn) => {
222                    trace!(
223                        "Successfully got RPC connection (Attempt {})",
224                        current_retries + 1
225                    );
226                    return Ok(conn);
227                }
228                Err(e) => {
229                    error!(
230                        "Failed to get RPC connection (Attempt {}): {:?}",
231                        current_retries + 1,
232                        e
233                    );
234                    if current_retries < self.max_retries {
235                        current_retries += 1;
236                        warn!(
237                            "Retrying to get RPC connection in {:?} (Attempt {}/{})",
238                            current_delay,
239                            current_retries + 1,
240                            self.max_retries + 1
241                        );
242                        tokio::task::yield_now().await;
243                        sleep(current_delay).await;
244                        current_delay = min(current_delay * 2, self.max_retry_delay);
245                    } else {
246                        error!(
247                            "Failed to get RPC connection after {} attempts. Last error: {:?}",
248                            self.max_retries + 1,
249                            e
250                        );
251                        return Err(PoolError::MaxRetriesExceeded(
252                            self.max_retries + 1,
253                            e.to_string(),
254                        ));
255                    }
256                }
257            }
258        }
259    }
260}