use std::{
future::Future,
sync::{Arc, Mutex},
time::{Duration, Instant},
};
use crate::policy::Policy;
use crate::rate_limit::RateLimitResult;
/// Mutable token-bucket state, guarded by the limiter's mutex.
struct Inner {
    /// Tokens currently in the bucket; kept as `f64` so partial refills accumulate.
    available_tokens: f64,
    /// Instant up to which `available_tokens` has been credited.
    last_refill: Instant,
}

/// A token-bucket rate limiter.
///
/// The bucket holds at most `max_tokens` tokens and regains one token per
/// `refill_rate` of elapsed time. The bucket state lives behind an
/// `Arc<Mutex<_>>`, so clones of a limiter share (and drain) the same bucket.
///
/// Note: `max_tokens` and `refill_rate` are public fields; assigning them
/// directly does not touch the current bucket contents (the builder
/// [`RateLimiter::with_max_tokens`] does).
#[derive(Clone)]
pub struct RateLimiter {
    /// Bucket capacity.
    pub max_tokens: usize,
    /// Time needed to regenerate a single token.
    pub refill_rate: Duration,
    inner: Arc<Mutex<Inner>>,
}

impl Default for RateLimiter {
    /// A limiter with a full bucket of 10 tokens that regains one token per second.
    fn default() -> Self {
        let now = Instant::now();
        Self {
            max_tokens: 10,
            refill_rate: Duration::from_secs(1),
            inner: Arc::new(Mutex::new(Inner {
                available_tokens: 10.0,
                last_refill: now,
            })),
        }
    }
}

impl RateLimiter {
    /// Sets the bucket capacity and fills the bucket to that capacity.
    ///
    /// Filling keeps the "starts full" behaviour of [`Default`] when the
    /// capacity is configured through the builder. Previously the bucket kept
    /// the default 10 tokens whatever maximum was configured. Because clones
    /// share the bucket, any existing clone observes the refilled bucket too.
    pub fn with_max_tokens(mut self, max: usize) -> Self {
        self.max_tokens = max;
        self.lock_state().available_tokens = max as f64;
        self
    }

    /// Sets the time needed to regenerate one token.
    pub fn with_refill_rate(mut self, rate: Duration) -> Self {
        self.refill_rate = rate;
        self
    }

    /// Attempts to take `tokens` tokens from the bucket.
    ///
    /// Tokens earned since the last update are credited first (capped at
    /// `max_tokens`). Returns `true` and deducts the tokens if enough are
    /// available; otherwise returns `false` and leaves the bucket untouched.
    /// A request larger than `max_tokens` can never succeed.
    pub fn try_consume(&self, tokens: usize) -> bool {
        let mut inner = self.lock_state();
        self.refill(&mut inner, Instant::now());
        let wanted = tokens as f64;
        if inner.available_tokens < wanted {
            return false;
        }
        inner.available_tokens -= wanted;
        true
    }

    /// Returns the number of whole tokens currently available.
    ///
    /// Tokens earned since the last update are credited first, so this agrees
    /// with what [`RateLimiter::try_consume`] would see right now. Fractional
    /// tokens are truncated.
    pub fn available_tokens(&self) -> usize {
        let mut inner = self.lock_state();
        self.refill(&mut inner, Instant::now());
        inner.available_tokens as usize
    }

    /// Returns a handle that shares this limiter's bucket.
    ///
    /// Equivalent to [`Clone::clone`]: the returned limiter draws from the
    /// same underlying state.
    pub fn clone_inner(&self) -> Self {
        self.clone()
    }

    /// Locks the shared bucket state, recovering it if a previous holder panicked.
    fn lock_state(&self) -> std::sync::MutexGuard<'_, Inner> {
        self.inner.lock().unwrap_or_else(|e| e.into_inner())
    }

    /// Credits tokens earned between `inner.last_refill` and `now`, capped at
    /// `max_tokens`, and advances `last_refill` to `now`.
    fn refill(&self, inner: &mut Inner, now: Instant) {
        // A clock that appears to go backwards yields zero elapsed time.
        let elapsed = now.saturating_duration_since(inner.last_refill);
        let earned = if self.refill_rate.is_zero() {
            // A zero refill period means tokens regenerate instantly; avoid
            // computing 0.0 / 0.0 = NaN.
            f64::INFINITY
        } else {
            elapsed.as_secs_f64() / self.refill_rate.as_secs_f64()
        };
        inner.available_tokens = (inner.available_tokens + earned).min(self.max_tokens as f64);
        inner.last_refill = now;
    }
}
impl RateLimiter {
    /// Runs `f` once, provided a single token can be taken from the bucket.
    ///
    /// # Errors
    ///
    /// Returns [`RateLimitResult::RateLimited`] without calling `f` when no
    /// token is available. An error produced by `f` is wrapped in
    /// [`RateLimitResult::Inner`].
    pub async fn run<F, Fut, T, E>(&self, mut f: F) -> Result<T, RateLimitResult<E>>
    where
        F: FnMut() -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        T: Send,
        E: Clone + Send,
    {
        let admitted = self.try_consume(1);
        if admitted {
            match f().await {
                Ok(value) => Ok(value),
                Err(err) => Err(RateLimitResult::Inner(err)),
            }
        } else {
            Err(RateLimitResult::RateLimited)
        }
    }
}
impl<T, E> Policy<T, E> for RateLimiter
where
    E: Send,
{
    /// Takes one token from the bucket (if available), then runs `f`.
    ///
    /// NOTE(review): the outcome of `try_consume` is ignored, so `f` runs
    /// whether or not a token was available. The limit is therefore *not*
    /// enforced through this `Policy` path. `E` has no way to express
    /// "rate limited" here, unlike `RateLimiter::run`. Enforcing it needs
    /// either an `E` constructor or an async wait for the next token — TODO
    /// decide which.
    fn call<F, Fut>(&self, f: &mut F) -> impl Future<Output = Result<T, E>> + Send
    where
        F: FnMut() -> Fut + Send,
        Fut: Future<Output = Result<T, E>> + Send,
        T: Send,
        E: Send,
    {
        // Move an owned handle (sharing the same bucket) into the future.
        let limiter = self.clone();
        async move {
            let _ = limiter.try_consume(1);
            f().await
        }
    }
}
/// Compile-time check that `RateLimiter` is `Send`.
///
/// `Policy::call` moves a clone of the limiter into a future that must be
/// `Send`, so losing this property would break that impl.
fn _assert_send() {
    fn require_send<S: Send>() {}
    require_send::<RateLimiter>();
}