use crate::container::is_running_in_container;
use crate::shutdown;
use crate::sysinfo::{cpu::CPU, memory::Memory};
use parking_lot::Mutex;
use ringbuf::traits::*;
use ringbuf::HeapRb;
use std::sync::atomic::{AtomicBool, AtomicU64, AtomicU8, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tracing::{info, warn};
/// Admission ticket handed out by `BBR::acquire` for a request that was let through.
///
/// The request stays counted as in flight for as long as the guard lives. Dropping the guard
/// records the request's latency (whole milliseconds, never less than 1) in the limiter's
/// rolling window and then releases the in-flight slot.
pub struct RequestGuard<'a> {
    /// Moment the request was admitted; latency is measured from here.
    created_at: Instant,
    /// Limiter whose statistics this request feeds back into.
    limiter: &'a BBR,
}
impl Drop for RequestGuard<'_> {
    fn drop(&mut self) {
        let window = &self.limiter.rolling_window;
        // Clamp to 1ms so very fast requests never report a zero latency.
        let latency_ms = std::cmp::max(self.created_at.elapsed().as_millis(), 1) as u64;
        window.add(latency_ms);
        window.sub_in_flight();
    }
}
/// Tuning knobs for the `BBR` limiter.
///
/// The rolling window keeps up to `bucket_count` completed buckets, each `bucket_interval`
/// long; CPU and memory usage are re-sampled every `collect_interval`.
#[derive(Clone, Debug)]
pub struct BBRConfig {
    /// Number of completed buckets kept in the rolling window.
    pub bucket_count: u32,
    /// Length of one bucket; completions and minimum latency are aggregated per bucket.
    pub bucket_interval: Duration,
    /// CPU usage in percent at or above which the process counts as overloaded.
    pub cpu_threshold: u8,
    /// Memory usage in percent at or above which the process counts as overloaded.
    pub memory_threshold: u8,
    /// After a request has been shed, every request is shed for this long.
    pub shed_cooldown: Duration,
    /// Period between CPU/memory usage samples.
    pub collect_interval: Duration,
}
impl Default for BBRConfig {
    /// 50 buckets of 200ms each, 85% CPU and memory thresholds, a 5s shed cooldown and a 3s
    /// sampling period.
    fn default() -> Self {
        Self {
            bucket_count: 50,
            bucket_interval: Duration::from_millis(200),
            cpu_threshold: 85,
            memory_threshold: 85,
            shed_cooldown: Duration::from_secs(5),
            collect_interval: Duration::from_secs(3),
        }
    }
}
pub struct BBR {
rolling_window: RollingWindow,
overload_collector: Arc<OverloadCollector>,
shed_cooldown: Duration,
shed_at: Mutex<Option<Instant>>,
}
impl BBR {
pub async fn new(config: BBRConfig) -> Self {
let overload_collector = Arc::new(OverloadCollector::new(config.clone()));
overload_collector.collect_overloaded().await;
let overload_collector_clone = overload_collector.clone();
tokio::spawn(async move {
overload_collector_clone.run().await;
});
Self {
rolling_window: RollingWindow::new(config.bucket_count, config.bucket_interval),
overload_collector,
shed_cooldown: config.shed_cooldown,
shed_at: Mutex::new(None),
}
}
pub async fn acquire(&self) -> Option<RequestGuard<'_>> {
self.rolling_window.add_in_flight();
if self.should_shed().await {
self.rolling_window.sub_in_flight();
return None;
}
Some(RequestGuard {
limiter: self,
created_at: Instant::now(),
})
}
async fn should_shed(&self) -> bool {
if self.is_in_cooldown() {
return true;
}
if !self.overload_collector.is_overloaded() {
return false;
}
let (max_pass, min_rt, in_flight) = self.rolling_window.get_stats();
if max_pass == 0 || in_flight == 0 {
return false;
}
let estimated_limit =
(max_pass as f64 * min_rt as f64 * self.rolling_window.bucket_count() as f64 / 1000.0)
.round() as u64;
if estimated_limit >= in_flight {
return false;
}
warn!(
"overloaded: cpu={}%, memory={}%, estimated_limit={}, in_flight={}",
self.overload_collector.cpu_used_percent(),
self.overload_collector.memory_used_percent(),
estimated_limit,
in_flight
);
self.shed_at.lock().replace(Instant::now());
true
}
#[inline]
fn is_in_cooldown(&self) -> bool {
self.shed_at
.lock()
.map_or(false, |shed_at| shed_at.elapsed() < self.shed_cooldown)
}
}
struct OverloadCollector {
is_overloaded: AtomicBool,
cpu: CPU,
cpu_threshold: u8,
cpu_used_percent: AtomicU8,
memory: Memory,
memory_threshold: u8,
memory_used_percent: AtomicU8,
collect_interval: Duration,
pid: u32,
is_running_in_container: bool,
}
impl OverloadCollector {
pub fn new(config: BBRConfig) -> Self {
Self {
is_overloaded: AtomicBool::new(false),
cpu: CPU::new(),
cpu_threshold: config.cpu_threshold,
cpu_used_percent: AtomicU8::new(0),
memory: Memory::default(),
memory_threshold: config.memory_threshold,
memory_used_percent: AtomicU8::new(0),
collect_interval: config.collect_interval,
pid: std::process::id(),
is_running_in_container: is_running_in_container(),
}
}
pub async fn run(&self) {
let mut interval = tokio::time::interval(self.collect_interval);
loop {
tokio::select! {
_ = interval.tick() => {
self.collect_overloaded().await;
}
_ = shutdown::shutdown_signal() => {
info!("ratelimiter's collecting server shutting down");
return
}
}
}
}
pub async fn collect_overloaded(&self) {
self.is_overloaded.store(
self.is_cpu_overloaded().await || self.is_memory_overloaded(),
Ordering::Relaxed,
);
}
pub fn is_overloaded(&self) -> bool {
self.is_overloaded.load(Ordering::Relaxed)
}
pub fn cpu_used_percent(&self) -> u8 {
self.cpu_used_percent.load(Ordering::Relaxed)
}
pub fn memory_used_percent(&self) -> u8 {
self.memory_used_percent.load(Ordering::Relaxed)
}
#[inline]
fn is_memory_overloaded(&self) -> bool {
let used_percent = if self.is_running_in_container {
match self.memory.get_cgroup_stats(self.pid) {
Some(stats) => (stats.used_percent * 100.0).round() as u8,
None => {
warn!("container detected but cgroup memory stats unavailable, falling back to process stats");
(self.memory.get_process_stats(self.pid).used_percent * 100.0).round() as u8
}
}
} else {
(self.memory.get_process_stats(self.pid).used_percent * 100.0).round() as u8
};
self.memory_used_percent
.store(used_percent, Ordering::Relaxed);
used_percent >= self.memory_threshold
}
#[inline]
async fn is_cpu_overloaded(&self) -> bool {
let used_percent = if self.is_running_in_container {
match self.cpu.get_cgroup_stats(self.pid).await {
Some(stats) => stats.used_percent.round() as u8,
None => {
warn!("container detected but cgroup CPU stats unavailable, falling back to process stats");
self.cpu.get_process_stats(self.pid).used_percent.round() as u8
}
}
} else {
self.cpu.get_process_stats(self.pid).used_percent.round() as u8
};
self.cpu_used_percent.store(used_percent, Ordering::Relaxed);
used_percent >= self.cpu_threshold
}
}
/// Aggregate of one completed (flushed) bucket.
#[derive(Clone, Copy)]
struct Sample {
    /// Requests completed during the bucket.
    pass: u64,
    /// Smallest response time (ms) observed during the bucket.
    min_rt: u64,
    /// When the bucket was opened; used to drop samples that fell out of the time window.
    started_at: Instant,
}
/// Bucketed history of completed requests plus a live in-flight counter.
///
/// Completions accumulate in a current bucket. The first `add` after `bucket_interval` has
/// elapsed flushes that bucket (if non-empty) into a ring holding at most `bucket_count`
/// samples. `get_stats` only looks at flushed samples that still overlap the last
/// `bucket_count * bucket_interval`, so stale data stops influencing the estimate after an
/// idle period.
pub struct RollingWindow {
    ring: Mutex<HeapRb<Sample>>,
    bucket_count: u32,
    bucket_interval: Duration,
    /// (bucket start, completed count, minimum response time) of the bucket being filled;
    /// `u64::MAX` as the minimum means "no completion yet".
    current_bucket: Mutex<(Instant, u64, u64)>,
    in_flight: AtomicU64,
}
impl RollingWindow {
    /// Creates an empty window of `bucket_count` buckets, each `bucket_interval` long.
    pub fn new(bucket_count: u32, bucket_interval: Duration) -> Self {
        Self {
            ring: Mutex::new(HeapRb::new(bucket_count as usize)),
            bucket_count,
            bucket_interval,
            current_bucket: Mutex::new((Instant::now(), 0, u64::MAX)),
            in_flight: AtomicU64::new(0),
        }
    }

    /// Records one completed request with response time `rt` (ms), rotating the current
    /// bucket into the ring first if its interval has elapsed.
    pub fn add(&self, rt: u64) {
        let now = Instant::now();
        let mut current_bucket = self.current_bucket.lock();
        if now.duration_since(current_bucket.0) >= self.bucket_interval {
            // Empty buckets are not flushed, so they never dilute max_pass.
            if current_bucket.1 > 0 {
                let mut ring = self.ring.lock();
                ring.push_overwrite(Sample {
                    pass: current_bucket.1,
                    min_rt: current_bucket.2,
                    started_at: current_bucket.0,
                });
            }
            *current_bucket = (now, 0, u64::MAX);
        }
        current_bucket.1 += 1;
        current_bucket.2 = current_bucket.2.min(rt);
    }

    /// Returns `(max_pass, min_rt, in_flight)` over the flushed samples still inside the
    /// window. With no such samples, `max_pass` is 0 and `min_rt` falls back to 1.
    /// The bucket currently being filled is not included.
    pub fn get_stats(&self) -> (u64, u64, u64) {
        let now = Instant::now();
        let max_age = self.max_sample_age();
        let ring = self.ring.lock();
        let (max_pass, min_rt) = ring
            .iter()
            .filter(|s| now.saturating_duration_since(s.started_at) < max_age)
            .fold((0, u64::MAX), |(max_pass, min_rt), s| {
                (max_pass.max(s.pass), min_rt.min(s.min_rt))
            });
        let min_rt = if min_rt == u64::MAX { 1 } else { min_rt };
        (max_pass, min_rt, self.in_flight())
    }

    /// Age (measured from bucket start) beyond which a sample no longer overlaps the last
    /// `bucket_count * bucket_interval`. A bucket spans `bucket_interval` itself, hence the
    /// extra interval.
    fn max_sample_age(&self) -> Duration {
        self.bucket_interval
            .saturating_mul(self.bucket_count.saturating_add(1))
    }

    /// Requests admitted but not yet completed.
    #[inline]
    pub fn in_flight(&self) -> u64 {
        self.in_flight.load(Ordering::Relaxed)
    }

    /// Increments the in-flight counter and returns the new value.
    #[inline]
    pub fn add_in_flight(&self) -> u64 {
        self.in_flight.fetch_add(1, Ordering::Relaxed) + 1
    }

    /// Decrements the in-flight counter.
    #[inline]
    pub fn sub_in_flight(&self) {
        self.in_flight.fetch_sub(1, Ordering::Relaxed);
    }

    /// Capacity of the sample ring.
    #[inline]
    pub fn bucket_count(&self) -> u32 {
        self.bucket_count
    }
}
// Unit tests for `RollingWindow`. Several tests sleep slightly longer than the bucket interval
// so that the next `add` rotates the current bucket into the ring; they therefore depend on
// wall-clock timing.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::thread;
    use std::time::Duration;
    // A fresh window has no samples: max_pass is 0 and min_rt falls back to 1.
    #[test]
    fn test_new_rolling_window() {
        let window = RollingWindow::new(5, Duration::from_millis(50));
        assert_eq!(window.in_flight(), 0);
        let (max_pass, min_rt, in_flight) = window.get_stats();
        assert_eq!(max_pass, 0);
        assert_eq!(min_rt, 1);
        assert_eq!(in_flight, 0);
    }
    // The first completion is accumulated in the current bucket.
    #[test]
    fn test_add_single_request() {
        let window = RollingWindow::new(10, Duration::from_millis(100));
        window.add(100);
        let current_bucket = window.current_bucket.lock();
        assert_eq!(current_bucket.1, 1);
        assert_eq!(current_bucket.2, 100);
    }
    // Completions within one interval share a bucket: the count grows and the minimum
    // response time is kept.
    #[test]
    fn test_add_multiple_requests_same_bucket() {
        let window = RollingWindow::new(10, Duration::from_millis(100));
        window.add(100);
        window.add(50);
        window.add(200);
        let current_bucket = window.current_bucket.lock();
        assert_eq!(current_bucket.1, 3);
        assert_eq!(current_bucket.2, 50);
    }
    // Once the interval has elapsed, the next add flushes the previous bucket (2 requests,
    // min 80) into the ring; the new completion (150) starts a fresh bucket.
    #[test]
    fn test_bucket_rotation() {
        let window = RollingWindow::new(10, Duration::from_millis(100));
        window.add(100);
        window.add(80);
        thread::sleep(Duration::from_millis(150));
        window.add(150);
        let ring = window.ring.lock();
        assert_eq!(ring.occupied_len(), 1);
        let sample = ring.iter().next().unwrap();
        assert_eq!(sample.pass, 2);
        assert_eq!(sample.min_rt, 80);
    }
    #[test]
    fn test_in_flight_initial_value() {
        let window = RollingWindow::new(10, Duration::from_millis(100));
        assert_eq!(window.in_flight(), 0);
    }
    // add_in_flight returns the counter value after the increment.
    #[test]
    fn test_add_in_flight() {
        let window = RollingWindow::new(10, Duration::from_millis(100));
        assert_eq!(window.add_in_flight(), 1);
        assert_eq!(window.add_in_flight(), 2);
        assert_eq!(window.add_in_flight(), 3);
        assert_eq!(window.in_flight(), 3);
    }
    #[test]
    fn test_sub_in_flight() {
        let window = RollingWindow::new(10, Duration::from_millis(100));
        window.add_in_flight();
        window.add_in_flight();
        window.add_in_flight();
        window.sub_in_flight();
        assert_eq!(window.in_flight(), 2);
        window.sub_in_flight();
        assert_eq!(window.in_flight(), 1);
        window.sub_in_flight();
        assert_eq!(window.in_flight(), 0);
    }
    // Nothing has been flushed yet, so the defaults (0, 1, 0) are returned.
    #[test]
    fn test_get_stats_empty_window() {
        let window = RollingWindow::new(10, Duration::from_millis(100));
        let (max_pass, min_rt, in_flight) = window.get_stats();
        assert_eq!(max_pass, 0);
        assert_eq!(min_rt, 1);
        assert_eq!(in_flight, 0);
    }
    // Two buckets are flushed ({2 requests, min 50} and {3 requests, min 120}); the trailing
    // add(80) stays in the current bucket, which get_stats does not read.
    #[test]
    fn test_get_stats_with_data() {
        let window = RollingWindow::new(10, Duration::from_millis(100));
        window.add(100);
        window.add(50);
        thread::sleep(Duration::from_millis(120));
        window.add(200);
        window.add(150);
        window.add(120);
        thread::sleep(Duration::from_millis(110));
        window.add(80);
        window.add_in_flight();
        window.add_in_flight();
        let (max_pass, min_rt, in_flight) = window.get_stats();
        assert_eq!(max_pass, 3);
        assert_eq!(min_rt, 50);
        assert_eq!(in_flight, 2);
    }
    #[test]
    fn test_get_stats_includes_in_flight() {
        let window = RollingWindow::new(10, Duration::from_millis(100));
        window.add_in_flight();
        window.add_in_flight();
        window.add_in_flight();
        let (_, _, in_flight) = window.get_stats();
        assert_eq!(in_flight, 3);
    }
    // With capacity 3, flushing a fourth bucket overwrites the oldest sample (min 100). The
    // final add(30) stays in the current bucket, so min_rt comes from the remaining flushed
    // samples (80, 60, 40). Every flushed bucket saw exactly one request.
    #[test]
    fn test_expired_samples_filtered() {
        let window = RollingWindow::new(3, Duration::from_millis(100));
        window.add(100);
        thread::sleep(Duration::from_millis(120));
        window.add(80);
        thread::sleep(Duration::from_millis(110));
        window.add(60);
        thread::sleep(Duration::from_millis(110));
        window.add(40);
        thread::sleep(Duration::from_millis(200));
        window.add(30);
        let (max_pass, min_rt, _) = window.get_stats();
        assert_eq!(max_pass, 1);
        assert!(min_rt == 40);
    }
    // RollingWindow stores a 0ms latency as-is (only RequestGuard clamps to 1ms).
    #[test]
    fn test_zero_response_time() {
        let window = RollingWindow::new(10, Duration::from_millis(100));
        window.add(0);
        let current_bucket = window.current_bucket.lock();
        assert_eq!(current_bucket.1, 1);
        assert_eq!(current_bucket.2, 0);
    }
    // Large latencies are kept; u64::MAX itself is the "no completion yet" minimum.
    #[test]
    fn test_max_response_time() {
        let window = RollingWindow::new(10, Duration::from_millis(100));
        window.add(u64::MAX - 1);
        let current_bucket = window.current_bucket.lock();
        assert_eq!(current_bucket.1, 1);
        assert_eq!(current_bucket.2, u64::MAX - 1);
    }
    // A one-slot ring holds only the most recently flushed bucket.
    #[test]
    fn test_single_bucket_window() {
        let window = RollingWindow::new(1, Duration::from_millis(100));
        window.add(100);
        thread::sleep(Duration::from_millis(150));
        window.add(50);
        let ring = window.ring.lock();
        assert_eq!(ring.occupied_len(), 1);
    }
    // Rotating a bucket that received no completions must not push an empty sample.
    #[test]
    fn test_empty_bucket_not_flushed() {
        let window = RollingWindow::new(10, Duration::from_millis(100));
        thread::sleep(Duration::from_millis(120));
        window.add(100);
        let ring = window.ring.lock();
        assert_eq!(ring.occupied_len(), 0);
    }
    // Concurrent adds from four threads must not panic or deadlock; the current bucket ends
    // non-empty with a minimum drawn from 0..=99.
    #[test]
    fn test_concurrent_add_requests() {
        let window = Arc::new(RollingWindow::new(10, Duration::from_millis(100)));
        let mut handles = vec![];
        for _ in 0..4 {
            let w = window.clone();
            handles.push(thread::spawn(move || {
                for i in 0..100 {
                    w.add(i as u64);
                }
            }));
        }
        for handle in handles {
            handle.join().unwrap();
        }
        let current_bucket = window.current_bucket.lock();
        assert!(current_bucket.1 > 0);
        assert!(current_bucket.2 <= 99);
    }
    // Writers and readers interleave; get_stats must always report min_rt >= 1 and never more
    // passes per bucket than the 100 adds performed in total.
    #[tokio::test]
    async fn test_concurrent_add_and_stats() {
        let window = Arc::new(RollingWindow::new(10, Duration::from_millis(100)));
        let mut handles = vec![];
        for _ in 0..2 {
            let w = window.clone();
            handles.push(tokio::spawn(async move {
                for i in 0..50 {
                    w.add(i as u64 + 10);
                    tokio::time::sleep(Duration::from_micros(100)).await;
                }
            }));
        }
        for _ in 0..2 {
            let w = window.clone();
            handles.push(tokio::spawn(async move {
                for _ in 0..50 {
                    let (max_pass, min_rt, _) = w.get_stats();
                    assert!(min_rt > 0);
                    assert!(max_pass <= 100);
                    tokio::time::sleep(Duration::from_micros(100)).await;
                }
            }));
        }
        for handle in handles {
            handle.await.unwrap();
        }
    }
    // Ten requests (min 50) fill the first bucket, which the first add after the sleep
    // flushes; the last five stay in the current bucket. So max_pass = 10 and min_rt = 50,
    // and the simplified product max_pass * min_rt / 1000 is 0.5. This product omits the
    // bucket-length scaling that BBR::should_shed applies.
    #[test]
    fn test_bbr_style_usage() {
        let window = RollingWindow::new(10, Duration::from_millis(100));
        for rt in [100, 80, 120, 60, 90, 70, 110, 50, 85, 95] {
            let _ = window.add_in_flight();
            window.add(rt);
            window.sub_in_flight();
        }
        thread::sleep(Duration::from_millis(110));
        for rt in [40, 55, 45, 60, 50] {
            let _ = window.add_in_flight();
            window.add(rt);
            window.sub_in_flight();
        }
        let (max_pass, min_rt, _) = window.get_stats();
        let estimated_limit = max_pass as f64 * min_rt as f64 / 1000.0;
        assert!(estimated_limit == 0.5 || max_pass == 0);
    }
}