use std::net::{IpAddr, SocketAddr};
use std::task::{Context as TaskContext, Poll};
use std::time::Duration;
use anyhow::{Context, ensure};
use governor::middleware::StateInformationMiddleware;
use tower::limit::GlobalConcurrencyLimitLayer;
use tower::{Layer, Service};
use tower_governor::GovernorError;
use tower_governor::governor::GovernorConfigBuilder;
use tower_governor::key_extractor::{KeyExtractor, SmartIpKeyExtractor};
use crate::clap::GrpcOptionsExternal;
pub fn rate_limit_concurrent_connections(
grpc_options: GrpcOptionsExternal,
) -> GlobalConcurrencyLimitLayer {
tower::limit::GlobalConcurrencyLimitLayer::new(grpc_options.max_concurrent_connections as usize)
}
pub fn rate_limit_per_ip(
grpc_options: GrpcOptionsExternal,
) -> anyhow::Result<
tower_governor::GovernorLayer<GrpcIpExtractor, StateInformationMiddleware, tonic::body::Body>,
> {
let nanos_per_replenish = Duration::from_secs(1)
.as_nanos()
.checked_div(u128::from(grpc_options.replenish_n_per_second_per_ip.get()))
.unwrap_or_default();
ensure!(
nanos_per_replenish > 0,
"grpc.replenish_n_per_second must be less than or equal to 1e9"
);
let replenish_period = Duration::from_nanos(
u64::try_from(nanos_per_replenish).context("invalid gRPC rate limit configuration")?,
);
let config = GovernorConfigBuilder::default()
.key_extractor(GrpcIpExtractor::default())
.period(replenish_period)
.burst_size(grpc_options.burst_size.into())
.use_headers()
.finish()
.context("invalid gRPC rate limit configuration")?;
let limiter = std::sync::Arc::clone(config.limiter());
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(60));
loop {
interval.tick().await;
limiter.retain_recent();
}
});
Ok(tower_governor::GovernorLayer::new(config))
}
/// The resolved IP address of the client that issued a request.
///
/// `ResolveClientIp` stores this as a request extension when it can determine the address.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ClientIp(pub IpAddr);
impl ClientIp {
    /// Looks up a previously stored [`ClientIp`] in `request`'s extensions.
    ///
    /// Returns `None` when no address was recorded for this request.
    pub fn from_request<T>(request: &tonic::Request<T>) -> Option<IpAddr> {
        let stored: Option<&ClientIp> = request.extensions().get();
        stored.copied().map(|ClientIp(addr)| addr)
    }
}
/// Tower [`Layer`] that wraps a service in [`ResolveClientIp`].
#[derive(Debug, Clone, Copy, Default)]
pub struct ResolveClientIpLayer;
impl<S> Layer<S> for ResolveClientIpLayer {
    type Service = ResolveClientIp<S>;
    fn layer(&self, service: S) -> Self::Service {
        ResolveClientIp { inner: service }
    }
}
#[derive(Debug, Clone, Copy)]
pub struct ResolveClientIp<S> {
inner: S,
}
impl<S, B> Service<http::Request<B>> for ResolveClientIp<S>
where
S: Service<http::Request<B>>,
{
type Response = S::Response;
type Error = S::Error;
type Future = S::Future;
fn poll_ready(&mut self, cx: &mut TaskContext<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, mut request: http::Request<B>) -> Self::Future {
if let Ok(ip) = GrpcIpExtractor::default().extract(&request) {
request.extensions_mut().insert(ClientIp(ip));
}
self.inner.call(request)
}
}
/// Client-IP key extractor for gRPC requests.
///
/// It first delegates to tower_governor's [`SmartIpKeyExtractor`]. If that fails, it falls back
/// to the TCP peer address that tonic's transport attaches to the request.
///
/// NOTE(review): per tower_governor's docs, `SmartIpKeyExtractor` consults forwarding headers
/// (`X-Forwarded-For`, `X-Real-IP`, `Forwarded`) before the peer address. Unless a trusted proxy
/// overwrites those headers, clients can pick their own rate-limit key — confirm the deployment.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GrpcIpExtractor(SmartIpKeyExtractor);
impl Default for GrpcIpExtractor {
    fn default() -> Self {
        GrpcIpExtractor(SmartIpKeyExtractor)
    }
}
impl GrpcIpExtractor {
    /// Reads the peer IP from the `TcpConnectInfo` extension attached by tonic's transport.
    ///
    /// # Errors
    ///
    /// Returns [`GovernorError::UnableToExtractKey`] when the extension is absent or records no
    /// remote address.
    ///
    /// NOTE(review): when tonic serves over TLS it stores `TlsConnectInfo<TcpConnectInfo>`
    /// instead of a bare `TcpConnectInfo`, so this lookup likely finds nothing for TLS
    /// connections — confirm whether that case needs handling.
    #[expect(clippy::result_large_err, reason = "this is a third party error type")]
    fn extract_tonic_address<T>(
        request: &http::Request<T>,
    ) -> Result<<Self as KeyExtractor>::Key, GovernorError> {
        let Some(connect_info) = request
            .extensions()
            .get::<tonic::transport::server::TcpConnectInfo>()
        else {
            return Err(GovernorError::UnableToExtractKey);
        };
        connect_info
            .remote_addr()
            .as_ref()
            .map(SocketAddr::ip)
            .ok_or(GovernorError::UnableToExtractKey)
    }
}
/// Keys requests by client IP. The wrapped [`SmartIpKeyExtractor`] is tried first, then tonic's
/// TCP peer address.
impl KeyExtractor for GrpcIpExtractor {
    type Key = IpAddr;
    fn extract<T>(
        &self,
        request: &http::Request<T>,
    ) -> Result<Self::Key, tower_governor::GovernorError> {
        match self.0.extract(request) {
            Ok(ip) => Ok(ip),
            // The smart extractor's error is discarded; callers only see the fallback's result.
            Err(_) => Self::extract_tonic_address(request),
        }
    }
}