use ghostflow_core::Tensor;
use crate::linear::Linear;
use crate::Module;

/// Configuration for a LoRA (Low-Rank Adaptation) layer.
#[derive(Debug, Clone)]
pub struct LoRAConfig {
    /// Rank of the low-rank update matrices.
    pub rank: usize,
    /// Scaling numerator; the LoRA output is scaled by `alpha / rank`.
    pub alpha: f32,
    /// Dropout probability for the LoRA path.
    pub dropout: f32,
    /// Whether the adapted layer carries a bias term.
    pub use_bias: bool,
}

impl Default for LoRAConfig {
    fn default() -> Self {
        LoRAConfig {
            rank: 8,
            alpha: 16.0,
            dropout: 0.0,
            use_bias: false,
        }
    }
}

impl LoRAConfig {
    /// Small preset: rank 4, alpha 8.
    pub fn low_rank() -> Self {
        LoRAConfig {
            rank: 4,
            alpha: 8.0,
            ..Default::default()
        }
    }

    /// Medium preset: the default rank 8, alpha 16.
    pub fn medium_rank() -> Self {
        Self::default()
    }

    /// Large preset: rank 16, alpha 32.
    pub fn high_rank() -> Self {
        LoRAConfig {
            rank: 16,
            alpha: 32.0,
            ..Default::default()
        }
    }
}

/// A `Linear` layer augmented with trainable low-rank adapters (LoRA).
///
/// The effective transformation is `base(x) + (alpha / rank) * x A B`, where
/// only `A` (`lora_a`) and `B` (`lora_b`) are intended to be trained.
pub struct LoRALinear {
    base_layer: Linear,
    lora_a: Tensor,
    lora_b: Tensor,
    scaling: f32,
    config: LoRAConfig,
    merged: bool,
}

impl LoRALinear {
    pub fn new(in_features: usize, out_features: usize, config: LoRAConfig) -> Self {
        let base_layer = Linear::new(in_features, out_features);

        // A starts random, B starts at zero, so the adapter is initially a no-op.
        let lora_a = Tensor::randn(&[in_features, config.rank]);
        let lora_b = Tensor::zeros(&[config.rank, out_features]);

        let scaling = config.alpha / config.rank as f32;

        LoRALinear {
            base_layer,
            lora_a,
            lora_b,
            scaling,
            config,
            merged: false,
        }
    }

    pub fn forward(&self, x: &Tensor) -> Tensor {
        let base_output = self.base_layer.forward(x);

        // When merged, only the base path is evaluated (see `merge_weights`).
        if self.merged {
            return base_output;
        }

        let lora_output = self.compute_lora_output(x);

        base_output.add(&lora_output).unwrap_or(base_output)
    }

    fn compute_lora_output(&self, x: &Tensor) -> Tensor {
        // Low-rank path: (x A) B, scaled by alpha / rank.
        let intermediate = x.matmul(&self.lora_a).unwrap_or_else(|_| x.clone());

        let lora_out = intermediate.matmul(&self.lora_b).unwrap_or(intermediate);

        lora_out.mul_scalar(self.scaling)
    }

    /// Marks the LoRA delta as merged.
    ///
    /// The delta `scaling * (A B)` is computed here, but the base `Linear`
    /// weight is not modified in this implementation; while `merged` is set,
    /// `forward` simply skips the LoRA path.
    pub fn merge_weights(&mut self) {
        if self.merged {
            return;
        }

        // Delta that a full merge would fold into the base weight.
        let _lora_weight = self.lora_a.matmul(&self.lora_b)
            .map(|w| w.mul_scalar(self.scaling))
            .unwrap_or_else(|_| Tensor::zeros(&[self.lora_a.dims()[0], self.lora_b.dims()[1]]));

        self.merged = true;
    }

    pub fn unmerge_weights(&mut self) {
        if !self.merged {
            return;
        }

        self.merged = false;
    }

    pub fn lora_parameters(&self) -> Vec<Tensor> {
        vec![self.lora_a.clone(), self.lora_b.clone()]
    }

    pub fn rank(&self) -> usize {
        self.config.rank
    }

    pub fn scaling(&self) -> f32 {
        self.scaling
    }
}

/// Configuration for a QLoRA layer: LoRA adapters on top of a quantized base weight.
#[derive(Debug, Clone)]
pub struct QLoRAConfig {
    pub lora_config: LoRAConfig,
    /// Bit width of the quantized base weight (e.g. 4 or 8).
    pub bits: usize,
    /// Whether the quantization constants are themselves quantized.
    pub double_quant: bool,
    pub quant_type: QuantType,
}

/// Quantization data type for the frozen base weight.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum QuantType {
    /// 4-bit NormalFloat.
    NF4,
    /// 4-bit floating point.
    FP4,
    /// 8-bit integer.
    INT8,
}

impl Default for QLoRAConfig {
    fn default() -> Self {
        QLoRAConfig {
            lora_config: LoRAConfig::default(),
            bits: 4,
            double_quant: true,
            quant_type: QuantType::NF4,
        }
    }
}

/// A linear layer whose base weight is stored quantized, with LoRA adapters on top.
pub struct QLoRALinear {
    quantized_weight: Tensor,
    scale: f32,
    zero_point: f32,
    lora_a: Tensor,
    lora_b: Tensor,
    scaling: f32,
    config: QLoRAConfig,
}

impl QLoRALinear {
    pub fn new(in_features: usize, out_features: usize, config: QLoRAConfig) -> Self {
        // The base weight here is a randomly initialized placeholder that is
        // quantized immediately; a pretrained weight would normally be loaded instead.
        let base_weight = Tensor::randn(&[out_features, in_features]);
        let (quantized_weight, scale, zero_point) = Self::quantize_weight(&base_weight, config.bits);

        // A starts random, B starts at zero, so the adapter is initially a no-op.
        let lora_a = Tensor::randn(&[in_features, config.lora_config.rank]);
        let lora_b = Tensor::zeros(&[config.lora_config.rank, out_features]);

        let scaling = config.lora_config.alpha / config.lora_config.rank as f32;

        QLoRALinear {
            quantized_weight,
            scale,
            zero_point,
            lora_a,
            lora_b,
            scaling,
            config,
        }
    }

    /// Affine min/max quantization of `weight` to `bits` levels, stored as f32 codes.
    fn quantize_weight(weight: &Tensor, bits: usize) -> (Tensor, f32, f32) {
        let data = weight.data_f32();
        let dims = weight.dims();

        let min_val = data.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_val = data.iter().cloned().fold(f32::NEG_INFINITY, f32::max);

        let qmin = 0.0;
        let qmax = (1 << bits) as f32 - 1.0;

        // Guard against a degenerate (constant) weight, which would make the scale zero.
        let scale = ((max_val - min_val) / (qmax - qmin)).max(f32::EPSILON);
        let zero_point = qmin - min_val / scale;

        let quantized: Vec<f32> = data
            .iter()
            .map(|&x| (x / scale + zero_point).round().clamp(qmin, qmax))
            .collect();

        (Tensor::from_slice(&quantized, dims).unwrap(), scale, zero_point)
    }

    /// Reconstructs an fp32 approximation of the quantized base weight.
    fn dequantize_weight(&self) -> Tensor {
        let data = self.quantized_weight.data_f32();
        let dims = self.quantized_weight.dims();

        let dequantized: Vec<f32> = data
            .iter()
            .map(|&q| (q - self.zero_point) * self.scale)
            .collect();

        Tensor::from_slice(&dequantized, dims).unwrap()
    }

    pub fn forward(&self, x: &Tensor) -> Tensor {
        // Dequantize the base weight on the fly for the matmul.
        let base_weight = self.dequantize_weight();

        let base_output = x.matmul(&base_weight.t().unwrap()).unwrap_or_else(|_| x.clone());

        // Low-rank path: (x A) B, scaled by alpha / rank.
        let lora_output = x.matmul(&self.lora_a)
            .and_then(|intermediate| intermediate.matmul(&self.lora_b))
            .map(|out| out.mul_scalar(self.scaling))
            .unwrap_or_else(|_| Tensor::zeros(base_output.dims()));

        base_output.add(&lora_output).unwrap_or(base_output)
    }

    pub fn lora_parameters(&self) -> Vec<Tensor> {
        vec![self.lora_a.clone(), self.lora_b.clone()]
    }

    /// Fraction of weight memory saved relative to a full-precision (fp32) layer,
    /// counting the quantized base weight plus the fp32 LoRA matrices.
    pub fn memory_savings_ratio(&self) -> f32 {
        let base_params = self.quantized_weight.data_f32().len();
        let lora_params = self.lora_a.data_f32().len() + self.lora_b.data_f32().len();

        // Base weight stored at `bits` per value versus 32 bits per value.
        let base_memory = (base_params as f32) * (self.config.bits as f32 / 32.0);
        let lora_memory = lora_params as f32;
        let full_memory = base_params as f32;

        (full_memory - (base_memory + lora_memory)) / full_memory
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lora_config() {
        let config = LoRAConfig::default();
        assert_eq!(config.rank, 8);
        assert_eq!(config.alpha, 16.0);

        let low = LoRAConfig::low_rank();
        assert_eq!(low.rank, 4);
    }

    #[test]
    fn test_lora_linear() {
        let config = LoRAConfig::default();
        let layer = LoRALinear::new(128, 64, config);

        assert_eq!(layer.rank(), 8);
        assert!(!layer.merged);

        let input = Tensor::randn(&[4, 128]);
        let output = layer.forward(&input);
        assert_eq!(output.dims(), &[4, 64]);
    }

    #[test]
    fn test_lora_parameters() {
        let config = LoRAConfig::default();
        let layer = LoRALinear::new(128, 64, config);

        let params = layer.lora_parameters();
        assert_eq!(params.len(), 2);
        assert_eq!(params[0].dims(), &[128, 8]);
        assert_eq!(params[1].dims(), &[8, 64]);
    }

    #[test]
    fn test_lora_merge_unmerge() {
        let config = LoRAConfig::default();
        let mut layer = LoRALinear::new(128, 64, config);

        assert!(!layer.merged);

        layer.merge_weights();
        assert!(layer.merged);

        layer.unmerge_weights();
        assert!(!layer.merged);
    }
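
    // With a freshly constructed layer, `lora_b` is all zeros, so the low-rank
    // delta vanishes and `forward` should reproduce the base layer's output.
    // This is a sketch of an invariant check; it assumes `data_f32` yields
    // element values in a consistent order, as the other tests here already do.
    #[test]
    fn test_lora_zero_init_matches_base() {
        let layer = LoRALinear::new(16, 8, LoRAConfig::default());
        let input = Tensor::randn(&[2, 16]);

        let with_lora = layer.forward(&input);
        let base_only = layer.base_layer.forward(&input);

        let a = with_lora.data_f32();
        let b = base_only.data_f32();
        for (x, y) in a.iter().zip(b.iter()) {
            assert!((x - y).abs() < 1e-5);
        }
    }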

    #[test]
    fn test_qlora_config() {
        let config = QLoRAConfig::default();
        assert_eq!(config.bits, 4);
        assert_eq!(config.quant_type, QuantType::NF4);
        assert!(config.double_quant);
    }

    #[test]
    fn test_qlora_linear() {
        let config = QLoRAConfig::default();
        let layer = QLoRALinear::new(128, 64, config);

        let input = Tensor::randn(&[4, 128]);
        let output = layer.forward(&input);
        assert_eq!(output.dims(), &[4, 64]);
    }

    #[test]
    fn test_quantization() {
        let weight = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        // `zero_point` is not checked in this test, so bind it with a leading underscore.
        let (quantized, scale, _zero_point) = QLoRALinear::quantize_weight(&weight, 4);

        assert!(scale > 0.0);
        assert_eq!(quantized.dims(), &[2, 2]);
    }
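
    // Round-trip sketch for the affine scheme in `quantize_weight`: values inside
    // the [min, max] range are rounded to the nearest of 2^bits levels, so the
    // reconstruction `(q - zero_point) * scale` (the same mapping used by
    // `dequantize_weight`) should stay within half a quantization step.
    #[test]
    fn test_quantization_round_trip_error() {
        let weight = Tensor::randn(&[8, 8]);
        let (quantized, scale, zero_point) = QLoRALinear::quantize_weight(&weight, 4);

        let original = weight.data_f32();
        let codes = quantized.data_f32();

        for (x, q) in original.iter().zip(codes.iter()) {
            let reconstructed = (*q - zero_point) * scale;
            assert!((*x - reconstructed).abs() <= scale * 0.5 + 1e-4);
        }
    }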

    #[test]
    fn test_dequantization() {
        let config = QLoRAConfig::default();
        let layer = QLoRALinear::new(4, 4, config);

        let dequantized = layer.dequantize_weight();
        assert_eq!(dequantized.dims(), layer.quantized_weight.dims());
    }

    #[test]
    fn test_memory_savings() {
        let config = QLoRAConfig::default();
        let layer = QLoRALinear::new(1024, 1024, config);

        let savings = layer.memory_savings_ratio();
        assert!(savings > 0.0);
        assert!(savings < 1.0);
    }
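
    // Worked check of the formula in `memory_savings_ratio` for the default config
    // (rank 8, 4-bit): the 1024*1024 base weight costs 4/32 of fp32, the two LoRA
    // matrices add 2 * 1024 * 8 fp32 values, giving
    // 1 - (131072 + 16384) / 1048576 = 0.859375.
    #[test]
    fn test_memory_savings_expected_value() {
        let layer = QLoRALinear::new(1024, 1024, QLoRAConfig::default());

        let savings = layer.memory_savings_ratio();
        assert!((savings - 0.859375).abs() < 1e-6);
    }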

    #[test]
    fn test_lora_scaling() {
        let config = LoRAConfig {
            rank: 8,
            alpha: 16.0,
            ..Default::default()
        };

        let layer = LoRALinear::new(64, 32, config);
        assert_eq!(layer.scaling(), 2.0);
    }
}