use std::sync::atomic::{AtomicUsize, Ordering};
use crate::{Error, Result};
/// Byte-budget accountant backed by atomic counters.
///
/// Records how many bytes are currently reserved through it and the highest
/// usage seen, and refuses reservations that would push usage past `limit`.
/// It only counts bytes; it never allocates memory itself.
#[derive(Debug)]
pub struct MemoryTracker {
    /// Bytes currently counted as reserved.
    current_usage: AtomicUsize,
    /// Highest value `current_usage` has been raised to (cleared by `reset_peak`).
    peak_usage: AtomicUsize,
    /// Maximum number of bytes that may be reserved at the same time.
    limit: usize,
}
impl MemoryTracker {
    /// Creates a tracker that permits at most `limit` bytes to be reserved at once.
    pub fn new(limit: usize) -> Self {
        Self {
            current_usage: AtomicUsize::new(0),
            peak_usage: AtomicUsize::new(0),
            limit,
        }
    }

    /// Creates a tracker whose limit is `usize::MAX`, i.e. effectively unbounded.
    pub fn unlimited() -> Self {
        Self::new(usize::MAX)
    }

    /// Returns the configured byte limit.
    pub fn limit(&self) -> usize {
        self.limit
    }

    /// Returns the number of bytes currently reserved.
    pub fn current_usage(&self) -> usize {
        self.current_usage.load(Ordering::SeqCst)
    }

    /// Returns the highest usage reached by this tracker's allocation methods
    /// since creation or the last [`reset_peak`](Self::reset_peak).
    pub fn peak_usage(&self) -> usize {
        self.peak_usage.load(Ordering::SeqCst)
    }

    /// Returns how many more bytes could be reserved right now (0 when full).
    ///
    /// This is a snapshot; concurrent allocations may change it immediately.
    pub fn available(&self) -> usize {
        self.limit.saturating_sub(self.current_usage())
    }

    /// Returns `true` if reserving `bytes` more would currently stay within the limit.
    ///
    /// Like [`available`](Self::available) this is only a snapshot; call
    /// [`allocate`](Self::allocate) to actually reserve. The sum is computed with
    /// `checked_add`, so a request that would overflow `usize` (e.g. against an
    /// unlimited tracker) reports `false` instead of panicking.
    pub fn can_allocate(&self, bytes: usize) -> bool {
        matches!(
            self.current_usage().checked_add(bytes),
            Some(total) if total <= self.limit
        )
    }

    /// Atomically reserves exactly `bytes` bytes and returns a guard that
    /// releases them when dropped.
    ///
    /// # Errors
    /// Returns `Error::ResourceLimitExceeded` if the new total would overflow
    /// `usize` or exceed the limit. Nothing is reserved in that case.
    pub fn allocate(&self, bytes: usize) -> Result<MemoryGuard<'_>> {
        let mut current = self.current_usage.load(Ordering::SeqCst);
        loop {
            let new_usage = current.checked_add(bytes).ok_or_else(|| {
                Error::ResourceLimitExceeded(format!(
                    "Memory allocation overflow: {} + {} bytes",
                    current, bytes
                ))
            })?;
            if new_usage > self.limit {
                return Err(Error::ResourceLimitExceeded(format!(
                    "Memory limit exceeded: {} + {} = {} bytes (limit: {} bytes)",
                    current, bytes, new_usage, self.limit
                )));
            }
            match self.current_usage.compare_exchange(
                current,
                new_usage,
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                Ok(_) => {
                    self.peak_usage.fetch_max(new_usage, Ordering::SeqCst);
                    return Ok(MemoryGuard {
                        tracker: self,
                        bytes,
                    });
                }
                // Another thread moved the counter; retry against the value it left.
                Err(actual) => current = actual,
            }
        }
    }

    /// Like [`allocate`](Self::allocate) but returns `None` instead of an error.
    pub fn try_allocate(&self, bytes: usize) -> Option<MemoryGuard<'_>> {
        self.allocate(bytes).ok()
    }

    /// Reserves as many bytes as possible, up to `bytes`.
    ///
    /// Returns the guard together with the number of bytes actually granted,
    /// which is `min(bytes, available)` at the moment of the successful
    /// compare-and-swap. When nothing can be granted, a zero-byte guard and 0
    /// are returned. Unlike a check-then-allocate sequence, contention never
    /// causes a spurious 0 while capacity is still free.
    pub fn allocate_up_to(&self, bytes: usize) -> (MemoryGuard<'_>, usize) {
        let mut current = self.current_usage.load(Ordering::SeqCst);
        loop {
            let granted = bytes.min(self.limit.saturating_sub(current));
            if granted == 0 {
                return (
                    MemoryGuard {
                        tracker: self,
                        bytes: 0,
                    },
                    0,
                );
            }
            // Cannot overflow: `granted <= limit - current`, so the sum is <= `limit`.
            let new_usage = current + granted;
            match self.current_usage.compare_exchange(
                current,
                new_usage,
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                Ok(_) => {
                    self.peak_usage.fetch_max(new_usage, Ordering::SeqCst);
                    return (
                        MemoryGuard {
                            tracker: self,
                            bytes: granted,
                        },
                        granted,
                    );
                }
                Err(actual) => current = actual,
            }
        }
    }

    /// Forces the current usage back to zero; the peak is left untouched.
    ///
    /// Guards that are still alive will still release their bytes when dropped.
    /// Those releases saturate at zero (see `release`), so usage may then be
    /// under-counted but never wraps around.
    pub fn reset(&self) {
        self.current_usage.store(0, Ordering::SeqCst);
    }

    /// Sets the peak to the current usage.
    pub fn reset_peak(&self) {
        self.peak_usage
            .store(self.current_usage(), Ordering::SeqCst);
    }

    /// Returns `bytes` to the budget; called from `MemoryGuard::drop`.
    fn release(&self, bytes: usize) {
        // Saturate instead of `fetch_sub`. If `reset` ran while guards were alive,
        // a plain subtraction would wrap the counter to nearly `usize::MAX` and make
        // every later allocation fail. The closure always returns `Some`, so the
        // update cannot fail.
        let _ = self
            .current_usage
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |used| {
                Some(used.saturating_sub(bytes))
            });
    }
}
impl Default for MemoryTracker {
    /// Builds a tracker with a 64 MiB budget.
    fn default() -> Self {
        // 64 << 20 == 64 * 1024 * 1024 bytes.
        const DEFAULT_LIMIT_BYTES: usize = 64 << 20;
        Self::new(DEFAULT_LIMIT_BYTES)
    }
}
/// Reservation handle returned by [`MemoryTracker::allocate`] and related methods.
///
/// Dropping the guard returns its bytes to the tracker. [`MemoryGuard::forget`]
/// keeps them reserved instead. Because dropping releases the reservation, an
/// unused guard is almost always a bug, hence `#[must_use]`.
#[derive(Debug)]
#[must_use = "dropping a MemoryGuard immediately releases its reservation"]
pub struct MemoryGuard<'a> {
    /// Tracker the bytes are charged to.
    tracker: &'a MemoryTracker,
    /// Number of bytes this guard holds (may be 0).
    bytes: usize,
}
impl<'a> MemoryGuard<'a> {
    /// Returns how many bytes this guard holds against its tracker.
    pub fn bytes(&self) -> usize {
        self.bytes
    }

    /// Consumes the guard without releasing its bytes and returns how many it held.
    ///
    /// The reservation stays counted in the tracker's usage afterwards, e.g.
    /// until `MemoryTracker::reset` is called.
    pub fn forget(self) -> usize {
        // `ManuallyDrop` suppresses the guard's `Drop`, so nothing is released.
        std::mem::ManuallyDrop::new(self).bytes
    }
}
impl Drop for MemoryGuard<'_> {
    /// Hands this guard's bytes back to the tracker; zero-byte guards do nothing.
    fn drop(&mut self) {
        if self.bytes == 0 {
            return;
        }
        self.tracker.release(self.bytes);
    }
}
/// A byte buffer whose size is charged against a [`MemoryTracker`].
///
/// The reservation is held by `_guard` and is released when the buffer is dropped.
#[derive(Debug)]
pub struct TrackedBuffer<'a> {
    /// The owned bytes.
    data: Vec<u8>,
    // Declared after `data`. Rust drops fields in declaration order, so the
    // vector is freed before the reservation is returned to the tracker.
    _guard: MemoryGuard<'a>,
}
impl<'a> TrackedBuffer<'a> {
    /// Reserves `capacity` bytes on `tracker`, then creates an empty buffer with
    /// room for at least that many bytes.
    ///
    /// # Errors
    /// Propagates the tracker's error if the reservation is refused. No vector
    /// is created in that case.
    pub fn new(tracker: &'a MemoryTracker, capacity: usize) -> Result<Self> {
        tracker.allocate(capacity).map(|guard| Self {
            data: Vec::with_capacity(capacity),
            _guard: guard,
        })
    }

    /// Reserves `size` bytes on `tracker`, then creates a buffer of `size` zero bytes.
    ///
    /// # Errors
    /// Same as [`TrackedBuffer::new`].
    pub fn zeroed(tracker: &'a MemoryTracker, size: usize) -> Result<Self> {
        tracker.allocate(size).map(|guard| Self {
            data: vec![0u8; size],
            _guard: guard,
        })
    }

    /// Borrows the initialized bytes.
    pub fn as_slice(&self) -> &[u8] {
        self.data.as_slice()
    }

    /// Mutably borrows the initialized bytes.
    pub fn as_mut_slice(&mut self) -> &mut [u8] {
        self.data.as_mut_slice()
    }

    /// Number of initialized bytes (not the reserved capacity).
    pub fn len(&self) -> usize {
        self.as_slice().len()
    }

    /// `true` when the buffer holds no initialized bytes.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Capacity of the underlying vector.
    pub fn capacity(&self) -> usize {
        self.data.capacity()
    }

    /// Unwraps the buffer into its `Vec`.
    ///
    /// The guard is dropped here, so the tracker stops counting these bytes even
    /// though the returned vector still owns them.
    pub fn into_vec(self) -> Vec<u8> {
        let Self { data, .. } = self;
        data
    }
}
impl AsRef<[u8]> for TrackedBuffer<'_> {
    /// Views the buffer's initialized bytes as a slice.
    fn as_ref(&self) -> &[u8] {
        self.data.as_slice()
    }
}
impl AsMut<[u8]> for TrackedBuffer<'_> {
    /// Views the buffer's initialized bytes as a mutable slice.
    fn as_mut(&mut self) -> &mut [u8] {
        self.data.as_mut_slice()
    }
}
impl std::ops::Deref for TrackedBuffer<'_> {
    type Target = [u8];

    /// Lets a `TrackedBuffer` be used wherever a `&[u8]` is expected.
    fn deref(&self) -> &Self::Target {
        self.data.as_slice()
    }
}
impl std::ops::DerefMut for TrackedBuffer<'_> {
    /// Lets a `TrackedBuffer` be used wherever a `&mut [u8]` is expected.
    fn deref_mut(&mut self) -> &mut Self::Target {
        self.data.as_mut_slice()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tracker_basic() {
        let tracker = MemoryTracker::new(1024);
        assert_eq!(tracker.limit(), 1024);
        assert_eq!(tracker.current_usage(), 0);
        assert_eq!(tracker.peak_usage(), 0);
        assert_eq!(tracker.available(), 1024);
    }

    #[test]
    fn test_allocate_success() {
        let tracker = MemoryTracker::new(1024);
        let guard = tracker.allocate(512).unwrap();
        assert_eq!(tracker.current_usage(), 512);
        assert_eq!(tracker.available(), 512);
        assert_eq!(guard.bytes(), 512);
        drop(guard);
        assert_eq!(tracker.current_usage(), 0);
    }

    #[test]
    fn test_allocate_exceeds_limit() {
        let tracker = MemoryTracker::new(1024);
        let result = tracker.allocate(2048);
        assert!(result.is_err());
        assert_eq!(tracker.current_usage(), 0);
    }

    #[test]
    fn test_allocate_overflow_is_error() {
        let tracker = MemoryTracker::unlimited();
        let _guard = tracker.allocate(1).unwrap();
        // 1 + usize::MAX overflows; this must be reported, not reserved.
        assert!(tracker.allocate(usize::MAX).is_err());
        assert_eq!(tracker.current_usage(), 1);
    }

    #[test]
    fn test_multiple_allocations() {
        let tracker = MemoryTracker::new(1024);
        let guard1 = tracker.allocate(256).unwrap();
        assert_eq!(tracker.current_usage(), 256);
        let guard2 = tracker.allocate(256).unwrap();
        assert_eq!(tracker.current_usage(), 512);
        drop(guard1);
        assert_eq!(tracker.current_usage(), 256);
        drop(guard2);
        assert_eq!(tracker.current_usage(), 0);
    }

    #[test]
    fn test_peak_usage() {
        let tracker = MemoryTracker::new(1024);
        let guard1 = tracker.allocate(300).unwrap();
        let guard2 = tracker.allocate(400).unwrap();
        assert_eq!(tracker.peak_usage(), 700);
        drop(guard1);
        assert_eq!(tracker.current_usage(), 400);
        assert_eq!(tracker.peak_usage(), 700);
        let guard3 = tracker.allocate(500).unwrap();
        assert_eq!(tracker.peak_usage(), 900);
        drop(guard2);
        drop(guard3);
    }

    #[test]
    fn test_can_allocate() {
        let tracker = MemoryTracker::new(1024);
        assert!(tracker.can_allocate(512));
        assert!(tracker.can_allocate(1024));
        assert!(!tracker.can_allocate(2048));
        let _guard = tracker.allocate(512).unwrap();
        assert!(tracker.can_allocate(512));
        assert!(!tracker.can_allocate(1024));
    }

    #[test]
    fn test_try_allocate() {
        let tracker = MemoryTracker::new(1024);
        let guard = tracker.try_allocate(512);
        assert!(guard.is_some());
        let guard2 = tracker.try_allocate(1024);
        assert!(guard2.is_none());
        drop(guard);
    }

    #[test]
    fn test_allocate_up_to() {
        let tracker = MemoryTracker::new(1024);
        let (guard1, amount1) = tracker.allocate_up_to(2048);
        assert_eq!(amount1, 1024);
        assert_eq!(tracker.current_usage(), 1024);
        let (guard2, amount2) = tracker.allocate_up_to(512);
        assert_eq!(amount2, 0);
        drop(guard1);
        drop(guard2);
        assert_eq!(tracker.current_usage(), 0);
    }

    #[test]
    fn test_allocate_up_to_zero_request() {
        let tracker = MemoryTracker::new(1024);
        let (guard, amount) = tracker.allocate_up_to(0);
        assert_eq!(amount, 0);
        assert_eq!(guard.bytes(), 0);
        assert_eq!(tracker.current_usage(), 0);
    }

    #[test]
    fn test_guard_forget() {
        let tracker = MemoryTracker::new(1024);
        let guard = tracker.allocate(256).unwrap();
        let bytes = guard.forget();
        assert_eq!(bytes, 256);
        assert_eq!(tracker.current_usage(), 256);
        tracker.reset();
        assert_eq!(tracker.current_usage(), 0);
    }

    #[test]
    fn test_tracked_buffer() {
        let tracker = MemoryTracker::new(1024);
        let buffer = TrackedBuffer::new(&tracker, 256).unwrap();
        // `Vec::with_capacity` only guarantees *at least* the requested capacity.
        assert!(buffer.capacity() >= 256);
        assert!(buffer.is_empty());
        assert_eq!(tracker.current_usage(), 256);
        drop(buffer);
        assert_eq!(tracker.current_usage(), 0);
    }

    #[test]
    fn test_tracked_buffer_zeroed() {
        let tracker = MemoryTracker::new(1024);
        let buffer = TrackedBuffer::zeroed(&tracker, 128).unwrap();
        assert_eq!(buffer.len(), 128);
        assert!(buffer.iter().all(|&b| b == 0));
        assert_eq!(tracker.current_usage(), 128);
        drop(buffer);
        assert_eq!(tracker.current_usage(), 0);
    }

    #[test]
    fn test_tracked_buffer_over_limit_reserves_nothing() {
        let tracker = MemoryTracker::new(1024);
        assert!(TrackedBuffer::new(&tracker, 2048).is_err());
        assert!(TrackedBuffer::zeroed(&tracker, 2048).is_err());
        assert_eq!(tracker.current_usage(), 0);
    }

    #[test]
    fn test_into_vec_releases_reservation() {
        let tracker = MemoryTracker::new(1024);
        let buffer = TrackedBuffer::zeroed(&tracker, 64).unwrap();
        let data = buffer.into_vec();
        assert_eq!(data.len(), 64);
        // The guard is dropped by `into_vec`; the Vec is no longer tracked.
        assert_eq!(tracker.current_usage(), 0);
    }

    #[test]
    fn test_unlimited_tracker() {
        let tracker = MemoryTracker::unlimited();
        let guard = tracker.allocate(1024 * 1024 * 1024).unwrap();
        assert_eq!(guard.bytes(), 1024 * 1024 * 1024);
        drop(guard);
        assert_eq!(tracker.current_usage(), 0);
    }

    #[test]
    fn test_reset_peak() {
        let tracker = MemoryTracker::new(1024);
        let guard = tracker.allocate(500).unwrap();
        assert_eq!(tracker.peak_usage(), 500);
        drop(guard);
        assert_eq!(tracker.peak_usage(), 500);
        tracker.reset_peak();
        assert_eq!(tracker.peak_usage(), 0);
    }
}