#![deny(clippy::all, clippy::cargo)]
#![allow(clippy::multiple_crate_versions)]
use crate::counter::Counter;
use crate::errors::LimitadorError;
use crate::limit::{Context, Limit, Namespace};
use crate::storage::in_memory::InMemoryStorage;
use crate::storage::{
AsyncCounterStorage, AsyncStorage, Authorization, CounterStorage, Storage, StorageErr,
};
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
#[macro_use]
extern crate core;
pub mod counter;
pub mod errors;
pub mod limit;
pub mod storage;
/// Synchronous rate limiter: limits and counters are kept in `storage`.
pub struct RateLimiter {
    storage: Storage,
}
/// Asynchronous counterpart of `RateLimiter`: limits and counters are kept in an `AsyncStorage`.
pub struct AsyncRateLimiter {
    storage: AsyncStorage,
}
/// Builder producing a `RateLimiter` from the `Storage` it has been configured with.
pub struct RateLimiterBuilder {
    storage: Storage,
}
/// Shorthand for `Result<T, LimitadorError>`, the return type of the fallible limiter methods.
type LimitadorResult<T> = Result<T, LimitadorError>;
/// Outcome of a rate-limit check performed by `RateLimiter` or `AsyncRateLimiter`.
pub struct CheckResult {
    /// `true` when the check reported the request as limited.
    pub limited: bool,
    /// Counters loaded during the check. The limiters in this file only fill it from
    /// `check_rate_limited_and_update` called with `load_counters == true`; otherwise it is empty.
    pub counters: Vec<Counter>,
    /// Name of the limit that caused the request to be limited, when the request was limited
    /// and that limit has a name; `None` otherwise.
    pub limit_name: Option<String>,
}
impl CheckResult {
pub fn response_header(&mut self) -> HashMap<String, String> {
let mut headers = HashMap::new();
self.counters.sort_by(|a, b| {
let a_remaining = a.remaining().unwrap_or(a.max_value());
let b_remaining = b.remaining().unwrap_or(b.max_value());
a_remaining.cmp(&b_remaining)
});
let mut all_limits_text = String::with_capacity(20 * self.counters.len());
self.counters.iter().for_each(|counter| {
all_limits_text.push_str(
format!(", {};w={}", counter.max_value(), counter.window().as_secs()).as_str(),
);
if let Some(name) = counter.limit().name() {
all_limits_text.push_str(format!(";name=\"{}\"", name.replace('"', "'")).as_str());
}
});
if let Some(counter) = self.counters.first() {
headers.insert(
"X-RateLimit-Limit".to_string(),
format!("{}{all_limits_text}", counter.max_value()),
);
let remaining = counter.remaining().unwrap_or(counter.max_value());
headers.insert("X-RateLimit-Remaining".to_string(), format!("{remaining}"));
if let Some(duration) = counter.expires_in() {
headers.insert(
"X-RateLimit-Reset".to_string(),
format!("{}", duration.as_secs()),
);
}
}
headers
}
}
impl From<CheckResult> for bool {
fn from(value: CheckResult) -> Self {
value.limited
}
}
impl RateLimiterBuilder {
    /// Starts a builder around an already constructed `storage`.
    pub fn with_storage(storage: Storage) -> Self {
        Self { storage }
    }

    /// Starts a builder around `Storage::new(cache_size)`.
    pub fn new(cache_size: u64) -> Self {
        Self::with_storage(Storage::new(cache_size))
    }

    /// Replaces the storage the limiter will be built with.
    pub fn storage(mut self, storage: Storage) -> Self {
        self.storage = storage;
        self
    }

    /// Consumes the builder and returns a `RateLimiter` owning its storage.
    pub fn build(self) -> RateLimiter {
        let Self { storage } = self;
        RateLimiter { storage }
    }
}
/// Builder producing an `AsyncRateLimiter` from an `AsyncStorage`.
pub struct AsyncRateLimiterBuilder {
    storage: AsyncStorage,
}
impl AsyncRateLimiterBuilder {
    /// Starts a builder around the given asynchronous `storage`.
    pub fn new(storage: AsyncStorage) -> Self {
        Self { storage }
    }

    /// Consumes the builder and returns an `AsyncRateLimiter` owning its storage.
    pub fn build(self) -> AsyncRateLimiter {
        let Self { storage } = self;
        AsyncRateLimiter { storage }
    }
}
impl RateLimiter {
    /// Creates a rate limiter backed by `Storage::new(cache_size)`.
    pub fn new(cache_size: u64) -> Self {
        Self {
            storage: Storage::new(cache_size),
        }
    }

    /// Creates a rate limiter whose counters are kept in `counters`.
    pub fn new_with_storage(counters: Box<dyn CounterStorage>) -> Self {
        Self {
            storage: Storage::with_counter_storage(counters),
        }
    }

    /// Returns the namespaces currently known to the storage.
    pub fn get_namespaces(&self) -> HashSet<Namespace> {
        self.storage.get_namespaces()
    }

    /// Registers `limit` and returns the flag reported by the storage
    /// (presumably whether it was newly added — confirm against `Storage::add_limit`).
    pub fn add_limit(&self, limit: Limit) -> bool {
        self.storage.add_limit(limit)
    }

    /// Removes `limit` from the storage.
    ///
    /// # Errors
    ///
    /// Propagates any storage error, converted into `LimitadorError`.
    pub fn delete_limit(&self, limit: &Limit) -> LimitadorResult<()> {
        self.storage.delete_limit(limit)?;
        Ok(())
    }

    /// Returns owned copies of the limits defined in `namespace`.
    pub fn get_limits(&self, namespace: &Namespace) -> HashSet<Limit> {
        self.storage
            .get_limits(namespace)
            .iter()
            .map(|l| (**l).clone())
            .collect()
    }

    /// Removes every limit defined in `namespace`.
    ///
    /// # Errors
    ///
    /// Propagates any storage error, converted into `LimitadorError`.
    pub fn delete_limits(&self, namespace: &Namespace) -> LimitadorResult<()> {
        self.storage.delete_limits(namespace)?;
        Ok(())
    }

    /// Checks, without updating any counter, whether a hit of `delta` in context `values`
    /// would be limited by one of the limits of `namespace`.
    ///
    /// Counters are checked one at a time and the first one reported as over its limit decides
    /// the outcome. The returned `CheckResult` never carries counters.
    ///
    /// # Errors
    ///
    /// Fails if building the applicable counters or querying the storage fails.
    pub fn is_rate_limited(
        &self,
        namespace: &Namespace,
        values: &Context,
        delta: u64,
    ) -> LimitadorResult<CheckResult> {
        let counters = self.counters_that_apply(namespace, values)?;
        let (limited, limit_name) = match self.find_first_limited_counter(&counters, delta)? {
            Authorization::Ok => (false, None),
            Authorization::Limited(name) => (true, name),
        };
        Ok(CheckResult {
            limited,
            counters: Vec::default(),
            limit_name,
        })
    }

    /// Returns `Limited` (carrying the limit's name, if any) for the first counter the storage
    /// reports as not within its limit for `delta`, or `Ok` when every counter is within limits.
    fn find_first_limited_counter(
        &self,
        counters: &[Counter],
        delta: u64,
    ) -> Result<Authorization, StorageErr> {
        for counter in counters {
            if !self.storage.is_within_limits(counter, delta)? {
                return Ok(Authorization::Limited(
                    counter.limit().name().map(|n| n.to_owned()),
                ));
            }
        }
        Ok(Authorization::Ok)
    }

    /// Adds `delta` to every counter that applies to `ctx` in `namespace`, without checking
    /// whether the limits are exceeded.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error raised while building or updating counters.
    pub fn update_counters(
        &self,
        namespace: &Namespace,
        ctx: &Context,
        delta: u64,
    ) -> LimitadorResult<()> {
        let counters = self.counters_that_apply(namespace, ctx)?;
        counters
            .iter()
            .try_for_each(|counter| self.storage.update_counter(counter, delta))
            .map_err(|err| err.into())
    }

    /// Checks the counters that apply to `ctx` in `namespace` and updates them by `delta` in a
    /// single storage call.
    ///
    /// When no counter applies the request is reported as not limited. When `load_counters` is
    /// `true` the counters, as left by the storage, are returned in the result; otherwise the
    /// result carries none.
    ///
    /// # Errors
    ///
    /// Fails if building the applicable counters or the storage operation fails.
    pub fn check_rate_limited_and_update(
        &self,
        namespace: &Namespace,
        ctx: &Context,
        delta: u64,
        load_counters: bool,
    ) -> LimitadorResult<CheckResult> {
        let mut counters = self.counters_that_apply(namespace, ctx)?;
        if counters.is_empty() {
            return Ok(CheckResult {
                limited: false,
                counters,
                limit_name: None,
            });
        }

        let check_result = self
            .storage
            .check_and_update(&mut counters, delta, load_counters)?;

        let counters = if load_counters {
            counters
        } else {
            Vec::default()
        };
        let (limited, limit_name) = match check_result {
            Authorization::Ok => (false, None),
            Authorization::Limited(name) => (true, name),
        };
        Ok(CheckResult {
            limited,
            counters,
            limit_name,
        })
    }

    /// Returns the counters the storage holds for `namespace`.
    ///
    /// # Errors
    ///
    /// Propagates any storage error, converted into `LimitadorError`.
    pub fn get_counters(&self, namespace: &Namespace) -> LimitadorResult<HashSet<Counter>> {
        self.storage
            .get_counters(namespace)
            .map_err(|err| err.into())
    }

    /// Makes the configured limits match `limits`.
    ///
    /// For every namespace that exists today or appears in `limits`, limits missing from the
    /// new configuration are deleted, new limits are added, and then the desired copy of every
    /// configured limit is handed to `Storage::update_limit`. That last step lets attributes
    /// that `Limit` equality may not cover (such as `max_value`) be refreshed.
    ///
    /// # Errors
    ///
    /// Returns the first error raised while deleting a limit. Changes made before the error are
    /// not rolled back.
    pub fn configure_with(&self, limits: impl IntoIterator<Item = Limit>) -> LimitadorResult<()> {
        let mut limits_to_keep_or_create = classify_limits_by_namespace(limits);
        let namespaces_limits_to_keep_or_create: HashSet<Namespace> =
            limits_to_keep_or_create.keys().cloned().collect();

        for namespace in self
            .get_namespaces()
            .union(&namespaces_limits_to_keep_or_create)
        {
            let limits_in_namespace = self.get_limits(namespace);
            // `union` yields each namespace once, so the desired set can be moved out, not cloned.
            let limits_to_keep_in_ns: HashSet<Limit> = limits_to_keep_or_create
                .remove(namespace)
                .unwrap_or_default();

            for limit in limits_in_namespace.difference(&limits_to_keep_in_ns) {
                self.delete_limit(limit)?;
            }
            for limit in limits_to_keep_in_ns.difference(&limits_in_namespace) {
                self.add_limit(limit.clone());
            }
            // Iterate the desired set itself. `HashSet::union` does not promise which of two
            // equal elements it yields (std walks the larger set first), so the previous
            // `union` could pass the stale stored copy here and silently drop an update.
            for limit in &limits_to_keep_in_ns {
                self.storage.update_limit(limit);
            }
        }
        Ok(())
    }

    /// Builds one counter for each limit of `namespace` that applies to `ctx`, skipping limits
    /// for which `Counter::new` yields no counter and failing on the first construction error.
    fn counters_that_apply(
        &self,
        namespace: &Namespace,
        ctx: &Context,
    ) -> LimitadorResult<Vec<Counter>> {
        let limits = self.storage.get_limits(namespace);
        limits
            .iter()
            .filter(|lim| lim.applies(ctx))
            .filter_map(|lim| Counter::new(Arc::clone(lim), ctx).transpose())
            .collect()
    }
}
impl AsyncRateLimiter {
    /// Creates an asynchronous rate limiter whose counters are kept in `storage`.
    pub fn new_with_storage(storage: Box<dyn AsyncCounterStorage>) -> Self {
        Self {
            storage: AsyncStorage::with_counter_storage(storage),
        }
    }

    /// Returns the namespaces currently known to the storage.
    pub fn get_namespaces(&self) -> HashSet<Namespace> {
        self.storage.get_namespaces()
    }

    /// Registers `limit` and returns the flag reported by the storage
    /// (presumably whether it was newly added — confirm against `AsyncStorage::add_limit`).
    pub fn add_limit(&self, limit: Limit) -> bool {
        self.storage.add_limit(limit)
    }

    /// Removes `limit` from the storage.
    ///
    /// # Errors
    ///
    /// Propagates any storage error, converted into `LimitadorError`.
    pub async fn delete_limit(&self, limit: &Limit) -> LimitadorResult<()> {
        self.storage.delete_limit(limit).await?;
        Ok(())
    }

    /// Returns owned copies of the limits defined in `namespace`.
    pub fn get_limits(&self, namespace: &Namespace) -> HashSet<Limit> {
        self.storage
            .get_limits(namespace)
            .iter()
            .map(|l| (**l).clone())
            .collect()
    }

    /// Removes every limit defined in `namespace`.
    ///
    /// # Errors
    ///
    /// Propagates any storage error, converted into `LimitadorError`.
    pub async fn delete_limits(&self, namespace: &Namespace) -> LimitadorResult<()> {
        self.storage.delete_limits(namespace).await?;
        Ok(())
    }

    /// Checks, without updating any counter, whether a hit of `delta` in context `ctx` would be
    /// limited by one of the limits of `namespace`.
    ///
    /// Counters are checked one at a time and the first one reported as over its limit decides
    /// the outcome. The returned `CheckResult` never carries counters.
    ///
    /// # Errors
    ///
    /// Fails if building the applicable counters or querying the storage fails.
    pub async fn is_rate_limited(
        &self,
        namespace: &Namespace,
        ctx: &Context<'_>,
        delta: u64,
    ) -> LimitadorResult<CheckResult> {
        let counters = self.counters_that_apply(namespace, ctx).await?;
        let (limited, limit_name) = match self.find_first_limited_counter(&counters, delta).await? {
            Authorization::Ok => (false, None),
            Authorization::Limited(name) => (true, name),
        };
        Ok(CheckResult {
            limited,
            counters: Vec::default(),
            limit_name,
        })
    }

    /// Returns `Limited` (carrying the limit's name, if any) for the first counter the storage
    /// reports as not within its limit for `delta`, or `Ok` when every counter is within limits.
    async fn find_first_limited_counter(
        &self,
        counters: &[Counter],
        delta: u64,
    ) -> Result<Authorization, StorageErr> {
        for counter in counters {
            if !self.storage.is_within_limits(counter, delta).await? {
                return Ok(Authorization::Limited(
                    counter.limit().name().map(|n| n.to_owned()),
                ));
            }
        }
        Ok(Authorization::Ok)
    }

    /// Adds `delta` to every counter that applies to `ctx` in `namespace`, one at a time,
    /// without checking whether the limits are exceeded.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error raised while building or updating counters.
    pub async fn update_counters(
        &self,
        namespace: &Namespace,
        ctx: &Context<'_>,
        delta: u64,
    ) -> LimitadorResult<()> {
        let counters = self.counters_that_apply(namespace, ctx).await?;
        for counter in counters {
            self.storage.update_counter(&counter, delta).await?;
        }
        Ok(())
    }

    /// Checks the counters that apply to `ctx` in `namespace` and updates them by `delta` in a
    /// single storage call.
    ///
    /// When no counter applies the request is reported as not limited. When `load_counters` is
    /// `true` the counters, as left by the storage, are returned in the result; otherwise the
    /// result carries none.
    ///
    /// # Errors
    ///
    /// Fails if building the applicable counters or the storage operation fails.
    pub async fn check_rate_limited_and_update(
        &self,
        namespace: &Namespace,
        ctx: &Context<'_>,
        delta: u64,
        load_counters: bool,
    ) -> LimitadorResult<CheckResult> {
        let mut counters = self.counters_that_apply(namespace, ctx).await?;
        if counters.is_empty() {
            return Ok(CheckResult {
                limited: false,
                counters,
                limit_name: None,
            });
        }

        let check_result = self
            .storage
            .check_and_update(&mut counters, delta, load_counters)
            .await?;

        let counters = if load_counters {
            counters
        } else {
            Vec::default()
        };
        let (limited, limit_name) = match check_result {
            Authorization::Ok => (false, None),
            Authorization::Limited(name) => (true, name),
        };
        Ok(CheckResult {
            limited,
            counters,
            limit_name,
        })
    }

    /// Returns the counters the storage holds for `namespace`.
    ///
    /// # Errors
    ///
    /// Propagates any storage error, converted into `LimitadorError`.
    pub async fn get_counters(&self, namespace: &Namespace) -> LimitadorResult<HashSet<Counter>> {
        self.storage
            .get_counters(namespace)
            .await
            .map_err(|err| err.into())
    }

    /// Makes the configured limits match `limits`.
    ///
    /// For every namespace that exists today or appears in `limits`, limits missing from the
    /// new configuration are deleted, new limits are added, and then the desired copy of every
    /// configured limit is handed to `AsyncStorage::update_limit`. That last step lets
    /// attributes that `Limit` equality may not cover (such as `max_value`) be refreshed.
    ///
    /// # Errors
    ///
    /// Returns the first error raised while deleting a limit. Changes made before the error are
    /// not rolled back.
    pub async fn configure_with(
        &self,
        limits: impl IntoIterator<Item = Limit>,
    ) -> LimitadorResult<()> {
        let mut limits_to_keep_or_create = classify_limits_by_namespace(limits);
        let namespaces_limits_to_keep_or_create: HashSet<Namespace> =
            limits_to_keep_or_create.keys().cloned().collect();

        for namespace in self
            .get_namespaces()
            .union(&namespaces_limits_to_keep_or_create)
        {
            let limits_in_namespace = self.get_limits(namespace);
            // `union` yields each namespace once, so the desired set can be moved out, not cloned.
            let limits_to_keep_in_ns: HashSet<Limit> = limits_to_keep_or_create
                .remove(namespace)
                .unwrap_or_default();

            for limit in limits_in_namespace.difference(&limits_to_keep_in_ns) {
                self.delete_limit(limit).await?;
            }
            for limit in limits_to_keep_in_ns.difference(&limits_in_namespace) {
                self.add_limit(limit.clone());
            }
            // Iterate the desired set itself. `HashSet::union` does not promise which of two
            // equal elements it yields (std walks the larger set first), so the previous
            // `union` could pass the stale stored copy here and silently drop an update.
            for limit in &limits_to_keep_in_ns {
                self.storage.update_limit(limit);
            }
        }
        Ok(())
    }

    /// Builds one counter for each limit of `namespace` that applies to `ctx`, skipping limits
    /// for which `Counter::new` yields no counter and failing on the first construction error.
    async fn counters_that_apply(
        &self,
        namespace: &Namespace,
        ctx: &Context<'_>,
    ) -> LimitadorResult<Vec<Counter>> {
        let limits = self.storage.get_limits(namespace);
        limits
            .iter()
            .filter(|lim| lim.applies(ctx))
            .filter_map(|lim| Counter::new(Arc::clone(lim), ctx).transpose())
            .collect()
    }
}
/// Groups `limits` into one set per namespace.
///
/// When two limits of the same namespace compare equal, the first one seen is kept
/// (`HashSet::insert` does not replace an existing equal element).
fn classify_limits_by_namespace(
    limits: impl IntoIterator<Item = Limit>,
) -> HashMap<Namespace, HashSet<Limit>> {
    let mut res: HashMap<Namespace, HashSet<Limit>> = HashMap::new();
    for limit in limits {
        if let Some(set) = res.get_mut(limit.namespace()) {
            set.insert(limit);
        } else {
            // Clone only the namespace key before moving the limit in; the previous version
            // cloned the whole `Limit` just to read its namespace afterwards.
            let namespace = limit.namespace().clone();
            res.insert(namespace, HashSet::from([limit]));
        }
    }
    res
}
#[cfg(test)]
mod test {
    use crate::limit::{Context, Expression, Limit};
    use crate::RateLimiter;
    use std::collections::HashMap;

    /// Reconfiguring with a limit that differs only in `max_value` must leave a single limit
    /// carrying the new maximum, and subsequent checks must report that maximum.
    #[test]
    fn properly_updates_existing_limits() {
        let limiter = RateLimiter::new(100);
        let ns = "foo";
        let initial = Limit::new(ns, 42, 100, vec![], Vec::<Expression>::default());
        limiter.add_limit(initial.clone());

        let stored = limiter.get_limits(&ns.into());
        assert_eq!(stored.len(), 1);
        assert!(stored.contains(&initial));
        assert_eq!(stored.iter().next().unwrap().max_value(), 42);

        let outcome = limiter
            .check_rate_limited_and_update(&ns.into(), &Context::default(), 1, true)
            .unwrap();
        assert_eq!(outcome.counters.first().unwrap().max_value(), 42);

        let mut raised = initial;
        raised.set_max_value(50);
        limiter.configure_with([raised.clone()]).unwrap();

        let stored = limiter.get_limits(&ns.into());
        assert_eq!(stored.len(), 1);
        assert!(stored.contains(&raised));
        assert_eq!(stored.iter().next().unwrap().max_value(), 50);

        let outcome = limiter
            .check_rate_limited_and_update(&ns.into(), &Context::default(), 1, true)
            .unwrap();
        assert_eq!(outcome.counters.first().unwrap().max_value(), 50);
    }

    /// Deleting a limit that has qualified counters and adding it back must start counting
    /// from scratch: the first hit after re-adding leaves 41 of 42 remaining again.
    #[test]
    fn deletes_qualified_counters() {
        let limiter = RateLimiter::new(100);
        let ns = "foo";
        let limit = Limit::new(ns, 42, 100, vec![], vec![Expression::parse("x").unwrap()]);
        let ctx = Context::from(HashMap::from([("x".to_string(), "a".to_string())]));
        limiter.add_limit(limit.clone());

        // Records one hit and reports what the first returned counter has left.
        let remaining_after_hit = || {
            limiter
                .check_rate_limited_and_update(&ns.into(), &ctx, 1, true)
                .unwrap()
                .counters
                .first()
                .unwrap()
                .remaining()
        };

        assert_eq!(remaining_after_hit(), Some(41));
        limiter.delete_limit(&limit).unwrap();
        limiter.add_limit(limit);
        assert_eq!(remaining_after_hit(), Some(41));
    }
}