1use crate::memory::MemoryPlanner;
4use crate::Method;
5use entrenar_common::{EntrenarError, Result};
6
/// A fully-resolved fine-tuning configuration produced by
/// [`LoraOptimizer::optimize`], including memory and speed estimates.
#[derive(Debug, Clone)]
pub struct OptimalConfig {
    /// Selected fine-tuning method (Auto is resolved before construction).
    pub method: Method,
    /// LoRA adapter rank; 0 when `method` is full fine-tuning.
    pub rank: u32,
    /// LoRA scaling factor; set to `rank / 4` by the optimizer.
    pub alpha: f32,
    /// Module names the adapters attach to (attention projections).
    pub target_modules: Vec<String>,
    /// Estimated number of trainable parameters.
    pub trainable_params: u64,
    /// Trainable parameters as a percentage of total model parameters.
    pub trainable_percent: f64,
    /// Estimated total training memory footprint in gigabytes.
    pub memory_gb: f64,
    /// Estimated memory use as a percentage of available VRAM.
    pub utilization_percent: f64,
    /// Rough training throughput multiplier vs. full fine-tuning.
    pub speedup: f64,
}
29
30impl OptimalConfig {
31 pub fn to_comparison_table(&self) -> String {
33 format!(
34 "Optimal Configuration:\n Method: {:?}\n Rank: {}\n Alpha: {:.1}\n Trainable: {} ({:.2}%)\n Memory: {:.1} GB ({:.0}% utilization)\n Speedup: {:.1}x vs full",
35 self.method,
36 self.rank,
37 self.alpha,
38 format_params(self.trainable_params),
39 self.trainable_percent,
40 self.memory_gb,
41 self.utilization_percent,
42 self.speedup
43 )
44 }
45}
46
/// Format a parameter count as a human-readable string with a one-decimal
/// B / M / K suffix, e.g. `7.0B`, `350.0M`, `1.5K`.
///
/// Counts below 1 000 are printed verbatim; the previous version rendered
/// them as a misleading fraction such as `"0.5K"` for 500.
fn format_params(params: u64) -> String {
    if params >= 1_000_000_000 {
        format!("{:.1}B", params as f64 / 1e9)
    } else if params >= 1_000_000 {
        format!("{:.1}M", params as f64 / 1e6)
    } else if params >= 1_000 {
        format!("{:.1}K", params as f64 / 1e3)
    } else {
        params.to_string()
    }
}
56
/// Searches for the fine-tuning configuration (method + rank) that best
/// exploits the available VRAM for a model of a given size.
#[derive(Debug)]
pub struct LoraOptimizer {
    // Total parameter count of the base model.
    model_params: u64,
    // Available GPU memory in bytes (converted from GB in `new`).
    available_vram_bytes: u64,
    // Fraction of VRAM the plan may consume; clamped to 0.5..=0.95.
    target_utilization: f64,
}
64
65impl LoraOptimizer {
66 pub fn new(model_params: u64, available_vram_gb: f64) -> Self {
68 Self {
69 model_params,
70 available_vram_bytes: (available_vram_gb * 1e9) as u64,
71 target_utilization: 0.85, }
73 }
74
75 pub fn with_target_utilization(mut self, utilization: f64) -> Self {
77 self.target_utilization = utilization.clamp(0.5, 0.95);
78 self
79 }
80
81 pub fn optimize(&self, method: Method) -> Result<OptimalConfig> {
83 let method = if method == Method::Auto { self.select_method() } else { method };
84
85 let rank = self.find_optimal_rank(method)?;
86 let planner = MemoryPlanner::new(self.model_params);
87 let memory = planner.estimate(method, rank);
88
89 let trainable_params = self.calculate_trainable_params(method, rank);
90 let trainable_percent = (trainable_params as f64 / self.model_params as f64) * 100.0;
91
92 let memory_gb = memory.total_bytes as f64 / 1e9;
93 let utilization = memory.total_bytes as f64 / self.available_vram_bytes as f64 * 100.0;
94
95 let speedup = match method {
96 Method::Full => 1.0,
97 Method::LoRA => 2.5,
98 Method::QLoRA => 1.8, Method::Auto => 2.0,
100 };
101
102 Ok(OptimalConfig {
103 method,
104 rank,
105 alpha: rank as f32 / 4.0, target_modules: vec![
107 "q_proj".to_string(),
108 "k_proj".to_string(),
109 "v_proj".to_string(),
110 "o_proj".to_string(),
111 ],
112 trainable_params,
113 trainable_percent,
114 memory_gb,
115 utilization_percent: utilization,
116 speedup,
117 })
118 }
119
120 fn select_method(&self) -> Method {
121 let planner = MemoryPlanner::new(self.model_params);
122
123 let full_mem = planner.estimate_full().total_bytes;
125 if full_mem < (self.available_vram_bytes as f64 * self.target_utilization) as u64 {
126 return Method::Full;
127 }
128
129 let lora_mem = planner.estimate_lora(64).total_bytes;
131 if lora_mem < (self.available_vram_bytes as f64 * self.target_utilization) as u64 {
132 return Method::LoRA;
133 }
134
135 Method::QLoRA
137 }
138
139 fn find_optimal_rank(&self, method: Method) -> Result<u32> {
140 if method == Method::Full {
141 return Ok(0);
142 }
143
144 let planner = MemoryPlanner::new(self.model_params);
145 let target_mem = (self.available_vram_bytes as f64 * self.target_utilization) as u64;
146
147 let mut low = 8u32;
149 let mut high = 256u32;
150 let mut best_rank = 64u32;
151
152 while low <= high {
153 let mid = u32::midpoint(low, high);
154 let mem = if method == Method::QLoRA {
155 planner.estimate_qlora(mid, 4).total_bytes
156 } else {
157 planner.estimate_lora(mid).total_bytes
158 };
159
160 if mem <= target_mem {
161 best_rank = mid;
162 low = mid + 1;
163 } else {
164 if mid == 0 {
165 break;
166 }
167 high = mid - 1;
168 }
169 }
170
171 if best_rank < 8 {
172 return Err(EntrenarError::InsufficientMemory {
173 required: planner.estimate_qlora(8, 4).total_bytes as f64 / 1e9,
174 available: self.available_vram_bytes as f64 / 1e9,
175 });
176 }
177
178 Ok(best_rank)
179 }
180
181 fn calculate_trainable_params(&self, method: Method, rank: u32) -> u64 {
182 if method == Method::Full {
183 return self.model_params;
184 }
185
186 let (hidden_dim, num_layers) = if self.model_params > 60_000_000_000 {
188 (8192u64, 80u64)
189 } else if self.model_params > 10_000_000_000 {
190 (5120, 40)
191 } else if self.model_params > 5_000_000_000 {
192 (4096, 32)
193 } else if self.model_params > 1_000_000_000 {
194 (2048, 22)
195 } else {
196 (1024, 12)
197 };
198
199 (hidden_dim * u64::from(rank) * 2) * 4 * num_layers
202 }
203}
204
205pub fn compare_methods(model_params: u64, available_vram_gb: f64) -> Vec<MethodComparison> {
207 let methods = [Method::Full, Method::LoRA, Method::QLoRA];
208 let optimizer = LoraOptimizer::new(model_params, available_vram_gb);
209
210 methods
211 .iter()
212 .filter_map(|&method| {
213 optimizer.optimize(method).ok().map(|config| MethodComparison {
214 method,
215 fits: config.utilization_percent <= 100.0,
216 memory_gb: config.memory_gb,
217 trainable_params: config.trainable_params,
218 speedup: config.speedup,
219 rank: config.rank,
220 })
221 })
222 .collect()
223}
224
/// One row of the method comparison produced by [`compare_methods`].
#[derive(Debug, Clone)]
pub struct MethodComparison {
    /// The fine-tuning method this row describes.
    pub method: Method,
    /// Whether the plan fits in VRAM (utilization <= 100%).
    pub fits: bool,
    /// Estimated total training memory in gigabytes.
    pub memory_gb: f64,
    /// Estimated trainable parameter count.
    pub trainable_params: u64,
    /// Throughput multiplier vs. full fine-tuning.
    pub speedup: f64,
    /// Selected adapter rank (0 for full fine-tuning).
    pub rank: u32,
}
235
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_optimizer_selects_qlora_for_small_vram() {
        // 7B model with only 8 GB: full and LoRA budgets are exceeded.
        let optimizer = LoraOptimizer::new(7_000_000_000, 8.0);
        let config = optimizer.optimize(Method::Auto).expect("config should be valid");

        assert_eq!(config.method, Method::QLoRA);
    }

    #[test]
    fn test_optimizer_selects_lora_for_medium_vram() {
        let optimizer = LoraOptimizer::new(7_000_000_000, 24.0);
        let config = optimizer.optimize(Method::Auto).expect("config should be valid");

        // Smoke test: Auto must resolve to one of the concrete methods.
        assert!(matches!(config.method, Method::LoRA | Method::QLoRA | Method::Full));
    }

    #[test]
    fn test_optimal_rank_is_positive() {
        let optimizer = LoraOptimizer::new(7_000_000_000, 16.0);
        let config = optimizer.optimize(Method::LoRA).expect("config should be valid");

        // The rank search is bounded to 8..=256.
        assert!(config.rank >= 8);
        assert!(config.rank <= 256);
    }

    #[test]
    fn test_trainable_params_less_than_total() {
        let optimizer = LoraOptimizer::new(7_000_000_000, 16.0);
        let config = optimizer.optimize(Method::LoRA).expect("config should be valid");

        // Adapters train only a small fraction of the model.
        assert!(config.trainable_params < 7_000_000_000);
        assert!(config.trainable_percent < 10.0);
    }

    #[test]
    fn test_compare_methods() {
        let comparisons = compare_methods(7_000_000_000, 16.0);

        assert!(!comparisons.is_empty());
        assert!(comparisons.iter().any(|c| c.method == Method::QLoRA));
    }

    #[test]
    fn test_alpha_is_rank_over_4() {
        let optimizer = LoraOptimizer::new(7_000_000_000, 16.0);
        let config = optimizer.optimize(Method::LoRA).expect("config should be valid");

        assert!((config.alpha - config.rank as f32 / 4.0).abs() < 0.01);
    }

    #[test]
    fn test_target_modules_populated() {
        let optimizer = LoraOptimizer::new(7_000_000_000, 16.0);
        let config = optimizer.optimize(Method::LoRA).expect("config should be valid");

        assert!(!config.target_modules.is_empty());
        assert!(config.target_modules.contains(&"q_proj".to_string()));
    }

    #[test]
    fn test_with_target_utilization() {
        // A tighter budget must never yield a larger rank than a looser one.
        let optimizer = LoraOptimizer::new(7_000_000_000, 16.0).with_target_utilization(0.75);
        let config = optimizer.optimize(Method::LoRA).expect("config should be valid");

        let high_util = LoraOptimizer::new(7_000_000_000, 16.0)
            .with_target_utilization(0.95)
            .optimize(Method::LoRA)
            .expect("operation should succeed");

        assert!(config.rank <= high_util.rank);
    }

    #[test]
    fn test_target_utilization_clamping() {
        let low = LoraOptimizer::new(7_000_000_000, 16.0).with_target_utilization(0.1);
        assert!(low.target_utilization >= 0.5);

        let high = LoraOptimizer::new(7_000_000_000, 16.0).with_target_utilization(1.5);
        assert!(high.target_utilization <= 0.95);
    }

    #[test]
    fn test_format_params_billion() {
        assert_eq!(format_params(7_000_000_000), "7.0B");
        assert_eq!(format_params(1_500_000_000), "1.5B");
    }

    #[test]
    fn test_format_params_million() {
        assert_eq!(format_params(350_000_000), "350.0M");
        assert_eq!(format_params(1_500_000), "1.5M");
    }

    #[test]
    fn test_format_params_thousand() {
        assert_eq!(format_params(500_000), "500.0K");
        assert_eq!(format_params(1_500), "1.5K");
    }

    #[test]
    fn test_to_comparison_table() {
        let optimizer = LoraOptimizer::new(7_000_000_000, 16.0);
        let config = optimizer.optimize(Method::LoRA).expect("config should be valid");
        let table = config.to_comparison_table();

        assert!(table.contains("Optimal Configuration"));
        assert!(table.contains("Method:"));
        assert!(table.contains("Rank:"));
        assert!(table.contains("Alpha:"));
        assert!(table.contains("Memory:"));
    }

    #[test]
    fn test_full_method_rank_zero() {
        let optimizer = LoraOptimizer::new(1_000_000_000, 100.0);
        let config = optimizer.optimize(Method::Full).expect("config should be valid");
        assert_eq!(config.rank, 0);
    }

    #[test]
    fn test_full_method_all_params_trainable() {
        let optimizer = LoraOptimizer::new(1_000_000_000, 100.0);
        let config = optimizer.optimize(Method::Full).expect("config should be valid");
        assert_eq!(config.trainable_params, 1_000_000_000);
        assert_eq!(config.trainable_percent, 100.0);
    }

    #[test]
    fn test_speedup_values() {
        let optimizer = LoraOptimizer::new(7_000_000_000, 100.0);

        let full = optimizer.optimize(Method::Full).expect("operation should succeed");
        assert_eq!(full.speedup, 1.0);

        let lora = optimizer.optimize(Method::LoRA).expect("operation should succeed");
        assert_eq!(lora.speedup, 2.5);

        let qlora = optimizer.optimize(Method::QLoRA).expect("operation should succeed");
        assert_eq!(qlora.speedup, 1.8);
    }

    #[test]
    fn test_compare_methods_includes_all() {
        // With an enormous 100 GB budget, every method should produce a plan.
        let comparisons = compare_methods(7_000_000_000, 100.0);

        assert!(comparisons.iter().any(|c| c.method == Method::Full));
        assert!(comparisons.iter().any(|c| c.method == Method::LoRA));
        assert!(comparisons.iter().any(|c| c.method == Method::QLoRA));
    }

    #[test]
    fn test_compare_methods_small_vram() {
        let comparisons = compare_methods(7_000_000_000, 4.0);

        // Full fine-tuning never errors out of `optimize` (rank search is
        // skipped for Full), so the comparison list is never empty.
        assert!(!comparisons.is_empty());

        // `fits` means utilization <= 100%, which by construction implies the
        // plan's memory does not exceed the 4 GB budget.
        let fitting: Vec<_> = comparisons.iter().filter(|c| c.fits).collect();
        assert!(fitting.iter().all(|c| c.memory_gb <= 4.0));
    }

    #[test]
    fn test_method_comparison_struct() {
        let comparisons = compare_methods(7_000_000_000, 16.0);
        let qlora = comparisons.iter().find(|c| c.method == Method::QLoRA);

        if let Some(c) = qlora {
            assert!(c.memory_gb > 0.0);
            assert!(c.trainable_params > 0);
            assert!(c.speedup > 0.0);
            assert!(c.rank >= 8);
        }
    }
}