use crate::error::{Result, UnslothError};
/// Byte-accounting record for a memory budget tied to one device type.
///
/// The pool stores counters only; it does not own or hand out any buffers.
/// Callers report allocations and frees, and the pool tracks the running
/// total, the high-water mark, and an optional upper bound.
#[derive(Debug)]
pub struct MemoryPool {
    /// Bytes currently counted as in use.
    allocated: usize,
    /// Highest value `allocated` has reached since the peak was last reset.
    peak: usize,
    /// Optional cap on `allocated`, in bytes; `None` means no cap is enforced.
    limit: Option<usize>,
    /// Device type this pool is associated with.
    device_type: DeviceType,
}
/// Kind of device a [`MemoryPool`] is associated with.
///
/// The value is a plain tag: it is `Copy`, comparable, and defaults to
/// [`DeviceType::Cpu`].
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum DeviceType {
    /// CPU; used when no device is specified.
    #[default]
    Cpu,
    /// CUDA device.
    Cuda,
    /// Metal device.
    Metal,
    /// Vulkan device.
    Vulkan,
}
impl MemoryPool {
    /// Creates an empty pool for the default device type.
    ///
    /// `limit` is the maximum number of bytes that may be counted as
    /// allocated at once; `None` disables the check.
    #[must_use]
    pub fn new(limit: Option<usize>) -> Self {
        Self::with_device(limit, DeviceType::default())
    }

    /// Creates an empty pool associated with `device_type`.
    ///
    /// `limit` is the maximum number of bytes that may be counted as
    /// allocated at once; `None` disables the check.
    #[must_use]
    pub fn with_device(limit: Option<usize>, device_type: DeviceType) -> Self {
        Self {
            allocated: 0,
            peak: 0,
            limit,
            device_type,
        }
    }

    /// Records an allocation of `bytes` and updates the peak.
    ///
    /// # Errors
    ///
    /// Returns [`UnslothError::OutOfMemory`] and leaves the pool unchanged
    /// when the new total would exceed the configured limit, or when the
    /// new total does not fit in a `usize`. In the error, `required` is the
    /// total that was requested (`usize::MAX` if it overflowed) and
    /// `available` is the headroom left before the request.
    pub fn allocate(&mut self, bytes: usize) -> Result<()> {
        // An unlimited pool is still bounded by what `usize` can represent.
        let ceiling = self.limit.unwrap_or(usize::MAX);

        // `checked_add` keeps an overflowing request from wrapping around to
        // a small total that would slip under the limit.
        match self.allocated.checked_add(bytes) {
            Some(new_total) if new_total <= ceiling => {
                self.allocated = new_total;
                self.peak = self.peak.max(new_total);
                Ok(())
            }
            over_budget => Err(UnslothError::OutOfMemory {
                required: over_budget.unwrap_or(usize::MAX),
                available: ceiling.saturating_sub(self.allocated),
            }),
        }
    }

    /// Records that `bytes` were released.
    ///
    /// Freeing more than is currently allocated clamps the counter to zero
    /// instead of underflowing.
    pub fn free(&mut self, bytes: usize) {
        self.allocated = self.allocated.saturating_sub(bytes);
    }

    /// Returns the number of bytes currently counted as allocated.
    #[must_use]
    pub fn allocated(&self) -> usize {
        self.allocated
    }

    /// Returns the highest allocation total seen since the last peak reset.
    #[must_use]
    pub fn peak(&self) -> usize {
        self.peak
    }

    /// Returns the device type this pool was created for.
    #[must_use]
    pub fn device_type(&self) -> DeviceType {
        self.device_type
    }

    /// Restarts peak tracking from the current allocation total.
    pub fn reset_peak(&mut self) {
        self.peak = self.allocated;
    }

    /// Returns the ratio of current usage to peak usage.
    ///
    /// The result is `1.0` when nothing has been recorded yet (peak of zero).
    #[must_use]
    // The ratio is informational; f64 rounding of very large byte counts is acceptable.
    #[allow(clippy::cast_precision_loss)]
    pub fn efficiency(&self) -> f64 {
        if self.peak == 0 {
            1.0
        } else {
            self.allocated as f64 / self.peak as f64
        }
    }
}
/// Settings that control how layers are counted as stored in the memory
/// estimates.
#[derive(Clone, Debug)]
pub struct CheckpointConfig {
    /// Interval, in layers, between checkpoints: the estimates count one
    /// stored layer per `checkpoint_every` layers, rounding up.
    pub checkpoint_every: usize,
    /// Whether checkpointing is active; when `false` the estimates count
    /// every layer as stored.
    pub enabled: bool,
}
impl Default for CheckpointConfig {
fn default() -> Self {
Self {
checkpoint_every: 1,
enabled: true,
}
}
}
impl CheckpointConfig {
    /// Creates a configuration from an interval and an on/off switch.
    ///
    /// `checkpoint_every` is stored as given; a value of `0` is treated as
    /// `1` wherever the interval is used as a divisor.
    #[must_use]
    pub fn new(checkpoint_every: usize, enabled: bool) -> Self {
        Self {
            checkpoint_every,
            enabled,
        }
    }

    /// Returns the fraction of `num_layers` that is counted as stored.
    ///
    /// The result is `ceil(num_layers / checkpoint_every) / num_layers`.
    /// It is `1.0` (no reduction) when checkpointing is disabled or when
    /// `num_layers` is zero.
    #[must_use]
    // The ratio is an estimate; f64 rounding of the layer counts is acceptable.
    #[allow(clippy::cast_precision_loss)]
    pub fn memory_reduction_factor(&self, num_layers: usize) -> f64 {
        if !self.enabled || num_layers == 0 {
            return 1.0;
        }

        // `div_ceil` panics on a zero divisor, and the field is public, so a
        // zero interval is clamped to "checkpoint every layer".
        let interval = self.checkpoint_every.max(1);
        let checkpointed = num_layers.div_ceil(interval);
        checkpointed as f64 / num_layers as f64
    }
}
/// Estimates the bytes of activations kept for a forward pass.
///
/// One layer's activations are counted as
/// `batch_size * seq_len * hidden_size` elements of 4 bytes each
/// (presumably `f32`; the element size is fixed here). With checkpointing
/// enabled, only `ceil(num_layers / checkpoint_every)` layers are counted
/// as stored; otherwise all `num_layers` are.
///
/// A `checkpoint_every` of `0` is treated as `1`, and the arithmetic
/// saturates at `usize::MAX` rather than overflowing, so an absurdly large
/// request yields a "does not fit" estimate instead of a panic or a
/// wrapped-around small number.
#[must_use]
pub fn estimate_forward_memory(
    batch_size: usize,
    seq_len: usize,
    hidden_size: usize,
    num_layers: usize,
    checkpoint_config: &CheckpointConfig,
) -> usize {
    const BYTES_PER_ELEM: usize = 4;

    let activation_per_layer = batch_size
        .saturating_mul(seq_len)
        .saturating_mul(hidden_size)
        .saturating_mul(BYTES_PER_ELEM);

    let stored_layers = if checkpoint_config.enabled {
        // Guard the divisor: `div_ceil(0)` would panic.
        num_layers.div_ceil(checkpoint_config.checkpoint_every.max(1))
    } else {
        num_layers
    };

    stored_layers.saturating_mul(activation_per_layer)
}
/// Estimates the bytes needed by one attention computation.
///
/// The estimate is the sum of three buffers, each with 4-byte elements
/// (presumably `f32`; the element size is fixed here):
///
/// * Q, K and V: `batch_size * seq_len * 3 * hidden_size`
/// * attention scores: `batch_size * num_heads * seq_len * seq_len`
/// * output: `batch_size * seq_len * hidden_size`
///
/// All arithmetic saturates at `usize::MAX`, so inputs too large to
/// represent produce a "does not fit" estimate instead of a panic (debug
/// builds) or a wrapped-around small number (release builds).
#[must_use]
pub fn estimate_attention_vram(
    batch_size: usize,
    seq_len: usize,
    hidden_size: usize,
    num_heads: usize,
) -> usize {
    const BYTES_PER_ELEM: usize = 4;

    // One `[batch, seq, hidden]` buffer; the output is one of these and
    // Q/K/V together are three.
    let output_size = batch_size
        .saturating_mul(seq_len)
        .saturating_mul(hidden_size)
        .saturating_mul(BYTES_PER_ELEM);
    let qkv_size = output_size.saturating_mul(3);

    // The score matrix grows quadratically with sequence length.
    let scores_size = batch_size
        .saturating_mul(num_heads)
        .saturating_mul(seq_len)
        .saturating_mul(seq_len)
        .saturating_mul(BYTES_PER_ELEM);

    qkv_size
        .saturating_add(scores_size)
        .saturating_add(output_size)
}
/// Renders a byte count as a short human-readable string.
///
/// Units are binary (1 KB = 1024 bytes). The largest unit that does not
/// exceed `bytes` is chosen from GB, MB and KB and printed with two decimal
/// places; values below 1 KB are printed as a whole number of bytes.
#[must_use]
// Display-only conversion; f64 rounding of huge byte counts is acceptable.
#[allow(clippy::cast_precision_loss)]
pub fn format_bytes(bytes: usize) -> String {
    // Ordered largest first so the first match is the best-fitting unit.
    const UNITS: [(usize, &str); 3] = [(1 << 30, "GB"), (1 << 20, "MB"), (1 << 10, "KB")];

    UNITS
        .iter()
        .find(|&&(threshold, _)| bytes >= threshold)
        .map_or_else(
            || format!("{bytes} bytes"),
            |&(threshold, suffix)| format!("{:.2} {suffix}", bytes as f64 / threshold as f64),
        )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Allocations accumulate, an over-limit request is rejected, and
    /// freeing lowers the running total.
    #[test]
    fn test_memory_pool_allocation() {
        let mut pool = MemoryPool::new(Some(1000));

        assert!(pool.allocate(500).is_ok());
        assert_eq!(pool.allocated(), 500);

        assert!(pool.allocate(400).is_ok());
        assert_eq!(pool.allocated(), 900);

        // 900 + 200 exceeds the 1000-byte limit.
        assert!(pool.allocate(200).is_err());

        pool.free(300);
        assert_eq!(pool.allocated(), 600);
    }

    /// A pool built with an explicit device reports it and starts empty.
    #[test]
    fn test_memory_pool_with_device() {
        let pool = MemoryPool::with_device(Some(1024 * 1024), DeviceType::Cuda);

        assert_eq!(pool.device_type(), DeviceType::Cuda);
        assert_eq!(pool.allocated(), 0);
    }

    /// Checkpointing every 4th layer should more than halve the estimate.
    #[test]
    fn test_checkpoint_memory_reduction() {
        let batch = 4;
        let seq = 2048;
        let hidden = 4096;
        let layers = 32;

        let no_checkpoint = CheckpointConfig {
            enabled: false,
            ..Default::default()
        };
        let with_checkpoint = CheckpointConfig {
            enabled: true,
            checkpoint_every: 4,
        };

        let mem_full = estimate_forward_memory(batch, seq, hidden, layers, &no_checkpoint);
        let mem_checkpoint = estimate_forward_memory(batch, seq, hidden, layers, &with_checkpoint);

        assert!(mem_checkpoint < mem_full / 2);
    }

    /// 32 layers checkpointed every 4 keeps 8 of them, i.e. a factor of 0.25.
    #[test]
    fn test_checkpoint_reduction_factor() {
        let config = CheckpointConfig::new(4, true);
        let factor = config.memory_reduction_factor(32);

        assert!((factor - 0.25).abs() < 0.01);
    }

    /// Each unit boundary formats with the expected suffix.
    #[test]
    fn test_format_bytes() {
        assert_eq!(format_bytes(500), "500 bytes");
        assert_eq!(format_bytes(1024), "1.00 KB");
        assert_eq!(format_bytes(1024 * 1024), "1.00 MB");
        assert_eq!(format_bytes(1024 * 1024 * 1024), "1.00 GB");
    }

    /// The attention estimate for a typical configuration lands between
    /// 100 MiB and 10 GiB.
    #[test]
    fn test_attention_vram_estimate() {
        // Bounds are expressed in `u64`: 10 GiB does not fit in a 32-bit
        // `usize`, and a constant that overflows fails to compile there.
        const MIB: u64 = 1024 * 1024;
        const GIB: u64 = 1024 * MIB;

        let vram = estimate_attention_vram(4, 2048, 4096, 32) as u64;

        assert!(vram > 100 * MIB);
        assert!(vram < 10 * GIB);
    }
}