use crate::error::{CoreError, Result};
use candle_core::DType;
use std::sync::atomic::{AtomicUsize, Ordering};
/// Default multiplier applied to raw tensor-size estimates to account for
/// allocator/runtime overhead (10% — see `MemoryTracker::estimate_with_overhead`).
pub const DEFAULT_OVERHEAD_FACTOR: f64 = 1.1;
/// Bytes required for a dense tensor of `shape` with element type `dtype`.
///
/// Saturates at `usize::MAX` instead of overflowing: the plain
/// `product() * size_in_bytes()` form panics in debug builds (and silently
/// wraps in release) for pathological shapes.
#[must_use]
pub fn estimate_tensor_bytes(shape: &[usize], dtype: DType) -> usize {
    // Element count; saturating_mul keeps huge shapes from wrapping around.
    let numel = shape
        .iter()
        .fold(1usize, |acc, &dim| acc.saturating_mul(dim));
    numel.saturating_mul(dtype.size_in_bytes())
}
/// Estimated activation memory (in bytes) for one attention layer:
/// Q/K/V tensors, the attention-weight matrix, and the output tensor.
///
/// All arithmetic saturates at `usize::MAX` so oversized dimensions cannot
/// overflow (debug panic / release wrap-around).
#[must_use]
pub fn estimate_attention_memory(
    batch_size: usize,
    num_heads: usize,
    seq_len: usize,
    head_dim: usize,
    dtype: DType,
) -> usize {
    let bytes_per_elem = dtype.size_in_bytes();
    // One [batch, heads, seq, head_dim] tensor — the shape of Q, K, V and output.
    let per_tensor = [batch_size, num_heads, seq_len, head_dim, bytes_per_elem]
        .iter()
        .fold(1usize, |acc, &x| acc.saturating_mul(x));
    // Q, K and V together.
    let qkv_bytes = per_tensor.saturating_mul(3);
    // Attention weights: [batch, heads, seq, seq].
    let attn_weights_bytes = [batch_size, num_heads, seq_len, seq_len, bytes_per_elem]
        .iter()
        .fold(1usize, |acc, &x| acc.saturating_mul(x));
    // Output has the same shape (and size) as a single Q/K/V tensor.
    let output_bytes = per_tensor;
    qkv_bytes
        .saturating_add(attn_weights_bytes)
        .saturating_add(output_bytes)
}
/// Thread-safe byte counter for tracking tensor allocations against an
/// optional limit.
#[derive(Debug)]
pub struct MemoryTracker {
    // Bytes currently accounted as allocated.
    allocated: AtomicUsize,
    // Highest value `allocated` has reached since construction or `reset`.
    peak: AtomicUsize,
    // Allocation ceiling in bytes; 0 means "no limit" (see `allocate`).
    limit: AtomicUsize,
    // Multiplier applied in `estimate_with_overhead`.
    overhead_factor: f64,
}
impl Default for MemoryTracker {
fn default() -> Self {
Self::new()
}
}
impl MemoryTracker {
    /// Creates an unlimited tracker (a limit of 0 disables limit checks).
    #[must_use]
    pub fn new() -> Self {
        Self {
            allocated: AtomicUsize::new(0),
            peak: AtomicUsize::new(0),
            limit: AtomicUsize::new(0),
            overhead_factor: DEFAULT_OVERHEAD_FACTOR,
        }
    }
    /// Creates a tracker that rejects allocations pushing the running total
    /// above `limit_bytes`. A limit of 0 means "unlimited".
    #[must_use]
    pub fn with_limit(limit_bytes: usize) -> Self {
        Self {
            limit: AtomicUsize::new(limit_bytes),
            ..Self::new()
        }
    }
    /// Sets the multiplier used by `estimate_with_overhead`.
    #[must_use]
    pub fn with_overhead_factor(mut self, factor: f64) -> Self {
        self.overhead_factor = factor;
        self
    }
    /// Records an allocation of `bytes`, updating the peak counter.
    ///
    /// # Errors
    ///
    /// Returns `CoreError::oom` when a nonzero limit would be exceeded; the
    /// allocated counter is left unchanged in that case.
    pub fn allocate(&self, bytes: usize) -> Result<()> {
        let limit = self.limit.load(Ordering::SeqCst);
        // CAS loop so the limit check and counter update happen atomically.
        // The previous load/check/fetch_add sequence was racy: two threads
        // could both pass the check and together overshoot the limit.
        let mut current = self.allocated.load(Ordering::SeqCst);
        loop {
            // saturating_add: a wrap-around here would silently bypass the
            // limit check.
            let new_allocated = current.saturating_add(bytes);
            if limit > 0 && new_allocated > limit {
                return Err(CoreError::oom(format!(
                    "allocation of {bytes} bytes would exceed limit of {limit} bytes \
                     (current: {current} bytes)"
                )));
            }
            match self.allocated.compare_exchange_weak(
                current,
                new_allocated,
                Ordering::SeqCst,
                Ordering::SeqCst,
            ) {
                Ok(_) => {
                    // fetch_max replaces the hand-rolled peak CAS loop.
                    self.peak.fetch_max(new_allocated, Ordering::SeqCst);
                    return Ok(());
                }
                // Someone else changed the counter; retry with the new value.
                Err(observed) => current = observed,
            }
        }
    }
    /// Records a deallocation of `bytes`.
    pub fn deallocate(&self, bytes: usize) {
        // saturating_sub guards against a mismatched deallocate wrapping the
        // counter around to a huge value. The closure always returns Some,
        // so fetch_update cannot fail.
        let _ = self
            .allocated
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |cur| {
                Some(cur.saturating_sub(bytes))
            });
    }
    /// Bytes currently tracked as allocated.
    #[must_use]
    pub fn allocated_bytes(&self) -> usize {
        self.allocated.load(Ordering::SeqCst)
    }
    /// High-water mark of tracked bytes since construction or `reset`.
    #[must_use]
    pub fn peak_bytes(&self) -> usize {
        self.peak.load(Ordering::SeqCst)
    }
    /// Configured limit in bytes (0 = unlimited).
    #[must_use]
    pub fn limit_bytes(&self) -> usize {
        self.limit.load(Ordering::SeqCst)
    }
    /// Estimated bytes for a tensor of `shape`/`dtype`, scaled by this
    /// tracker's overhead factor.
    #[must_use]
    pub fn estimate_with_overhead(&self, shape: &[usize], dtype: DType) -> usize {
        let raw = estimate_tensor_bytes(shape, dtype);
        #[allow(
            clippy::cast_sign_loss,
            clippy::cast_possible_truncation,
            clippy::cast_precision_loss
        )]
        {
            (raw as f64 * self.overhead_factor) as usize
        }
    }
    /// Whether an allocation of `bytes` could currently succeed.
    ///
    /// Note: purely advisory — another thread may allocate between this
    /// check and a subsequent `allocate` call.
    #[must_use]
    pub fn would_fit(&self, bytes: usize) -> bool {
        let limit = self.limit.load(Ordering::SeqCst);
        if limit == 0 {
            return true;
        }
        // saturating_add: overflow must read as "does not fit", not wrap.
        self.allocated.load(Ordering::SeqCst).saturating_add(bytes) <= limit
    }
    /// Resets the allocated and peak counters to zero; the limit is kept.
    pub fn reset(&self) {
        self.allocated.store(0, Ordering::SeqCst);
        self.peak.store(0, Ordering::SeqCst);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_estimate_tensor_bytes() {
        // 1000 elements at 4 bytes (f32) and 2 bytes (f16).
        assert_eq!(estimate_tensor_bytes(&[10, 100], DType::F32), 4000);
        assert_eq!(estimate_tensor_bytes(&[10, 100], DType::F16), 2000);
        // A zero dimension means zero elements, hence zero bytes.
        assert_eq!(estimate_tensor_bytes(&[0], DType::F32), 0);
    }

    #[test]
    fn test_estimate_attention_memory() {
        // 3 QKV tensors (32 B each) + weights (64 B) + output (32 B) = 192 B.
        assert_eq!(estimate_attention_memory(1, 1, 4, 2, DType::F32), 192);
    }

    #[test]
    fn test_memory_tracker_allocation() {
        let t = MemoryTracker::with_limit(1000);
        t.allocate(500).expect("500 fits under a 1000-byte limit");
        assert_eq!(t.allocated_bytes(), 500);
        t.allocate(400).expect("900 total fits under a 1000-byte limit");
        assert_eq!(t.allocated_bytes(), 900);
        // 900 + 200 exceeds the limit: rejected, counter untouched.
        assert!(t.allocate(200).is_err());
        assert_eq!(t.allocated_bytes(), 900);
        // Freeing makes room for the previously rejected allocation.
        t.deallocate(400);
        assert_eq!(t.allocated_bytes(), 500);
        assert!(t.allocate(200).is_ok());
    }

    #[test]
    fn test_memory_tracker_peak() {
        let t = MemoryTracker::new();
        t.allocate(100).unwrap();
        t.allocate(200).unwrap();
        assert_eq!(t.peak_bytes(), 300);
        // Freeing lowers `allocated` but never `peak`.
        t.deallocate(200);
        assert_eq!(t.allocated_bytes(), 100);
        assert_eq!(t.peak_bytes(), 300);
        t.allocate(50).unwrap();
        assert_eq!(t.peak_bytes(), 300);
        // 150 + 300 = 450 sets a new high-water mark.
        t.allocate(300).unwrap();
        assert_eq!(t.peak_bytes(), 450);
    }

    #[test]
    fn test_would_fit() {
        let limited = MemoryTracker::with_limit(1000);
        limited.allocate(500).unwrap();
        assert!(limited.would_fit(400));
        assert!(limited.would_fit(500));
        assert!(!limited.would_fit(501));
        // A zero limit means unlimited.
        let unlimited = MemoryTracker::new();
        assert!(unlimited.would_fit(usize::MAX));
    }
}