1use crate::Method;
4
/// Breakdown of the estimated memory needed to fine-tune a model.
///
/// All `*_bytes` fields are byte counts. The human-readable rendering divides
/// them by `1e9` to report decimal gigabytes.
///
/// `PartialEq` is derived so that two estimates can be compared directly
/// (for example with `assert_eq!`). `Eq` is not available because of the
/// floating-point `savings_percent` field.
#[derive(Debug, Clone, PartialEq)]
pub struct MemoryRequirement {
    /// Bytes occupied by the base model weights.
    pub model_bytes: u64,
    /// Bytes occupied by adapter weights; `0` for full fine-tuning.
    pub adapter_bytes: u64,
    /// Bytes occupied by optimizer state for the trainable parameters.
    pub optimizer_bytes: u64,
    /// Bytes occupied by activations.
    pub activation_bytes: u64,
    /// Sum of the model, adapter, optimizer and activation components.
    pub total_bytes: u64,
    /// Percentage reduction of `total_bytes` relative to the full fine-tuning
    /// estimate for the same planner; `0.0` for full fine-tuning itself.
    pub savings_percent: f64,
}
21
22impl MemoryRequirement {
23 pub fn to_human_readable(&self) -> String {
25 format!(
26 "Memory Requirement:\n Model: {:.1} GB\n Adapter: {:.1} GB\n Optimizer: {:.1} GB\n Activations: {:.1} GB\n Total: {:.1} GB\n Savings: {:.0}%",
27 self.model_bytes as f64 / 1e9,
28 self.adapter_bytes as f64 / 1e9,
29 self.optimizer_bytes as f64 / 1e9,
30 self.activation_bytes as f64 / 1e9,
31 self.total_bytes as f64 / 1e9,
32 self.savings_percent
33 )
34 }
35}
36
/// Estimates training memory for a model of a given parameter count under
/// full fine-tuning, LoRA and QLoRA.
///
/// `Clone` is derived because the `with_*` builder methods consume `self`;
/// cloning lets one configured planner serve as the base for several variants.
/// `PartialEq`/`Eq` allow planners to be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct MemoryPlanner {
    /// Total number of parameters in the base model.
    model_params: u64,
    /// Hidden dimension, derived from `model_params` by `estimate_architecture`.
    hidden_dim: u64,
    /// Number of layers, derived from `model_params` by `estimate_architecture`.
    num_layers: u32,
    /// Training batch size; `new` sets this to 32.
    batch_size: u32,
    /// Sequence length; `new` sets this to 512.
    seq_len: u32,
}
46
47impl MemoryPlanner {
48 pub fn new(model_params: u64) -> Self {
50 let (hidden_dim, num_layers) = estimate_architecture(model_params);
52
53 Self { model_params, hidden_dim, num_layers, batch_size: 32, seq_len: 512 }
54 }
55
56 pub fn with_batch_size(mut self, batch_size: u32) -> Self {
58 self.batch_size = batch_size;
59 self
60 }
61
62 pub fn with_seq_len(mut self, seq_len: u32) -> Self {
64 self.seq_len = seq_len;
65 self
66 }
67
68 pub fn estimate_full(&self) -> MemoryRequirement {
70 let model_bytes = self.model_params * 2; let optimizer_bytes = self.model_params * 8; let activation_bytes = self.estimate_activations();
73
74 let total_bytes = model_bytes + optimizer_bytes + activation_bytes;
75
76 MemoryRequirement {
77 model_bytes,
78 adapter_bytes: 0,
79 optimizer_bytes,
80 activation_bytes,
81 total_bytes,
82 savings_percent: 0.0,
83 }
84 }
85
86 pub fn estimate_lora(&self, rank: u32) -> MemoryRequirement {
88 let model_bytes = self.model_params * 2; let adapter_params =
93 (self.hidden_dim * u64::from(rank) * 2) * 4 * u64::from(self.num_layers);
94 let adapter_bytes = adapter_params * 2; let optimizer_bytes = adapter_params * 8; let activation_bytes = self.estimate_activations();
100
101 let total_bytes = model_bytes + adapter_bytes + optimizer_bytes + activation_bytes;
102 let full_total = self.estimate_full().total_bytes;
103 let savings_percent = (1.0 - total_bytes as f64 / full_total as f64) * 100.0;
104
105 MemoryRequirement {
106 model_bytes,
107 adapter_bytes,
108 optimizer_bytes,
109 activation_bytes,
110 total_bytes,
111 savings_percent,
112 }
113 }
114
115 pub fn estimate_qlora(&self, rank: u32, bits: u8) -> MemoryRequirement {
117 let model_bytes = self.model_params * u64::from(bits) / 8;
119
120 let adapter_params =
122 (self.hidden_dim * u64::from(rank) * 2) * 4 * u64::from(self.num_layers);
123 let adapter_bytes = adapter_params * 2;
124
125 let optimizer_bytes = adapter_params * 8;
127
128 let activation_bytes = self.estimate_activations();
129
130 let total_bytes = model_bytes + adapter_bytes + optimizer_bytes + activation_bytes;
131 let full_total = self.estimate_full().total_bytes;
132 let savings_percent = (1.0 - total_bytes as f64 / full_total as f64) * 100.0;
133
134 MemoryRequirement {
135 model_bytes,
136 adapter_bytes,
137 optimizer_bytes,
138 activation_bytes,
139 total_bytes,
140 savings_percent,
141 }
142 }
143
144 pub fn estimate(&self, method: Method, rank: u32) -> MemoryRequirement {
146 match method {
147 Method::Full => self.estimate_full(),
148 Method::LoRA => self.estimate_lora(rank),
149 Method::QLoRA => self.estimate_qlora(rank, 4),
150 Method::Auto => {
151 self.estimate_qlora(rank, 4)
153 }
154 }
155 }
156
157 fn estimate_activations(&self) -> u64 {
158 let per_layer =
160 u64::from(self.batch_size) * u64::from(self.seq_len) * self.hidden_dim * 2 * 2; per_layer * u64::from(self.num_layers)
163 }
164}
165
/// Maps a parameter count to an approximate `(hidden_dim, num_layers)` pair.
///
/// The count is compared against a descending list of thresholds and the
/// first tier whose threshold is strictly exceeded wins. Counts at or below
/// the smallest threshold fall back to `(768, 12)`.
fn estimate_architecture(params: u64) -> (u64, u32) {
    // (exclusive lower bound on the parameter count, (hidden_dim, num_layers)),
    // ordered from the largest tier to the smallest.
    const TIERS: [(u64, (u64, u32)); 5] = [
        (60_000_000_000, (8192, 80)),
        (10_000_000_000, (5120, 40)),
        (5_000_000_000, (4096, 32)),
        (1_000_000_000, (2048, 22)),
        (300_000_000, (1024, 12)),
    ];
    // Shape used when no threshold is exceeded.
    const SMALLEST: (u64, u32) = (768, 12);

    TIERS
        .iter()
        .find(|&&(threshold, _)| params > threshold)
        .map_or(SMALLEST, |&(_, shape)| shape)
}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// Parameter count shared by most tests; it falls in the (4096, 32) tier.
    const PARAMS_7B: u64 = 7_000_000_000;

    /// Planner for a 7B-parameter model with the default batch size and
    /// sequence length.
    fn planner_7b() -> MemoryPlanner {
        MemoryPlanner::new(PARAMS_7B)
    }

    #[test]
    fn test_memory_planner_7b() {
        let planner = planner_7b();

        let full_total = planner.estimate_full().total_bytes;
        let lora_total = planner.estimate_lora(64).total_bytes;
        let qlora = planner.estimate_qlora(64, 4);

        // Each method must need strictly less memory than the one before it.
        assert!(full_total > lora_total);
        assert!(lora_total > qlora.total_bytes);

        assert!(qlora.savings_percent > 50.0);
    }

    #[test]
    fn test_lora_adapter_memory_scales_with_rank() {
        let planner = planner_7b();

        let adapter_sizes: Vec<u64> = [16, 64, 128]
            .iter()
            .map(|&rank| planner.estimate_lora(rank).adapter_bytes)
            .collect();

        // Adapter size must grow strictly with each increase in rank.
        assert!(adapter_sizes.windows(2).all(|pair| pair[0] < pair[1]));
    }

    #[test]
    fn test_qlora_4bit_vs_8bit() {
        let planner = planner_7b();

        let four_bit = planner.estimate_qlora(64, 4);
        let eight_bit = planner.estimate_qlora(64, 8);

        assert!(four_bit.model_bytes < eight_bit.model_bytes);
    }

    #[test]
    fn test_batch_size_affects_activations() {
        let activations_for =
            |batch: u32| planner_7b().with_batch_size(batch).estimate_full().activation_bytes;

        assert!(activations_for(8) < activations_for(64));
    }

    #[test]
    fn test_architecture_estimation() {
        assert_eq!(estimate_architecture(7_000_000_000), (4096, 32));
        assert_eq!(estimate_architecture(350_000_000), (1024, 12));
    }

    #[test]
    fn test_architecture_estimation_all_tiers() {
        let expectations: [(u64, (u64, u32)); 4] = [
            (70_000_000_000, (8192, 80)),
            (13_000_000_000, (5120, 40)),
            (2_000_000_000, (2048, 22)),
            (100_000_000, (768, 12)),
        ];

        for &(params, (hidden, layers)) in expectations.iter() {
            let (got_hidden, got_layers) = estimate_architecture(params);
            assert_eq!(got_hidden, hidden);
            assert_eq!(got_layers, layers);
        }
    }

    #[test]
    fn test_with_seq_len() {
        let long_context = planner_7b().with_seq_len(1024).estimate_full();
        let short_context = planner_7b().with_seq_len(256).estimate_full();

        assert!(long_context.activation_bytes > short_context.activation_bytes);
    }

    #[test]
    fn test_estimate_method_dispatch() {
        let planner = planner_7b();

        let via_full = planner.estimate(Method::Full, 64);
        assert_eq!(via_full.adapter_bytes, 0);

        let via_lora = planner.estimate(Method::LoRA, 64);
        assert!(via_lora.adapter_bytes > 0);

        let via_qlora = planner.estimate(Method::QLoRA, 64);
        assert!(via_qlora.model_bytes < via_lora.model_bytes);

        let via_auto = planner.estimate(Method::Auto, 64);
        assert!(via_auto.savings_percent > 0.0);
    }

    #[test]
    fn test_to_human_readable() {
        let summary = planner_7b().estimate_full().to_human_readable();

        for needle in ["Memory Requirement", "GB", "Model:", "Total:"].iter() {
            assert!(summary.contains(needle));
        }
    }

    #[test]
    fn test_full_has_zero_savings() {
        let baseline = planner_7b().estimate_full();
        assert_eq!(baseline.savings_percent, 0.0);
    }

    #[test]
    fn test_lora_has_positive_savings() {
        let with_lora = planner_7b().estimate_lora(64);
        assert!(with_lora.savings_percent > 0.0);
    }

    #[test]
    fn test_qlora_saves_more_than_lora() {
        let planner = planner_7b();
        let lora_savings = planner.estimate_lora(64).savings_percent;
        let qlora_savings = planner.estimate_qlora(64, 4).savings_percent;
        assert!(qlora_savings > lora_savings);
    }
}