1use crate::error::{Result, UnslothError};
25
26pub struct MemoryPool {
31 allocated: usize,
33 peak: usize,
35 limit: Option<usize>,
37 device_type: DeviceType,
39}
40
/// Device kind recorded on a [`MemoryPool`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DeviceType {
    /// CPU; this is the variant returned by `DeviceType::default()`.
    #[default]
    Cpu,
    /// CUDA device.
    Cuda,
    /// Metal device.
    Metal,
    /// Vulkan device.
    Vulkan,
}
54
55impl MemoryPool {
56 #[must_use]
58 pub fn new(limit: Option<usize>) -> Self {
59 Self {
60 allocated: 0,
61 peak: 0,
62 limit,
63 device_type: DeviceType::default(),
64 }
65 }
66
67 #[must_use]
69 pub fn with_device(limit: Option<usize>, device_type: DeviceType) -> Self {
70 Self {
71 allocated: 0,
72 peak: 0,
73 limit,
74 device_type,
75 }
76 }
77
78 pub fn allocate(&mut self, bytes: usize) -> Result<()> {
83 let new_total = self.allocated + bytes;
84
85 if let Some(limit) = self.limit {
86 if new_total > limit {
87 return Err(UnslothError::OutOfMemory {
88 required: new_total,
89 available: limit.saturating_sub(self.allocated),
90 });
91 }
92 }
93
94 self.allocated = new_total;
95 self.peak = self.peak.max(self.allocated);
96 Ok(())
97 }
98
99 pub fn free(&mut self, bytes: usize) {
101 self.allocated = self.allocated.saturating_sub(bytes);
102 }
103
104 #[must_use]
106 pub fn allocated(&self) -> usize {
107 self.allocated
108 }
109
110 #[must_use]
112 pub fn peak(&self) -> usize {
113 self.peak
114 }
115
116 #[must_use]
118 pub fn device_type(&self) -> DeviceType {
119 self.device_type
120 }
121
122 pub fn reset_peak(&mut self) {
124 self.peak = self.allocated;
125 }
126
127 #[must_use]
129 pub fn efficiency(&self) -> f64 {
130 if self.peak == 0 {
131 1.0
132 } else {
133 #[allow(clippy::cast_precision_loss)]
135 {
136 self.allocated as f64 / self.peak as f64
137 }
138 }
139 }
140}
141
/// Checkpointing settings consumed by the memory estimators in this file.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Checkpoint interval in layers: when checkpointing is enabled the
    /// estimators count `ceil(num_layers / checkpoint_every)` stored layers.
    pub checkpoint_every: usize,
    /// When `false`, the estimators treat every layer as stored.
    pub enabled: bool,
}
153
154impl Default for CheckpointConfig {
155 fn default() -> Self {
156 Self {
157 checkpoint_every: 1,
158 enabled: true,
159 }
160 }
161}
162
163impl CheckpointConfig {
164 #[must_use]
166 pub fn new(checkpoint_every: usize, enabled: bool) -> Self {
167 Self {
168 checkpoint_every,
169 enabled,
170 }
171 }
172
173 #[must_use]
177 pub fn memory_reduction_factor(&self, num_layers: usize) -> f64 {
178 if !self.enabled || num_layers == 0 {
179 1.0
180 } else {
181 let checkpointed = num_layers.div_ceil(self.checkpoint_every);
182 #[allow(clippy::cast_precision_loss)]
184 {
185 checkpointed as f64 / num_layers as f64
186 }
187 }
188 }
189}
190
191#[must_use]
203pub fn estimate_forward_memory(
204 batch_size: usize,
205 seq_len: usize,
206 hidden_size: usize,
207 num_layers: usize,
208 checkpoint_config: &CheckpointConfig,
209) -> usize {
210 let bytes_per_elem = 4; let activation_per_layer = batch_size * seq_len * hidden_size * bytes_per_elem;
214
215 let stored_layers = if checkpoint_config.enabled {
217 num_layers.div_ceil(checkpoint_config.checkpoint_every)
218 } else {
219 num_layers
220 };
221
222 stored_layers * activation_per_layer
223}
224
/// Estimates the bytes needed by one attention computation.
///
/// The estimate is the sum of three buffers, each counted at 4 bytes per
/// element (presumably `f32` — confirm against callers):
///
/// * combined Q/K/V: `batch_size * seq_len * 3 * hidden_size`
/// * attention scores: `batch_size * num_heads * seq_len * seq_len`
/// * output: `batch_size * seq_len * hidden_size`
///
/// Products and sums that do not fit in `usize` saturate at `usize::MAX`
/// instead of wrapping to a misleadingly small estimate.
#[must_use]
pub fn estimate_attention_vram(
    batch_size: usize,
    seq_len: usize,
    hidden_size: usize,
    num_heads: usize,
) -> usize {
    const BYTES_PER_ELEM: usize = 4;

    let tokens = batch_size.saturating_mul(seq_len);

    let qkv_size = tokens
        .saturating_mul(3)
        .saturating_mul(hidden_size)
        .saturating_mul(BYTES_PER_ELEM);
    // Quadratic in `seq_len`, so this term is the first to overflow for long sequences.
    let scores_size = batch_size
        .saturating_mul(num_heads)
        .saturating_mul(seq_len)
        .saturating_mul(seq_len)
        .saturating_mul(BYTES_PER_ELEM);
    let output_size = tokens
        .saturating_mul(hidden_size)
        .saturating_mul(BYTES_PER_ELEM);

    qkv_size
        .saturating_add(scores_size)
        .saturating_add(output_size)
}
253
/// Renders a byte count as a human-readable string.
///
/// Counts of at least 1024 bytes are shown with two decimals in the largest
/// fitting binary unit (`KB` = 1024 bytes, `MB` = 1024 KB, `GB` = 1024 MB);
/// smaller counts are shown as a whole number of `bytes`.
#[must_use]
// Display only: rounding very large counts to `f64` is acceptable.
#[allow(clippy::cast_precision_loss)]
pub fn format_bytes(bytes: usize) -> String {
    const KB: usize = 1024;
    const MB: usize = KB * 1024;
    const GB: usize = MB * 1024;

    // Pick the largest unit that does not exceed the value.
    let (divisor, unit) = match bytes {
        b if b >= GB => (GB, "GB"),
        b if b >= MB => (MB, "MB"),
        b if b >= KB => (KB, "KB"),
        _ => return format!("{bytes} bytes"),
    };

    format!("{:.2} {unit}", bytes as f64 / divisor as f64)
}
273
// Unit tests for the pool bookkeeping, the checkpoint-aware estimators and
// the byte formatter defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    // Allocations succeed up to the limit, fail past it, and `free` lowers the count.
    #[test]
    fn test_memory_pool_allocation() {
        let mut pool = MemoryPool::new(Some(1000));

        assert!(pool.allocate(500).is_ok());
        assert_eq!(pool.allocated(), 500);

        assert!(pool.allocate(400).is_ok());
        assert_eq!(pool.allocated(), 900);

        // 900 + 200 would exceed the 1000-byte limit.
        assert!(pool.allocate(200).is_err());

        // The failed allocation left the count at 900, so 900 - 300 = 600.
        pool.free(300);
        assert_eq!(pool.allocated(), 600);
    }

    // `with_device` records the requested device type and starts empty.
    #[test]
    fn test_memory_pool_with_device() {
        let pool = MemoryPool::with_device(Some(1024 * 1024), DeviceType::Cuda);
        assert_eq!(pool.device_type(), DeviceType::Cuda);
        assert_eq!(pool.allocated(), 0);
    }

    // Checkpointing every 4 of 32 layers stores 8 layers, i.e. a quarter of the
    // un-checkpointed estimate, which is comfortably below half.
    #[test]
    fn test_checkpoint_memory_reduction() {
        let batch = 4;
        let seq = 2048;
        let hidden = 4096;
        let layers = 32;

        let no_checkpoint = CheckpointConfig {
            enabled: false,
            ..Default::default()
        };
        let with_checkpoint = CheckpointConfig {
            enabled: true,
            checkpoint_every: 4,
        };

        let mem_full = estimate_forward_memory(batch, seq, hidden, layers, &no_checkpoint);
        let mem_checkpoint = estimate_forward_memory(batch, seq, hidden, layers, &with_checkpoint);

        assert!(mem_checkpoint < mem_full / 2);
    }

    // ceil(32 / 4) / 32 = 8 / 32 = 0.25.
    #[test]
    fn test_checkpoint_reduction_factor() {
        let config = CheckpointConfig::new(4, true);
        let factor = config.memory_reduction_factor(32);
        assert!((factor - 0.25).abs() < 0.01);
    }

    // One sample per unit, each exactly at the unit boundary except the plain-bytes case.
    #[test]
    fn test_format_bytes() {
        assert_eq!(format_bytes(500), "500 bytes");
        assert_eq!(format_bytes(1024), "1.00 KB");
        assert_eq!(format_bytes(1024 * 1024), "1.00 MB");
        assert_eq!(format_bytes(1024 * 1024 * 1024), "1.00 GB");
    }

    // Sanity bounds only: the estimate must land between 100 MB and 10 GB
    // (binary units, as in `format_bytes`).
    #[test]
    fn test_attention_vram_estimate() {
        let vram = estimate_attention_vram(4, 2048, 4096, 32);
        assert!(vram > 100 * 1024 * 1024);
        assert!(vram < 10 * 1024 * 1024 * 1024);
    }
}