use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::OnceLock;
use super::error::{ResourceError, ResourceResult};
use super::estimate::MemoryEstimate;
use super::system::get_available_memory;
use super::{DEFAULT_MAX_MEMORY_BYTES, SYSTEM_MEMORY_MARGIN};
/// Tracks memory usage against a configurable byte budget.
///
/// Every counter is an atomic and all mutating methods take `&self`, so one guard can be
/// shared by reference (for example the instance returned by [`global_guard`]).
#[derive(Debug)]
pub struct ResourceGuard {
    /// Budget in bytes that `current + reserved` must stay within while limits are enforced.
    max_memory_bytes: AtomicU64,
    /// Bytes reported via `record_allocation` (including committed reservations).
    current_memory_bytes: AtomicU64,
    /// Bytes held by outstanding `ReservationGuard`s.
    reserved_bytes: AtomicU64,
    /// When `false`, the allocation checks always succeed; usage is still tracked.
    enforce_limits: AtomicBool,
    /// Fraction of available memory that `max_safe_elements` holds back;
    /// `with_safety_margin` clamps it to `0.0..=0.9`.
    safety_margin: f32,
}
impl Default for ResourceGuard {
fn default() -> Self {
Self::new()
}
}
impl ResourceGuard {
    /// Safety margin installed by the constructors (30% of available memory held back).
    const DEFAULT_SAFETY_MARGIN: f32 = 0.3;
    /// Largest margin accepted by [`Self::with_safety_margin`].
    const MAX_SAFETY_MARGIN: f32 = 0.9;

    /// Creates an enforcing guard with a budget of `DEFAULT_MAX_MEMORY_BYTES`.
    #[must_use]
    pub fn new() -> Self {
        Self::with_max_memory(DEFAULT_MAX_MEMORY_BYTES)
    }

    /// Creates an enforcing guard with a budget of `max_bytes` and no usage recorded yet.
    #[must_use]
    pub fn with_max_memory(max_bytes: u64) -> Self {
        Self {
            max_memory_bytes: AtomicU64::new(max_bytes),
            current_memory_bytes: AtomicU64::new(0),
            reserved_bytes: AtomicU64::new(0),
            enforce_limits: AtomicBool::new(true),
            safety_margin: Self::DEFAULT_SAFETY_MARGIN,
        }
    }

    /// Sets the fraction of available memory that [`Self::max_safe_elements`] holds back.
    ///
    /// `margin` is clamped to `0.0..=0.9`. A NaN margin is ignored and the current margin
    /// is kept: `clamp` passes NaN through, which would make `max_safe_elements`
    /// silently return 0.
    #[must_use]
    pub fn with_safety_margin(mut self, margin: f32) -> Self {
        if !margin.is_nan() {
            self.safety_margin = margin.clamp(0.0, Self::MAX_SAFETY_MARGIN);
        }
        self
    }

    /// Creates a default guard whose limits are not enforced: the allocation checks and
    /// `reserve` always succeed, but usage is still tracked.
    #[must_use]
    pub fn unguarded() -> Self {
        let guard = Self::new();
        guard.set_enforce_limits(false);
        guard
    }

    /// Replaces the budget. Existing usage is not re-validated; if it already exceeds
    /// the new budget, [`Self::available_memory`] reports 0 until usage drops.
    pub fn set_max_memory(&self, max_bytes: u64) {
        self.max_memory_bytes.store(max_bytes, Ordering::SeqCst);
    }

    /// Current budget in bytes.
    #[must_use]
    pub fn max_memory(&self) -> u64 {
        self.max_memory_bytes.load(Ordering::SeqCst)
    }

    /// Bytes currently recorded as allocated.
    #[must_use]
    pub fn current_memory(&self) -> u64 {
        self.current_memory_bytes.load(Ordering::SeqCst)
    }

    /// Bytes held by outstanding reservations.
    #[must_use]
    pub fn reserved_memory(&self) -> u64 {
        self.reserved_bytes.load(Ordering::SeqCst)
    }

    /// Budget left over: `max - current - reserved`, saturating at 0.
    ///
    /// Ignores the enforcement flag and does not consult system memory.
    #[must_use]
    pub fn available_memory(&self) -> u64 {
        let max = self.max_memory();
        let current = self.current_memory();
        let reserved = self.reserved_memory();
        max.saturating_sub(current).saturating_sub(reserved)
    }

    /// Quick budget check. Returns `true` when limits are not enforced; otherwise
    /// returns whether `bytes` fits in [`Self::available_memory`].
    ///
    /// Unlike [`Self::can_allocate_safe`], this does not consult system memory.
    #[must_use]
    pub fn can_allocate(&self, bytes: u64) -> bool {
        if !self.is_enforcing() {
            return true;
        }
        bytes <= self.available_memory()
    }

    /// Checks whether `bytes` could be allocated now.
    ///
    /// # Errors
    ///
    /// When limits are enforced, returns [`ResourceError::MemoryLimitExceeded`] if
    /// `current + reserved + bytes` exceeds the budget. Returns
    /// [`ResourceError::InsufficientSystemMemory`] if `bytes` exceeds the host's reported
    /// free memory minus `SYSTEM_MEMORY_MARGIN`.
    pub fn can_allocate_safe(&self, bytes: u64) -> ResourceResult<()> {
        if !self.is_enforcing() {
            return Ok(());
        }
        let used = self.current_memory().saturating_add(self.reserved_memory());
        let max = self.max_memory();
        if used.saturating_add(bytes) > max {
            return Err(ResourceError::MemoryLimitExceeded {
                requested: bytes,
                current: used,
                max,
            });
        }
        Self::check_system_memory(bytes)
    }

    /// Fails when `bytes` exceeds the host's reported free memory minus
    /// `SYSTEM_MEMORY_MARGIN`. Skipped when `get_available_memory` returns `None`.
    fn check_system_memory(bytes: u64) -> ResourceResult<()> {
        match get_available_memory() {
            Some(available) if bytes > available.saturating_sub(SYSTEM_MEMORY_MARGIN) => {
                Err(ResourceError::InsufficientSystemMemory {
                    requested: bytes,
                    available,
                    margin: SYSTEM_MEMORY_MARGIN,
                })
            }
            _ => Ok(()),
        }
    }

    /// Adds `bytes` to the recorded usage, saturating at `u64::MAX`.
    pub fn record_allocation(&self, bytes: u64) {
        Self::update_counter(&self.current_memory_bytes, |cur| cur.saturating_add(bytes));
    }

    /// Removes `bytes` from the recorded usage, saturating at 0.
    ///
    /// `fetch_sub` would wrap on a mismatched deallocation and leave a near-`u64::MAX`
    /// usage figure that blocks every later allocation.
    pub fn record_deallocation(&self, bytes: u64) {
        Self::update_counter(&self.current_memory_bytes, |cur| cur.saturating_sub(bytes));
    }

    /// Atomically replaces the value of `counter` with `f(old)`.
    fn update_counter(counter: &AtomicU64, f: impl Fn(u64) -> u64) {
        // The closure always returns `Some`, so `fetch_update` cannot report failure.
        let _ = counter.fetch_update(Ordering::SeqCst, Ordering::SeqCst, |cur| Some(f(cur)));
    }

    /// Reserves `bytes` against the budget and returns a handle that releases the
    /// reservation on drop unless it is committed.
    ///
    /// # Errors
    ///
    /// Same conditions as [`Self::can_allocate_safe`]. No bytes remain reserved when an
    /// error is returned.
    pub fn reserve(&self, bytes: u64) -> ResourceResult<ReservationGuard<'_>> {
        if !self.is_enforcing() {
            self.reserved_bytes.fetch_add(bytes, Ordering::SeqCst);
            return Ok(ReservationGuard {
                guard: self,
                bytes,
                committed: false,
            });
        }

        let current = self.current_memory();
        let max = self.max_memory();
        // Check the budget and claim the bytes in one compare-and-swap loop. A separate
        // check followed by `fetch_add` lets two concurrent callers both pass the check
        // and together exceed `max`.
        let claim = self
            .reserved_bytes
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |reserved| {
                let new_reserved = reserved.checked_add(bytes)?;
                (current.saturating_add(new_reserved) <= max).then_some(new_reserved)
            });
        if let Err(reserved) = claim {
            return Err(ResourceError::MemoryLimitExceeded {
                requested: bytes,
                current: current.saturating_add(reserved),
                max,
            });
        }

        // The bytes are now claimed. If the system check fails, `reservation` is dropped
        // on the early return and its `Drop` hands the bytes back.
        let reservation = ReservationGuard {
            guard: self,
            bytes,
            committed: false,
        };
        Self::check_system_memory(bytes)?;
        Ok(reservation)
    }

    /// Checks that an estimate's `peak_bytes` could be allocated now.
    ///
    /// # Errors
    ///
    /// Same conditions as [`Self::can_allocate_safe`].
    pub fn validate(&self, estimate: &MemoryEstimate) -> ResourceResult<()> {
        self.can_allocate_safe(estimate.peak_bytes)
    }

    /// Number of `bytes_per_element`-sized elements that fit in
    /// [`Self::available_memory`] after holding back the safety margin.
    ///
    /// Returns `usize::MAX` for zero-sized elements. Ignores the enforcement flag.
    #[must_use]
    pub fn max_safe_elements(&self, bytes_per_element: usize) -> usize {
        if bytes_per_element == 0 {
            return usize::MAX;
        }
        let headroom = 1.0 - f64::from(self.safety_margin);
        // Float-to-int `as` saturates (and maps NaN to 0), so this cannot wrap.
        let safe_bytes = (self.available_memory() as f64 * headroom) as u64;
        let elements = safe_bytes / bytes_per_element as u64;
        // Saturate instead of truncating on targets where usize is narrower than u64.
        usize::try_from(elements).unwrap_or(usize::MAX)
    }

    /// Turns limit enforcement on or off.
    pub fn set_enforce_limits(&self, enforce: bool) {
        self.enforce_limits.store(enforce, Ordering::SeqCst);
    }

    /// Whether limits are currently enforced.
    #[must_use]
    pub fn is_enforcing(&self) -> bool {
        self.enforce_limits.load(Ordering::SeqCst)
    }
}
/// RAII handle for bytes reserved through [`ResourceGuard::reserve`].
///
/// While the handle is alive, its bytes count against the guard's budget as reserved.
/// Committing turns them into recorded usage. Otherwise they are handed back when the
/// handle is released or dropped.
#[must_use = "dropping a ReservationGuard immediately releases the reservation"]
#[derive(Debug)]
pub struct ReservationGuard<'a> {
    /// Guard whose reserved counter includes these bytes.
    guard: &'a ResourceGuard,
    /// Size of the reservation in bytes.
    bytes: u64,
    /// `true` once the bytes have already been removed from the reserved counter, so
    /// `Drop` must not subtract them again.
    committed: bool,
}
impl ReservationGuard<'_> {
    /// Number of bytes this reservation holds.
    #[must_use]
    pub fn bytes(&self) -> u64 {
        self.bytes
    }

    /// Converts the reservation into recorded usage of [`Self::bytes`] bytes.
    pub fn commit(mut self) {
        // Record the allocation before removing the reservation, so the bytes are never
        // absent from both counters. A momentary double count is the safe direction
        // for a concurrent `reserve` to observe.
        self.guard.record_allocation(self.bytes);
        self.guard
            .reserved_bytes
            .fetch_sub(self.bytes, Ordering::SeqCst);
        self.committed = true;
    }

    /// Hands the reserved bytes back without recording any usage.
    ///
    /// This does the same thing as dropping the handle; the method exists to make that
    /// intent explicit at the call site.
    pub fn release(self) {
        drop(self);
    }
}
impl Drop for ReservationGuard<'_> {
    /// Hands back a reservation that was neither committed nor already released.
    fn drop(&mut self) {
        if self.committed {
            return;
        }
        let reserved = &self.guard.reserved_bytes;
        reserved.fetch_sub(self.bytes, Ordering::SeqCst);
    }
}
/// Returns the process-wide [`ResourceGuard`], creating it with default settings on
/// first access.
pub fn global_guard() -> &'static ResourceGuard {
    static INSTANCE: OnceLock<ResourceGuard> = OnceLock::new();
    INSTANCE.get_or_init(ResourceGuard::default)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Budget shared by most tests (1 MB).
    const BUDGET: u64 = 1_000_000;

    #[test]
    fn test_resource_guard_new() {
        let rg = ResourceGuard::new();
        assert_eq!(rg.max_memory(), DEFAULT_MAX_MEMORY_BYTES);
        assert_eq!(rg.current_memory(), 0);
    }

    #[test]
    fn test_resource_guard_allocation() {
        let rg = ResourceGuard::with_max_memory(BUDGET);
        assert!(rg.can_allocate(100_000));
        assert!(!rg.can_allocate(2 * BUDGET));

        rg.record_allocation(BUDGET / 2);
        assert_eq!(rg.current_memory(), 500_000);
        // Half of the budget remains.
        assert!(rg.can_allocate(400_000));
        assert!(!rg.can_allocate(600_000));
    }

    #[test]
    fn test_resource_guard_deallocation() {
        let rg = ResourceGuard::with_max_memory(BUDGET);
        rg.record_allocation(500_000);
        assert_eq!(rg.current_memory(), 500_000);
        rg.record_deallocation(200_000);
        assert_eq!(rg.current_memory(), 300_000);
    }

    #[test]
    fn test_reservation_guard() {
        let rg = ResourceGuard::with_max_memory(BUDGET);
        let held = rg.reserve(500_000).unwrap();
        assert_eq!(held.bytes(), 500_000);
        assert_eq!(rg.reserved_memory(), 500_000);
        assert!(!rg.can_allocate(600_000));

        // Dropping the handle must hand the reservation back.
        drop(held);
        assert_eq!(rg.reserved_memory(), 0);
        assert!(rg.can_allocate(600_000));
    }

    #[test]
    fn test_reservation_commit() {
        let rg = ResourceGuard::with_max_memory(BUDGET);
        rg.reserve(500_000).unwrap().commit();
        assert_eq!(rg.reserved_memory(), 0);
        assert_eq!(rg.current_memory(), 500_000);
    }

    #[test]
    fn test_reservation_release() {
        let rg = ResourceGuard::with_max_memory(BUDGET);
        rg.reserve(500_000).unwrap().release();
        assert_eq!((rg.reserved_memory(), rg.current_memory()), (0, 0));
    }

    #[test]
    fn test_max_safe_elements() {
        let rg = ResourceGuard::with_max_memory(BUDGET);
        let max_elements = rg.max_safe_elements(100);
        assert!(
            (6999..=7001).contains(&max_elements),
            "max_elements {} not in range [6999, 7001]",
            max_elements
        );
    }

    #[test]
    fn test_unguarded() {
        assert!(ResourceGuard::unguarded().can_allocate(u64::MAX / 2));
    }

    #[test]
    fn test_global_guard() {
        assert_eq!(global_guard().current_memory(), 0);
    }
}