alloy_transport/layers/throttle.rs

1use crate::{TransportError, TransportFut};
2use alloy_json_rpc::{RequestPacket, ResponsePacket};
3use governor::{
4    clock::{QuantaClock, QuantaInstant},
5    middleware::NoOpMiddleware,
6    state::{InMemoryState, NotKeyed},
7    Quota, RateLimiter,
8};
9use std::{
10    num::NonZeroU32,
11    sync::Arc,
12    task::{Context, Poll},
13};
14use tower::{Layer, Service};
15
16/// A rate limiter for throttling RPC requests.
17type Throttle = RateLimiter<NotKeyed, InMemoryState, QuantaClock, NoOpMiddleware<QuantaInstant>>;
18
19/// A Transport Layer responsible for throttling RPC requests.
20#[derive(Debug)]
21pub struct ThrottleLayer {
22    /// Rate limiter used to throttle requests.
23    pub throttle: Arc<Throttle>,
24}
25
26impl ThrottleLayer {
27    /// Creates a new throttle layer with the specified requests per second.
28    ///
29    /// # Panics
30    ///
31    /// Panics if `requests_per_second` is 0.
32    pub fn new(requests_per_second: u32) -> Self {
33        let quota = Quota::per_second(
34            NonZeroU32::new(requests_per_second)
35                .expect("Request per second must be greater than 0"),
36        )
37        .allow_burst(NonZeroU32::new(1).unwrap());
38        let throttle = Arc::new(RateLimiter::direct(quota));
39
40        Self { throttle }
41    }
42}
43
44/// A Tower Service used by the ThrottleLayer that is responsible for throttling rpc requests.
45#[derive(Debug, Clone)]
46pub struct ThrottleService<S> {
47    /// The inner service
48    inner: S,
49    throttle: Arc<Throttle>,
50}
51
52impl<S> Layer<S> for ThrottleLayer {
53    type Service = ThrottleService<S>;
54
55    fn layer(&self, inner: S) -> Self::Service {
56        ThrottleService { inner, throttle: self.throttle.clone() }
57    }
58}
59
60impl<S> Service<RequestPacket> for ThrottleService<S>
61where
62    S: Service<RequestPacket, Response = ResponsePacket, Error = TransportError>
63        + Send
64        + 'static
65        + Clone,
66    S::Future: Send + 'static,
67{
68    type Response = ResponsePacket;
69    type Error = TransportError;
70    type Future = TransportFut<'static>;
71
72    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
73        self.inner.poll_ready(cx)
74    }
75
76    fn call(&mut self, request: RequestPacket) -> Self::Future {
77        let throttle = self.throttle.clone();
78        let mut inner = self.inner.clone();
79
80        Box::pin(async move {
81            throttle.until_ready().await;
82            inner.call(request).await
83        })
84    }
85}