entrenar_lora/memory.rs

//! Memory planning for LoRA configurations (Heijunka principle).
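//!
//! # Example
//!
//! A minimal planning sketch; the import path assumes this module is exposed
//! as `entrenar_lora::memory` (adjust to the crate's actual re-exports).
//!
//! ```
//! use entrenar_lora::memory::MemoryPlanner;
//!
//! // Plan memory for a 7B-parameter model fine-tuned with LoRA at rank 64.
//! let planner = MemoryPlanner::new(7_000_000_000)
//!     .with_batch_size(8)
//!     .with_seq_len(1024);
//! let estimate = planner.estimate_lora(64);
//! println!("{}", estimate.to_human_readable());
//! ```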

use crate::Method;

/// Memory requirement estimation.
#[derive(Debug, Clone)]
pub struct MemoryRequirement {
    /// Base model memory in bytes
    pub model_bytes: u64,
    /// Adapter memory in bytes
    pub adapter_bytes: u64,
    /// Optimizer state memory in bytes
    pub optimizer_bytes: u64,
    /// Activation memory in bytes
    pub activation_bytes: u64,
    /// Total memory in bytes
    pub total_bytes: u64,
    /// Memory savings compared to full fine-tuning (percentage)
    pub savings_percent: f64,
}

impl MemoryRequirement {
    /// Format as human-readable string.
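    ///
    /// The output lists model, adapter, optimizer, and activation memory in
    /// decimal GB (1e9 bytes), followed by the total and the savings
    /// percentage relative to full fine-tuning.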
    pub fn to_human_readable(&self) -> String {
        format!(
            "Memory Requirement:\n  Model: {:.1} GB\n  Adapter: {:.1} GB\n  Optimizer: {:.1} GB\n  Activations: {:.1} GB\n  Total: {:.1} GB\n  Savings: {:.0}%",
            self.model_bytes as f64 / 1e9,
            self.adapter_bytes as f64 / 1e9,
            self.optimizer_bytes as f64 / 1e9,
            self.activation_bytes as f64 / 1e9,
            self.total_bytes as f64 / 1e9,
            self.savings_percent
        )
    }
}

/// Memory planner for different fine-tuning methods.
#[derive(Debug)]
pub struct MemoryPlanner {
    model_params: u64,
    hidden_dim: u64,
    num_layers: u32,
    batch_size: u32,
    seq_len: u32,
}

impl MemoryPlanner {
    /// Create a new memory planner.
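    ///
    /// The hidden dimension and layer count are inferred from the parameter
    /// count (for example, 7,000,000,000 parameters maps to a hidden
    /// dimension of 4096 and 32 layers); batch size defaults to 32 and
    /// sequence length to 512.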
    pub fn new(model_params: u64) -> Self {
        // Estimate architecture from param count
        let (hidden_dim, num_layers) = estimate_architecture(model_params);

        Self { model_params, hidden_dim, num_layers, batch_size: 32, seq_len: 512 }
    }

    /// Set batch size.
    pub fn with_batch_size(mut self, batch_size: u32) -> Self {
        self.batch_size = batch_size;
        self
    }

    /// Set sequence length.
    pub fn with_seq_len(mut self, seq_len: u32) -> Self {
        self.seq_len = seq_len;
        self
    }

    /// Estimate memory for full fine-tuning.
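    ///
    /// Counts FP16 weights (2 bytes per parameter), Adam optimizer state
    /// (8 bytes per parameter), and activations; gradients are not modeled
    /// separately. For a 7B model with the default batch size and sequence
    /// length this is roughly 14 GB + 56 GB + 8.6 GB ≈ 78.6 GB.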
    pub fn estimate_full(&self) -> MemoryRequirement {
        let model_bytes = self.model_params * 2; // FP16
        let optimizer_bytes = self.model_params * 8; // Adam: 2 FP32 states
        let activation_bytes = self.estimate_activations();

        let total_bytes = model_bytes + optimizer_bytes + activation_bytes;

        MemoryRequirement {
            model_bytes,
            adapter_bytes: 0,
            optimizer_bytes,
            activation_bytes,
            total_bytes,
            savings_percent: 0.0,
        }
    }

    /// Estimate memory for LoRA fine-tuning.
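    ///
    /// Adapter parameters are modeled as two `d × r` matrices per target
    /// module and four target modules per layer. For a 7B model
    /// (hidden 4096, 32 layers) at rank 64 that is
    /// 4096 × 64 × 2 × 4 × 32 ≈ 67M adapter parameters, roughly 134 MB in
    /// FP16 plus about 0.5 GB of Adam state.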
    pub fn estimate_lora(&self, rank: u32) -> MemoryRequirement {
        let model_bytes = self.model_params * 2; // FP16 (frozen)

        // LoRA adapters: 2 matrices per target module (typically 4 modules per layer)
        // A: d × r, B: r × d for each module
        let adapter_params =
            (self.hidden_dim * u64::from(rank) * 2) * 4 * u64::from(self.num_layers);
        let adapter_bytes = adapter_params * 2; // FP16

        // Optimizer only for adapter params
        let optimizer_bytes = adapter_params * 8; // Adam states

        let activation_bytes = self.estimate_activations();

        let total_bytes = model_bytes + adapter_bytes + optimizer_bytes + activation_bytes;
        let full_total = self.estimate_full().total_bytes;
        let savings_percent = (1.0 - total_bytes as f64 / full_total as f64) * 100.0;

        MemoryRequirement {
            model_bytes,
            adapter_bytes,
            optimizer_bytes,
            activation_bytes,
            total_bytes,
            savings_percent,
        }
    }

    /// Estimate memory for QLoRA fine-tuning.
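    ///
    /// The frozen base model is counted at `bits` bits per parameter while
    /// the adapters stay in FP16 with full optimizer state. For a 7B model
    /// at 4 bits and rank 64 this is roughly 3.5 GB of weights, ~0.7 GB of
    /// adapter and optimizer memory, and ~8.6 GB of activations, about 84%
    /// below the full fine-tuning estimate.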
    pub fn estimate_qlora(&self, rank: u32, bits: u8) -> MemoryRequirement {
        // Base model in quantized format
        let model_bytes = self.model_params * u64::from(bits) / 8;

        // LoRA adapters in FP16
        let adapter_params =
            (self.hidden_dim * u64::from(rank) * 2) * 4 * u64::from(self.num_layers);
        let adapter_bytes = adapter_params * 2;

        // Optimizer only for adapter params
        let optimizer_bytes = adapter_params * 8;

        let activation_bytes = self.estimate_activations();

        let total_bytes = model_bytes + adapter_bytes + optimizer_bytes + activation_bytes;
        let full_total = self.estimate_full().total_bytes;
        let savings_percent = (1.0 - total_bytes as f64 / full_total as f64) * 100.0;

        MemoryRequirement {
            model_bytes,
            adapter_bytes,
            optimizer_bytes,
            activation_bytes,
            total_bytes,
            savings_percent,
        }
    }

    /// Estimate memory for a given method.
    pub fn estimate(&self, method: Method, rank: u32) -> MemoryRequirement {
        match method {
            Method::Full => self.estimate_full(),
            Method::LoRA => self.estimate_lora(rank),
            Method::QLoRA => self.estimate_qlora(rank, 4),
            Method::Auto => {
                // Auto prefers the most memory-efficient option, so plan for
                // 4-bit QLoRA.
                self.estimate_qlora(rank, 4)
            }
        }
    }

    fn estimate_activations(&self) -> u64 {
        // Per layer: batch × seq × hidden elements, doubled for forward and
        // backward passes, at 2 bytes per element (FP16).
        let per_layer =
            u64::from(self.batch_size) * u64::from(self.seq_len) * self.hidden_dim * 2 * 2;

        per_layer * u64::from(self.num_layers)
    }
}

fn estimate_architecture(params: u64) -> (u64, u32) {
    // Rough estimates based on common model sizes
    if params > 60_000_000_000 {
        (8192, 80) // 70B class
    } else if params > 10_000_000_000 {
        (5120, 40) // 13B class
    } else if params > 5_000_000_000 {
        (4096, 32) // 7B class
    } else if params > 1_000_000_000 {
        (2048, 22) // 1-3B class
    } else if params > 300_000_000 {
        (1024, 12) // 350M class
    } else {
        (768, 12) // Small models
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_planner_7b() {
        let planner = MemoryPlanner::new(7_000_000_000);

        let full = planner.estimate_full();
        let lora = planner.estimate_lora(64);
        let qlora = planner.estimate_qlora(64, 4);

        // Full should use most memory
        assert!(full.total_bytes > lora.total_bytes);
        assert!(lora.total_bytes > qlora.total_bytes);

        // QLoRA should have significant savings
        assert!(qlora.savings_percent > 50.0);
    }

    #[test]
    fn test_lora_adapter_memory_scales_with_rank() {
        let planner = MemoryPlanner::new(7_000_000_000);

        let lora_16 = planner.estimate_lora(16);
        let lora_64 = planner.estimate_lora(64);
        let lora_128 = planner.estimate_lora(128);

        assert!(lora_16.adapter_bytes < lora_64.adapter_bytes);
        assert!(lora_64.adapter_bytes < lora_128.adapter_bytes);
    }

    #[test]
    fn test_qlora_4bit_vs_8bit() {
        let planner = MemoryPlanner::new(7_000_000_000);

        let qlora_4 = planner.estimate_qlora(64, 4);
        let qlora_8 = planner.estimate_qlora(64, 8);

        // 4-bit should use less model memory
        assert!(qlora_4.model_bytes < qlora_8.model_bytes);
    }

    #[test]
    fn test_batch_size_affects_activations() {
        let planner_small = MemoryPlanner::new(7_000_000_000).with_batch_size(8);
        let planner_large = MemoryPlanner::new(7_000_000_000).with_batch_size(64);

        let small = planner_small.estimate_full();
        let large = planner_large.estimate_full();

        assert!(small.activation_bytes < large.activation_bytes);
    }

    #[test]
    fn test_architecture_estimation() {
        let (hidden, layers) = estimate_architecture(7_000_000_000);
        assert_eq!(hidden, 4096);
        assert_eq!(layers, 32);

        let (hidden, layers) = estimate_architecture(350_000_000);
        assert_eq!(hidden, 1024);
        assert_eq!(layers, 12);
    }

    #[test]
    fn test_architecture_estimation_all_tiers() {
        // 70B class
        let (hidden, layers) = estimate_architecture(70_000_000_000);
        assert_eq!(hidden, 8192);
        assert_eq!(layers, 80);

        // 13B class
        let (hidden, layers) = estimate_architecture(13_000_000_000);
        assert_eq!(hidden, 5120);
        assert_eq!(layers, 40);

        // 1-3B class
        let (hidden, layers) = estimate_architecture(2_000_000_000);
        assert_eq!(hidden, 2048);
        assert_eq!(layers, 22);

        // Small models
        let (hidden, layers) = estimate_architecture(100_000_000);
        assert_eq!(hidden, 768);
        assert_eq!(layers, 12);
    }

    #[test]
    fn test_with_seq_len() {
        let planner = MemoryPlanner::new(7_000_000_000).with_seq_len(1024);
        let full_1024 = planner.estimate_full();

        let planner_short = MemoryPlanner::new(7_000_000_000).with_seq_len(256);
        let full_256 = planner_short.estimate_full();

        // Longer sequences require more activation memory
        assert!(full_1024.activation_bytes > full_256.activation_bytes);
    }

    #[test]
    fn test_estimate_method_dispatch() {
        let planner = MemoryPlanner::new(7_000_000_000);

        let full = planner.estimate(Method::Full, 64);
        assert_eq!(full.adapter_bytes, 0);

        let lora = planner.estimate(Method::LoRA, 64);
        assert!(lora.adapter_bytes > 0);

        let qlora = planner.estimate(Method::QLoRA, 64);
        assert!(qlora.model_bytes < lora.model_bytes);

        let auto = planner.estimate(Method::Auto, 64);
        assert!(auto.savings_percent > 0.0);
    }

    #[test]
    fn test_to_human_readable() {
        let planner = MemoryPlanner::new(7_000_000_000);
        let req = planner.estimate_full();
        let readable = req.to_human_readable();

        assert!(readable.contains("Memory Requirement"));
        assert!(readable.contains("GB"));
        assert!(readable.contains("Model:"));
        assert!(readable.contains("Total:"));
    }

    #[test]
    fn test_full_has_zero_savings() {
        let planner = MemoryPlanner::new(7_000_000_000);
        let full = planner.estimate_full();
        assert_eq!(full.savings_percent, 0.0);
    }

    #[test]
    fn test_lora_has_positive_savings() {
        let planner = MemoryPlanner::new(7_000_000_000);
        let lora = planner.estimate_lora(64);
        assert!(lora.savings_percent > 0.0);
    }

    #[test]
    fn test_qlora_saves_more_than_lora() {
        let planner = MemoryPlanner::new(7_000_000_000);
        let lora = planner.estimate_lora(64);
        let qlora = planner.estimate_qlora(64, 4);
        assert!(qlora.savings_percent > lora.savings_percent);
    }
}