use std::fmt;
use std::sync::Arc;
use std::sync::atomic::Ordering;
/// Converts a `usize` count into an `f64` so it can be used in ratio arithmetic.
///
/// Values above 2^53 cannot be represented exactly in an `f64` and are rounded
/// to the nearest representable value.
// The lint is silenced for the whole helper: callers only use the result to
// compute approximate ratios, where the rounding is acceptable.
#[allow(clippy::cast_precision_loss)]
fn usize_to_f64(value: usize) -> f64 {
    value as f64
}
use super::reservation::{AdmissionError, Reservation, ReservationGuard};
use super::state::SharedBufferState;
/// Gatekeeper that admits or rejects reservations of bytes and operations
/// against fixed upper limits, using counters held in a shared buffer state.
#[derive(Debug)]
pub struct AdmissionController {
/// Shared counters (reserved and committed bytes/ops) that reservations are
/// charged against.
buffer_state: Arc<SharedBufferState>,
/// Upper bound that committed + reserved bytes are checked against.
max_bytes: usize,
/// Upper bound that committed + reserved operations are checked against.
max_ops: usize,
}
impl AdmissionController {
    /// Utilization ratio at or above which [`Self::is_near_limit`] returns `true`.
    const NEAR_LIMIT_THRESHOLD: f64 = 0.8;

    /// Creates a controller that charges reservations to `buffer_state` and
    /// rejects any that would push usage above `max_bytes` or `max_ops`.
    #[must_use]
    pub fn new(buffer_state: Arc<SharedBufferState>, max_bytes: usize, max_ops: usize) -> Self {
        Self {
            buffer_state,
            max_bytes,
            max_ops,
        }
    }

    /// Returns the shared buffer state this controller operates on.
    #[must_use]
    pub fn buffer_state(&self) -> &Arc<SharedBufferState> {
        &self.buffer_state
    }

    /// Returns the configured byte limit.
    #[must_use]
    pub fn max_bytes(&self) -> usize {
        self.max_bytes
    }

    /// Returns the configured operation limit.
    #[must_use]
    pub fn max_ops(&self) -> usize {
        self.max_ops
    }

    /// Attempts to reserve `bytes` bytes and `ops` operations.
    ///
    /// Bytes are reserved first, then operations. If the operation reservation
    /// is rejected, the bytes reserved by this call are released again before
    /// the error is returned.
    ///
    /// # Errors
    ///
    /// * [`AdmissionError::ZeroReservation`] if both `bytes` and `ops` are zero.
    /// * [`AdmissionError::ByteLimitExceeded`] if committed + reserved +
    ///   requested bytes would exceed the byte limit.
    /// * [`AdmissionError::OpsLimitExceeded`] if committed + reserved +
    ///   requested operations would exceed the operation limit.
    pub fn try_reserve(
        &self,
        bytes: usize,
        ops: usize,
    ) -> Result<ReservationGuard, AdmissionError> {
        if bytes == 0 && ops == 0 {
            return Err(AdmissionError::ZeroReservation);
        }
        self.try_reserve_bytes(bytes)?;
        // On rejection `try_reserve_ops` undoes the byte reservation made above.
        self.try_reserve_ops(ops, bytes)?;
        Ok(ReservationGuard::new(
            Arc::clone(&self.buffer_state),
            Reservation { bytes, ops },
        ))
    }

    /// Adds `bytes` to the reserved-bytes counter, retrying until the
    /// compare-exchange succeeds or the limit check rejects the request.
    ///
    /// A request for zero bytes succeeds without touching the counter.
    fn try_reserve_bytes(&self, bytes: usize) -> Result<(), AdmissionError> {
        if bytes == 0 {
            return Ok(());
        }
        loop {
            let (current_reserved, committed) = self.load_bytes_snapshot();
            self.check_byte_limit(committed, current_reserved, bytes)?;
            if self.try_update_reserved_bytes(current_reserved, bytes) {
                return Ok(());
            }
        }
    }

    /// Adds `ops` to the reserved-ops counter, retrying until the
    /// compare-exchange succeeds or the limit check rejects the request.
    ///
    /// `bytes` is the amount already reserved by the same `try_reserve` call;
    /// it is released again if the operation limit check fails. A request for
    /// zero operations succeeds without touching the counter.
    fn try_reserve_ops(&self, ops: usize, bytes: usize) -> Result<(), AdmissionError> {
        if ops == 0 {
            return Ok(());
        }
        loop {
            let (current_reserved, committed) = self.load_ops_snapshot();
            if let Err(err) = self.check_ops_limit(committed, current_reserved, ops) {
                self.rollback_bytes_if_needed(bytes);
                return Err(err);
            }
            if self.try_update_reserved_ops(current_reserved, ops) {
                return Ok(());
            }
        }
    }

    /// Loads `(reserved_bytes, committed_bytes)`.
    ///
    /// The two values come from separate atomic loads, so together they are
    /// not a single consistent snapshot.
    fn load_bytes_snapshot(&self) -> (usize, usize) {
        let reserved = self.buffer_state.reserved_bytes.load(Ordering::Acquire);
        let committed = self.buffer_state.committed_bytes.load(Ordering::Acquire);
        (reserved, committed)
    }

    /// Loads `(reserved_ops, committed_ops)`.
    ///
    /// The two values come from separate atomic loads, so together they are
    /// not a single consistent snapshot.
    fn load_ops_snapshot(&self) -> (usize, usize) {
        let reserved = self.buffer_state.reserved_ops.load(Ordering::Acquire);
        let committed = self.buffer_state.committed_ops.load(Ordering::Acquire);
        (reserved, committed)
    }

    /// Checks whether `requested` additional units fit under `max` given the
    /// current `committed` and `reserved` amounts.
    ///
    /// Returns `Err(available)` with the remaining capacity when they do not.
    /// The sum is computed with checked arithmetic: a total that does not fit
    /// in a `usize` is always rejected, even when `max` is `usize::MAX`. This
    /// guarantees that `reserved + requested` cannot overflow after a
    /// successful check.
    fn check_capacity(
        max: usize,
        committed: usize,
        reserved: usize,
        requested: usize,
    ) -> Result<(), usize> {
        let total = committed
            .checked_add(reserved)
            .and_then(|in_use| in_use.checked_add(requested));
        match total {
            Some(total) if total <= max => Ok(()),
            _ => Err(max.saturating_sub(committed.saturating_add(reserved))),
        }
    }

    /// Rejects the request with [`AdmissionError::ByteLimitExceeded`] if
    /// `committed + reserved + requested` exceeds the byte limit.
    fn check_byte_limit(
        &self,
        committed: usize,
        reserved: usize,
        requested: usize,
    ) -> Result<(), AdmissionError> {
        Self::check_capacity(self.max_bytes, committed, reserved, requested).map_err(
            |available| AdmissionError::ByteLimitExceeded {
                requested,
                available,
                max: self.max_bytes,
            },
        )
    }

    /// Rejects the request with [`AdmissionError::OpsLimitExceeded`] if
    /// `committed + reserved + requested` exceeds the operation limit.
    fn check_ops_limit(
        &self,
        committed: usize,
        reserved: usize,
        requested: usize,
    ) -> Result<(), AdmissionError> {
        Self::check_capacity(self.max_ops, committed, reserved, requested).map_err(
            |available| AdmissionError::OpsLimitExceeded {
                requested,
                available,
                max: self.max_ops,
            },
        )
    }

    /// Tries once to move the reserved-bytes counter from `current_reserved`
    /// to `current_reserved + bytes`; returns whether the exchange succeeded.
    ///
    /// Must only be called after `check_byte_limit` accepted the same values,
    /// which guarantees the addition cannot overflow.
    fn try_update_reserved_bytes(&self, current_reserved: usize, bytes: usize) -> bool {
        self.buffer_state
            .reserved_bytes
            .compare_exchange_weak(
                current_reserved,
                current_reserved + bytes,
                Ordering::AcqRel,
                Ordering::Relaxed,
            )
            .is_ok()
    }

    /// Tries once to move the reserved-ops counter from `current_reserved`
    /// to `current_reserved + ops`; returns whether the exchange succeeded.
    ///
    /// Must only be called after `check_ops_limit` accepted the same values,
    /// which guarantees the addition cannot overflow.
    fn try_update_reserved_ops(&self, current_reserved: usize, ops: usize) -> bool {
        self.buffer_state
            .reserved_ops
            .compare_exchange_weak(
                current_reserved,
                current_reserved + ops,
                Ordering::AcqRel,
                Ordering::Relaxed,
            )
            .is_ok()
    }

    /// Releases `bytes` from the reserved-bytes counter unless `bytes` is zero
    /// (in which case nothing was reserved in the first place).
    fn rollback_bytes_if_needed(&self, bytes: usize) {
        if bytes > 0 {
            self.compensating_rollback_bytes(bytes);
        }
    }

    /// Atomically subtracts `bytes` from the reserved-bytes counter.
    ///
    /// If fewer than `bytes` are currently reserved the counter is left
    /// untouched rather than underflowing.
    fn compensating_rollback_bytes(&self, bytes: usize) {
        // `checked_sub` returning `None` aborts the update, which is the
        // "leave untouched" case described above; that outcome is deliberately
        // ignored.
        let _ = self.buffer_state.reserved_bytes.fetch_update(
            Ordering::AcqRel,
            Ordering::Acquire,
            |current| current.checked_sub(bytes),
        );
    }

    /// Returns the larger of the byte and operation utilization ratios.
    ///
    /// Each ratio is the buffer state's total for that dimension divided by
    /// the configured maximum. A dimension whose maximum is zero contributes
    /// `0.0`.
    #[must_use]
    pub fn utilization(&self) -> f64 {
        let bytes_util = if self.max_bytes > 0 {
            usize_to_f64(self.buffer_state.total_bytes()) / usize_to_f64(self.max_bytes)
        } else {
            0.0
        };
        let ops_util = if self.max_ops > 0 {
            usize_to_f64(self.buffer_state.total_ops()) / usize_to_f64(self.max_ops)
        } else {
            0.0
        };
        bytes_util.max(ops_util)
    }

    /// Returns `true` when [`Self::utilization`] is at least 80%.
    #[must_use]
    pub fn is_near_limit(&self) -> bool {
        self.utilization() >= Self::NEAR_LIMIT_THRESHOLD
    }

    /// Captures the configured limits together with the current counters.
    ///
    /// Each counter is read separately, so the result is not guaranteed to be
    /// a single consistent snapshot under concurrent updates.
    #[must_use]
    pub fn stats(&self) -> AdmissionControllerStats {
        AdmissionControllerStats {
            max_bytes: self.max_bytes,
            max_ops: self.max_ops,
            committed_bytes: self.buffer_state.committed_bytes(),
            committed_ops: self.buffer_state.committed_ops(),
            reserved_bytes: self.buffer_state.reserved_bytes(),
            reserved_ops: self.buffer_state.reserved_ops(),
            active_guards: self.buffer_state.active_guards(),
        }
    }
}
/// Point-in-time view of an admission controller's limits and usage counters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AdmissionControllerStats {
    /// Configured byte limit.
    pub max_bytes: usize,
    /// Configured operation limit.
    pub max_ops: usize,
    /// Bytes that have been committed.
    pub committed_bytes: usize,
    /// Operations that have been committed.
    pub committed_ops: usize,
    /// Bytes currently reserved but not committed.
    pub reserved_bytes: usize,
    /// Operations currently reserved but not committed.
    pub reserved_ops: usize,
    /// Number of reservation guards currently active.
    pub active_guards: usize,
}

impl AdmissionControllerStats {
    /// Committed plus reserved bytes, saturating at `usize::MAX` instead of
    /// overflowing.
    #[inline]
    #[must_use]
    pub const fn total_bytes(&self) -> usize {
        self.committed_bytes.saturating_add(self.reserved_bytes)
    }

    /// Committed plus reserved operations, saturating at `usize::MAX` instead
    /// of overflowing.
    #[inline]
    #[must_use]
    pub const fn total_ops(&self) -> usize {
        self.committed_ops.saturating_add(self.reserved_ops)
    }

    /// Bytes still available under the limit; zero when usage meets or
    /// exceeds it.
    #[inline]
    #[must_use]
    pub const fn available_bytes(&self) -> usize {
        self.max_bytes.saturating_sub(self.total_bytes())
    }

    /// Operations still available under the limit; zero when usage meets or
    /// exceeds it.
    #[inline]
    #[must_use]
    pub const fn available_ops(&self) -> usize {
        self.max_ops.saturating_sub(self.total_ops())
    }

    /// Integer percentage (rounded down) of `max` that `used` represents, or
    /// zero when `max` is zero.
    ///
    /// The multiplication is carried out in `u128` so that `used * 100` cannot
    /// overflow; widening a `usize` to `u128` never truncates.
    const fn percent_used(used: usize, max: usize) -> u128 {
        if max == 0 {
            0
        } else {
            (used as u128 * 100) / max as u128
        }
    }
}

impl fmt::Display for AdmissionControllerStats {
    /// Formats the stats as
    /// `bytes: <used>/<max> (<pct>% used), ops: <used>/<max> (<pct>% used), guards: <n>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let total_bytes = self.total_bytes();
        let total_ops = self.total_ops();
        write!(
            f,
            "bytes: {}/{} ({}% used), ops: {}/{} ({}% used), guards: {}",
            total_bytes,
            self.max_bytes,
            Self::percent_used(total_bytes, self.max_bytes),
            total_ops,
            self.max_ops,
            Self::percent_used(total_ops, self.max_ops),
            self.active_guards
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a controller over a fresh buffer state with the given limits.
    fn make_controller(max_bytes: usize, max_ops: usize) -> AdmissionController {
        let state = Arc::new(SharedBufferState::new());
        AdmissionController::new(state, max_bytes, max_ops)
    }

    #[test]
    fn test_new() {
        let state = Arc::new(SharedBufferState::new());
        let controller = AdmissionController::new(Arc::clone(&state), 1000, 100);
        assert_eq!(controller.max_bytes(), 1000);
        assert_eq!(controller.max_ops(), 100);
        // The controller must hand back the very state it was constructed with.
        assert!(Arc::ptr_eq(controller.buffer_state(), &state));
    }

    #[test]
    fn test_try_reserve_success() {
        let controller = make_controller(1000, 100);
        let guard = controller.try_reserve(100, 1).expect("should succeed");
        assert_eq!(controller.buffer_state().reserved_bytes(), 100);
        assert_eq!(controller.buffer_state().reserved_ops(), 1);
        assert_eq!(controller.buffer_state().active_guards(), 1);
        let _ = guard.abort();
        assert_eq!(controller.buffer_state().reserved_bytes(), 0);
        assert_eq!(controller.buffer_state().reserved_ops(), 0);
        assert_eq!(controller.buffer_state().active_guards(), 0);
    }

    #[test]
    fn test_try_reserve_zero_rejected() {
        let controller = make_controller(1000, 100);
        let result = controller.try_reserve(0, 0);
        assert!(matches!(result, Err(AdmissionError::ZeroReservation)));
    }

    #[test]
    fn test_try_reserve_bytes_only() {
        let controller = make_controller(1000, 100);
        let guard = controller.try_reserve(100, 0).expect("should succeed");
        assert_eq!(controller.buffer_state().reserved_bytes(), 100);
        assert_eq!(controller.buffer_state().reserved_ops(), 0);
        let _ = guard.abort();
    }

    #[test]
    fn test_try_reserve_ops_only() {
        let controller = make_controller(1000, 100);
        let guard = controller.try_reserve(0, 10).expect("should succeed");
        assert_eq!(controller.buffer_state().reserved_bytes(), 0);
        assert_eq!(controller.buffer_state().reserved_ops(), 10);
        let _ = guard.abort();
    }

    #[test]
    fn test_byte_limit_exceeded() {
        let controller = make_controller(100, 100);
        let result = controller.try_reserve(150, 1);
        match result {
            Err(AdmissionError::ByteLimitExceeded {
                requested,
                available,
                max,
            }) => {
                assert_eq!(requested, 150);
                assert_eq!(available, 100);
                assert_eq!(max, 100);
            }
            Err(other) => panic!("Expected ByteLimitExceeded, got {other:?}"),
            Ok(_) => panic!("Expected ByteLimitExceeded, but the reservation succeeded"),
        }
    }

    #[test]
    fn test_ops_limit_exceeded() {
        let controller = make_controller(1000, 10);
        let result = controller.try_reserve(100, 20);
        match result {
            Err(AdmissionError::OpsLimitExceeded {
                requested,
                available,
                max,
            }) => {
                assert_eq!(requested, 20);
                assert_eq!(available, 10);
                assert_eq!(max, 10);
            }
            Err(other) => panic!("Expected OpsLimitExceeded, got {other:?}"),
            Ok(_) => panic!("Expected OpsLimitExceeded, but the reservation succeeded"),
        }
        // The bytes reserved before the ops check failed must have been released.
        assert_eq!(controller.buffer_state().reserved_bytes(), 0);
    }

    #[test]
    fn test_multiple_reservations() {
        let controller = make_controller(1000, 100);
        let guard1 = controller.try_reserve(100, 5).unwrap();
        let guard2 = controller.try_reserve(200, 10).unwrap();
        let guard3 = controller.try_reserve(300, 15).unwrap();
        assert_eq!(controller.buffer_state().reserved_bytes(), 600);
        assert_eq!(controller.buffer_state().reserved_ops(), 30);
        assert_eq!(controller.buffer_state().active_guards(), 3);

        let _ = guard1.abort();
        assert_eq!(controller.buffer_state().reserved_bytes(), 500);
        assert_eq!(controller.buffer_state().active_guards(), 2);

        let _ = guard2.commit();
        assert_eq!(controller.buffer_state().reserved_bytes(), 300);
        assert_eq!(controller.buffer_state().committed_bytes(), 200);
        assert_eq!(controller.buffer_state().active_guards(), 1);

        let _ = guard3.abort();
        assert_eq!(controller.buffer_state().reserved_bytes(), 0);
        assert_eq!(controller.buffer_state().committed_bytes(), 200);
        assert_eq!(controller.buffer_state().active_guards(), 0);
    }

    #[test]
    fn test_commit_with_actual() {
        let controller = make_controller(1000, 100);
        let guard = controller.try_reserve(100, 10).unwrap();
        guard.commit_with_actual(50, 5).unwrap();
        assert_eq!(controller.buffer_state().reserved_bytes(), 0);
        assert_eq!(controller.buffer_state().committed_bytes(), 50);
        assert_eq!(controller.buffer_state().committed_ops(), 5);
    }

    #[test]
    fn test_limit_boundary() {
        let controller = make_controller(100, 10);
        // Reserving exactly the limit is allowed...
        let guard = controller.try_reserve(100, 10).unwrap();
        assert_eq!(controller.buffer_state().reserved_bytes(), 100);
        assert_eq!(controller.buffer_state().reserved_ops(), 10);
        // ...but anything beyond it is rejected.
        let result = controller.try_reserve(1, 1);
        assert!(result.is_err());
        let _ = guard.abort();
    }

    #[test]
    fn test_utilization() {
        let controller = make_controller(100, 100);
        assert!(controller.utilization().abs() < f64::EPSILON);

        let guard = controller.try_reserve(50, 25).unwrap();
        assert!((controller.utilization() - 0.5).abs() < 0.01);
        assert!(!controller.is_near_limit());
        let _ = guard.abort();

        let guard2 = controller.try_reserve(80, 10).unwrap();
        assert!((controller.utilization() - 0.8).abs() < 0.01);
        assert!(controller.is_near_limit());
        let _ = guard2.abort();
    }

    #[test]
    fn test_stats() {
        let controller = make_controller(1000, 100);
        let guard = controller.try_reserve(200, 20).unwrap();
        guard.commit_with_actual(150, 15).unwrap();
        let stats = controller.stats();
        assert_eq!(stats.max_bytes, 1000);
        assert_eq!(stats.max_ops, 100);
        assert_eq!(stats.committed_bytes, 150);
        assert_eq!(stats.committed_ops, 15);
        assert_eq!(stats.reserved_bytes, 0);
        assert_eq!(stats.reserved_ops, 0);
        assert_eq!(stats.total_bytes(), 150);
        assert_eq!(stats.available_bytes(), 850);
    }

    #[test]
    fn test_stats_display() {
        let stats = AdmissionControllerStats {
            max_bytes: 1000,
            max_ops: 100,
            committed_bytes: 100,
            committed_ops: 10,
            reserved_bytes: 50,
            reserved_ops: 5,
            active_guards: 1,
        };
        let display = format!("{stats}");
        assert!(display.contains("bytes: 150/1000"));
        assert!(display.contains("ops: 15/100"));
        assert!(display.contains("guards: 1"));
    }

    #[test]
    fn test_concurrent_reservations() {
        use std::thread;

        let state = Arc::new(SharedBufferState::new());
        let controller = Arc::new(AdmissionController::new(Arc::clone(&state), 10_000, 1_000));
        let mut handles = vec![];
        for _ in 0..10 {
            let controller = Arc::clone(&controller);
            handles.push(thread::spawn(move || {
                for _ in 0..10 {
                    // Rejections are tolerated here: the test only checks that
                    // nothing stays reserved once every thread has finished.
                    if let Ok(guard) = controller.try_reserve(100, 1) {
                        thread::yield_now();
                        let _ = guard.commit();
                    }
                }
            }));
        }
        for h in handles {
            h.join().unwrap();
        }
        assert_eq!(state.active_guards(), 0);
        assert_eq!(state.reserved_bytes(), 0);
        assert_eq!(state.reserved_ops(), 0);
    }

    #[test]
    fn test_compensating_rollback_under_contention() {
        let state = Arc::new(SharedBufferState::new());
        let controller = AdmissionController::new(Arc::clone(&state), 10_000, 1);
        let guard1 = controller.try_reserve(100, 1).unwrap();

        // The second attempt fits the byte limit but not the ops limit, so its
        // byte reservation has to be rolled back.
        let result = controller.try_reserve(100, 1);
        assert!(matches!(
            result,
            Err(AdmissionError::OpsLimitExceeded { .. })
        ));

        // Only the first guard's reservation may remain.
        assert_eq!(state.reserved_bytes(), 100);
        assert_eq!(state.reserved_ops(), 1);

        let _ = guard1.abort();
        assert_eq!(state.reserved_bytes(), 0);
        assert_eq!(state.reserved_ops(), 0);
    }
}