Skip to main content

entrenar/hf_pipeline/loader/
memory.rs

//! Memory estimation for model loading.
2
/// Memory estimation for model loading.
///
/// All quantities are in bytes. The precision-specific constructors
/// ([`MemoryEstimate::fp32`], [`MemoryEstimate::fp16`], [`MemoryEstimate::int4`])
/// describe a frozen model, so they always set `gradients` to 0.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct MemoryEstimate {
    /// Memory for model weights, in bytes
    pub weights: u64,
    /// Memory for activations during forward pass, in bytes
    pub activations: u64,
    /// Memory for gradients, in bytes (0 for frozen teacher)
    pub gradients: u64,
}

impl MemoryEstimate {
    /// Total memory required, in bytes.
    ///
    /// Saturates at `u64::MAX` instead of overflowing, so an enormous estimate
    /// can never wrap around to a small value and be reported as fitting.
    #[must_use]
    pub fn total(&self) -> u64 {
        self.weights
            .saturating_add(self.activations)
            .saturating_add(self.gradients)
    }

    /// Check if model fits in `available` bytes of memory.
    #[must_use]
    pub fn fits_in(&self, available: u64) -> bool {
        self.total() <= available
    }

    /// Create estimate for FP32 model (4 bytes per weight and per activation element).
    ///
    /// Activations are estimated as one `batch_size × seq_len × hidden_size`
    /// tensor. All arithmetic saturates at `u64::MAX` rather than overflowing.
    #[must_use]
    pub fn fp32(param_count: u64, batch_size: usize, seq_len: usize, hidden_size: usize) -> Self {
        Self::frozen(
            param_count.saturating_mul(4),
            Self::activation_bytes(batch_size, seq_len, hidden_size, 4),
        )
    }

    /// Create estimate for FP16 model (2 bytes per weight and per activation element).
    ///
    /// Activations are estimated as one `batch_size × seq_len × hidden_size`
    /// tensor. All arithmetic saturates at `u64::MAX` rather than overflowing.
    #[must_use]
    pub fn fp16(param_count: u64, batch_size: usize, seq_len: usize, hidden_size: usize) -> Self {
        Self::frozen(
            param_count.saturating_mul(2),
            Self::activation_bytes(batch_size, seq_len, hidden_size, 2),
        )
    }

    /// Create estimate for INT4/Q4 model (0.5 bytes per weight, FP16 activations).
    ///
    /// Weight bytes are rounded up, so an odd `param_count` still accounts for
    /// the byte holding its final 4-bit value.
    #[must_use]
    pub fn int4(param_count: u64, batch_size: usize, seq_len: usize, hidden_size: usize) -> Self {
        // 4-bit = 0.5 bytes per param; ceil(param_count / 2) without overflow.
        let weights = param_count / 2 + param_count % 2;
        Self::frozen(
            weights,
            // Activations still in FP16 for compute
            Self::activation_bytes(batch_size, seq_len, hidden_size, 2),
        )
    }

    /// Build an estimate with no gradient memory (frozen teacher).
    fn frozen(weights: u64, activations: u64) -> Self {
        Self {
            weights,
            activations,
            gradients: 0,
        }
    }

    /// Bytes for one `batch_size × seq_len × hidden_size` tensor with
    /// `bytes_per_elem` bytes per element.
    ///
    /// Computed in `u64` (not `usize`, which is only 32 bits on some targets)
    /// and saturating at `u64::MAX` instead of overflowing.
    fn activation_bytes(
        batch_size: usize,
        seq_len: usize,
        hidden_size: usize,
        bytes_per_elem: u64,
    ) -> u64 {
        // `usize` is at most 64 bits on every target Rust supports, so these
        // widening casts are lossless.
        (batch_size as u64)
            .saturating_mul(seq_len as u64)
            .saturating_mul(hidden_size as u64)
            .saturating_mul(bytes_per_elem)
    }
}
57}