use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use crate::error::{Error, Result};
/// Cheaply-cloneable tracker for a shared byte budget.
///
/// Clones share the same underlying counter via `Arc`, so an allocation
/// recorded through one clone is visible through all of them.
/// `Default` yields an unlimited tracker (`inner: None`).
#[derive(Debug, Clone, Default)]
pub struct MemoryTracker {
    // `None` = no limit (all operations are no-ops that succeed);
    // `Some` = shared counter plus hard cap.
    inner: Option<Arc<MemoryTrackerInner>>,
}
/// State shared by all clones of a limited `MemoryTracker`.
#[derive(Debug)]
struct MemoryTrackerInner {
    // Bytes currently reserved across all clones of the tracker.
    allocated: AtomicU64,
    // Hard cap in bytes; reservations that would push `allocated` past
    // this value are rejected.
    limit: u64,
}
impl MemoryTracker {
pub fn unlimited() -> Self {
Self { inner: None }
}
pub fn with_limit(limit: u64) -> Self {
Self {
inner: Some(Arc::new(MemoryTrackerInner {
allocated: AtomicU64::new(0),
limit,
})),
}
}
pub fn from_limit(limit: Option<u64>) -> Self {
match limit {
Some(limit) => Self::with_limit(limit),
None => Self::unlimited(),
}
}
pub fn has_limit(&self) -> bool {
self.inner.is_some()
}
pub fn allocated(&self) -> u64 {
self.inner
.as_ref()
.map(|i| i.allocated.load(Ordering::Relaxed))
.unwrap_or(0)
}
pub fn limit(&self) -> Option<u64> {
self.inner.as_ref().map(|i| i.limit)
}
pub fn try_allocate(&self, bytes: u64) -> Result<()> {
let Some(inner) = &self.inner else {
return Ok(());
};
let prev = inner.allocated.fetch_add(bytes, Ordering::Relaxed);
let new_total = prev.saturating_add(bytes);
if new_total > inner.limit {
inner.allocated.fetch_sub(bytes, Ordering::Relaxed);
return Err(Error::LimitExceeded {
resource: "memory_bytes",
actual: new_total,
limit: inner.limit,
});
}
Ok(())
}
pub fn release(&self, bytes: u64) {
if let Some(inner) = &self.inner {
inner.allocated.fetch_sub(bytes, Ordering::Relaxed);
}
}
pub fn would_exceed(&self, bytes: u64) -> bool {
let Some(inner) = &self.inner else {
return false;
};
let current = inner.allocated.load(Ordering::Relaxed);
current.saturating_add(bytes) > inner.limit
}
pub fn check_alloc(&self, bytes: u64) -> Result<()> {
if self.would_exceed(bytes) {
let Some(inner) = &self.inner else {
return Ok(());
};
return Err(Error::LimitExceeded {
resource: "memory_bytes",
actual: inner
.allocated
.load(Ordering::Relaxed)
.saturating_add(bytes),
limit: inner.limit,
});
}
Ok(())
}
}
/// RAII guard that returns `bytes` to its tracker when dropped.
///
/// Marked `#[must_use]` because discarding the guard at the call site drops
/// it immediately, silently releasing the reservation it was meant to hold.
#[must_use = "dropping a MemoryGuard immediately releases its reservation"]
#[derive(Debug)]
pub struct MemoryGuard {
    // Clone of the tracker the reservation was made against.
    tracker: MemoryTracker,
    // Bytes to release on drop; zeroed by `disarm()`.
    bytes: u64,
}
impl MemoryGuard {
pub fn new(tracker: MemoryTracker, bytes: u64) -> Self {
Self { tracker, bytes }
}
pub fn disarm(mut self) {
self.bytes = 0;
}
#[deprecated(note = "use disarm() instead — forget() leaks the internal Arc")]
#[allow(dead_code)]
pub fn forget(self) {
std::mem::forget(self);
}
}
/// Releases the guard's reserved bytes back to its tracker.
impl Drop for MemoryGuard {
    fn drop(&mut self) {
        // After disarm(), `bytes` is 0 and this is a no-op.
        self.tracker.release(self.bytes);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // An unlimited tracker accepts anything and reports no limit.
    #[test]
    fn test_unlimited() {
        let tracker = MemoryTracker::unlimited();
        assert!(!tracker.has_limit());
        assert!(tracker.try_allocate(u64::MAX).is_ok());
    }

    // Added: `from_limit` dispatches to the right constructor.
    #[test]
    fn test_from_limit() {
        assert_eq!(MemoryTracker::from_limit(Some(7)).limit(), Some(7));
        assert_eq!(MemoryTracker::from_limit(None).limit(), None);
    }

    // Allocation succeeds up to the limit, fails past it, and release
    // frees capacity for reuse.
    #[test]
    fn test_with_limit() {
        let tracker = MemoryTracker::with_limit(1000);
        assert!(tracker.has_limit());
        assert_eq!(tracker.limit(), Some(1000));
        assert!(tracker.try_allocate(500).is_ok());
        assert_eq!(tracker.allocated(), 500);
        assert!(tracker.try_allocate(500).is_ok());
        assert_eq!(tracker.allocated(), 1000);
        assert!(tracker.try_allocate(1).is_err());
        // A failed allocation must not change the counter.
        assert_eq!(tracker.allocated(), 1000);
        tracker.release(500);
        assert_eq!(tracker.allocated(), 500);
        assert!(tracker.try_allocate(100).is_ok());
        assert_eq!(tracker.allocated(), 600);
    }

    // Added: advisory checks observe the shared counter but reserve nothing.
    #[test]
    fn test_would_exceed_and_check_alloc() {
        let tracker = MemoryTracker::with_limit(100);
        assert!(!tracker.would_exceed(100));
        assert!(tracker.would_exceed(101));
        tracker.try_allocate(60).unwrap();
        assert!(!tracker.would_exceed(40));
        assert!(tracker.would_exceed(41));
        assert!(tracker.check_alloc(40).is_ok());
        assert!(tracker.check_alloc(41).is_err());
        // check_alloc must not reserve anything.
        assert_eq!(tracker.allocated(), 60);
    }

    // Clones share one counter through the Arc.
    #[test]
    fn test_clone_shares_state() {
        let tracker1 = MemoryTracker::with_limit(1000);
        let tracker2 = tracker1.clone();
        tracker1.try_allocate(500).unwrap();
        assert_eq!(tracker2.allocated(), 500);
        tracker2.try_allocate(300).unwrap();
        assert_eq!(tracker1.allocated(), 800);
    }

    // Dropping the guard releases its reservation.
    #[test]
    fn test_memory_guard() {
        let tracker = MemoryTracker::with_limit(1000);
        tracker.try_allocate(500).unwrap();
        {
            let _guard = MemoryGuard::new(tracker.clone(), 500);
        }
        assert_eq!(tracker.allocated(), 0);
    }

    // A disarmed guard leaves the reservation in place.
    #[test]
    fn test_memory_guard_disarm() {
        let tracker = MemoryTracker::with_limit(1000);
        tracker.try_allocate(500).unwrap();
        let guard = MemoryGuard::new(tracker.clone(), 500);
        guard.disarm();
        assert_eq!(tracker.allocated(), 500);
    }
}