use astrid_core::principal::PrincipalId;
use uuid::Uuid;
/// Key identifying one rate-limit bucket: the capsule's UUID paired with the principal
/// whose traffic is being metered.
pub type RateLimiterKey = (Uuid, PrincipalId);
/// Hard upper bound (5 MiB) on a single IPC payload. `IpcRateLimiter::check_quota`
/// rejects anything larger regardless of the per-second throughput limit it is given.
pub const MAX_IPC_PAYLOAD_BYTES: usize = 5 * 1024 * 1024;
/// Per-`(capsule, principal)` IPC throughput limiter using fixed one-second windows.
///
/// Bucket state lives in a `DashMap`, which lets `check_quota` update buckets through
/// `&self`.
#[derive(Debug)]
pub struct IpcRateLimiter {
/// Bucket state keyed by `(capsule UUID, principal)`: the instant the current window
/// started, and the number of bytes accepted within that window.
state: dashmap::DashMap<RateLimiterKey, (std::time::Instant, usize)>,
/// Instant of the most recent prune pass (set to the construction time initially).
/// Acquired with `try_lock`, so at most one caller prunes at a time and others skip.
last_prune: std::sync::Mutex<std::time::Instant>,
}
impl IpcRateLimiter {
    /// Number of tracked buckets above which stale entries become eligible for pruning.
    const PRUNE_THRESHOLD: usize = 1000;
    /// A prune pass runs only once strictly more than this many whole seconds have
    /// elapsed since the previous pass.
    const PRUNE_INTERVAL_SECS: u64 = 60;
    /// Length of one accounting window, in whole seconds.
    const WINDOW_SECS: u64 = 1;

    /// Creates an empty rate limiter with no tracked buckets.
    #[must_use]
    pub fn new() -> Self {
        Self {
            state: dashmap::DashMap::new(),
            last_prune: std::sync::Mutex::new(std::time::Instant::now()),
        }
    }

    /// Charges `size_bytes` against the `(capsule_uuid, principal)` bucket.
    ///
    /// Each bucket uses a fixed one-second window. Once a second or more has passed since
    /// the window started, the window restarts at the current instant with a zero byte
    /// count. A request is admitted only if the bytes already accepted in the current
    /// window plus `size_bytes` do not exceed `max_throughput_bytes_per_sec`. Rejected
    /// requests do not consume any budget.
    ///
    /// # Errors
    ///
    /// - `"Payload too large"` if `size_bytes` exceeds [`MAX_IPC_PAYLOAD_BYTES`],
    ///   regardless of `max_throughput_bytes_per_sec`.
    /// - `"Rate limit exceeded"` if admitting the payload would push the bucket past
    ///   `max_throughput_bytes_per_sec` for the current window.
    pub fn check_quota(
        &self,
        capsule_uuid: Uuid,
        principal: &PrincipalId,
        size_bytes: usize,
        max_throughput_bytes_per_sec: usize,
    ) -> Result<(), String> {
        if size_bytes > MAX_IPC_PAYLOAD_BYTES {
            return Err("Payload too large".to_string());
        }
        let now = std::time::Instant::now();
        self.prune_stale(now);

        let key: RateLimiterKey = (capsule_uuid, principal.clone());
        let mut entry = self.state.entry(key).or_insert((now, 0));
        // Fixed window: start a fresh budget once the current window has fully elapsed.
        if now.saturating_duration_since(entry.0).as_secs() >= Self::WINDOW_SECS {
            entry.0 = now;
            entry.1 = 0;
        }
        let charged = entry.1.saturating_add(size_bytes);
        if charged > max_throughput_bytes_per_sec {
            // Leave the bucket untouched so a rejected request costs nothing.
            return Err("Rate limit exceeded".to_string());
        }
        entry.1 = charged;
        Ok(())
    }

    /// Opportunistically drops buckets whose window has already expired.
    ///
    /// Does nothing unless more than [`Self::PRUNE_THRESHOLD`] buckets are tracked and
    /// more than [`Self::PRUNE_INTERVAL_SECS`] whole seconds have passed since the last
    /// pass. `try_lock` keeps the hot path non-blocking: if another caller already holds
    /// the prune lock, this call skips pruning.
    fn prune_stale(&self, now: std::time::Instant) {
        if self.state.len() <= Self::PRUNE_THRESHOLD {
            return;
        }
        let mut last = match self.last_prune.try_lock() {
            Ok(guard) => guard,
            // The mutex guards a plain `Instant`, so a panic elsewhere cannot leave it
            // inconsistent. Recover rather than silently disabling pruning forever,
            // which would let `state` grow without bound.
            Err(std::sync::TryLockError::Poisoned(poisoned)) => poisoned.into_inner(),
            Err(std::sync::TryLockError::WouldBlock) => return,
        };
        if now.saturating_duration_since(*last).as_secs() <= Self::PRUNE_INTERVAL_SECS {
            return;
        }
        *last = now;
        // Any bucket whose window started a full window ago would be reset on its next
        // use anyway, so dropping it loses no accounting information.
        self.state
            .retain(|_, v| now.saturating_duration_since(v.0).as_secs() < Self::WINDOW_SECS);
    }
}
impl Default for IpcRateLimiter {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a principal for tests; the names used here are known to be valid.
    fn pid(name: &str) -> PrincipalId {
        PrincipalId::new(name).expect("valid principal")
    }

    const ONE_MIB: usize = 1024 * 1024;
    const TEN_MIB: usize = 10 * 1024 * 1024;

    #[test]
    fn payload_above_hard_cap_rejected_regardless_of_profile() {
        let rl = IpcRateLimiter::new();
        let err = rl
            .check_quota(
                Uuid::new_v4(),
                &pid("alice"),
                MAX_IPC_PAYLOAD_BYTES + 1,
                usize::MAX,
            )
            .expect_err("payload > 5 MiB must reject");
        assert!(err.contains("Payload too large"));
    }

    #[test]
    fn payload_exactly_at_hard_cap_is_accepted() {
        // The hard cap is inclusive: only payloads strictly larger than it are rejected.
        let rl = IpcRateLimiter::new();
        rl.check_quota(
            Uuid::new_v4(),
            &pid("alice"),
            MAX_IPC_PAYLOAD_BYTES,
            usize::MAX,
        )
        .expect("payload == 5 MiB is within the hard cap");
    }

    #[test]
    fn single_principal_honors_profile_ceiling() {
        let rl = IpcRateLimiter::new();
        let cap = Uuid::new_v4();
        let p = pid("alice");
        rl.check_quota(cap, &p, ONE_MIB, ONE_MIB)
            .expect("first send fits");
        let err = rl
            .check_quota(cap, &p, 1, ONE_MIB)
            .expect_err("next byte should bust the 1 MiB cap");
        assert!(err.contains("Rate limit exceeded"));
    }

    #[test]
    fn rejected_request_does_not_consume_budget() {
        let rl = IpcRateLimiter::new();
        let cap = Uuid::new_v4();
        let p = pid("alice");
        rl.check_quota(cap, &p, 60, 100).expect("60 of 100 fits");
        assert!(
            rl.check_quota(cap, &p, 50, 100).is_err(),
            "60 + 50 exceeds the 100-byte ceiling"
        );
        rl.check_quota(cap, &p, 40, 100)
            .expect("the failed attempt must not be charged; 60 + 40 == 100 fits");
    }

    #[test]
    fn two_principals_have_independent_buckets() {
        let rl = IpcRateLimiter::new();
        let cap = Uuid::new_v4();
        let alice = pid("alice");
        let bob = pid("bob");
        rl.check_quota(cap, &alice, ONE_MIB, ONE_MIB)
            .expect("alice fills her bucket");
        assert!(
            rl.check_quota(cap, &alice, 1, ONE_MIB).is_err(),
            "alice must be rate-limited now"
        );
        rl.check_quota(cap, &bob, ONE_MIB, TEN_MIB)
            .expect("bob unaffected");
        rl.check_quota(cap, &bob, ONE_MIB, TEN_MIB)
            .expect("bob unaffected still");
    }

    #[test]
    fn same_principal_on_two_capsules_has_independent_buckets() {
        let rl = IpcRateLimiter::new();
        let cap_a = Uuid::new_v4();
        let cap_b = Uuid::new_v4();
        let p = pid("alice");
        rl.check_quota(cap_a, &p, ONE_MIB, ONE_MIB)
            .expect("cap_a fills");
        assert!(rl.check_quota(cap_a, &p, 1, ONE_MIB).is_err());
        rl.check_quota(cap_b, &p, ONE_MIB, ONE_MIB)
            .expect("cap_b independent");
    }

    #[test]
    fn window_resets_after_one_second() {
        let rl = IpcRateLimiter::new();
        let cap = Uuid::new_v4();
        let p = pid("slow");
        rl.check_quota(cap, &p, 100, 100).expect("initial fits");
        assert!(rl.check_quota(cap, &p, 1, 100).is_err(), "at cap");
        // The limiter reads the real monotonic clock, so the window can only be crossed
        // by actually waiting past one full second.
        std::thread::sleep(std::time::Duration::from_millis(1100));
        rl.check_quota(cap, &p, 100, 100)
            .expect("after window reset, fresh budget");
    }
}