use std::{
future::{Future, ready},
net::IpAddr,
num::NonZeroU32,
pin::Pin,
sync::Arc,
task::{Context, Poll},
time::Duration,
};
use governor::{DefaultKeyedRateLimiter, Quota, RateLimiter};
use tower::{Layer, Service, ServiceBuilder, ServiceExt as _};
use crate::{
codec::header::Rcode,
resolver::pipeline::{BoxError, DnsRequest, Outcome, PipelineResponse},
};
/// Error produced when a client has used up its per-IP request quota.
///
/// The type carries no data; callers recognise it by type, for example with
/// `is::<RateLimited>()` on a boxed error.
#[derive(Debug)]
pub struct RateLimited;

impl std::fmt::Display for RateLimited {
    /// Writes the fixed rejection message; formatter width/fill flags are not applied.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(out, "per-client rate limit exceeded")
    }
}

// This error wraps nothing, so the default `source()` (returning `None`) is kept.
impl std::error::Error for RateLimited {}
/// [`Layer`] that wraps a service in a [`KeyedRateLimitService`].
///
/// The layer only holds a shared handle to the limiter, so cloning it is
/// cheap and every clone (and every service it produces) draws from the same
/// per-IP quota state. `Clone` is derived so that a `ServiceBuilder` stack
/// containing this layer can itself be cloned.
#[derive(Clone)]
pub struct KeyedRateLimitLayer {
    /// Limiter keyed by client IP address, shared with every produced service.
    limiter: Arc<DefaultKeyedRateLimiter<IpAddr>>,
}

impl KeyedRateLimitLayer {
    /// Creates a layer whose services consult `limiter` for each request.
    pub fn new(limiter: Arc<DefaultKeyedRateLimiter<IpAddr>>) -> Self {
        Self { limiter }
    }
}
impl<S> Layer<S> for KeyedRateLimitLayer {
type Service = KeyedRateLimitService<S>;
fn layer(&self, inner: S) -> Self::Service {
KeyedRateLimitService {
limiter: self.limiter.clone(),
inner,
}
}
}
/// Middleware that applies a per-client-IP quota before forwarding a request
/// to the wrapped service.
#[derive(Clone)]
pub struct KeyedRateLimitService<S> {
    /// Limiter keyed by client IP address; shared by every clone of this service.
    limiter: Arc<DefaultKeyedRateLimiter<IpAddr>>,
    /// Wrapped service that handles requests which are within quota.
    inner: S,
}

impl<S> Service<DnsRequest> for KeyedRateLimitService<S>
where
    S: Service<DnsRequest, Response = PipelineResponse, Error = BoxError> + Clone + Send + 'static,
    S::Future: Send + 'static,
{
    type Response = PipelineResponse;
    type Error = BoxError;
    type Future = Pin<Box<dyn Future<Output = Result<PipelineResponse, BoxError>> + Send>>;

    /// Readiness is delegated to the wrapped service; the quota is only
    /// consulted in [`call`](Service::call), once the client address is known.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        self.inner.poll_ready(cx)
    }

    /// Forwards `req` to the wrapped service if the client's IP is within
    /// quota; otherwise resolves immediately to a boxed [`RateLimited`] error.
    fn call(&mut self, req: DnsRequest) -> Self::Future {
        // Take the instance that `poll_ready` drove to readiness and leave a
        // fresh clone behind for the next `poll_ready`/`call` cycle.
        let fresh = self.inner.clone();
        let mut readied = std::mem::replace(&mut self.inner, fresh);

        let ip = req.client().ip();
        if self.limiter.check_key(&ip).is_err() {
            // A readied service may be holding a reservation made in
            // `poll_ready` (for instance the concurrency-limit permit of the
            // stack assembled in `build_protective_service`). Dropping it here
            // releases that reservation right away instead of leaving it
            // pinned in `self.inner` until the next call.
            drop(readied);
            return Box::pin(ready(Err(Box::new(RateLimited) as BoxError)));
        }

        Box::pin(async move { readied.call(req).await })
    }
}
/// Limits applied by the protective middleware stack built in
/// `build_protective_service`.
#[derive(Debug, Clone)]
pub struct ProtectiveConfig {
    /// Per-client rate handed to `Quota::per_second`.
    pub rate_per_second: u32,
    /// Per-client burst size handed to `Quota::allow_burst`.
    pub rate_burst: u32,
    /// Value handed to the global concurrency-limit layer.
    pub concurrency_cap: usize,
    /// Duration handed to the timeout layer.
    pub request_timeout: Duration,
}

impl Default for ProtectiveConfig {
    /// Returns the built-in limits: 100 requests/s with a burst of 200 per
    /// client, a concurrency cap of 1024 and a 5 second request timeout.
    fn default() -> Self {
        const RATE_PER_SECOND: u32 = 100;
        const RATE_BURST: u32 = 200;
        const CONCURRENCY_CAP: usize = 1024;
        const REQUEST_TIMEOUT: Duration = Duration::from_secs(5);

        ProtectiveConfig {
            request_timeout: REQUEST_TIMEOUT,
            concurrency_cap: CONCURRENCY_CAP,
            rate_burst: RATE_BURST,
            rate_per_second: RATE_PER_SECOND,
        }
    }
}
/// Assembles the protective middleware stack around `resolve` and returns it
/// type-erased as a cloneable boxed service.
///
/// Layers are added to the builder in this order:
/// 1. per-client-IP rate limiting ([`KeyedRateLimitLayer`]), using a quota of
///    `config.rate_per_second` with a burst of `config.rate_burst`;
/// 2. load shedding;
/// 3. a global concurrency limit of `config.concurrency_cap`;
/// 4. a timeout of `config.request_timeout`.
///
/// NOTE(review): nothing in this function prunes idle per-IP entries from the
/// keyed limiter. If client addresses are effectively unbounded, confirm that
/// housekeeping (e.g. governor's `retain_recent`) happens elsewhere.
///
/// # Panics
///
/// Panics if `config.rate_per_second` or `config.rate_burst` is zero.
pub fn build_protective_service<S>(
    config: &ProtectiveConfig,
    resolve: S,
) -> tower::util::BoxCloneService<DnsRequest, PipelineResponse, BoxError>
where
    S: Service<DnsRequest, Response = PipelineResponse, Error = BoxError> + Clone + Send + 'static,
    S::Future: Send + 'static,
{
    /// Converts a configured rate into the non-zero form the quota API needs.
    fn non_zero(value: u32) -> NonZeroU32 {
        NonZeroU32::new(value).expect("rate limit must be non-zero")
    }

    let sustained = non_zero(config.rate_per_second);
    let burst = non_zero(config.rate_burst);
    let quota = Quota::per_second(sustained).allow_burst(burst);

    // One limiter instance backs the whole stack; service clones share it via `Arc`.
    let shared_limiter: Arc<DefaultKeyedRateLimiter<IpAddr>> =
        Arc::new(RateLimiter::keyed(quota));

    let throttle = KeyedRateLimitLayer::new(shared_limiter);
    let shed = tower::load_shed::LoadShedLayer::new();
    let cap = tower::limit::GlobalConcurrencyLimitLayer::new(config.concurrency_cap);
    let deadline = tower::timeout::TimeoutLayer::new(config.request_timeout);

    let stack = ServiceBuilder::new()
        .layer(throttle)
        .layer(shed)
        .layer(cap)
        .layer(deadline)
        .service(resolve);

    stack.boxed_clone()
}
pub trait ClassifyRejection {
fn rejection_policy(&self) -> (Outcome, Rcode);
}
impl ClassifyRejection for dyn std::error::Error + Send + Sync + 'static {
fn rejection_policy(&self) -> (Outcome, Rcode) {
if self.is::<RateLimited>() || self.is::<tower::load_shed::error::Overloaded>() {
(Outcome::Refused, Rcode::Refused)
} else {
(Outcome::Servfail, Rcode::ServFail)
}
}
}
// Tests for the protective stack and for the error classifier.
#[cfg(test)]
mod tests {
    use std::{net::SocketAddr, sync::Arc, time::Duration};
    use bytes::Bytes;
    use tokio::sync::Notify;
    use tower::ServiceExt as _;
    use super::*;
    use crate::test_support::a_query;
    use crate::{
        codec::message::Query,
        resolver::pipeline::{BoxError, DnsRequest, Outcome, PipelineResponse},
    };

    /// Parses `raw` into a `Query` and pairs it with `client` as a `DnsRequest`.
    /// Panics if `raw` does not parse.
    fn make_request(raw: Bytes, client: SocketAddr) -> DnsRequest {
        let query = Query::try_from(raw).expect("valid query");
        DnsRequest::new(query, client)
    }

    /// Stand-in inner service: succeeds immediately with a response built from
    /// the request's `raw()` payload and `Outcome::Forwarded`.
    fn stub_fn(req: DnsRequest) -> std::future::Ready<Result<PipelineResponse, BoxError>> {
        std::future::ready(Ok(PipelineResponse::new(
            req.raw().clone(),
            Outcome::Forwarded,
        )))
    }

    /// With default limits a single request passes through every layer.
    #[tokio::test]
    async fn happy_path_returns_forwarded() {
        let config = ProtectiveConfig::default();
        let svc = build_protective_service(&config, tower::service_fn(stub_fn));
        let raw = a_query(0x0001, "example.com");
        let req = make_request(raw, "10.0.0.1:1234".parse().unwrap());
        let resp = svc.clone().oneshot(req).await.expect("must succeed");
        assert_eq!(resp.outcome, Outcome::Forwarded);
    }

    /// A quota of one request (burst of one) rejects the second request from
    /// the same IP while a different IP is still served.
    #[tokio::test]
    async fn rate_limit_throttles_one_client_others_pass() {
        let config = ProtectiveConfig {
            rate_per_second: 1,
            rate_burst: 1,
            concurrency_cap: 1024,
            request_timeout: Duration::from_secs(5),
        };
        let svc = build_protective_service(&config, tower::service_fn(stub_fn));
        let ip1: SocketAddr = "1.1.1.1:1234".parse().unwrap();
        let ip2: SocketAddr = "2.2.2.2:1234".parse().unwrap();
        let raw1 = a_query(0x0001, "example.com");
        let raw2 = a_query(0x0002, "example.com");
        let raw3 = a_query(0x0003, "example.com");
        // Each request uses its own clone of the service; the clones share the limiter.
        let r1 = svc.clone().oneshot(make_request(raw1, ip1)).await;
        assert!(r1.is_ok(), "first request from 1.1.1.1 must succeed");
        let r2 = svc.clone().oneshot(make_request(raw2, ip1)).await;
        let err = r2.expect_err("second request from 1.1.1.1 must be rate-limited");
        let (outcome, rcode) = err.rejection_policy();
        assert_eq!(outcome, Outcome::Refused, "rate-limited must be Refused");
        assert_eq!(rcode, Rcode::Refused, "rcode must be Refused");
        let r3 = svc.clone().oneshot(make_request(raw3, ip2)).await;
        assert!(r3.is_ok(), "request from different IP must still pass");
    }

    /// An inner service that sleeps for 500 ms under a 50 ms timeout fails,
    /// and the resulting error classifies as ServFail.
    #[tokio::test]
    async fn timeout_returns_servfail() {
        let config = ProtectiveConfig {
            rate_per_second: 100,
            rate_burst: 200,
            concurrency_cap: 1024,
            request_timeout: Duration::from_millis(50),
        };
        let slow_svc = tower::service_fn(|req: DnsRequest| async move {
            tokio::time::sleep(Duration::from_millis(500)).await;
            Ok::<_, BoxError>(PipelineResponse::new(req.raw().clone(), Outcome::Forwarded))
        });
        let svc = build_protective_service(&config, slow_svc);
        let raw = a_query(0x0004, "slow.example.com");
        let req = make_request(raw, "10.0.0.2:1234".parse().unwrap());
        let err = svc
            .oneshot(req)
            .await
            .expect_err("slow request must time out");
        let (outcome, rcode) = err.rejection_policy();
        assert_eq!(outcome, Outcome::Servfail, "timeout must be Servfail");
        assert_eq!(rcode, Rcode::ServFail, "rcode must be ServFail");
    }

    /// With a concurrency cap of one, a second request issued while the first
    /// is still in flight is rejected and classifies as Refused.
    #[tokio::test]
    async fn concurrency_load_shed_returns_refused() {
        // The rate limits are set high so that only the concurrency cap can reject.
        let config = ProtectiveConfig {
            rate_per_second: 1000,
            rate_burst: 1000,
            concurrency_cap: 1,
            request_timeout: Duration::from_secs(5),
        };
        // The inner service parks on `gate` until the test notifies it.
        let gate = Arc::new(Notify::new());
        let gate_clone = gate.clone();
        let blocking_svc = tower::service_fn(move |req: DnsRequest| {
            let gate = gate_clone.clone();
            async move {
                gate.notified().await;
                Ok::<_, BoxError>(PipelineResponse::new(req.raw().clone(), Outcome::Forwarded))
            }
        });
        let svc = build_protective_service(&config, blocking_svc);
        let raw_a = a_query(0x0005, "a.example.com");
        let raw_b = a_query(0x0006, "b.example.com");
        let addr: SocketAddr = "10.0.0.3:1234".parse().unwrap();
        // The outer 2 s timeout turns a hang into a test failure.
        tokio::time::timeout(Duration::from_secs(2), async {
            let req_a = make_request(raw_a, addr);
            let mut svc_a = svc.clone();
            let fut_a = tokio::spawn(async move {
                svc_a.ready().await.expect("ready");
                svc_a.call(req_a).await
            });
            // Yield so the spawned task can start request A before B is sent.
            tokio::time::sleep(Duration::from_millis(20)).await;
            let req_b = make_request(raw_b, addr);
            let result_b = svc.clone().oneshot(req_b).await;
            let err = result_b.expect_err("request B must be load-shed");
            let (outcome, rcode) = err.rejection_policy();
            assert_eq!(outcome, Outcome::Refused, "load-shed must be Refused");
            assert_eq!(rcode, Rcode::Refused, "rcode must be Refused");
            // Release request A and make sure it still completes successfully.
            gate.notify_one();
            fut_a
                .await
                .expect("task A must complete")
                .expect("A must succeed");
        })
        .await
        .expect("test must complete within the outer timeout");
    }

    /// `RateLimited` classifies as Refused.
    #[test]
    fn classifier_rate_limited_maps_to_refused() {
        let err: BoxError = Box::new(RateLimited);
        let (outcome, rcode) = err.rejection_policy();
        assert_eq!(outcome, Outcome::Refused);
        assert_eq!(rcode, Rcode::Refused);
    }

    /// tower's load-shed `Overloaded` error classifies as Refused.
    #[test]
    fn classifier_overloaded_maps_to_refused() {
        let err: BoxError = Box::new(tower::load_shed::error::Overloaded::new());
        let (outcome, rcode) = err.rejection_policy();
        assert_eq!(outcome, Outcome::Refused);
        assert_eq!(rcode, Rcode::Refused);
    }

    /// tower's timeout `Elapsed` error classifies as ServFail.
    #[test]
    fn classifier_elapsed_maps_to_servfail() {
        let err: BoxError = Box::new(tower::timeout::error::Elapsed::new());
        let (outcome, rcode) = err.rejection_policy();
        assert_eq!(outcome, Outcome::Servfail);
        assert_eq!(rcode, Rcode::ServFail);
    }

    /// Any other error type falls through to ServFail.
    #[test]
    fn classifier_generic_error_maps_to_servfail() {
        let err: BoxError = "oops".into();
        let (outcome, rcode) = err.rejection_policy();
        assert_eq!(outcome, Outcome::Servfail);
        assert_eq!(rcode, Rcode::ServFail);
    }
}