// ghostflow_nn/lora.rs

//! LoRA (Low-Rank Adaptation)
//!
//! Implements parameter-efficient fine-tuning:
//! - LoRA: Low-Rank Adaptation of Large Language Models
//! - QLoRA: Quantized LoRA for even more efficiency
//! - Adapter layers with low-rank decomposition
//! - Merge and unmerge LoRA weights
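//!
//! Example usage (a minimal sketch; the `ghostflow_nn::lora` import path is
//! an assumption and may need adjusting to the crate's actual re-exports):
//!
//! ```ignore
//! use ghostflow_core::Tensor;
//! use ghostflow_nn::lora::{LoRAConfig, LoRALinear};
//!
//! // Wrap a 128 -> 64 projection with rank-8 adapters.
//! let layer = LoRALinear::new(128, 64, LoRAConfig::default());
//!
//! // Forward pass: frozen base output plus the (initially zero) LoRA delta.
//! let input = Tensor::randn(&[4, 128]);
//! let output = layer.forward(&input);
//! assert_eq!(output.dims(), &[4, 64]);
//!
//! // Only the adapter matrices A and B are trainable.
//! assert_eq!(layer.lora_parameters().len(), 2);
//! ```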

use ghostflow_core::Tensor;
use crate::linear::Linear;
use crate::Module;

/// LoRA configuration
#[derive(Debug, Clone)]
pub struct LoRAConfig {
    /// Rank of the low-rank decomposition
    pub rank: usize,
    /// Alpha parameter for scaling (effective scale is `alpha / rank`)
    pub alpha: f32,
    /// Dropout probability (currently not applied in the forward pass)
    pub dropout: f32,
    /// Enable bias in LoRA layers (currently unused)
    pub use_bias: bool,
}

impl Default for LoRAConfig {
    fn default() -> Self {
        LoRAConfig {
            rank: 8,
            alpha: 16.0,
            dropout: 0.0,
            use_bias: false,
        }
    }
}

impl LoRAConfig {
    /// Low rank configuration (r=4)
    pub fn low_rank() -> Self {
        LoRAConfig {
            rank: 4,
            alpha: 8.0,
            ..Default::default()
        }
    }

    /// Medium rank configuration (r=8)
    pub fn medium_rank() -> Self {
        Self::default()
    }

    /// High rank configuration (r=16)
    pub fn high_rank() -> Self {
        LoRAConfig {
            rank: 16,
            alpha: 32.0,
            ..Default::default()
        }
    }
}

/// LoRA layer wrapping a linear layer
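///
/// The base layer stays frozen; the adapters add a low-rank update, so the
/// forward pass computes `base(x) + (alpha / rank) * (x @ A @ B)`. `A` is
/// Gaussian-initialized and `B` starts at zero, making the initial LoRA
/// contribution zero; only `A` and `B` are trained.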
pub struct LoRALinear {
    /// Original linear layer (frozen)
    base_layer: Linear,
    /// LoRA A matrix: [in_features, rank]
    lora_a: Tensor,
    /// LoRA B matrix: [rank, out_features]
    lora_b: Tensor,
    /// Scaling factor
    scaling: f32,
    /// Configuration
    config: LoRAConfig,
    /// Whether LoRA is merged into base weights
    merged: bool,
}

impl LoRALinear {
    /// Create new LoRA linear layer
    pub fn new(in_features: usize, out_features: usize, config: LoRAConfig) -> Self {
        let base_layer = Linear::new(in_features, out_features);

        // Initialize LoRA matrices
        // A: Gaussian initialization
        let lora_a = Tensor::randn(&[in_features, config.rank]);
        // B: Zero initialization (so initial LoRA contribution is zero)
        let lora_b = Tensor::zeros(&[config.rank, out_features]);

        // Scaling factor: alpha / rank
        let scaling = config.alpha / config.rank as f32;

        LoRALinear {
            base_layer,
            lora_a,
            lora_b,
            scaling,
            config,
            merged: false,
        }
    }

    /// Forward pass
    pub fn forward(&self, x: &Tensor) -> Tensor {
        // Base output
        let base_output = self.base_layer.forward(x);

        if self.merged {
            // LoRA already merged into base weights
            return base_output;
        }

        // LoRA contribution: x @ A @ B * scaling
        let lora_output = self.compute_lora_output(x);

        // Combine base and LoRA
        base_output.add(&lora_output).unwrap_or(base_output)
    }

    /// Compute LoRA output
    fn compute_lora_output(&self, x: &Tensor) -> Tensor {
        // x @ A
        let intermediate = x.matmul(&self.lora_a).unwrap_or_else(|_| x.clone());

        // (x @ A) @ B
        let lora_out = intermediate.matmul(&self.lora_b).unwrap_or(intermediate);

        // Scale
        lora_out.mul_scalar(self.scaling)
    }

    /// Merge LoRA weights into base layer
    pub fn merge_weights(&mut self) {
        if self.merged {
            return;
        }

        // Compute the LoRA weight delta: A @ B * scaling
        let _lora_weight = self.lora_a.matmul(&self.lora_b)
            .map(|w| w.mul_scalar(self.scaling))
            .unwrap_or_else(|_| Tensor::zeros(&[self.lora_a.dims()[0], self.lora_b.dims()[1]]));

        // Adding the delta to the base weights would need mutable access to
        // base_layer.weight, which is not done here; for now, mark as merged.
        self.merged = true;
    }

    /// Unmerge LoRA weights from base layer
    pub fn unmerge_weights(&mut self) {
        if !self.merged {
            return;
        }

        // Subtracting the LoRA delta from the base weights has the same
        // limitation as merge_weights; for now, just mark as unmerged.
        self.merged = false;
    }

    /// Get LoRA parameters (only these are trainable)
    pub fn lora_parameters(&self) -> Vec<Tensor> {
        vec![self.lora_a.clone(), self.lora_b.clone()]
    }

    /// Get rank
    pub fn rank(&self) -> usize {
        self.config.rank
    }

    /// Get scaling factor
    pub fn scaling(&self) -> f32 {
        self.scaling
    }
}

/// QLoRA (Quantized LoRA) configuration
#[derive(Debug, Clone)]
pub struct QLoRAConfig {
    /// Base LoRA configuration
    pub lora_config: LoRAConfig,
    /// Quantization bits (4 or 8)
    pub bits: usize,
    /// Use double quantization
    pub double_quant: bool,
    /// Quantization data type
    pub quant_type: QuantType,
}

/// Quantization type for QLoRA
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum QuantType {
    /// Normal Float 4-bit
    NF4,
    /// Float 4-bit
    FP4,
    /// 8-bit integer
    INT8,
}

impl Default for QLoRAConfig {
    fn default() -> Self {
        QLoRAConfig {
            lora_config: LoRAConfig::default(),
            bits: 4,
            double_quant: true,
            quant_type: QuantType::NF4,
        }
    }
}

/// QLoRA layer with quantized base weights
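///
/// The base weight is stored in quantized form (per `QLoRAConfig::bits`) and
/// dequantized on the fly in `forward`, while the LoRA matrices stay in full
/// precision. Note that quantization here is simple affine quantization; the
/// configured `QuantType` and `double_quant` options are not yet consulted.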
pub struct QLoRALinear {
    /// Quantized base weights
    quantized_weight: Tensor,
    /// Quantization scale
    scale: f32,
    /// Zero point
    zero_point: f32,
    /// LoRA A matrix
    lora_a: Tensor,
    /// LoRA B matrix
    lora_b: Tensor,
    /// Scaling factor
    scaling: f32,
    /// Configuration
    config: QLoRAConfig,
}

impl QLoRALinear {
    /// Create new QLoRA linear layer
    pub fn new(in_features: usize, out_features: usize, config: QLoRAConfig) -> Self {
        // Create and quantize base weights
        let base_weight = Tensor::randn(&[out_features, in_features]);
        let (quantized_weight, scale, zero_point) = Self::quantize_weight(&base_weight, config.bits);

        // Initialize LoRA matrices
        let lora_a = Tensor::randn(&[in_features, config.lora_config.rank]);
        let lora_b = Tensor::zeros(&[config.lora_config.rank, out_features]);

        let scaling = config.lora_config.alpha / config.lora_config.rank as f32;

        QLoRALinear {
            quantized_weight,
            scale,
            zero_point,
            lora_a,
            lora_b,
            scaling,
            config,
        }
    }

    /// Quantize weight tensor
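    ///
    /// Affine quantization: `q = round(x / scale + zero_point)` with
    /// `scale = (max - min) / (2^bits - 1)` and `zero_point = -min / scale`.
    /// Worked example: for weights spanning `[-1.0, 2.0]` with `bits = 4`,
    /// `scale = 3.0 / 15 = 0.2`, `zero_point = 1.0 / 0.2 = 5.0`, so `x = 0.0`
    /// maps to `q = 5` and dequantizes back to `0.0`.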
    fn quantize_weight(weight: &Tensor, bits: usize) -> (Tensor, f32, f32) {
        let data = weight.data_f32();
        let dims = weight.dims();

        // Compute scale and zero point from the observed value range
        let min_val = data.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        let qmin = 0.0;
        let qmax = (1 << bits) as f32 - 1.0;

        // Guard against a zero range (constant weight) to avoid dividing by zero
        let range = (max_val - min_val).max(f32::EPSILON);
        let scale = range / (qmax - qmin);
        let zero_point = qmin - min_val / scale;

        // Quantize: map each value onto the [qmin, qmax] grid
        let quantized: Vec<f32> = data.iter()
            .map(|&x| (x / scale + zero_point).round().clamp(qmin, qmax))
            .collect();

        (Tensor::from_slice(&quantized, dims).unwrap(), scale, zero_point)
    }

    /// Dequantize weight tensor
    fn dequantize_weight(&self) -> Tensor {
        let data = self.quantized_weight.data_f32();
        let dims = self.quantized_weight.dims();

        let dequantized: Vec<f32> = data.iter().map(|&q| {
            (q - self.zero_point) * self.scale
        }).collect();

        Tensor::from_slice(&dequantized, dims).unwrap()
    }

    /// Forward pass
    pub fn forward(&self, x: &Tensor) -> Tensor {
        // Dequantize base weights
        let base_weight = self.dequantize_weight();

        // Base output: x @ W^T
        let base_output = x.matmul(&base_weight.t().unwrap()).unwrap_or_else(|_| x.clone());

        // LoRA contribution
        let lora_output = x.matmul(&self.lora_a)
            .and_then(|intermediate| intermediate.matmul(&self.lora_b))
            .map(|out| out.mul_scalar(self.scaling))
            .unwrap_or_else(|_| Tensor::zeros(base_output.dims()));

        // Combine
        base_output.add(&lora_output).unwrap_or(base_output)
    }

    /// Get LoRA parameters
    pub fn lora_parameters(&self) -> Vec<Tensor> {
        vec![self.lora_a.clone(), self.lora_b.clone()]
    }

    /// Get memory savings compared to full fine-tuning
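    ///
    /// Memory is counted in fp32-parameter equivalents: the quantized base
    /// weight costs `bits / 32` per parameter while the LoRA matrices stay in
    /// full precision. For example, a 1024x1024 base weight with rank 8 at
    /// 4 bits yields roughly `(1048576 - (131072 + 16384)) / 1048576 ≈ 0.86`.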
    pub fn memory_savings_ratio(&self) -> f32 {
        let base_params = self.quantized_weight.data_f32().len();
        let lora_params = self.lora_a.data_f32().len() + self.lora_b.data_f32().len();

        let base_memory = (base_params as f32) * (self.config.bits as f32 / 32.0); // quantized base
        let lora_memory = lora_params as f32; // full-precision adapters
        let full_memory = base_params as f32; // full-precision baseline

        (full_memory - (base_memory + lora_memory)) / full_memory
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lora_config() {
        let config = LoRAConfig::default();
        assert_eq!(config.rank, 8);
        assert_eq!(config.alpha, 16.0);

        let low = LoRAConfig::low_rank();
        assert_eq!(low.rank, 4);
    }

    #[test]
    fn test_lora_linear() {
        let config = LoRAConfig::default();
        let layer = LoRALinear::new(128, 64, config);

        assert_eq!(layer.rank(), 8);
        assert!(!layer.merged);

        let input = Tensor::randn(&[4, 128]);
        let output = layer.forward(&input);
        assert_eq!(output.dims(), &[4, 64]);
    }

    #[test]
    fn test_lora_parameters() {
        let config = LoRAConfig::default();
        let layer = LoRALinear::new(128, 64, config);

        let params = layer.lora_parameters();
        assert_eq!(params.len(), 2);
        assert_eq!(params[0].dims(), &[128, 8]); // A matrix
        assert_eq!(params[1].dims(), &[8, 64]);  // B matrix
    }

    #[test]
    fn test_lora_merge_unmerge() {
        let config = LoRAConfig::default();
        let mut layer = LoRALinear::new(128, 64, config);

        assert!(!layer.merged);

        layer.merge_weights();
        assert!(layer.merged);

        layer.unmerge_weights();
        assert!(!layer.merged);
    }

    #[test]
    fn test_qlora_config() {
        let config = QLoRAConfig::default();
        assert_eq!(config.bits, 4);
        assert_eq!(config.quant_type, QuantType::NF4);
        assert!(config.double_quant);
    }

    #[test]
    fn test_qlora_linear() {
        let config = QLoRAConfig::default();
        let layer = QLoRALinear::new(128, 64, config);

        let input = Tensor::randn(&[4, 128]);
        let output = layer.forward(&input);
        assert_eq!(output.dims(), &[4, 64]);
    }

    #[test]
    fn test_quantization() {
        let weight = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        let (quantized, scale, _zero_point) = QLoRALinear::quantize_weight(&weight, 4);

        assert!(scale > 0.0);
        assert_eq!(quantized.dims(), &[2, 2]);
    }

    #[test]
    fn test_dequantization() {
        let config = QLoRAConfig::default();
        let layer = QLoRALinear::new(4, 4, config);

        let dequantized = layer.dequantize_weight();
        assert_eq!(dequantized.dims(), layer.quantized_weight.dims());
    }

    #[test]
    fn test_memory_savings() {
        let config = QLoRAConfig::default();
        let layer = QLoRALinear::new(1024, 1024, config);

        let savings = layer.memory_savings_ratio();
        assert!(savings > 0.0);
        assert!(savings < 1.0);
    }

    #[test]
    fn test_lora_scaling() {
        let config = LoRAConfig {
            rank: 8,
            alpha: 16.0,
            ..Default::default()
        };

        let layer = LoRALinear::new(64, 32, config);
        assert_eq!(layer.scaling(), 2.0); // alpha / rank = 16 / 8
    }
}