Skip to main content

miden_node_utils/grpc/
layers.rs

1use std::net::{IpAddr, SocketAddr};
2use std::time::Duration;
3
4use anyhow::{Context, ensure};
5use governor::middleware::StateInformationMiddleware;
6use tower::limit::GlobalConcurrencyLimitLayer;
7use tower_governor::GovernorError;
8use tower_governor::governor::GovernorConfigBuilder;
9use tower_governor::key_extractor::{KeyExtractor, SmartIpKeyExtractor};
10
11use crate::clap::GrpcOptionsExternal;
12
13/// Builds a global concurrency limit layer using the configured semaphore.
14pub fn rate_limit_concurrent_connections(
15    grpc_options: GrpcOptionsExternal,
16) -> GlobalConcurrencyLimitLayer {
17    tower::limit::GlobalConcurrencyLimitLayer::new(grpc_options.max_concurrent_connections as usize)
18}
19
20/// Creates a per-IP rate limit layer using the configured governor settings.
21pub fn rate_limit_per_ip(
22    grpc_options: GrpcOptionsExternal,
23) -> anyhow::Result<
24    tower_governor::GovernorLayer<GrpcIpExtractor, StateInformationMiddleware, tonic::body::Body>,
25> {
26    let nanos_per_replenish = Duration::from_secs(1)
27        .as_nanos()
28        .checked_div(u128::from(grpc_options.replenish_n_per_second_per_ip.get()))
29        .unwrap_or_default();
30    ensure!(
31        nanos_per_replenish > 0,
32        "grpc.replenish_n_per_second must be less than or equal to 1e9"
33    );
34    let replenish_period = Duration::from_nanos(
35        u64::try_from(nanos_per_replenish).context("invalid gRPC rate limit configuration")?,
36    );
37    let config = GovernorConfigBuilder::default()
38        .key_extractor(GrpcIpExtractor::default())
39        .period(replenish_period)
40        .burst_size(grpc_options.burst_size.into())
41        .use_headers()
42        .finish()
43        .context("invalid gRPC rate limit configuration")?;
44    let limiter = std::sync::Arc::clone(config.limiter());
45    tokio::spawn(async move {
46        let mut interval = tokio::time::interval(Duration::from_secs(60));
47        loop {
48            interval.tick().await;
49            // avoid a DoS vector
50            limiter.retain_recent();
51        }
52    });
53    Ok(tower_governor::GovernorLayer::new(config))
54}
55
56/// Wraps [`SmartIpKeyExtractor`] by providing a fallback to the client IP address provided by the
57/// gRPC transport.
58///
59/// [`SmartIpKeyExtractor`]'s own fallback of checking the peer IP directly fails because we are in
60/// a gRPC transport and not the typical `SocketAddr` as it expects.
61#[derive(Debug, Clone, Copy, PartialEq, Eq)]
62pub struct GrpcIpExtractor(SmartIpKeyExtractor);
63
64impl Default for GrpcIpExtractor {
65    fn default() -> Self {
66        Self(SmartIpKeyExtractor)
67    }
68}
69
70impl GrpcIpExtractor {
71    #[expect(clippy::result_large_err, reason = "this is a third party error type")]
72    fn extract_tonic_address<T>(
73        request: &http::Request<T>,
74    ) -> Result<<Self as KeyExtractor>::Key, GovernorError> {
75        request
76            .extensions()
77            .get::<tonic::transport::server::TcpConnectInfo>()
78            .and_then(tonic::transport::server::TcpConnectInfo::remote_addr)
79            .as_ref()
80            .map(SocketAddr::ip)
81            .ok_or(GovernorError::UnableToExtractKey)
82    }
83}
84
85impl KeyExtractor for GrpcIpExtractor {
86    type Key = IpAddr;
87
88    fn extract<T>(
89        &self,
90        request: &http::Request<T>,
91    ) -> Result<Self::Key, tower_governor::GovernorError> {
92        self.0.extract(request).or_else(|_| Self::extract_tonic_address(request))
93    }
94}