// unsloth_rs/memory.rs
1// SPDX-License-Identifier: MIT
2// Copyright 2026 Tyler Zervas
3
4//! Memory management utilities for tracking allocations.
5//!
6//! This module provides tools for memory tracking and estimation, which are
7//! essential for managing GPU memory when training large language models.
8//!
9//! ## Why Memory Management?
10//!
11//! Large language models can easily exhaust GPU memory. These utilities help:
12//! - Estimate memory requirements before running operations
13//! - Track actual allocations during execution
14//! - Configure gradient checkpointing to trade compute for memory
15//!
16//! ## Provided Utilities
17//!
18//! - `MemoryPool`: Tracks allocations with optional limit enforcement
19//! - `CheckpointConfig`: Configuration for gradient checkpointing
20//! - `estimate_forward_memory`: Estimates memory for forward passes
21//! - `estimate_attention_vram`: Estimates memory for attention operations
22//! - `format_bytes`: Human-readable byte formatting
23
24use crate::error::{Result, UnslothError};
25
26/// Memory pool for efficient GPU allocation.
27///
28/// Tracks memory allocations and provides limit enforcement.
29/// Future versions will integrate with `CubeCL` for device-aware allocation.
30pub struct MemoryPool {
31    /// Total allocated bytes
32    allocated: usize,
33    /// Peak memory usage
34    peak: usize,
35    /// Memory limit (if set)
36    limit: Option<usize>,
37    /// Device type for allocation tracking
38    device_type: DeviceType,
39}
40
/// Device type for memory tracking.
///
/// Used purely as a label on `MemoryPool`; it does not affect accounting.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DeviceType {
    /// CPU (host) memory — the default
    #[default]
    Cpu,
    /// CUDA GPU memory (NVIDIA)
    Cuda,
    /// Metal GPU memory (Apple)
    Metal,
    /// Vulkan GPU memory
    Vulkan,
}
54
55impl MemoryPool {
56    /// Create a new memory pool.
57    #[must_use]
58    pub fn new(limit: Option<usize>) -> Self {
59        Self {
60            allocated: 0,
61            peak: 0,
62            limit,
63            device_type: DeviceType::default(),
64        }
65    }
66
67    /// Create a new memory pool for a specific device.
68    #[must_use]
69    pub fn with_device(limit: Option<usize>, device_type: DeviceType) -> Self {
70        Self {
71            allocated: 0,
72            peak: 0,
73            limit,
74            device_type,
75        }
76    }
77
78    /// Request memory allocation.
79    ///
80    /// # Errors
81    /// Returns `OutOfMemory` if limit would be exceeded.
82    pub fn allocate(&mut self, bytes: usize) -> Result<()> {
83        let new_total = self.allocated + bytes;
84
85        if let Some(limit) = self.limit {
86            if new_total > limit {
87                return Err(UnslothError::OutOfMemory {
88                    required: new_total,
89                    available: limit.saturating_sub(self.allocated),
90                });
91            }
92        }
93
94        self.allocated = new_total;
95        self.peak = self.peak.max(self.allocated);
96        Ok(())
97    }
98
99    /// Free memory.
100    pub fn free(&mut self, bytes: usize) {
101        self.allocated = self.allocated.saturating_sub(bytes);
102    }
103
104    /// Get current allocation.
105    #[must_use]
106    pub fn allocated(&self) -> usize {
107        self.allocated
108    }
109
110    /// Get peak allocation.
111    #[must_use]
112    pub fn peak(&self) -> usize {
113        self.peak
114    }
115
116    /// Get the device type.
117    #[must_use]
118    pub fn device_type(&self) -> DeviceType {
119        self.device_type
120    }
121
122    /// Reset peak tracking.
123    pub fn reset_peak(&mut self) {
124        self.peak = self.allocated;
125    }
126
127    /// Calculate memory efficiency (allocated vs peak).
128    #[must_use]
129    pub fn efficiency(&self) -> f64 {
130        if self.peak == 0 {
131            1.0
132        } else {
133            // Precision loss acceptable for efficiency metric
134            #[allow(clippy::cast_precision_loss)]
135            {
136                self.allocated as f64 / self.peak as f64
137            }
138        }
139    }
140}
141
/// Gradient checkpointing configuration.
///
/// Controls how activations are stored during forward pass.
/// Higher `checkpoint_every` values reduce memory but increase compute
/// (dropped activations must be recomputed during the backward pass).
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Store activations only every N layers
    pub checkpoint_every: usize,
    /// Enable checkpointing; when `false` all layer activations are kept
    pub enabled: bool,
}
153
154impl Default for CheckpointConfig {
155    fn default() -> Self {
156        Self {
157            checkpoint_every: 1,
158            enabled: true,
159        }
160    }
161}
162
163impl CheckpointConfig {
164    /// Create a new checkpoint config.
165    #[must_use]
166    pub fn new(checkpoint_every: usize, enabled: bool) -> Self {
167        Self {
168            checkpoint_every,
169            enabled,
170        }
171    }
172
173    /// Calculate expected memory reduction factor.
174    ///
175    /// Returns a value between 0 and 1, where lower is better.
176    #[must_use]
177    pub fn memory_reduction_factor(&self, num_layers: usize) -> f64 {
178        if !self.enabled || num_layers == 0 {
179            1.0
180        } else {
181            let checkpointed = num_layers.div_ceil(self.checkpoint_every);
182            // Precision loss acceptable for memory reduction factor metric
183            #[allow(clippy::cast_precision_loss)]
184            {
185                checkpointed as f64 / num_layers as f64
186            }
187        }
188    }
189}
190
191/// Calculate memory requirements for a forward pass.
192///
193/// # Arguments
194/// * `batch_size` - Batch size
195/// * `seq_len` - Sequence length
196/// * `hidden_size` - Hidden dimension size
197/// * `num_layers` - Number of transformer layers
198/// * `checkpoint_config` - Gradient checkpointing configuration
199///
200/// # Returns
201/// Estimated memory usage in bytes
202#[must_use]
203pub fn estimate_forward_memory(
204    batch_size: usize,
205    seq_len: usize,
206    hidden_size: usize,
207    num_layers: usize,
208    checkpoint_config: &CheckpointConfig,
209) -> usize {
210    let bytes_per_elem = 4; // f32
211
212    // Per-layer activation memory
213    let activation_per_layer = batch_size * seq_len * hidden_size * bytes_per_elem;
214
215    // With checkpointing, only store every N layers
216    let stored_layers = if checkpoint_config.enabled {
217        num_layers.div_ceil(checkpoint_config.checkpoint_every)
218    } else {
219        num_layers
220    };
221
222    stored_layers * activation_per_layer
223}
224
/// Estimate VRAM for attention operation.
///
/// Accounts for the QKV projection output, the full attention score
/// matrix, and the attention output, all as f32.
///
/// # Arguments
/// * `batch_size` - Batch size
/// * `seq_len` - Sequence length
/// * `hidden_size` - Hidden dimension
/// * `num_heads` - Number of attention heads
///
/// # Returns
/// Estimated VRAM in bytes
#[must_use]
pub fn estimate_attention_vram(
    batch_size: usize,
    seq_len: usize,
    hidden_size: usize,
    num_heads: usize,
) -> usize {
    const BYTES_PER_ELEM: usize = 4; // f32

    let tokens = batch_size * seq_len;
    // QKV projection output: [batch, seq, 3 * hidden]
    let qkv_elems = tokens * 3 * hidden_size;
    // Attention scores: [batch, num_heads, seq_len, seq_len]
    let score_elems = batch_size * num_heads * seq_len * seq_len;
    // Attention output: [batch, seq, hidden]
    let output_elems = tokens * hidden_size;

    (qkv_elems + score_elems + output_elems) * BYTES_PER_ELEM
}
253
/// Format bytes as human-readable string.
///
/// Uses binary units (1 KB = 1024 bytes); values below 1 KB are printed
/// as a plain byte count.
#[must_use]
pub fn format_bytes(bytes: usize) -> String {
    // Largest-to-smallest scale table; first matching scale wins.
    const SCALES: [(usize, &str); 3] = [
        (1024 * 1024 * 1024, "GB"),
        (1024 * 1024, "MB"),
        (1024, "KB"),
    ];

    for &(scale, unit) in &SCALES {
        if bytes >= scale {
            // Precision loss acceptable for human-readable byte formatting
            #[allow(clippy::cast_precision_loss)]
            return format!("{:.2} {unit}", bytes as f64 / scale as f64);
        }
    }

    format!("{bytes} bytes")
}
273
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_pool_allocation() {
        let mut pool = MemoryPool::new(Some(1000));

        // Two allocations within the limit succeed and accumulate.
        assert!(pool.allocate(500).is_ok());
        assert_eq!(pool.allocated(), 500);
        assert!(pool.allocate(400).is_ok());
        assert_eq!(pool.allocated(), 900);

        // 900 already allocated, so another 200 would exceed the 1000 limit.
        assert!(pool.allocate(200).is_err());

        pool.free(300);
        assert_eq!(pool.allocated(), 600);
    }

    #[test]
    fn test_memory_pool_with_device() {
        let pool = MemoryPool::with_device(Some(1024 * 1024), DeviceType::Cuda);
        assert_eq!(pool.device_type(), DeviceType::Cuda);
        assert_eq!(pool.allocated(), 0);
    }

    #[test]
    fn test_checkpoint_memory_reduction() {
        let (batch, seq, hidden, layers) = (4, 2048, 4096, 32);

        let no_checkpoint = CheckpointConfig {
            enabled: false,
            ..Default::default()
        };
        let with_checkpoint = CheckpointConfig {
            enabled: true,
            checkpoint_every: 4,
        };

        let mem_full = estimate_forward_memory(batch, seq, hidden, layers, &no_checkpoint);
        let mem_checkpoint = estimate_forward_memory(batch, seq, hidden, layers, &with_checkpoint);

        // Storing only every 4th layer should cut memory well below half.
        assert!(mem_checkpoint < mem_full / 2);
    }

    #[test]
    fn test_checkpoint_reduction_factor() {
        // 32 layers, checkpoint every 4 => 8 stored layers => 8/32 = 0.25.
        let config = CheckpointConfig::new(4, true);
        let factor = config.memory_reduction_factor(32);
        assert!((factor - 0.25).abs() < 0.01);
    }

    #[test]
    fn test_format_bytes() {
        assert_eq!(format_bytes(500), "500 bytes");
        assert_eq!(format_bytes(1024), "1.00 KB");
        assert_eq!(format_bytes(1024 * 1024), "1.00 MB");
        assert_eq!(format_bytes(1024 * 1024 * 1024), "1.00 GB");
    }

    #[test]
    fn test_attention_vram_estimate() {
        let vram = estimate_attention_vram(4, 2048, 4096, 32);
        // Sanity bounds: substantial but not unreasonable.
        assert!(vram > 100 * 1024 * 1024); // > 100 MB
        assert!(vram < 10 * 1024 * 1024 * 1024); // < 10 GB
    }
}