miden_node_utils/grpc/
layers.rs1use std::net::{IpAddr, SocketAddr};
2use std::task::{Context as TaskContext, Poll};
3use std::time::Duration;
4
5use anyhow::{Context, ensure};
6use governor::middleware::StateInformationMiddleware;
7use tower::limit::GlobalConcurrencyLimitLayer;
8use tower::{Layer, Service};
9use tower_governor::GovernorError;
10use tower_governor::governor::GovernorConfigBuilder;
11use tower_governor::key_extractor::{KeyExtractor, SmartIpKeyExtractor};
12
13use crate::clap::GrpcOptionsExternal;
14
15pub fn rate_limit_concurrent_connections(
17 grpc_options: GrpcOptionsExternal,
18) -> GlobalConcurrencyLimitLayer {
19 tower::limit::GlobalConcurrencyLimitLayer::new(grpc_options.max_concurrent_connections as usize)
20}
21
22pub fn rate_limit_per_ip(
24 grpc_options: GrpcOptionsExternal,
25) -> anyhow::Result<
26 tower_governor::GovernorLayer<GrpcIpExtractor, StateInformationMiddleware, tonic::body::Body>,
27> {
28 let nanos_per_replenish = Duration::from_secs(1)
29 .as_nanos()
30 .checked_div(u128::from(grpc_options.replenish_n_per_second_per_ip.get()))
31 .unwrap_or_default();
32 ensure!(
33 nanos_per_replenish > 0,
34 "grpc.replenish_n_per_second must be less than or equal to 1e9"
35 );
36 let replenish_period = Duration::from_nanos(
37 u64::try_from(nanos_per_replenish).context("invalid gRPC rate limit configuration")?,
38 );
39 let config = GovernorConfigBuilder::default()
40 .key_extractor(GrpcIpExtractor::default())
41 .period(replenish_period)
42 .burst_size(grpc_options.burst_size.into())
43 .use_headers()
44 .finish()
45 .context("invalid gRPC rate limit configuration")?;
46 let limiter = std::sync::Arc::clone(config.limiter());
47 tokio::spawn(async move {
48 let mut interval = tokio::time::interval(Duration::from_secs(60));
49 loop {
50 interval.tick().await;
51 limiter.retain_recent();
53 }
54 });
55 Ok(tower_governor::GovernorLayer::new(config))
56}
57
58#[derive(Debug, Clone, Copy, PartialEq, Eq)]
64pub struct ClientIp(pub IpAddr);
65
66impl ClientIp {
67 pub fn from_request<T>(request: &tonic::Request<T>) -> Option<IpAddr> {
70 request.extensions().get::<Self>().map(|ip| ip.0)
71 }
72}
73
74#[derive(Debug, Clone, Copy, Default)]
82pub struct ResolveClientIpLayer;
83
84impl<S> Layer<S> for ResolveClientIpLayer {
85 type Service = ResolveClientIp<S>;
86
87 fn layer(&self, inner: S) -> Self::Service {
88 ResolveClientIp { inner }
89 }
90}
91
92#[derive(Debug, Clone, Copy)]
94pub struct ResolveClientIp<S> {
95 inner: S,
96}
97
98impl<S, B> Service<http::Request<B>> for ResolveClientIp<S>
99where
100 S: Service<http::Request<B>>,
101{
102 type Response = S::Response;
103 type Error = S::Error;
104 type Future = S::Future;
105
106 fn poll_ready(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Self::Error>> {
107 self.inner.poll_ready(cx)
108 }
109
110 fn call(&mut self, mut request: http::Request<B>) -> Self::Future {
111 if let Ok(ip) = GrpcIpExtractor::default().extract(&request) {
112 request.extensions_mut().insert(ClientIp(ip));
113 }
114 self.inner.call(request)
115 }
116}
117
118#[derive(Debug, Clone, Copy, PartialEq, Eq)]
124pub struct GrpcIpExtractor(SmartIpKeyExtractor);
125
126impl Default for GrpcIpExtractor {
127 fn default() -> Self {
128 Self(SmartIpKeyExtractor)
129 }
130}
131
132impl GrpcIpExtractor {
133 #[expect(clippy::result_large_err, reason = "this is a third party error type")]
134 fn extract_tonic_address<T>(
135 request: &http::Request<T>,
136 ) -> Result<<Self as KeyExtractor>::Key, GovernorError> {
137 request
138 .extensions()
139 .get::<tonic::transport::server::TcpConnectInfo>()
140 .and_then(tonic::transport::server::TcpConnectInfo::remote_addr)
141 .as_ref()
142 .map(SocketAddr::ip)
143 .ok_or(GovernorError::UnableToExtractKey)
144 }
145}
146
147impl KeyExtractor for GrpcIpExtractor {
148 type Key = IpAddr;
149
150 fn extract<T>(
151 &self,
152 request: &http::Request<T>,
153 ) -> Result<Self::Key, tower_governor::GovernorError> {
154 self.0.extract(request).or_else(|_| Self::extract_tonic_address(request))
155 }
156}