use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use parking_lot::Mutex;
use rakka_core::actor::{Actor, Context};
use rakka_distributed_data::{DeltaCrdt, GCounter};
use tokio::sync::oneshot;
use inference_core::deployment::RateLimits;
use inference_core::error::InferenceError;
/// Allowance handed back to a caller whose acquire request was accepted.
///
/// Both fields are copied verbatim from the `AcquirePermit` message that was granted.
#[derive(Debug)]
pub struct Permit {
/// Number of requests that were granted.
pub requests: u32,
/// Number of tokens that were granted.
pub tokens: u32,
}
/// Message asking a rate-limiter actor for budget in the current window.
///
/// It is the `Msg` type of both actor implementations in this file.
pub struct AcquirePermit {
/// Number of requests the caller wants to charge against the request limit.
pub requests: u32,
/// Number of tokens the caller wants to charge against the token limit.
pub tokens: u32,
/// Channel on which the actor sends the outcome: a `Permit` on success or an
/// `InferenceError` when a limit would be exceeded.
pub reply: oneshot::Sender<Result<Permit, InferenceError>>,
}
/// Usage accumulated locally during the current accounting window.
///
/// The derived `Default` yields an unstarted window (`started_at == None`) with zeroed totals.
#[derive(Default)]
struct Window {
// `None` until the window is first opened; afterwards the instant the current window began.
started_at: Option<Instant>,
// Requests granted so far in this window.
requests: u64,
// Tokens granted so far in this window.
tokens: u64,
}
/// Cloneable handle for reading a limiter's current window totals.
///
/// Handles obtained from `RateLimiterActor::handle` share the actor's window through the
/// `Arc`; a handle built with `Default` owns a fresh, empty window of its own.
#[derive(Default, Clone)]
pub struct RateLimiterHandle {
// Shared, mutex-protected window state.
state: Arc<Mutex<Window>>,
}
impl RateLimiterHandle {
    /// Returns the totals recorded in the shared window as `(requests, tokens)`.
    ///
    /// The window mutex is held only long enough to copy the two counters out.
    pub fn snapshot(&self) -> (u64, u64) {
        let guard = self.state.lock();
        let requests = guard.requests;
        let tokens = guard.tokens;
        (requests, tokens)
    }
}
/// Actor that grants request/token budget within a time window.
///
/// Limit decisions are made against the local `window` totals. The two `GCounter`s record
/// the same grants under this node's id and can absorb deltas from other nodes.
pub struct RateLimiterActor {
// Key under which this node's increments are recorded in the counters.
node_id: String,
// Configured limits; a `None` limit is not enforced.
limits: RateLimits,
// Replicated counter of granted requests (local increments plus merged remote deltas).
requests_counter: GCounter,
// Replicated counter of granted tokens (local increments plus merged remote deltas).
tokens_counter: GCounter,
// Local usage for the current window; shared with every `RateLimiterHandle` handed out.
window: Arc<Mutex<Window>>,
}
impl RateLimiterActor {
    /// Length of one accounting window; the per-minute limits are enforced per window.
    const WINDOW_LENGTH: Duration = Duration::from_secs(60);

    /// Creates a limiter identified by `node_id` that enforces `limits`.
    ///
    /// The window starts out unstarted and empty; it is opened by the first `acquire`.
    pub fn new(node_id: impl Into<String>, limits: RateLimits) -> Self {
        Self {
            node_id: node_id.into(),
            limits,
            requests_counter: GCounter::new(),
            tokens_counter: GCounter::new(),
            window: Arc::new(Mutex::new(Window::default())),
        }
    }

    /// Returns a handle that shares this actor's window state for snapshots.
    pub fn handle(&self) -> RateLimiterHandle {
        RateLimiterHandle {
            state: Arc::clone(&self.window),
        }
    }

    /// Merges a delta received from another node into the request counter.
    pub fn merge_remote_delta_requests(&mut self, delta: &<GCounter as DeltaCrdt>::Delta) {
        self.requests_counter.merge_delta(delta);
    }

    /// Merges a delta received from another node into the token counter.
    pub fn merge_remote_delta_tokens(&mut self, delta: &<GCounter as DeltaCrdt>::Delta) {
        self.tokens_counter.merge_delta(delta);
    }

    /// Opens a fresh window when none has been started yet or the current one has run
    /// for at least [`Self::WINDOW_LENGTH`].
    ///
    /// Rotation zeroes the local totals and replaces both counters with new, empty ones.
    fn rotate_window_if_needed(&mut self) {
        let mut w = self.window.lock();
        let expired = w
            .started_at
            .map_or(true, |started| started.elapsed() >= Self::WINDOW_LENGTH);
        if expired {
            *w = Window {
                started_at: Some(Instant::now()),
                requests: 0,
                tokens: 0,
            };
            self.requests_counter = GCounter::new();
            self.tokens_counter = GCounter::new();
        }
    }

    /// Reports whether charging `requested` on top of `used` would go past `limit`.
    ///
    /// An addition that overflows `u64` is treated as exceeding the limit rather than
    /// being allowed to wrap around to a small value.
    fn would_exceed(used: u64, requested: u64, limit: u64) -> bool {
        used.checked_add(requested)
            .map_or(true, |total| total > limit)
    }

    /// Tries to charge `req.requests` and `req.tokens` against the current window.
    ///
    /// The request limit is checked first, then the token limit; nothing is recorded
    /// unless both checks pass. `req.reply` is not used here.
    ///
    /// # Errors
    ///
    /// Returns `InferenceError::Backpressure` when granting the request would exceed the
    /// configured requests-per-minute or tokens-per-minute limit.
    fn acquire(&mut self, req: AcquirePermit) -> Result<Permit, InferenceError> {
        self.rotate_window_if_needed();

        // Lossless widening; avoids `as` casts.
        let requests = u64::from(req.requests);
        let tokens = u64::from(req.tokens);

        // NOTE(review): the checks below look only at the local window totals; deltas
        // merged into the counters from other nodes are not consulted here — confirm
        // that this is the intended "approximate" behaviour.
        {
            let mut w = self.window.lock();
            if let Some(rpm) = self.limits.requests_per_minute {
                if Self::would_exceed(w.requests, requests, rpm) {
                    return Err(InferenceError::Backpressure(format!(
                        "requests-per-minute limit reached ({}/{})",
                        w.requests, rpm
                    )));
                }
            }
            if let Some(tpm) = self.limits.tokens_per_minute {
                if Self::would_exceed(w.tokens, tokens, tpm) {
                    return Err(InferenceError::Backpressure(format!(
                        "tokens-per-minute limit reached ({}/{})",
                        w.tokens, tpm
                    )));
                }
            }
            // Saturate instead of wrapping when no limit bounds the totals.
            w.requests = w.requests.saturating_add(requests);
            w.tokens = w.tokens.saturating_add(tokens);
        } // window lock released before touching the counters

        self.requests_counter.increment(&self.node_id, requests);
        self.tokens_counter.increment(&self.node_id, tokens);

        Ok(Permit {
            requests: req.requests,
            tokens: req.tokens,
        })
    }
}
#[async_trait]
impl Actor for RateLimiterActor {
    type Msg = AcquirePermit;

    /// Processes one acquire request and sends the outcome back on the message's reply
    /// channel. A failed send (receiver already gone) is ignored.
    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: Self::Msg) {
        let AcquirePermit {
            requests,
            tokens,
            reply,
        } = msg;
        // `acquire` takes a whole `AcquirePermit`, so the real reply sender is swapped
        // for a placeholder and kept here to deliver the result.
        let outcome = self.acquire(AcquirePermit {
            requests,
            tokens,
            reply: dummy_reply(),
        });
        let _ = reply.send(outcome);
    }
}
/// Actor wrapper around a `RateLimiterActor`.
///
/// In the code visible in this file its message handler forwards to the wrapped limiter's
/// acquire logic without any additional checks.
pub struct StrictRateLimiterActor {
// The wrapped limiter that performs the actual accounting.
inner: RateLimiterActor,
}
impl StrictRateLimiterActor {
pub fn new(inner: RateLimiterActor) -> Self {
Self { inner }
}
}
#[async_trait]
impl Actor for StrictRateLimiterActor {
    type Msg = AcquirePermit;

    /// Forwards one acquire request to the wrapped limiter and sends the outcome back on
    /// the message's reply channel. A failed send (receiver already gone) is ignored.
    async fn handle(&mut self, _ctx: &mut Context<Self>, msg: Self::Msg) {
        let AcquirePermit {
            requests,
            tokens,
            reply,
        } = msg;
        // The wrapped `acquire` takes a whole `AcquirePermit`; pass a placeholder sender
        // and keep the real one to deliver the result.
        let outcome = self.inner.acquire(AcquirePermit {
            requests,
            tokens,
            reply: dummy_reply(),
        });
        let _ = reply.send(outcome);
    }
}
/// Builds a placeholder reply sender whose receiving half has already been discarded.
///
/// Used to fill the `reply` field of an `AcquirePermit` when the real sender has been
/// taken out of the message.
fn dummy_reply() -> oneshot::Sender<Result<Permit, InferenceError>> {
    let (sender, _unused_receiver) = oneshot::channel();
    sender
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a message asking for one request and zero tokens; its reply channel is
    /// never read.
    fn single_request() -> AcquirePermit {
        let (reply, _) = oneshot::channel();
        AcquirePermit {
            requests: 1,
            tokens: 0,
            reply,
        }
    }

    /// With a limit of two requests per minute, the first two acquisitions succeed and
    /// the third is rejected with backpressure.
    #[test]
    fn approximate_limiter_blocks_on_rpm() {
        let limits = RateLimits {
            requests_per_minute: Some(2),
            tokens_per_minute: None,
            concurrent_requests: None,
            strict: false,
        };
        let mut limiter = RateLimiterActor::new("node-a", limits);

        for _ in 0..2 {
            assert!(limiter.acquire(single_request()).is_ok());
        }

        let third = limiter.acquire(single_request());
        assert!(matches!(third, Err(InferenceError::Backpressure(_))));
    }
}