Skip to main content

miden_node_utils/grpc/
layers.rs

1use std::net::{IpAddr, SocketAddr};
2use std::task::{Context as TaskContext, Poll};
3use std::time::Duration;
4
5use anyhow::{Context, ensure};
6use governor::middleware::StateInformationMiddleware;
7use tower::limit::GlobalConcurrencyLimitLayer;
8use tower::{Layer, Service};
9use tower_governor::GovernorError;
10use tower_governor::governor::GovernorConfigBuilder;
11use tower_governor::key_extractor::{KeyExtractor, SmartIpKeyExtractor};
12
13use crate::clap::GrpcOptionsExternal;
14
15/// Builds a global concurrency limit layer using the configured semaphore.
16pub fn rate_limit_concurrent_connections(
17    grpc_options: GrpcOptionsExternal,
18) -> GlobalConcurrencyLimitLayer {
19    tower::limit::GlobalConcurrencyLimitLayer::new(grpc_options.max_concurrent_connections as usize)
20}
21
22/// Creates a per-IP rate limit layer using the configured governor settings.
23pub fn rate_limit_per_ip(
24    grpc_options: GrpcOptionsExternal,
25) -> anyhow::Result<
26    tower_governor::GovernorLayer<GrpcIpExtractor, StateInformationMiddleware, tonic::body::Body>,
27> {
28    let nanos_per_replenish = Duration::from_secs(1)
29        .as_nanos()
30        .checked_div(u128::from(grpc_options.replenish_n_per_second_per_ip.get()))
31        .unwrap_or_default();
32    ensure!(
33        nanos_per_replenish > 0,
34        "grpc.replenish_n_per_second must be less than or equal to 1e9"
35    );
36    let replenish_period = Duration::from_nanos(
37        u64::try_from(nanos_per_replenish).context("invalid gRPC rate limit configuration")?,
38    );
39    let config = GovernorConfigBuilder::default()
40        .key_extractor(GrpcIpExtractor::default())
41        .period(replenish_period)
42        .burst_size(grpc_options.burst_size.into())
43        .use_headers()
44        .finish()
45        .context("invalid gRPC rate limit configuration")?;
46    let limiter = std::sync::Arc::clone(config.limiter());
47    tokio::spawn(async move {
48        let mut interval = tokio::time::interval(Duration::from_secs(60));
49        loop {
50            interval.tick().await;
51            // avoid a DoS vector
52            limiter.retain_recent();
53        }
54    });
55    Ok(tower_governor::GovernorLayer::new(config))
56}
57
58/// The originating client IP, resolved by [`ResolveClientIpLayer`] and stored in a request's
59/// extensions.
60///
61/// gRPC handlers can read this via `request.extensions().get::<ClientIp>()` to obtain the
62/// load-balancer-aware client address without re-implementing IP extraction.
63#[derive(Debug, Clone, Copy, PartialEq, Eq)]
64pub struct ClientIp(pub IpAddr);
65
66impl ClientIp {
67    /// Returns the client IP resolved for `request` by [`ResolveClientIpLayer`], or `None` if it
68    /// could not be determined.
69    pub fn from_request<T>(request: &tonic::Request<T>) -> Option<IpAddr> {
70        request.extensions().get::<Self>().map(|ip| ip.0)
71    }
72}
73
74/// A [`tower::Layer`] that resolves the originating client IP and stores it in the request's
75/// extensions as [`ClientIp`].
76///
77/// IP resolution reuses [`GrpcIpExtractor`], so clients behind a load balancer or reverse proxy are
78/// identified by their forwarded IP (via `X-Forwarded-For` / `X-Real-Ip` / `Forwarded` headers),
79/// falling back to the peer address. Resolving once at the transport layer keeps this consistent
80/// with the per-IP rate limiter and lets handlers read the result instead of re-deriving it.
81#[derive(Debug, Clone, Copy, Default)]
82pub struct ResolveClientIpLayer;
83
84impl<S> Layer<S> for ResolveClientIpLayer {
85    type Service = ResolveClientIp<S>;
86
87    fn layer(&self, inner: S) -> Self::Service {
88        ResolveClientIp { inner }
89    }
90}
91
92/// The service produced by [`ResolveClientIpLayer`].
93#[derive(Debug, Clone, Copy)]
94pub struct ResolveClientIp<S> {
95    inner: S,
96}
97
98impl<S, B> Service<http::Request<B>> for ResolveClientIp<S>
99where
100    S: Service<http::Request<B>>,
101{
102    type Response = S::Response;
103    type Error = S::Error;
104    type Future = S::Future;
105
106    fn poll_ready(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Self::Error>> {
107        self.inner.poll_ready(cx)
108    }
109
110    fn call(&mut self, mut request: http::Request<B>) -> Self::Future {
111        if let Ok(ip) = GrpcIpExtractor::default().extract(&request) {
112            request.extensions_mut().insert(ClientIp(ip));
113        }
114        self.inner.call(request)
115    }
116}
117
118/// Wraps [`SmartIpKeyExtractor`] by providing a fallback to the client IP address provided by the
119/// gRPC transport.
120///
121/// [`SmartIpKeyExtractor`]'s own fallback of checking the peer IP directly fails because we are in
122/// a gRPC transport and not the typical `SocketAddr` as it expects.
123#[derive(Debug, Clone, Copy, PartialEq, Eq)]
124pub struct GrpcIpExtractor(SmartIpKeyExtractor);
125
126impl Default for GrpcIpExtractor {
127    fn default() -> Self {
128        Self(SmartIpKeyExtractor)
129    }
130}
131
132impl GrpcIpExtractor {
133    #[expect(clippy::result_large_err, reason = "this is a third party error type")]
134    fn extract_tonic_address<T>(
135        request: &http::Request<T>,
136    ) -> Result<<Self as KeyExtractor>::Key, GovernorError> {
137        request
138            .extensions()
139            .get::<tonic::transport::server::TcpConnectInfo>()
140            .and_then(tonic::transport::server::TcpConnectInfo::remote_addr)
141            .as_ref()
142            .map(SocketAddr::ip)
143            .ok_or(GovernorError::UnableToExtractKey)
144    }
145}
146
147impl KeyExtractor for GrpcIpExtractor {
148    type Key = IpAddr;
149
150    fn extract<T>(
151        &self,
152        request: &http::Request<T>,
153    ) -> Result<Self::Key, tower_governor::GovernorError> {
154        self.0.extract(request).or_else(|_| Self::extract_tonic_address(request))
155    }
156}