use crate::{
key_extractor::{KeyExtractor, PeerIpKeyExtractor},
GovernorError,
};
use governor::{
clock::{DefaultClock, QuantaInstant},
middleware::{NoOpMiddleware, RateLimitingMiddleware, StateInformationMiddleware},
state::keyed::DefaultKeyedStateStore,
Quota, RateLimiter,
};
use http::{Method, Response};
use std::{fmt, marker::PhantomData, num::NonZeroU32, sync::Arc, time::Duration};
/// Replenishment period used by `GovernorConfigBuilder::const_default` (500 ms).
pub const DEFAULT_PERIOD: Duration = Duration::from_millis(500);
/// Burst size used by `GovernorConfigBuilder::const_default`.
pub const DEFAULT_BURST_SIZE: u32 = 8;
/// A keyed `governor` rate limiter behind an `Arc`, using governor's default keyed
/// state store and default clock, with middleware `M`.
pub type SharedRateLimiter<Key, M> =
Arc<RateLimiter<Key, DefaultKeyedStateStore<Key>, DefaultClock, M>>;
/// Builder for [`GovernorConfig`].
///
/// `period` is the interval handed to `governor::Quota::with_period` and `burst_size`
/// is the maximum burst handed to `Quota::allow_burst`; both must be non-zero for
/// `finish` to return `Some`. `M` selects the governor middleware and is tracked only
/// at the type level.
#[derive(Debug, Eq, Clone, PartialEq)]
pub struct GovernorConfigBuilder<K, M>
where
    K: KeyExtractor,
    M: RateLimitingMiddleware<QuantaInstant>,
{
    period: Duration,
    burst_size: u32,
    methods: Option<Vec<Method>>,
    key_extractor: K,
    middleware: PhantomData<M>,
}
/// Type-erased callback that turns a [`GovernorError`] into an HTTP `Response` whose
/// body type is `RespBody`.
///
/// The closure is stored behind an `Arc`, so clones share the same callback, and it is
/// required to be `Send + Sync`.
pub(crate) struct ErrorHandler<RespBody>(
Arc<dyn Fn(GovernorError) -> Response<RespBody> + Send + Sync>,
);
impl<RespBody> ErrorHandler<RespBody> {
pub(crate) fn new(
f: impl Fn(GovernorError) -> Response<RespBody> + Send + Sync + 'static,
) -> Self {
Self(Arc::new(f))
}
pub(crate) fn handle_error(&self, error: GovernorError) -> Response<RespBody> {
(self.0)(error)
}
}
impl<RespBody> Clone for ErrorHandler<RespBody> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
// The wrapped closure has no `Debug` representation, so only the type name is printed.
impl<RespBody> fmt::Debug for ErrorHandler<RespBody> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("ErrorHandler")
    }
}
impl Default for GovernorConfigBuilder<PeerIpKeyExtractor, NoOpMiddleware> {
fn default() -> Self {
Self::const_default()
}
}
/// `const` builder API for the default [`PeerIpKeyExtractor`] key.
///
/// These methods take and return the builder by value, which lets them be `const fn`
/// so a builder can be assembled in a `const` context.
impl<M: RateLimitingMiddleware<QuantaInstant>> GovernorConfigBuilder<PeerIpKeyExtractor, M> {
    /// Returns a builder with [`DEFAULT_PERIOD`], [`DEFAULT_BURST_SIZE`], no method
    /// filter and the [`PeerIpKeyExtractor`] key extractor.
    #[must_use]
    pub const fn const_default() -> Self {
        GovernorConfigBuilder {
            period: DEFAULT_PERIOD,
            burst_size: DEFAULT_BURST_SIZE,
            methods: None,
            key_extractor: PeerIpKeyExtractor,
            middleware: PhantomData,
        }
    }

    /// Sets the replenishment period (the interval passed to `Quota::with_period`).
    ///
    /// A zero period makes `finish` return `None`.
    #[must_use]
    pub const fn const_period(mut self, duration: Duration) -> Self {
        self.period = duration;
        self
    }

    /// Sets the replenishment period to `seconds` whole seconds.
    ///
    /// This is the interval *between* replenishments, not a request count.
    #[must_use]
    pub const fn const_per_second(mut self, seconds: u64) -> Self {
        self.period = Duration::from_secs(seconds);
        self
    }

    /// Sets the replenishment period to `milliseconds` milliseconds.
    #[must_use]
    pub const fn const_per_millisecond(mut self, milliseconds: u64) -> Self {
        self.period = Duration::from_millis(milliseconds);
        self
    }

    /// Sets the replenishment period to `nanoseconds` nanoseconds.
    #[must_use]
    pub const fn const_per_nanosecond(mut self, nanoseconds: u64) -> Self {
        self.period = Duration::from_nanos(nanoseconds);
        self
    }

    /// Sets the maximum burst size (passed to `Quota::allow_burst`).
    ///
    /// A burst size of zero makes `finish` return `None`.
    #[must_use]
    pub const fn const_burst_size(mut self, burst_size: u32) -> Self {
        self.burst_size = burst_size;
        self
    }
}
impl<K: KeyExtractor, M: RateLimitingMiddleware<QuantaInstant>> GovernorConfigBuilder<K, M> {
    /// Sets the replenishment period, i.e. the interval passed to
    /// `governor::Quota::with_period` after which one unit of quota is restored.
    ///
    /// A zero period makes [`finish`](Self::finish) return `None`.
    pub fn period(&mut self, duration: Duration) -> &mut Self {
        self.period = duration;
        self
    }

    /// Sets the replenishment period to `seconds` whole seconds.
    ///
    /// This is the interval *between* replenishments, not a request count:
    /// `per_second(2)` restores one unit of quota every two seconds.
    pub fn per_second(&mut self, seconds: u64) -> &mut Self {
        self.period = Duration::from_secs(seconds);
        self
    }

    /// Sets the replenishment period to `milliseconds` milliseconds
    /// (an interval, see [`per_second`](Self::per_second)).
    pub fn per_millisecond(&mut self, milliseconds: u64) -> &mut Self {
        self.period = Duration::from_millis(milliseconds);
        self
    }

    /// Sets the replenishment period to `nanoseconds` nanoseconds
    /// (an interval, see [`per_second`](Self::per_second)).
    pub fn per_nanosecond(&mut self, nanoseconds: u64) -> &mut Self {
        self.period = Duration::from_nanos(nanoseconds);
        self
    }

    /// Sets the maximum burst size passed to `Quota::allow_burst`.
    ///
    /// A burst size of zero makes [`finish`](Self::finish) return `None`.
    pub fn burst_size(&mut self, burst_size: u32) -> &mut Self {
        self.burst_size = burst_size;
        self
    }

    /// Stores `methods` as the configuration's method list (`Some(methods)`).
    pub fn methods(&mut self, methods: Vec<Method>) -> &mut Self {
        self.methods = Some(methods);
        self
    }

    /// Returns a new builder that uses `key_extractor`, copying the period, burst size
    /// and method list from `self`.
    ///
    /// `self` is not modified, so the returned builder must be used.
    #[must_use = "key_extractor returns a new builder and does not modify `self`"]
    pub fn key_extractor<K2: KeyExtractor>(
        &mut self,
        key_extractor: K2,
    ) -> GovernorConfigBuilder<K2, M> {
        GovernorConfigBuilder {
            period: self.period,
            burst_size: self.burst_size,
            methods: self.methods.clone(),
            key_extractor,
            middleware: PhantomData,
        }
    }

    /// Returns a new builder that uses governor's `StateInformationMiddleware`, keeping
    /// every other setting (including the key extractor) from `self`.
    ///
    /// `self` is not modified, so the returned builder must be used.
    #[must_use = "use_headers returns a new builder and does not modify `self`"]
    pub fn use_headers(&mut self) -> GovernorConfigBuilder<K, StateInformationMiddleware> {
        GovernorConfigBuilder {
            period: self.period,
            burst_size: self.burst_size,
            methods: self.methods.clone(),
            key_extractor: self.key_extractor.clone(),
            middleware: PhantomData,
        }
    }

    /// Builds the configuration, creating a fresh shared rate limiter.
    ///
    /// Returns `None` when either the burst size or the period is zero.
    #[must_use]
    pub fn finish(&mut self) -> Option<GovernorConfig<K, M>> {
        // `NonZeroU32::new` rejects a zero burst and `Quota::with_period` returns `None`
        // for a zero-length period, so `?` covers both invalid cases without `unwrap`.
        let burst_size = NonZeroU32::new(self.burst_size)?;
        let quota = Quota::with_period(self.period)?.allow_burst(burst_size);
        Some(GovernorConfig {
            key_extractor: self.key_extractor.clone(),
            limiter: Arc::new(RateLimiter::keyed(quota).with_middleware::<M>()),
            methods: self.methods.clone(),
        })
    }
}
#[derive(Debug, Clone)]
pub struct GovernorConfig<K: KeyExtractor, M: RateLimitingMiddleware<QuantaInstant>> {
key_extractor: K,
limiter: SharedRateLimiter<K::Key, M>,
methods: Option<Vec<Method>>,
}
impl<K, M> GovernorConfig<K, M>
where
    K: KeyExtractor,
    M: RateLimitingMiddleware<QuantaInstant>,
{
    /// Borrows the shared rate limiter backing this configuration.
    ///
    /// Clone the returned `Arc` to keep the limiter alive independently of the config.
    pub fn limiter(&self) -> &SharedRateLimiter<K::Key, M> {
        let Self { limiter, .. } = self;
        limiter
    }
}
impl Default for GovernorConfig<PeerIpKeyExtractor, NoOpMiddleware> {
    /// Builds a configuration from `GovernorConfigBuilder::default()`.
    ///
    /// # Panics
    ///
    /// Only if the default builder's period or burst size were zero, which the
    /// default constants rule out.
    fn default() -> Self {
        GovernorConfigBuilder::default()
            .finish()
            .expect("default GovernorConfigBuilder has a non-zero period and burst size")
    }
}
impl<M: RateLimitingMiddleware<QuantaInstant>> GovernorConfig<PeerIpKeyExtractor, M> {
    /// Returns a stricter preset keyed by peer IP: a 4-second replenishment period
    /// and a burst size of 2, with no method filter.
    ///
    /// # Panics
    ///
    /// Never in practice: the period and burst size are hard-coded non-zero values.
    #[must_use]
    pub fn secure() -> Self {
        GovernorConfigBuilder {
            period: Duration::from_secs(4),
            burst_size: 2,
            methods: None,
            key_extractor: PeerIpKeyExtractor,
            middleware: PhantomData,
        }
        .finish()
        .expect("secure preset uses a non-zero period and burst size")
    }
}
/// Rate-limiting wrapper around an inner value `S` (presumably the wrapped service).
///
/// `Governor::new` copies the key extractor, shared limiter and method list out of a
/// [`GovernorConfig`]. `RespBody` is the body type of responses built by the optional
/// custom error handler.
#[derive(Debug)]
pub struct Governor<K, M, S, RespBody>
where
    K: KeyExtractor,
    M: RateLimitingMiddleware<QuantaInstant>,
{
    pub key_extractor: K,
    pub limiter: SharedRateLimiter<K::Key, M>,
    pub methods: Option<Vec<Method>>,
    pub inner: S,
    error_handler: Option<ErrorHandler<RespBody>>,
}
impl<K: KeyExtractor, M: RateLimitingMiddleware<QuantaInstant>, S: Clone, RespBody> Clone
for Governor<K, M, S, RespBody>
{
fn clone(&self) -> Self {
Self {
key_extractor: self.key_extractor.clone(),
limiter: self.limiter.clone(),
methods: self.methods.clone(),
inner: self.inner.clone(),
error_handler: self.error_handler.clone(),
}
}
}
impl<K: KeyExtractor, M: RateLimitingMiddleware<QuantaInstant>, S, RespBody>
Governor<K, M, S, RespBody>
{
pub fn new(inner: S, config: &GovernorConfig<K, M>) -> Self {
Governor {
key_extractor: config.key_extractor.clone(),
limiter: config.limiter.clone(),
methods: config.methods.clone(),
inner,
error_handler: None,
}
}
pub fn error_handler(
mut self,
handler: impl Fn(GovernorError) -> Response<RespBody> + Send + Sync + 'static,
) -> Self {
self.error_handler = Some(ErrorHandler::new(handler));
self
}
pub(crate) fn set_error_handler(&mut self, handler: Option<ErrorHandler<RespBody>>) {
self.error_handler = handler;
}
pub(crate) fn handle_error(&self, error: GovernorError) -> Response<RespBody>
where
Response<RespBody>: From<GovernorError>,
{
if let Some(handler) = self.error_handler.as_ref() {
handler.handle_error(error)
} else {
error.into()
}
}
}