Skip to main content

miden_node_utils/grpc/
layers.rs

1use std::time::Duration;
2
3use anyhow::{Context, ensure};
4use governor::middleware::StateInformationMiddleware;
5use tonic::service::InterceptorLayer;
6use tower::limit::GlobalConcurrencyLimitLayer;
7use tower_governor::governor::GovernorConfigBuilder;
8use tower_governor::key_extractor::SmartIpKeyExtractor;
9
10use super::connect_info::ConnectInfoInterceptor;
11use crate::clap::GrpcOptionsExternal;
12
13/// Creates the gRPC interceptor layer that attaches connection metadata.
14pub fn connect_info_layer() -> InterceptorLayer<ConnectInfoInterceptor> {
15    InterceptorLayer::new(ConnectInfoInterceptor)
16}
17
18/// Builds a global concurrency limit layer using the configured semaphore.
19pub fn rate_limit_concurrent_connections(
20    grpc_options: GrpcOptionsExternal,
21) -> GlobalConcurrencyLimitLayer {
22    tower::limit::GlobalConcurrencyLimitLayer::new(grpc_options.max_concurrent_connections as usize)
23}
24
25/// Creates a per-IP rate limit layer using the configured governor settings.
26pub fn rate_limit_per_ip(
27    grpc_options: GrpcOptionsExternal,
28) -> anyhow::Result<
29    tower_governor::GovernorLayer<
30        SmartIpKeyExtractor,
31        StateInformationMiddleware,
32        tonic::body::Body,
33    >,
34> {
35    let nanos_per_replenish = Duration::from_secs(1)
36        .as_nanos()
37        .checked_div(u128::from(grpc_options.replenish_n_per_second_per_ip.get()))
38        .unwrap_or_default();
39    ensure!(
40        nanos_per_replenish > 0,
41        "grpc.replenish_n_per_second must be less than or equal to 1e9"
42    );
43    let replenish_period = Duration::from_nanos(
44        u64::try_from(nanos_per_replenish).context("invalid gRPC rate limit configuration")?,
45    );
46    let config = GovernorConfigBuilder::default()
47        .key_extractor(SmartIpKeyExtractor)
48        .period(replenish_period)
49        .burst_size(grpc_options.burst_size.into())
50        .use_headers()
51        .finish()
52        .context("invalid gRPC rate limit configuration")?;
53    let limiter = std::sync::Arc::clone(config.limiter());
54    tokio::spawn(async move {
55        let mut interval = tokio::time::interval(Duration::from_secs(60));
56        loop {
57            interval.tick().await;
58            // avoid a DoS vector
59            limiter.retain_recent();
60        }
61    });
62    Ok(tower_governor::GovernorLayer::new(config))
63}