alloy_transport/layers/
throttle.rs1use crate::{TransportError, TransportFut};
2use alloy_json_rpc::{RequestPacket, ResponsePacket};
3use governor::{
4 clock::{QuantaClock, QuantaInstant},
5 middleware::NoOpMiddleware,
6 state::{InMemoryState, NotKeyed},
7 Quota, RateLimiter,
8};
9use std::{
10 num::NonZeroU32,
11 sync::Arc,
12 task::{Context, Poll},
13};
14use tower::{Layer, Service};
15
16type Throttle = RateLimiter<NotKeyed, InMemoryState, QuantaClock, NoOpMiddleware<QuantaInstant>>;
18
19#[derive(Debug)]
21pub struct ThrottleLayer {
22 pub throttle: Arc<Throttle>,
24}
25
26impl ThrottleLayer {
27 pub fn new(requests_per_second: u32) -> Self {
33 let quota = Quota::per_second(
34 NonZeroU32::new(requests_per_second)
35 .expect("Request per second must be greater than 0"),
36 )
37 .allow_burst(NonZeroU32::new(1).unwrap());
38 let throttle = Arc::new(RateLimiter::direct(quota));
39
40 Self { throttle }
41 }
42}
43
44#[derive(Debug, Clone)]
46pub struct ThrottleService<S> {
47 inner: S,
49 throttle: Arc<Throttle>,
50}
51
52impl<S> Layer<S> for ThrottleLayer {
53 type Service = ThrottleService<S>;
54
55 fn layer(&self, inner: S) -> Self::Service {
56 ThrottleService { inner, throttle: self.throttle.clone() }
57 }
58}
59
60impl<S> Service<RequestPacket> for ThrottleService<S>
61where
62 S: Service<RequestPacket, Response = ResponsePacket, Error = TransportError>
63 + Send
64 + 'static
65 + Clone,
66 S::Future: Send + 'static,
67{
68 type Response = ResponsePacket;
69 type Error = TransportError;
70 type Future = TransportFut<'static>;
71
72 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
73 self.inner.poll_ready(cx)
74 }
75
76 fn call(&mut self, request: RequestPacket) -> Self::Future {
77 let throttle = self.throttle.clone();
78 let mut inner = self.inner.clone();
79
80 Box::pin(async move {
81 throttle.until_ready().await;
82 inner.call(request).await
83 })
84 }
85}