#![deny(clippy::all, clippy::cargo)]
#![allow(clippy::multiple_crate_versions)]
use std::collections::{HashMap, HashSet};
use crate::counter::Counter;
use crate::errors::LimitadorError;
use crate::limit::{Limit, Namespace};
use crate::prometheus_metrics::PrometheusMetrics;
use crate::storage::in_memory::InMemoryStorage;
use crate::storage::{AsyncCounterStorage, AsyncStorage, Authorization, CounterStorage, Storage};
#[macro_use]
extern crate lazy_static;
extern crate core;
pub mod counter;
pub mod errors;
pub mod limit;
mod prometheus_metrics;
pub mod storage;
/// Synchronous rate limiter.
///
/// Decisions are made against the limits and counters held in `storage`,
/// and every decision is recorded in `prometheus_metrics`.
pub struct RateLimiter {
/// Limits and counters consulted by every check/update.
storage: Storage,
/// Receives authorized/limited call counts for each decision.
prometheus_metrics: PrometheusMetrics,
}
/// Asynchronous counterpart of `RateLimiter`, backed by an `AsyncStorage`.
///
/// Limit bookkeeping lives in `storage`; each rate-limiting decision is
/// recorded in `prometheus_metrics`.
pub struct AsyncRateLimiter {
/// Limits and (asynchronously accessed) counters.
storage: AsyncStorage,
/// Receives authorized/limited call counts for each decision.
prometheus_metrics: PrometheusMetrics,
}
/// Builder for a `RateLimiter`, letting the caller pick the storage and
/// whether Prometheus counters are labelled by limit name.
pub struct RateLimiterBuilder {
/// Storage handed to the built limiter (`Storage::new()` unless overridden).
storage: Storage,
/// When `true`, `build` uses `PrometheusMetrics::new_with_counters_by_limit_name()`
/// instead of `PrometheusMetrics::new()`.
prometheus_limit_name_labels_enabled: bool,
}
impl RateLimiterBuilder {
    /// Starts a builder that uses `Storage::new()` and unlabelled
    /// Prometheus counters.
    pub fn new() -> Self {
        let storage = Storage::new();
        Self {
            prometheus_limit_name_labels_enabled: false,
            storage,
        }
    }

    /// Replaces the storage the built `RateLimiter` will use.
    pub fn storage(self, storage: Storage) -> Self {
        Self { storage, ..self }
    }

    /// Makes the built limiter label its Prometheus counters by limit name.
    pub fn with_prometheus_limit_name_labels(self) -> Self {
        Self {
            prometheus_limit_name_labels_enabled: true,
            ..self
        }
    }

    /// Consumes the builder and produces the configured `RateLimiter`.
    pub fn build(self) -> RateLimiter {
        let Self {
            storage,
            prometheus_limit_name_labels_enabled: label_by_name,
        } = self;

        let prometheus_metrics = if label_by_name {
            PrometheusMetrics::new_with_counters_by_limit_name()
        } else {
            PrometheusMetrics::new()
        };

        RateLimiter {
            storage,
            prometheus_metrics,
        }
    }
}
impl Default for RateLimiterBuilder {
fn default() -> Self {
Self::new()
}
}
/// Builder for an `AsyncRateLimiter`; the storage is mandatory and the
/// limit-name Prometheus labelling is opt-in.
pub struct AsyncRateLimiterBuilder {
/// Storage handed to the built limiter.
storage: AsyncStorage,
/// When `true`, `build` uses `PrometheusMetrics::new_with_counters_by_limit_name()`
/// instead of `PrometheusMetrics::new()`.
prometheus_limit_name_labels_enabled: bool,
}
impl AsyncRateLimiterBuilder {
    /// Starts a builder around `storage`, with unlabelled Prometheus counters.
    pub fn new(storage: AsyncStorage) -> Self {
        let prometheus_limit_name_labels_enabled = false;
        Self {
            storage,
            prometheus_limit_name_labels_enabled,
        }
    }

    /// Makes the built limiter label its Prometheus counters by limit name.
    pub fn with_prometheus_limit_name_labels(self) -> Self {
        Self {
            prometheus_limit_name_labels_enabled: true,
            ..self
        }
    }

    /// Consumes the builder and produces the configured `AsyncRateLimiter`.
    pub fn build(self) -> AsyncRateLimiter {
        let Self {
            storage,
            prometheus_limit_name_labels_enabled: label_by_name,
        } = self;

        let prometheus_metrics = if label_by_name {
            PrometheusMetrics::new_with_counters_by_limit_name()
        } else {
            PrometheusMetrics::new()
        };

        AsyncRateLimiter {
            storage,
            prometheus_metrics,
        }
    }
}
impl RateLimiter {
    /// Creates a limiter using `Storage::new()` and unlabelled Prometheus
    /// metrics.
    pub fn new() -> Self {
        Self {
            storage: Storage::new(),
            prometheus_metrics: PrometheusMetrics::new(),
        }
    }

    /// Creates a limiter whose counters are kept in `counters`.
    pub fn new_with_storage(counters: Box<dyn CounterStorage>) -> Self {
        Self {
            storage: Storage::with_counter_storage(counters),
            prometheus_metrics: PrometheusMetrics::new(),
        }
    }

    /// Returns the namespaces reported by the underlying storage.
    pub fn get_namespaces(&self) -> HashSet<Namespace> {
        self.storage.get_namespaces()
    }

    /// Registers `limit` with the storage and returns the storage's result
    /// (presumably `true` when the limit was newly inserted — TODO confirm).
    pub fn add_limit(&self, limit: Limit) -> bool {
        self.storage.add_limit(limit)
    }

    /// Asks the storage to delete `limit`.
    ///
    /// # Errors
    /// Returns the storage error converted into [`LimitadorError`].
    pub fn delete_limit(&self, limit: &Limit) -> Result<(), LimitadorError> {
        self.storage.delete_limit(limit)?;
        Ok(())
    }

    /// Returns the limits the storage holds for `namespace`.
    pub fn get_limits(&self, namespace: &Namespace) -> HashSet<Limit> {
        self.storage.get_limits(namespace)
    }

    /// Asks the storage to delete the limits of `namespace`.
    ///
    /// # Errors
    /// Returns the storage error converted into [`LimitadorError`].
    pub fn delete_limits(&self, namespace: &Namespace) -> Result<(), LimitadorError> {
        self.storage.delete_limits(namespace)?;
        Ok(())
    }

    /// Reports whether a request described by `values` (weighing `delta`)
    /// would exceed any applicable limit in `namespace`.
    ///
    /// Returns `Ok(true)` at the first applicable counter that is not within
    /// its limit (recording a limited call under that limit's name);
    /// otherwise records an authorized call and returns `Ok(false)`. This
    /// method does not call `update_counter` itself.
    ///
    /// # Errors
    /// Returns the first storage error, converted into [`LimitadorError`].
    pub fn is_rate_limited(
        &self,
        namespace: &Namespace,
        values: &HashMap<String, String>,
        delta: i64,
    ) -> Result<bool, LimitadorError> {
        for counter in self.counters_that_apply(namespace, values)? {
            if !self.storage.is_within_limits(&counter, delta)? {
                self.prometheus_metrics
                    .incr_limited_calls(namespace, counter.limit().name());
                return Ok(true);
            }
        }
        self.prometheus_metrics.incr_authorized_calls(namespace);
        Ok(false)
    }

    /// Applies `delta` to every counter in `namespace` that applies to `values`.
    ///
    /// # Errors
    /// Stops at, and returns, the first storage error (converted into
    /// [`LimitadorError`]); counters visited before it keep their update.
    pub fn update_counters(
        &self,
        namespace: &Namespace,
        values: &HashMap<String, String>,
        delta: i64,
    ) -> Result<(), LimitadorError> {
        for counter in self.counters_that_apply(namespace, values)? {
            self.storage.update_counter(&counter, delta)?;
        }
        Ok(())
    }

    /// Checks the applicable counters and updates them in a single storage
    /// call (`check_and_update`), returning `Ok(true)` when limited.
    ///
    /// If no limit applies, the call is recorded as authorized without
    /// touching the storage.
    ///
    /// # Errors
    /// Returns the storage error converted into [`LimitadorError`].
    pub fn check_rate_limited_and_update(
        &self,
        namespace: &Namespace,
        values: &HashMap<String, String>,
        delta: i64,
    ) -> Result<bool, LimitadorError> {
        let counters = self.counters_that_apply(namespace, values)?;
        if counters.is_empty() {
            self.prometheus_metrics.incr_authorized_calls(namespace);
            return Ok(false);
        }

        let check_result = self
            .storage
            .check_and_update(counters.into_iter().collect(), delta)?;

        match check_result {
            Authorization::Ok => {
                self.prometheus_metrics.incr_authorized_calls(namespace);
                Ok(false)
            }
            Authorization::Limited(name) => {
                self.prometheus_metrics
                    .incr_limited_calls(namespace, name.as_deref());
                Ok(true)
            }
        }
    }

    /// Returns the counters the storage holds for `namespace`.
    ///
    /// # Errors
    /// Returns the storage error converted into [`LimitadorError`].
    pub fn get_counters(&self, namespace: &Namespace) -> Result<HashSet<Counter>, LimitadorError> {
        self.storage.get_counters(namespace).map_err(Into::into)
    }

    /// Makes the stored limits match `limits`.
    ///
    /// For every namespace that is either already stored or mentioned in
    /// `limits`: stored limits absent from `limits` are deleted, missing
    /// ones are added, and the kept ones are passed to `update_limit` so
    /// that non-identity fields can be refreshed.
    ///
    /// # Errors
    /// Returns the first error from `delete_limit`; earlier changes stay
    /// applied.
    pub fn configure_with(
        &self,
        limits: impl IntoIterator<Item = Limit>,
    ) -> Result<(), LimitadorError> {
        let limits_to_keep_or_create = classify_limits_by_namespace(limits);
        let namespaces_limits_to_keep_or_create: HashSet<Namespace> =
            limits_to_keep_or_create.keys().cloned().collect();
        // Fallback for namespaces that `limits` does not mention; borrowing
        // it avoids cloning each namespace's set just to read it.
        let no_limits: HashSet<Limit> = HashSet::new();

        for namespace in self
            .get_namespaces()
            .union(&namespaces_limits_to_keep_or_create)
        {
            let limits_in_namespace = self.get_limits(namespace);
            let limits_to_keep_in_ns = limits_to_keep_or_create
                .get(namespace)
                .unwrap_or(&no_limits);

            for limit in limits_in_namespace.difference(limits_to_keep_in_ns) {
                self.delete_limit(limit)?;
            }
            for limit in limits_to_keep_in_ns.difference(&limits_in_namespace) {
                self.add_limit(limit.clone());
            }
            // Keep `union` (receiver = the desired limits): std's union yields
            // the receiver's elements first, so kept limits are passed with
            // their *new* values. `intersection` may yield the stale stored
            // copy of an equal limit instead, which would skip the update.
            for limit in limits_to_keep_in_ns.union(&limits_in_namespace) {
                self.storage.update_limit(limit);
            }
        }
        Ok(())
    }

    /// Returns the text produced by `PrometheusMetrics::gather_metrics`.
    pub fn gather_prometheus_metrics(&self) -> String {
        self.prometheus_metrics.gather_metrics()
    }

    /// Builds one `Counter` (carrying a copy of `values`) for each limit of
    /// `namespace` whose `applies(values)` is true. Currently always `Ok`.
    fn counters_that_apply(
        &self,
        namespace: &Namespace,
        values: &HashMap<String, String>,
    ) -> Result<Vec<Counter>, LimitadorError> {
        let limits = self.get_limits(namespace);
        let counters = limits
            .iter()
            .filter(|lim| lim.applies(values))
            .map(|lim| Counter::new(lim.clone(), values.clone()))
            .collect();
        Ok(counters)
    }
}
impl Default for RateLimiter {
fn default() -> Self {
Self::new()
}
}
impl AsyncRateLimiter {
pub fn new_with_storage(storage: Box<dyn AsyncCounterStorage>) -> Self {
Self {
storage: AsyncStorage::with_counter_storage(storage),
prometheus_metrics: PrometheusMetrics::new(),
}
}
pub fn get_namespaces(&self) -> HashSet<Namespace> {
self.storage.get_namespaces()
}
pub fn add_limit(&self, limit: Limit) -> bool {
self.storage.add_limit(limit)
}
pub async fn delete_limit(&self, limit: &Limit) -> Result<(), LimitadorError> {
self.storage.delete_limit(limit).await?;
Ok(())
}
pub fn get_limits(&self, namespace: &Namespace) -> HashSet<Limit> {
self.storage.get_limits(namespace)
}
pub async fn delete_limits(&self, namespace: &Namespace) -> Result<(), LimitadorError> {
self.storage.delete_limits(namespace).await?;
Ok(())
}
pub async fn is_rate_limited(
&self,
namespace: &Namespace,
values: &HashMap<String, String>,
delta: i64,
) -> Result<bool, LimitadorError> {
let counters = self.counters_that_apply(namespace, values).await?;
for counter in counters {
match self.storage.is_within_limits(&counter, delta).await {
Ok(within_limits) => {
if !within_limits {
self.prometheus_metrics
.incr_limited_calls(namespace, counter.limit().name());
return Ok(true);
}
}
Err(e) => return Err(e.into()),
}
}
self.prometheus_metrics.incr_authorized_calls(namespace);
Ok(false)
}
pub async fn update_counters(
&self,
namespace: &Namespace,
values: &HashMap<String, String>,
delta: i64,
) -> Result<(), LimitadorError> {
let counters = self.counters_that_apply(namespace, values).await?;
for counter in counters {
self.storage.update_counter(&counter, delta).await?
}
Ok(())
}
pub async fn check_rate_limited_and_update(
&self,
namespace: &Namespace,
values: &HashMap<String, String>,
delta: i64,
) -> Result<bool, LimitadorError> {
let counters = self.counters_that_apply(namespace, values).await?;
if counters.is_empty() {
self.prometheus_metrics.incr_authorized_calls(namespace);
return Ok(false);
}
let check_result = self
.storage
.check_and_update(counters.into_iter().collect(), delta)
.await?;
match check_result {
Authorization::Ok => {
self.prometheus_metrics.incr_authorized_calls(namespace);
Ok(false)
}
Authorization::Limited(name) => {
self.prometheus_metrics
.incr_limited_calls(namespace, name.as_deref());
Ok(true)
}
}
}
pub async fn get_counters(
&self,
namespace: &Namespace,
) -> Result<HashSet<Counter>, LimitadorError> {
self.storage
.get_counters(namespace)
.await
.map_err(|err| err.into())
}
pub async fn configure_with(
&self,
limits: impl IntoIterator<Item = Limit>,
) -> Result<(), LimitadorError> {
let limits_to_keep_or_create = classify_limits_by_namespace(limits);
let namespaces_limits_to_keep_or_create: HashSet<Namespace> =
limits_to_keep_or_create.keys().cloned().collect();
for namespace in self
.get_namespaces()
.union(&namespaces_limits_to_keep_or_create)
{
let limits_in_namespace = self.get_limits(namespace);
let limits_to_keep_in_ns: HashSet<Limit> = limits_to_keep_or_create
.get(namespace)
.cloned()
.unwrap_or_default();
for limit in limits_in_namespace.difference(&limits_to_keep_in_ns) {
self.delete_limit(limit).await?;
}
for limit in limits_to_keep_in_ns.difference(&limits_in_namespace) {
self.add_limit(limit.clone());
}
for limit in limits_to_keep_in_ns.union(&limits_in_namespace) {
self.storage.update_limit(limit);
}
}
Ok(())
}
pub fn gather_prometheus_metrics(&self) -> String {
self.prometheus_metrics.gather_metrics()
}
async fn counters_that_apply(
&self,
namespace: &Namespace,
values: &HashMap<String, String>,
) -> Result<Vec<Counter>, LimitadorError> {
let limits = self.get_limits(namespace);
let counters = limits
.iter()
.filter(|lim| lim.applies(values))
.map(|lim| Counter::new(lim.clone(), values.clone()))
.collect();
Ok(counters)
}
}
/// Groups `limits` into one set per namespace.
///
/// If several limits in the same namespace compare equal (per `Limit`'s
/// `Eq`/`Hash`), the first one seen is kept, because `HashSet::insert`
/// never replaces an element that is already present.
fn classify_limits_by_namespace(
    limits: impl IntoIterator<Item = Limit>,
) -> HashMap<Namespace, HashSet<Limit>> {
    let mut res: HashMap<Namespace, HashSet<Limit>> = HashMap::new();
    for limit in limits {
        // Clone only the namespace key; the limit itself is moved in.
        res.entry(limit.namespace().clone())
            .or_default()
            .insert(limit);
    }
    res
}