use core::hash::Hash;
use core::num::NonZeroU32;
use core::prelude::v1::*;
use crate::state::StateStore;
use crate::{
clock::{self, Reference},
errors::InsufficientCapacity,
middleware::RateLimitingMiddleware,
nanos::Nanos,
Quota, RateLimiter,
};
/// Hash builder used by keyed state stores when the caller does not supply one.
///
/// With the `std` feature this is the standard library's `RandomState`.
#[cfg(feature = "std")]
pub type DefaultHasher = std::collections::hash_map::RandomState;
/// Hash builder used by keyed state stores when the caller does not supply one.
///
/// Without the `std` feature this falls back to `hashbrown`'s default hash builder.
#[cfg(not(feature = "std"))]
pub type DefaultHasher = hashbrown::DefaultHashBuilder;
/// A [`StateStore`] usable by keyed rate limiters.
///
/// This is a marker trait with no methods of its own. It is implemented automatically for
/// every [`StateStore`] whose key type is `Eq + Clone + Hash`.
pub trait KeyedStateStore<K: Hash>: StateStore<Key = K> {}

impl<T, K> KeyedStateStore<K> for T
where
    K: Eq + Clone + Hash,
    T: StateStore<Key = K>,
{
}
/// Convenience constructors for keyed rate limiters that use the default clock.
impl<K> RateLimiter<K, DefaultKeyedStateStore<K>, clock::DefaultClock>
where
    K: Clone + Hash + Eq,
{
    /// Builds a keyed rate limiter for `quota`.
    ///
    /// State lives in a default-constructed `DefaultKeyedStateStore`, whose concrete type
    /// depends on the enabled cargo features. Time comes from `clock::DefaultClock`.
    pub fn keyed(quota: Quota) -> Self {
        Self::new(
            quota,
            DefaultKeyedStateStore::default(),
            clock::DefaultClock::default(),
        )
    }

    /// Builds a keyed rate limiter for `quota` backed by a default `DashMapStateStore`.
    ///
    /// Only available when both the `std` and `dashmap` features are enabled.
    #[cfg(all(feature = "std", feature = "dashmap"))]
    pub fn dashmap(quota: Quota) -> Self {
        Self::new(
            quota,
            DashMapStateStore::default(),
            clock::DefaultClock::default(),
        )
    }

    /// Builds a keyed rate limiter for `quota` backed by a default `HashMapStateStore`.
    ///
    /// This variant exists when `std` and `dashmap` are not both enabled. In that
    /// configuration `HashMapStateStore` is the default keyed store.
    #[cfg(not(all(feature = "std", feature = "dashmap")))]
    pub fn hashmap(quota: Quota) -> Self {
        Self::new(
            quota,
            HashMapStateStore::default(),
            clock::DefaultClock::default(),
        )
    }
}
/// Custom-hasher constructor for builds where the hash-map store is the default keyed store.
#[cfg(not(all(feature = "std", feature = "dashmap")))]
impl<K, S> RateLimiter<K, DefaultKeyedStateStore<K, S>, clock::DefaultClock>
where
    K: Clone + Hash + Eq,
    S: core::hash::BuildHasher + Default,
{
    /// Builds a keyed rate limiter for `quota` whose `HashMapStateStore` hashes keys with
    /// `hasher`. The limiter uses `clock::DefaultClock`.
    pub fn hashmap_with_hasher(quota: Quota, hasher: S) -> Self {
        let backing_map = hashmap::HashMap::with_hasher(hasher);
        Self::new(
            quota,
            HashMapStateStore::new(backing_map),
            clock::DefaultClock::default(),
        )
    }
}
/// Custom-hasher constructor for the `DashMapStateStore`-backed default keyed limiter.
#[cfg(all(feature = "std", feature = "dashmap"))]
impl<K, S> RateLimiter<K, DefaultKeyedStateStore<K, S>, clock::DefaultClock>
where
    K: Clone + Hash + Eq,
    S: core::hash::BuildHasher + Clone + Default,
{
    /// Builds a keyed rate limiter for `quota` whose `DashMapStateStore` hashes keys with
    /// `hasher`. The limiter uses `clock::DefaultClock`.
    pub fn dashmap_with_hasher(quota: Quota, hasher: S) -> Self {
        Self::new(
            quota,
            DashMapStateStore::with_hasher(hasher),
            clock::DefaultClock::default(),
        )
    }
}
/// Explicit hash-map constructor for builds where `DashMapStateStore` is the default store.
#[cfg(all(feature = "std", feature = "dashmap"))]
impl<K> RateLimiter<K, HashMapStateStore<K>, clock::DefaultClock>
where
    K: Clone + Hash + Eq,
{
    /// Builds a keyed rate limiter for `quota` backed by a default `HashMapStateStore`
    /// and `clock::DefaultClock`, bypassing the dashmap-based default store.
    pub fn hashmap(quota: Quota) -> Self {
        Self::new(
            quota,
            HashMapStateStore::default(),
            clock::DefaultClock::default(),
        )
    }
}
/// Custom-hasher hash-map constructor for builds where `DashMapStateStore` is the default store.
#[cfg(all(feature = "std", feature = "dashmap"))]
impl<K, S> RateLimiter<K, HashMapStateStore<K, S>, clock::DefaultClock>
where
    K: Clone + Hash + Eq,
    // No `Clone` bound: the hasher is only moved into `HashMap::with_hasher`. These bounds
    // match the `hashmap_with_hasher` offered when the `dashmap` feature is disabled, so the
    // same hashers are accepted regardless of feature flags.
    S: core::hash::BuildHasher + Default,
{
    /// Builds a keyed rate limiter for `quota` whose `HashMapStateStore` hashes keys with
    /// `hasher`. The limiter uses `clock::DefaultClock`.
    pub fn hashmap_with_hasher(quota: Quota, hasher: S) -> Self {
        let state = HashMapStateStore::new(hashmap::HashMap::with_hasher(hasher));
        let clock = clock::DefaultClock::default();
        RateLimiter::new(quota, state, clock)
    }
}
/// Rate-limiting checks for limiters whose state store is keyed.
impl<K, S, C, MW> RateLimiter<K, S, C, MW>
where
    S: KeyedStateStore<K>,
    K: Hash,
    C: clock::Clock,
    MW: RateLimitingMiddleware<C::Instant>,
{
    /// Checks a single cell for `key` against the quota at the clock's current time.
    ///
    /// Delegates to the GCRA's `test_and_update` with this limiter's start instant and
    /// state store. `Ok` carries the middleware's positive outcome and `Err` its negative
    /// outcome.
    pub fn check_key(&self, key: &K) -> Result<MW::PositiveOutcome, MW::NegativeOutcome> {
        let now = self.clock.now();
        self.gcra
            .test_and_update::<K, C::Instant, S, MW>(self.start, key, &self.state, now)
    }

    /// Checks a batch of `n` cells for `key` at the clock's current time.
    ///
    /// Delegates to the GCRA's `test_n_all_and_update`. The outer `Err(InsufficientCapacity)`
    /// is that function's capacity error. Presumably it means `n` can never fit the quota;
    /// confirm against the GCRA implementation. The inner `Result` carries the middleware's
    /// positive or negative outcome.
    pub fn check_key_n(
        &self,
        key: &K,
        n: NonZeroU32,
    ) -> Result<Result<MW::PositiveOutcome, MW::NegativeOutcome>, InsufficientCapacity> {
        let now = self.clock.now();
        self.gcra
            .test_n_all_and_update::<K, C::Instant, S, MW>(self.start, key, n, &self.state, now)
    }
}
/// A [`KeyedStateStore`] that can report its size and discard stale entries.
pub trait ShrinkableKeyedStateStore<K: Hash>: KeyedStateStore<K> {
    /// Discards entries that are stale relative to the cutoff `drop_below`.
    ///
    /// The exact comparison against `drop_below` is up to the implementation.
    fn retain_recent(&self, drop_below: Nanos);

    /// Releases any spare capacity the store is holding.
    ///
    /// The default implementation does nothing, for stores that have no way to shrink.
    fn shrink_to_fit(&self) {}

    /// Returns the number of entries currently held by the store.
    fn len(&self) -> usize;

    /// Returns `true` when the store holds no entries.
    fn is_empty(&self) -> bool;
}
/// Maintenance operations for limiters whose keyed store can be inspected and pruned.
impl<K, S, C, MW> RateLimiter<K, S, C, MW>
where
    S: ShrinkableKeyedStateStore<K>,
    K: Hash,
    C: clock::Clock,
    MW: RateLimitingMiddleware<C::Instant>,
{
    /// Asks the state store to drop entries that are no longer recent.
    ///
    /// The cutoff handed to the store is the time elapsed since this limiter's start, minus
    /// `self.gcra.t()`, saturating at zero.
    pub fn retain_recent(&self) {
        let elapsed = self.clock.now().duration_since(self.start);
        let cutoff = elapsed.saturating_sub(self.gcra.t());
        self.state.retain_recent(cutoff);
    }

    /// Forwards to the state store's `shrink_to_fit`. Some stores implement this as a no-op.
    pub fn shrink_to_fit(&self) {
        self.state.shrink_to_fit()
    }

    /// Returns the number of entries the state store currently holds.
    pub fn len(&self) -> usize {
        self.state.len()
    }

    /// Returns `true` when the state store holds no entries.
    pub fn is_empty(&self) -> bool {
        self.state.is_empty()
    }
}
mod hashmap;
pub use self::hashmap::HashMapStateStore;

#[cfg(all(feature = "std", feature = "dashmap"))]
mod dashmap;
#[cfg(all(feature = "std", feature = "dashmap"))]
pub use self::dashmap::DashMapStateStore;

#[cfg(feature = "std")]
mod future;

/// The keyed state store used by `RateLimiter::keyed`.
///
/// In this configuration (`std` and `dashmap` not both enabled) it is a
/// [`HashMapStateStore`] hashing keys with `S`, which defaults to [`DefaultHasher`].
#[cfg(not(all(feature = "std", feature = "dashmap")))]
pub type DefaultKeyedStateStore<K, S = DefaultHasher> = HashMapStateStore<K, S>;

/// The keyed state store used by `RateLimiter::keyed`.
///
/// With both `std` and `dashmap` enabled it is a `DashMapStateStore` hashing keys with `S`,
/// which defaults to [`DefaultHasher`].
#[cfg(all(feature = "std", feature = "dashmap"))]
pub type DefaultKeyedStateStore<K, S = DefaultHasher> = DashMapStateStore<K, S>;
#[cfg(test)]
mod test {
    use core::marker::PhantomData;

    use nonzero_ext::nonzero;

    use super::*;
    use crate::{
        clock::{Clock, FakeRelativeClock},
        middleware::NoOpMiddleware,
    };

    /// Drives a keyed limiter over a store that relies on the trait's default
    /// `shrink_to_fit`, so that default method is exercised.
    #[test]
    fn default_nonshrinkable_state_store_coverage() {
        /// A store that never remembers anything: every lookup sees no prior state.
        #[derive(Default)]
        struct ForgetfulStore<K>(PhantomData<K>);

        impl<K: Hash + Eq + Clone> StateStore for ForgetfulStore<K> {
            type Key = K;

            fn measure_and_replace<T, F, E>(&self, _key: &Self::Key, f: F) -> Result<T, E>
            where
                F: Fn(Option<Nanos>) -> Result<(T, Nanos), E>,
            {
                // The replacement state is computed but intentionally never stored.
                let (outcome, _discarded_state) = f(None)?;
                Ok(outcome)
            }
        }

        impl<K: Hash + Eq + Clone> ShrinkableKeyedStateStore<K> for ForgetfulStore<K> {
            fn retain_recent(&self, _drop_below: Nanos) {}

            fn len(&self) -> usize {
                0
            }

            fn is_empty(&self) -> bool {
                true
            }
        }

        type Limiter = RateLimiter<
            u32,
            ForgetfulStore<u32>,
            FakeRelativeClock,
            NoOpMiddleware<<FakeRelativeClock as Clock>::Instant>,
        >;

        let limiter: Limiter = RateLimiter::new(
            Quota::per_second(nonzero!(1_u32)),
            ForgetfulStore::default(),
            FakeRelativeClock::default(),
        );

        assert_eq!(limiter.check_key(&1u32), Ok(()));
        assert!(limiter.is_empty());
        assert_eq!(limiter.len(), 0);
        limiter.retain_recent();
        // Uses the trait's default (no-op) `shrink_to_fit`.
        limiter.shrink_to_fit();
    }
}