use std::fmt;
use std::sync::Arc;
use super::state::SharedBufferState;
/// Reasons a requested reservation can be refused.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AdmissionError {
    /// More bytes were requested than are currently available.
    ByteLimitExceeded {
        /// Bytes asked for.
        requested: usize,
        /// Bytes still available when the request was evaluated.
        available: usize,
        /// Upper bound on bytes (reported as `max` in the message).
        max: usize,
    },
    /// More operations were requested than are currently available.
    OpsLimitExceeded {
        /// Operations asked for.
        requested: usize,
        /// Operations still available when the request was evaluated.
        available: usize,
        /// Upper bound on operations (reported as `max` in the message).
        max: usize,
    },
    /// The request asked for zero bytes and zero operations.
    ZeroReservation,
}

impl fmt::Display for AdmissionError {
    /// Renders a one-line, human-readable explanation of the refusal.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::ZeroReservation => {
                f.write_str("zero reservation not allowed: bytes and ops both zero")
            }
            Self::OpsLimitExceeded {
                requested,
                available,
                max,
            } => write!(
                f,
                "ops limit exceeded: requested {requested} but only {available} available (max {max})"
            ),
            Self::ByteLimitExceeded {
                requested,
                available,
                max,
            } => write!(
                f,
                "byte limit exceeded: requested {requested} but only {available} available (max {max})"
            ),
        }
    }
}

impl std::error::Error for AdmissionError {}
/// Reasons committing or aborting a [`ReservationGuard`] can fail.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CommitError {
    /// The guard no longer holds a reservation to commit or abort.
    AlreadyConsumed,
    /// The reported actual usage is larger than the reservation in at least
    /// one dimension (bytes, ops, or both).
    ActualExceedsReservation {
        /// Bytes actually used.
        actual_bytes: usize,
        /// Bytes that were reserved.
        reserved_bytes: usize,
        /// Operations actually used.
        actual_ops: usize,
        /// Operations that were reserved.
        reserved_ops: usize,
    },
}

impl fmt::Display for CommitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Comparison symbol between an actual amount and its reservation. The
        // error is raised when *either* dimension overflows, so each one is
        // described individually instead of always claiming `>`.
        fn relation(actual: usize, reserved: usize) -> &'static str {
            if actual > reserved {
                ">"
            } else {
                "<="
            }
        }

        match self {
            Self::AlreadyConsumed => {
                write!(f, "reservation already consumed")
            }
            Self::ActualExceedsReservation {
                actual_bytes,
                reserved_bytes,
                actual_ops,
                reserved_ops,
            } => {
                let bytes_rel = relation(*actual_bytes, *reserved_bytes);
                let ops_rel = relation(*actual_ops, *reserved_ops);
                write!(
                    f,
                    "actual exceeds reservation: {actual_bytes} bytes {bytes_rel} {reserved_bytes} reserved, {actual_ops} ops {ops_rel} {reserved_ops} reserved"
                )
            }
        }
    }
}

impl std::error::Error for CommitError {}
/// An amount of buffer capacity: a byte count together with an operation count.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Reservation {
    /// Number of bytes covered by this reservation.
    pub bytes: usize,
    /// Number of operations covered by this reservation.
    pub ops: usize,
}

impl Reservation {
    /// Builds a reservation of `bytes` bytes and `ops` operations.
    // NOTE(review): `dead_code` is allowed presumably because in-crate callers
    // may be absent in some build configurations (the tests use it); confirm.
    #[inline]
    #[must_use]
    #[allow(dead_code)]
    pub(crate) const fn new(bytes: usize, ops: usize) -> Self {
        Reservation { ops, bytes }
    }
}
pub struct ReservationGuard {
buffer_state: Arc<SharedBufferState>,
reservation: Option<Reservation>,
}
impl ReservationGuard {
#[allow(dead_code)] pub(crate) fn new(buffer_state: Arc<SharedBufferState>, reservation: Reservation) -> Self {
buffer_state.increment_active_guards();
Self {
buffer_state,
reservation: Some(reservation),
}
}
#[must_use]
pub fn reservation(&self) -> Option<Reservation> {
self.reservation
}
#[must_use]
pub fn reserved_bytes(&self) -> usize {
self.reservation.map_or(0, |r| r.bytes)
}
#[must_use]
pub fn reserved_ops(&self) -> usize {
self.reservation.map_or(0, |r| r.ops)
}
#[must_use]
pub fn is_consumed(&self) -> bool {
self.reservation.is_none()
}
pub fn commit(self) -> Result<(), CommitError> {
let r = self.reservation.ok_or(CommitError::AlreadyConsumed)?;
self.commit_with_actual(r.bytes, r.ops)
}
pub fn commit_with_actual(
mut self,
actual_bytes: usize,
actual_ops: usize,
) -> Result<(), CommitError> {
let r = self
.reservation
.take()
.ok_or(CommitError::AlreadyConsumed)?;
if actual_bytes > r.bytes || actual_ops > r.ops {
self.reservation = Some(r);
return Err(CommitError::ActualExceedsReservation {
actual_bytes,
reserved_bytes: r.bytes,
actual_ops,
reserved_ops: r.ops,
});
}
self.buffer_state
.transfer_reserved_to_committed(r.bytes, r.ops, actual_bytes, actual_ops);
Ok(())
}
pub fn abort(mut self) -> Result<(), CommitError> {
let r = self
.reservation
.take()
.ok_or(CommitError::AlreadyConsumed)?;
self.buffer_state.sub_reserved(r.bytes, r.ops);
Ok(())
}
#[must_use]
#[allow(dead_code)] pub(crate) fn extract(mut self) -> Option<Reservation> {
self.reservation.take()
}
}
impl Drop for ReservationGuard {
    /// Releases a still-held reservation and unregisters the guard.
    ///
    /// If the guard was not committed, aborted, or extracted, its bytes and ops
    /// are subtracted from the shared reserved counters (an implicit abort).
    /// The active-guard count is decremented in every case.
    ///
    /// # Panics
    ///
    /// Panics if subtracting the reservation underflowed either reserved
    /// counter, which means the accounting was already broken elsewhere. The
    /// check is skipped while the thread is already unwinding from another
    /// panic, because a second panic inside `drop` would abort the process.
    fn drop(&mut self) {
        // `take` first so the reservation can only ever be released once.
        let released = self.reservation.take().map(|r| {
            // Atomic `fetch_sub` wraps on underflow instead of trapping, so the
            // previous values are kept to detect that case below.
            let prev_bytes = self
                .buffer_state
                .reserved_bytes
                .fetch_sub(r.bytes, std::sync::atomic::Ordering::AcqRel);
            let prev_ops = self
                .buffer_state
                .reserved_ops
                .fetch_sub(r.ops, std::sync::atomic::Ordering::AcqRel);
            (r, prev_bytes, prev_ops)
        });

        // Unregister before validating: a failing check below must not leave
        // the active-guard count permanently inflated.
        self.buffer_state.decrement_active_guards();

        if let Some((r, prev_bytes, prev_ops)) = released {
            // Panicking while already panicking aborts; only report the
            // invariant violation when it is safe to unwind.
            if !std::thread::panicking() {
                assert!(
                    prev_bytes >= r.bytes,
                    "reserved_bytes underflow in Drop: {} < {}",
                    prev_bytes,
                    r.bytes
                );
                assert!(
                    prev_ops >= r.ops,
                    "reserved_ops underflow in Drop: {} < {}",
                    prev_ops,
                    r.ops
                );
            }
        }
    }
}
impl fmt::Debug for ReservationGuard {
    /// Shows the shared state, the held reservation (if any), and whether the
    /// guard has been consumed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let consumed = self.is_consumed();
        let mut out = f.debug_struct("ReservationGuard");
        out.field("buffer_state", &self.buffer_state);
        out.field("reservation", &self.reservation);
        out.field("consumed", &consumed);
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh, empty shared state for a single test.
    fn make_state() -> Arc<SharedBufferState> {
        Arc::new(SharedBufferState::new())
    }

    #[test]
    fn test_reservation_creation() {
        let r = Reservation::new(100, 5);
        assert_eq!(r.bytes, 100);
        assert_eq!(r.ops, 5);
    }

    #[test]
    fn test_guard_creation() {
        let state = make_state();
        state.add_reserved(100, 5);
        let guard = ReservationGuard::new(Arc::clone(&state), Reservation::new(100, 5));
        assert_eq!(state.active_guards(), 1);
        assert_eq!(guard.reserved_bytes(), 100);
        assert_eq!(guard.reserved_ops(), 5);
        assert!(!guard.is_consumed());
        guard.abort().unwrap();
        assert_eq!(state.active_guards(), 0);
    }

    #[test]
    fn test_guard_commit() {
        let state = make_state();
        state.add_reserved(100, 5);
        let guard = ReservationGuard::new(Arc::clone(&state), Reservation::new(100, 5));
        guard.commit().unwrap();
        assert_eq!(state.active_guards(), 0);
        assert_eq!(state.reserved_bytes(), 0);
        assert_eq!(state.reserved_ops(), 0);
        assert_eq!(state.committed_bytes(), 100);
        assert_eq!(state.committed_ops(), 5);
    }

    #[test]
    fn test_guard_commit_with_actual() {
        let state = make_state();
        state.add_reserved(100, 5);
        let guard = ReservationGuard::new(Arc::clone(&state), Reservation::new(100, 5));
        guard.commit_with_actual(80, 4).unwrap();
        assert_eq!(state.active_guards(), 0);
        assert_eq!(state.reserved_bytes(), 0);
        assert_eq!(state.reserved_ops(), 0);
        assert_eq!(state.committed_bytes(), 80);
        assert_eq!(state.committed_ops(), 4);
    }

    #[test]
    fn test_guard_commit_with_actual_validation() {
        let state = make_state();
        state.add_reserved(100, 5);
        let guard = ReservationGuard::new(Arc::clone(&state), Reservation::new(100, 5));
        let result = guard.commit_with_actual(150, 6);
        assert!(matches!(
            result,
            Err(CommitError::ActualExceedsReservation { .. })
        ));
        assert_eq!(state.active_guards(), 0);
        assert_eq!(state.committed_bytes(), 0);
        // The rejected guard is dropped, which must release the reservation.
        assert_eq!(state.reserved_bytes(), 0);
        assert_eq!(state.reserved_ops(), 0);
    }

    #[test]
    fn test_guard_commit_with_actual_rejects_single_dimension_overflow() {
        let state = make_state();
        state.add_reserved(100, 5);
        let guard = ReservationGuard::new(Arc::clone(&state), Reservation::new(100, 5));
        // Bytes fit but ops do not: the commit must still be refused.
        let result = guard.commit_with_actual(80, 6);
        assert_eq!(
            result,
            Err(CommitError::ActualExceedsReservation {
                actual_bytes: 80,
                reserved_bytes: 100,
                actual_ops: 6,
                reserved_ops: 5,
            })
        );
        assert_eq!(state.active_guards(), 0);
        assert_eq!(state.reserved_bytes(), 0);
        assert_eq!(state.reserved_ops(), 0);
        assert_eq!(state.committed_bytes(), 0);
        assert_eq!(state.committed_ops(), 0);
    }

    #[test]
    fn test_guard_abort() {
        let state = make_state();
        state.add_reserved(100, 5);
        let guard = ReservationGuard::new(Arc::clone(&state), Reservation::new(100, 5));
        guard.abort().unwrap();
        assert_eq!(state.active_guards(), 0);
        assert_eq!(state.reserved_bytes(), 0);
        assert_eq!(state.reserved_ops(), 0);
        assert_eq!(state.committed_bytes(), 0);
    }

    #[test]
    fn test_guard_auto_abort_on_drop() {
        let state = make_state();
        state.add_reserved(100, 5);
        {
            let _guard = ReservationGuard::new(Arc::clone(&state), Reservation::new(100, 5));
        }
        assert_eq!(state.active_guards(), 0);
        assert_eq!(state.reserved_bytes(), 0);
        assert_eq!(state.reserved_ops(), 0);
        assert_eq!(state.committed_bytes(), 0);
    }

    #[test]
    fn test_guard_already_consumed() {
        let state = make_state();
        state.add_reserved(100, 5);
        let mut guard = ReservationGuard::new(Arc::clone(&state), Reservation::new(100, 5));
        let _extracted = guard.reservation.take();
        let result = guard.commit();
        assert!(matches!(result, Err(CommitError::AlreadyConsumed)));
        assert_eq!(state.active_guards(), 0);
    }

    #[test]
    fn test_multiple_guards() {
        let state = make_state();
        state.add_reserved(200, 10);
        let guard1 = ReservationGuard::new(Arc::clone(&state), Reservation::new(100, 5));
        let guard2 = ReservationGuard::new(Arc::clone(&state), Reservation::new(100, 5));
        assert_eq!(state.active_guards(), 2);
        guard1.commit().unwrap();
        assert_eq!(state.active_guards(), 1);
        guard2.abort().unwrap();
        assert_eq!(state.active_guards(), 0);
        assert_eq!(state.committed_bytes(), 100);
        assert_eq!(state.reserved_bytes(), 0);
    }

    #[test]
    fn test_guard_debug() {
        let state = make_state();
        state.add_reserved(100, 5);
        let guard = ReservationGuard::new(Arc::clone(&state), Reservation::new(100, 5));
        let debug = format!("{guard:?}");
        assert!(debug.contains("ReservationGuard"));
        assert!(debug.contains("reservation"));
        assert!(debug.contains("consumed"));
        guard.abort().unwrap();
    }

    #[test]
    fn test_admission_error_display() {
        let err = AdmissionError::ByteLimitExceeded {
            requested: 100,
            available: 50,
            max: 1000,
        };
        assert!(format!("{err}").contains("byte limit exceeded"));
        assert!(format!("{err}").contains("available"));
        let err = AdmissionError::OpsLimitExceeded {
            requested: 10,
            available: 5,
            max: 100,
        };
        assert!(format!("{err}").contains("ops limit exceeded"));
        let err = AdmissionError::ZeroReservation;
        assert!(format!("{err}").contains("zero reservation"));
    }

    #[test]
    fn test_commit_error_display() {
        let err = CommitError::AlreadyConsumed;
        assert!(format!("{err}").contains("already consumed"));
        let err = CommitError::ActualExceedsReservation {
            actual_bytes: 150,
            reserved_bytes: 100,
            actual_ops: 6,
            reserved_ops: 5,
        };
        assert!(format!("{err}").contains("actual exceeds reservation"));
    }

    #[test]
    fn test_guard_extract() {
        let state = make_state();
        state.add_reserved(100, 5);
        let guard = ReservationGuard::new(Arc::clone(&state), Reservation::new(100, 5));
        let extracted = guard.extract();
        assert_eq!(extracted, Some(Reservation::new(100, 5)));
        assert_eq!(state.active_guards(), 0);
        // Extraction hands the reservation to the caller; the shared reserved
        // counters must be left untouched.
        assert_eq!(state.reserved_bytes(), 100);
        assert_eq!(state.reserved_ops(), 5);
    }
}