// entrenar_lora/optimizer.rs

//! LoRA configuration optimizer (Kaizen principle).

use crate::memory::MemoryPlanner;
use crate::Method;
use entrenar_common::{EntrenarError, Result};
6
/// Optimal LoRA configuration result.
///
/// Produced by [`LoraOptimizer::optimize`]; summarizes the recommended
/// method, rank, and the memory/speed estimates backing the recommendation.
#[derive(Debug, Clone)]
pub struct OptimalConfig {
    /// Recommended fine-tuning method
    pub method: Method,
    /// Recommended LoRA rank (0 when `method` is full fine-tuning)
    pub rank: u32,
    /// Recommended alpha scaling (heuristic: rank / 4)
    pub alpha: f32,
    /// Target modules to apply LoRA
    pub target_modules: Vec<String>,
    /// Estimated trainable parameters
    pub trainable_params: u64,
    /// Percentage of total parameters that are trainable
    pub trainable_percent: f64,
    /// Estimated memory requirement in GB
    pub memory_gb: f64,
    /// VRAM utilization percentage
    pub utilization_percent: f64,
    /// Training speedup compared to full fine-tuning
    pub speedup: f64,
}
29
30impl OptimalConfig {
31    /// Format as human-readable comparison table.
32    pub fn to_comparison_table(&self) -> String {
33        format!(
34            "Optimal Configuration:\n  Method: {:?}\n  Rank: {}\n  Alpha: {:.1}\n  Trainable: {} ({:.2}%)\n  Memory: {:.1} GB ({:.0}% utilization)\n  Speedup: {:.1}x vs full",
35            self.method,
36            self.rank,
37            self.alpha,
38            format_params(self.trainable_params),
39            self.trainable_percent,
40            self.memory_gb,
41            self.utilization_percent,
42            self.speedup
43        )
44    }
45}
46
/// Format a parameter count with a B/M/K suffix (one decimal place).
/// Counts below one million fall through to the K form (e.g. 500 -> "0.5K").
fn format_params(params: u64) -> String {
    const BILLION: u64 = 1_000_000_000;
    const MILLION: u64 = 1_000_000;
    let value = params as f64;
    match params {
        p if p >= BILLION => format!("{:.1}B", value / 1e9),
        p if p >= MILLION => format!("{:.1}M", value / 1e6),
        _ => format!("{:.1}K", value / 1e3),
    }
}
56
/// LoRA configuration optimizer.
#[derive(Debug)]
pub struct LoraOptimizer {
    /// Total parameter count of the base model.
    model_params: u64,
    /// Available GPU memory in bytes (converted from GB at construction).
    available_vram_bytes: u64,
    /// Fraction of VRAM the optimizer targets (clamped to 0.5-0.95 by the setter).
    target_utilization: f64,
}
64
65impl LoraOptimizer {
66    /// Create a new optimizer.
67    pub fn new(model_params: u64, available_vram_gb: f64) -> Self {
68        Self {
69            model_params,
70            available_vram_bytes: (available_vram_gb * 1e9) as u64,
71            target_utilization: 0.85, // Target 85% VRAM utilization
72        }
73    }
74
75    /// Set target VRAM utilization (0.0 - 1.0).
76    pub fn with_target_utilization(mut self, utilization: f64) -> Self {
77        self.target_utilization = utilization.clamp(0.5, 0.95);
78        self
79    }
80
81    /// Find optimal configuration for the given method.
82    pub fn optimize(&self, method: Method) -> Result<OptimalConfig> {
83        let method = if method == Method::Auto { self.select_method() } else { method };
84
85        let rank = self.find_optimal_rank(method)?;
86        let planner = MemoryPlanner::new(self.model_params);
87        let memory = planner.estimate(method, rank);
88
89        let trainable_params = self.calculate_trainable_params(method, rank);
90        let trainable_percent = (trainable_params as f64 / self.model_params as f64) * 100.0;
91
92        let memory_gb = memory.total_bytes as f64 / 1e9;
93        let utilization = memory.total_bytes as f64 / self.available_vram_bytes as f64 * 100.0;
94
95        let speedup = match method {
96            Method::Full => 1.0,
97            Method::LoRA => 2.5,
98            Method::QLoRA => 1.8, // QLoRA has dequantization overhead
99            Method::Auto => 2.0,
100        };
101
102        Ok(OptimalConfig {
103            method,
104            rank,
105            alpha: rank as f32 / 4.0, // Common heuristic: alpha = rank/4
106            target_modules: vec![
107                "q_proj".to_string(),
108                "k_proj".to_string(),
109                "v_proj".to_string(),
110                "o_proj".to_string(),
111            ],
112            trainable_params,
113            trainable_percent,
114            memory_gb,
115            utilization_percent: utilization,
116            speedup,
117        })
118    }
119
120    fn select_method(&self) -> Method {
121        let planner = MemoryPlanner::new(self.model_params);
122
123        // Check if full fine-tuning fits
124        let full_mem = planner.estimate_full().total_bytes;
125        if full_mem < (self.available_vram_bytes as f64 * self.target_utilization) as u64 {
126            return Method::Full;
127        }
128
129        // Check if LoRA fits
130        let lora_mem = planner.estimate_lora(64).total_bytes;
131        if lora_mem < (self.available_vram_bytes as f64 * self.target_utilization) as u64 {
132            return Method::LoRA;
133        }
134
135        // Default to QLoRA
136        Method::QLoRA
137    }
138
139    fn find_optimal_rank(&self, method: Method) -> Result<u32> {
140        if method == Method::Full {
141            return Ok(0);
142        }
143
144        let planner = MemoryPlanner::new(self.model_params);
145        let target_mem = (self.available_vram_bytes as f64 * self.target_utilization) as u64;
146
147        // Binary search for optimal rank
148        let mut low = 8u32;
149        let mut high = 256u32;
150        let mut best_rank = 64u32;
151
152        while low <= high {
153            let mid = u32::midpoint(low, high);
154            let mem = if method == Method::QLoRA {
155                planner.estimate_qlora(mid, 4).total_bytes
156            } else {
157                planner.estimate_lora(mid).total_bytes
158            };
159
160            if mem <= target_mem {
161                best_rank = mid;
162                low = mid + 1;
163            } else {
164                if mid == 0 {
165                    break;
166                }
167                high = mid - 1;
168            }
169        }
170
171        if best_rank < 8 {
172            return Err(EntrenarError::InsufficientMemory {
173                required: planner.estimate_qlora(8, 4).total_bytes as f64 / 1e9,
174                available: self.available_vram_bytes as f64 / 1e9,
175            });
176        }
177
178        Ok(best_rank)
179    }
180
181    fn calculate_trainable_params(&self, method: Method, rank: u32) -> u64 {
182        if method == Method::Full {
183            return self.model_params;
184        }
185
186        // Estimate hidden dim and layers
187        let (hidden_dim, num_layers) = if self.model_params > 60_000_000_000 {
188            (8192u64, 80u64)
189        } else if self.model_params > 10_000_000_000 {
190            (5120, 40)
191        } else if self.model_params > 5_000_000_000 {
192            (4096, 32)
193        } else if self.model_params > 1_000_000_000 {
194            (2048, 22)
195        } else {
196            (1024, 12)
197        };
198
199        // LoRA params: 2 matrices × 4 modules × num_layers
200        // Each matrix is either (hidden × rank) or (rank × hidden)
201        (hidden_dim * u64::from(rank) * 2) * 4 * num_layers
202    }
203}
204
205/// Compare multiple fine-tuning methods.
206pub fn compare_methods(model_params: u64, available_vram_gb: f64) -> Vec<MethodComparison> {
207    let methods = [Method::Full, Method::LoRA, Method::QLoRA];
208    let optimizer = LoraOptimizer::new(model_params, available_vram_gb);
209
210    methods
211        .iter()
212        .filter_map(|&method| {
213            optimizer.optimize(method).ok().map(|config| MethodComparison {
214                method,
215                fits: config.utilization_percent <= 100.0,
216                memory_gb: config.memory_gb,
217                trainable_params: config.trainable_params,
218                speedup: config.speedup,
219                rank: config.rank,
220            })
221        })
222        .collect()
223}
224
225/// Method comparison result.
/// Method comparison result.
#[derive(Debug, Clone)]
pub struct MethodComparison {
    /// Fine-tuning method this row describes.
    pub method: Method,
    /// Whether the estimated memory stays within 100% of available VRAM.
    pub fits: bool,
    /// Estimated memory requirement in GB.
    pub memory_gb: f64,
    /// Estimated trainable parameter count.
    pub trainable_params: u64,
    /// Estimated speedup vs full fine-tuning.
    pub speedup: f64,
    /// Recommended rank (0 for full fine-tuning).
    pub rank: u32,
}
235
#[cfg(test)]
mod tests {
    //! Unit tests for the optimizer, rank search, and formatting helpers.
    use super::*;

    #[test]
    fn test_optimizer_selects_qlora_for_small_vram() {
        let optimizer = LoraOptimizer::new(7_000_000_000, 8.0);
        let config = optimizer.optimize(Method::Auto).expect("config should be valid");

        // With only 8GB, should select QLoRA
        assert_eq!(config.method, Method::QLoRA);
    }

    #[test]
    fn test_optimizer_selects_lora_for_medium_vram() {
        let optimizer = LoraOptimizer::new(7_000_000_000, 24.0);
        let config = optimizer.optimize(Method::Auto).expect("config should be valid");

        // With 24GB for 7B model, optimizer may select LoRA, QLoRA, or Full
        assert!(matches!(config.method, Method::LoRA | Method::QLoRA | Method::Full));
    }

    #[test]
    fn test_optimal_rank_is_positive() {
        let optimizer = LoraOptimizer::new(7_000_000_000, 16.0);
        let config = optimizer.optimize(Method::LoRA).expect("config should be valid");

        // Rank search space is [8, 256]
        assert!(config.rank >= 8);
        assert!(config.rank <= 256);
    }

    #[test]
    fn test_trainable_params_less_than_total() {
        let optimizer = LoraOptimizer::new(7_000_000_000, 16.0);
        let config = optimizer.optimize(Method::LoRA).expect("config should be valid");

        assert!(config.trainable_params < 7_000_000_000);
        assert!(config.trainable_percent < 10.0);
    }

    #[test]
    fn test_compare_methods() {
        let comparisons = compare_methods(7_000_000_000, 16.0);

        assert!(!comparisons.is_empty());
        assert!(comparisons.iter().any(|c| c.method == Method::QLoRA));
    }

    #[test]
    fn test_alpha_is_rank_over_4() {
        let optimizer = LoraOptimizer::new(7_000_000_000, 16.0);
        let config = optimizer.optimize(Method::LoRA).expect("config should be valid");

        // alpha = rank / 4 heuristic applied in optimize()
        assert!((config.alpha - config.rank as f32 / 4.0).abs() < 0.01);
    }

    #[test]
    fn test_target_modules_populated() {
        let optimizer = LoraOptimizer::new(7_000_000_000, 16.0);
        let config = optimizer.optimize(Method::LoRA).expect("config should be valid");

        assert!(!config.target_modules.is_empty());
        assert!(config.target_modules.contains(&"q_proj".to_string()));
    }

    #[test]
    fn test_with_target_utilization() {
        let optimizer = LoraOptimizer::new(7_000_000_000, 16.0).with_target_utilization(0.75);
        let config = optimizer.optimize(Method::LoRA).expect("config should be valid");

        // Lower target utilization should give smaller rank
        let high_util = LoraOptimizer::new(7_000_000_000, 16.0)
            .with_target_utilization(0.95)
            .optimize(Method::LoRA)
            .expect("operation should succeed");

        assert!(config.rank <= high_util.rank);
    }

    #[test]
    fn test_target_utilization_clamping() {
        // Test that utilization is clamped to 0.5-0.95
        let low = LoraOptimizer::new(7_000_000_000, 16.0).with_target_utilization(0.1);
        assert!(low.target_utilization >= 0.5);

        let high = LoraOptimizer::new(7_000_000_000, 16.0).with_target_utilization(1.5);
        assert!(high.target_utilization <= 0.95);
    }

    #[test]
    fn test_format_params_billion() {
        assert_eq!(format_params(7_000_000_000), "7.0B");
        assert_eq!(format_params(1_500_000_000), "1.5B");
    }

    #[test]
    fn test_format_params_million() {
        assert_eq!(format_params(350_000_000), "350.0M");
        assert_eq!(format_params(1_500_000), "1.5M");
    }

    #[test]
    fn test_format_params_thousand() {
        assert_eq!(format_params(500_000), "500.0K");
        assert_eq!(format_params(1_500), "1.5K");
    }

    #[test]
    fn test_to_comparison_table() {
        let optimizer = LoraOptimizer::new(7_000_000_000, 16.0);
        let config = optimizer.optimize(Method::LoRA).expect("config should be valid");
        let table = config.to_comparison_table();

        assert!(table.contains("Optimal Configuration"));
        assert!(table.contains("Method:"));
        assert!(table.contains("Rank:"));
        assert!(table.contains("Alpha:"));
        assert!(table.contains("Memory:"));
    }

    #[test]
    fn test_full_method_rank_zero() {
        // Full fine-tuning has no LoRA rank; find_optimal_rank returns 0.
        let optimizer = LoraOptimizer::new(1_000_000_000, 100.0);
        let config = optimizer.optimize(Method::Full).expect("config should be valid");
        assert_eq!(config.rank, 0);
    }

    #[test]
    fn test_full_method_all_params_trainable() {
        let optimizer = LoraOptimizer::new(1_000_000_000, 100.0);
        let config = optimizer.optimize(Method::Full).expect("config should be valid");
        assert_eq!(config.trainable_params, 1_000_000_000);
        assert_eq!(config.trainable_percent, 100.0);
    }

    #[test]
    fn test_speedup_values() {
        let optimizer = LoraOptimizer::new(7_000_000_000, 100.0);

        let full = optimizer.optimize(Method::Full).expect("operation should succeed");
        assert_eq!(full.speedup, 1.0);

        let lora = optimizer.optimize(Method::LoRA).expect("operation should succeed");
        assert_eq!(lora.speedup, 2.5);

        let qlora = optimizer.optimize(Method::QLoRA).expect("operation should succeed");
        assert_eq!(qlora.speedup, 1.8);
    }

    #[test]
    fn test_compare_methods_includes_all() {
        // 100GB is enough for every method on a 7B model
        let comparisons = compare_methods(7_000_000_000, 100.0);

        assert!(comparisons.iter().any(|c| c.method == Method::Full));
        assert!(comparisons.iter().any(|c| c.method == Method::LoRA));
        assert!(comparisons.iter().any(|c| c.method == Method::QLoRA));
    }

    #[test]
    fn test_compare_methods_small_vram() {
        let comparisons = compare_methods(7_000_000_000, 4.0);

        // With very small VRAM, only QLoRA might fit
        let _fitting: Vec<_> = comparisons.iter().filter(|c| c.fits).collect();
        // At least one method should work (QLoRA)
        assert!(!comparisons.is_empty());
    }

    #[test]
    fn test_method_comparison_struct() {
        let comparisons = compare_methods(7_000_000_000, 16.0);
        let qlora = comparisons.iter().find(|c| c.method == Method::QLoRA);

        if let Some(c) = qlora {
            assert!(c.memory_gb > 0.0);
            assert!(c.trainable_params > 0);
            assert!(c.speedup > 0.0);
            assert!(c.rank >= 8);
        }
    }
}