1use crate::autograd::matmul;
10use crate::lora::LoRALayer;
11use crate::quant::{
12 dequantize_4bit, dequantize_4bit_double, quantize_4bit, quantize_4bit_double,
13 DoubleQuantized4Bit, Quantized4Bit,
14};
15use crate::Tensor;
16
/// A LoRA layer whose frozen base weight is kept in 4-bit quantized form (QLoRA).
///
/// Only the low-rank adapter matrices `lora_a` / `lora_b` are stored in full
/// precision; the base weight is dequantized whenever it is needed.
pub struct QLoRALayer {
    /// Single 4-bit quantization of the base weight; always populated.
    base_weight_quantized: Quantized4Bit,
    /// Double-quantized base weight. `Some` only for layers built with double
    /// quantization, in which case it takes precedence over
    /// `base_weight_quantized` when dequantizing.
    base_weight_double: Option<DoubleQuantized4Bit>,
    /// Adapter matrix `A`, flattened row-major `rank x d_in`.
    lora_a: Tensor,
    /// Adapter matrix `B`, flattened row-major `d_out x rank`.
    lora_b: Tensor,
    /// Output dimension of the layer.
    d_out: usize,
    /// Input dimension of the layer.
    d_in: usize,
    /// Rank of the low-rank adapter.
    rank: usize,
    /// Multiplier applied to the adapter contribution `B A`.
    scale: f32,
    /// When `true`, the adapter is treated as already folded into the base
    /// weight and the forward pass skips the adapter path.
    merged: bool,
}
42
43impl QLoRALayer {
44 pub fn from_lora(lora_layer: LoRALayer) -> Self {
52 let base_weight_data = lora_layer.base_weight().data().to_vec();
53 let base_weight_quantized = quantize_4bit(&base_weight_data);
54
55 Self {
56 base_weight_quantized,
57 base_weight_double: None,
58 lora_a: lora_layer.lora_a().clone(),
59 lora_b: lora_layer.lora_b().clone(),
60 d_out: lora_layer.d_out(),
61 d_in: lora_layer.d_in(),
62 rank: lora_layer.rank(),
63 scale: lora_layer.scale(),
64 merged: false,
65 }
66 }
67
68 pub fn from_lora_double_quant(lora_layer: LoRALayer) -> Self {
70 let base_weight_data = lora_layer.base_weight().data().to_vec();
71 let base_weight_quantized = quantize_4bit(&base_weight_data);
72 let base_weight_double = Some(quantize_4bit_double(&base_weight_data));
73
74 Self {
75 base_weight_quantized,
76 base_weight_double,
77 lora_a: lora_layer.lora_a().clone(),
78 lora_b: lora_layer.lora_b().clone(),
79 d_out: lora_layer.d_out(),
80 d_in: lora_layer.d_in(),
81 rank: lora_layer.rank(),
82 scale: lora_layer.scale(),
83 merged: false,
84 }
85 }
86
87 pub fn new(base_weight: Tensor, d_out: usize, d_in: usize, rank: usize, alpha: f32) -> Self {
96 let lora_layer = LoRALayer::new(base_weight, d_out, d_in, rank, alpha);
98 Self::from_lora(lora_layer)
99 }
100
101 pub fn forward(&self, x: &Tensor) -> Tensor {
109 assert_eq!(x.len(), self.d_in, "Input size must match d_in");
113
114 let base_weight_data = if let Some(ref dq) = self.base_weight_double {
116 dequantize_4bit_double(dq)
117 } else {
118 dequantize_4bit(&self.base_weight_quantized)
119 };
120 let base_weight = Tensor::new(ndarray::arr1(&base_weight_data), false);
121
122 let base_output = matmul(&base_weight, x, self.d_out, self.d_in, 1);
124
125 if self.merged {
126 base_output
127 } else {
128 let lora_out_a = matmul(&self.lora_a, x, self.rank, self.d_in, 1);
130 let lora_out_b = matmul(&self.lora_b, &lora_out_a, self.d_out, self.rank, 1);
131
132 let mut scaled_lora_data = lora_out_b.data().to_owned();
134 for val in &mut scaled_lora_data {
135 *val *= self.scale;
136 }
137 let scaled_lora = Tensor::new(scaled_lora_data, false);
138
139 let mut result_data = base_output.data().to_owned();
140 for (i, val) in result_data.iter_mut().enumerate() {
141 *val += scaled_lora.data()[i];
142 }
143 Tensor::new(result_data, base_output.requires_grad())
144 }
145 }
146
147 pub fn lora_a(&self) -> &Tensor {
149 &self.lora_a
150 }
151
152 pub fn lora_a_mut(&mut self) -> &mut Tensor {
154 &mut self.lora_a
155 }
156
157 pub fn lora_b(&self) -> &Tensor {
159 &self.lora_b
160 }
161
162 pub fn lora_b_mut(&mut self) -> &mut Tensor {
164 &mut self.lora_b
165 }
166
167 pub fn trainable_params(&mut self) -> Vec<&mut Tensor> {
169 vec![&mut self.lora_a, &mut self.lora_b]
170 }
171
172 pub fn rank(&self) -> usize {
174 self.rank
175 }
176
177 pub fn scale(&self) -> f32 {
179 self.scale
180 }
181
182 pub fn d_out(&self) -> usize {
184 self.d_out
185 }
186
187 pub fn d_in(&self) -> usize {
189 self.d_in
190 }
191
192 pub fn memory_stats(&self) -> MemoryStats {
194 let base_unquantized_bytes = self.d_out * self.d_in * 4; let base_quantized_bytes = if let Some(ref dq) = self.base_weight_double {
196 dq.memory_bytes()
197 } else {
198 self.base_weight_quantized.memory_bytes()
199 };
200 let lora_a_bytes = self.lora_a.len() * 4;
201 let lora_b_bytes = self.lora_b.len() * 4;
202
203 MemoryStats {
204 base_unquantized_bytes,
205 base_quantized_bytes,
206 lora_bytes: lora_a_bytes + lora_b_bytes,
207 total_bytes: base_quantized_bytes + lora_a_bytes + lora_b_bytes,
208 compression_ratio: base_unquantized_bytes as f32 / base_quantized_bytes.max(1) as f32,
209 }
210 }
211
212 pub fn is_merged(&self) -> bool {
214 self.merged
215 }
216
217 pub fn merge_to_f32(&self) -> Vec<f32> {
224 let mut merged = if let Some(ref dq) = self.base_weight_double {
226 dequantize_4bit_double(dq)
227 } else {
228 dequantize_4bit(&self.base_weight_quantized)
229 };
230
231 let a_data = self.lora_a.data();
235 let b_data = self.lora_b.data();
236
237 for row in 0..self.d_out {
238 for col in 0..self.d_in {
239 let mut sum = 0.0f32;
240 for r in 0..self.rank {
241 let b_val = b_data[row * self.rank + r];
242 let a_val = a_data[r * self.d_in + col];
243 sum += b_val * a_val;
244 }
245 merged[row * self.d_in + col] += self.scale * sum;
246 }
247 }
248
249 merged
250 }
251
252 pub fn base_weight_quantized(&self) -> &Quantized4Bit {
254 &self.base_weight_quantized
255 }
256
257 pub fn is_double_quantized(&self) -> bool {
259 self.base_weight_double.is_some()
260 }
261}
262
/// Memory footprint summary for a QLoRA layer, as produced by
/// `QLoRALayer::memory_stats`. All sizes are in bytes.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct MemoryStats {
    /// Size the base weight would occupy as dense `f32` (`d_out * d_in * 4`).
    pub base_unquantized_bytes: usize,
    /// Size of the quantized base weight in use (the double-quantized copy when present).
    pub base_quantized_bytes: usize,
    /// Combined size of the `f32` adapter matrices `A` and `B`.
    pub lora_bytes: usize,
    /// `base_quantized_bytes + lora_bytes`.
    pub total_bytes: usize,
    /// `base_unquantized_bytes / base_quantized_bytes` (denominator clamped to at least 1).
    pub compression_ratio: f32,
}
277
#[cfg(test)]
mod tests {
    // Tests for QLoRALayer: property-based checks (proptest) on memory
    // accounting, parameter preservation, quantization error and shapes,
    // followed by example-based tests for forward, merge, double quantization
    // and the MemoryStats value type.
    use super::*;
    use approx::assert_abs_diff_eq;
    use proptest::prelude::*;

    proptest! {
        // Run each property over 200 generated cases.
        #![proptest_config(proptest::test_runner::Config::with_cases(200))]

        // Memory stats must be internally consistent: quantized base no larger
        // than the dense base, ratio >= 1, total = base + adapters, and adapter
        // bytes equal to (d*rank + d*rank) f32 values for a square layer.
        #[test]
        fn prop_qlora_memory_savings_consistent(
            d in 8usize..32,
            rank in 1usize..8,
            alpha in 1.0f32..32.0
        ) {
            let size = d * d;
            let base_weight = Tensor::from_vec(vec![0.5; size], false);
            let qlora = QLoRALayer::new(base_weight, d, d, rank, alpha);

            let stats = qlora.memory_stats();

            prop_assert!(stats.base_quantized_bytes <= stats.base_unquantized_bytes);

            prop_assert!(stats.compression_ratio >= 1.0);

            prop_assert_eq!(
                stats.total_bytes,
                stats.base_quantized_bytes + stats.lora_bytes
            );

            // A is rank x d, B is d x rank, both f32 (4 bytes).
            let expected_lora_bytes = (d * rank + d * rank) * 4;
            prop_assert_eq!(stats.lora_bytes, expected_lora_bytes);
        }

        // Converting a LoRALayer to QLoRA must keep dimensions, scale and the
        // adapter values unchanged (only the base weight is quantized).
        #[test]
        fn prop_lora_params_preserved_after_quantization(
            d_out in 4usize..16,
            d_in in 4usize..16,
            rank in 1usize..4,
            alpha in 1.0f32..16.0
        ) {
            let size = d_out * d_in;
            let base_weight = Tensor::from_vec(vec![1.0; size], false);
            let lora = LoRALayer::new(base_weight.clone(), d_out, d_in, rank, alpha);

            let qlora = QLoRALayer::from_lora(lora.clone());

            prop_assert_eq!(qlora.d_out(), lora.d_out());
            prop_assert_eq!(qlora.d_in(), lora.d_in());
            prop_assert_eq!(qlora.rank(), lora.rank());

            prop_assert!((qlora.scale() - lora.scale()).abs() < 1e-6);

            prop_assert_eq!(qlora.lora_a().data().len(), lora.lora_a().data().len());
            prop_assert_eq!(qlora.lora_b().data().len(), lora.lora_b().data().len());

            // Element-wise equality of the copied adapter matrices.
            for (a, b) in qlora.lora_a().data().iter().zip(lora.lora_a().data().iter()) {
                prop_assert!((a - b).abs() < 1e-6);
            }
            for (a, b) in qlora.lora_b().data().iter().zip(lora.lora_b().data().iter()) {
                prop_assert!((a - b).abs() < 1e-6);
            }
        }

        // The quantized forward pass must stay within a loose relative+absolute
        // tolerance (30% + 0.5) of the full-precision LoRA forward pass.
        #[test]
        fn prop_quantization_error_bounded(
            d in 8usize..24,
        ) {
            let size = d * d;
            // Deterministic weights cycling through [-0.8, 0.7] in steps of 0.1.
            let base_weight = Tensor::from_vec(
                (0..size).map(|i| ((i % 16) as f32 - 8.0) * 0.1).collect(),
                false
            );
            let lora = LoRALayer::new(base_weight.clone(), d, d, 2, 4.0);
            let qlora = QLoRALayer::from_lora(lora.clone());

            let x = Tensor::from_vec(vec![0.1; d], true);
            let lora_out = lora.forward(&x);
            let qlora_out = qlora.forward(&x);

            prop_assert_eq!(lora_out.len(), qlora_out.len());
            for i in 0..lora_out.len() {
                let diff = (lora_out.data()[i] - qlora_out.data()[i]).abs();
                let max_diff = lora_out.data()[i].abs() * 0.3 + 0.5;
                prop_assert!(
                    diff < max_diff,
                    "Quantization error {} > {} at index {}",
                    diff, max_diff, i
                );
            }
        }

        // forward() on a d_in-length input must produce a d_out-length output.
        #[test]
        fn prop_forward_dimensions_correct(
            d_out in 4usize..16,
            d_in in 4usize..16,
            rank in 1usize..4,
        ) {
            let size = d_out * d_in;
            let base_weight = Tensor::from_vec(vec![1.0; size], false);
            let qlora = QLoRALayer::new(base_weight, d_out, d_in, rank, 4.0);

            let x = Tensor::from_vec(vec![0.5; d_in], true);
            let output = qlora.forward(&x);

            prop_assert_eq!(output.len(), d_out);
        }

        // trainable_params() returns [A, B] with A = rank x d_in and
        // B = d_out x rank, both marked as requiring gradients.
        #[test]
        fn prop_trainable_params_dimensions(
            d_out in 4usize..16,
            d_in in 4usize..16,
            rank in 1usize..4,
        ) {
            let size = d_out * d_in;
            let base_weight = Tensor::from_vec(vec![1.0; size], false);
            let mut qlora = QLoRALayer::new(base_weight, d_out, d_in, rank, 4.0);

            let params = qlora.trainable_params();
            prop_assert_eq!(params.len(), 2);

            prop_assert_eq!(params[0].len(), rank * d_in);
            prop_assert_eq!(params[1].len(), d_out * rank);

            prop_assert!(params[0].requires_grad());
            prop_assert!(params[1].requires_grad());
        }
    }

    // Basic construction: metadata is exposed and the layer starts unmerged.
    #[test]
    fn test_qlora_creation() {
        let base_weight = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], false);
        let qlora = QLoRALayer::new(base_weight, 2, 2, 1, 2.0);

        assert_eq!(qlora.rank(), 1);
        assert_eq!(qlora.d_out(), 2);
        assert_eq!(qlora.d_in(), 2);
        // Expected scale 2.0 for alpha = 2.0, rank = 1.
        assert_abs_diff_eq!(qlora.scale(), 2.0, epsilon = 1e-6);
        assert!(!qlora.is_merged());
    }

    // With identical hand-set adapters, QLoRA and LoRA outputs on an identity
    // base weight must agree to within 0.2 per element.
    #[test]
    fn test_qlora_forward_matches_lora() {
        let base_weight = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], false);
        let mut lora = LoRALayer::new(base_weight.clone(), 2, 2, 1, 1.0);
        *lora.lora_a_mut().data_mut() = ndarray::arr1(&[0.5, 0.5]);
        *lora.lora_b_mut().data_mut() = ndarray::arr1(&[0.3, 0.3]);

        let mut qlora = QLoRALayer::new(base_weight, 2, 2, 1, 1.0);
        *qlora.lora_a_mut().data_mut() = ndarray::arr1(&[0.5, 0.5]);
        *qlora.lora_b_mut().data_mut() = ndarray::arr1(&[0.3, 0.3]);

        let x = Tensor::from_vec(vec![2.0, 3.0], true);

        let lora_output = lora.forward(&x);
        let qlora_output = qlora.forward(&x);

        assert_eq!(lora_output.len(), qlora_output.len());
        for i in 0..lora_output.len() {
            let diff = (lora_output.data()[i] - qlora_output.data()[i]).abs();
            assert!(
                diff < 0.2,
                "Output mismatch at {}: {} vs {} (diff: {})",
                i,
                lora_output.data()[i],
                qlora_output.data()[i],
                diff
            );
        }
    }

    // A 16x16 quantized base weight must be smaller than its f32 form and
    // compress by more than 6x.
    #[test]
    fn test_qlora_memory_savings() {
        let d = 16;
        let size = d * d;
        let base_weight = Tensor::from_vec(vec![1.0; size], false);
        let qlora = QLoRALayer::new(base_weight, d, d, 8, 16.0);

        let stats = qlora.memory_stats();

        assert!(
            stats.base_quantized_bytes < stats.base_unquantized_bytes,
            "Quantized should use less memory"
        );

        assert!(
            stats.compression_ratio > 6.0,
            "Compression ratio {} should be > 6.0",
            stats.compression_ratio
        );
    }

    // Exactly two trainable tensors are exposed, both requiring gradients.
    #[test]
    fn test_qlora_trainable_params() {
        let base_weight = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], false);
        let mut qlora = QLoRALayer::new(base_weight, 2, 2, 2, 4.0);

        let params = qlora.trainable_params();
        assert_eq!(params.len(), 2);

        assert!(params[0].requires_grad());
        assert!(params[1].requires_grad());
    }

    // from_lora preserves rank, dimensions and scale (alpha 8.0, rank 2 -> 4.0)
    // for a non-square layer.
    #[test]
    fn test_qlora_from_lora() {
        let base_weight = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], false);
        let lora = LoRALayer::new(base_weight, 3, 2, 2, 8.0);

        let qlora = QLoRALayer::from_lora(lora);

        assert_eq!(qlora.rank(), 2);
        assert_eq!(qlora.d_out(), 3);
        assert_eq!(qlora.d_in(), 2);
        assert_abs_diff_eq!(qlora.scale(), 4.0, epsilon = 1e-6);
    }

    // merge_to_f32 returns one value per base-weight element.
    #[test]
    fn test_qlora_merge_to_f32_dimensions() {
        let d_out = 8;
        let d_in = 16;
        let base_weight = Tensor::from_vec(vec![1.0; d_out * d_in], false);
        let qlora = QLoRALayer::new(base_weight, d_out, d_in, 4, 8.0);

        let merged = qlora.merge_to_f32();
        assert_eq!(merged.len(), d_out * d_in);
    }

    // With an all-zero base weight, any non-zero merged value can only come
    // from the hand-set adapter product.
    #[test]
    fn test_qlora_merge_to_f32_includes_adapter() {
        let d_out = 4;
        let d_in = 4;
        let base_weight = Tensor::from_vec(vec![0.0; d_out * d_in], false);
        let mut qlora = QLoRALayer::new(base_weight, d_out, d_in, 2, 2.0);

        *qlora.lora_a_mut().data_mut() = ndarray::arr1(&[1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0]);
        *qlora.lora_b_mut().data_mut() = ndarray::arr1(&[1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0]);

        let merged = qlora.merge_to_f32();

        let adapter_contribution: f32 = merged.iter().map(|v| v.abs()).sum();
        assert!(adapter_contribution > 0.0, "Merged weights should include adapter contribution");
    }

    // QLoRA's merge_to_f32 must track LoRA's own merge to within 0.5 per element
    // (the gap is the base-weight quantization error).
    #[test]
    fn test_qlora_merge_to_f32_equivalence_with_lora() {
        let d_out = 4;
        let d_in = 4;
        let base_data = vec![
            0.5, 0.3, -0.2, 0.1, 0.4, -0.1, 0.6, 0.2, -0.3, 0.5, 0.1, -0.4, 0.2, 0.3, -0.5, 0.6,
        ];
        let base_weight = Tensor::from_vec(base_data.clone(), false);
        let mut lora = LoRALayer::new(base_weight.clone(), d_out, d_in, 2, 4.0);

        let a_data = vec![0.1, 0.2, -0.1, 0.3, 0.2, -0.2, 0.1, 0.1];
        let b_data = vec![0.3, -0.1, 0.2, 0.1, -0.2, 0.3, 0.1, -0.1];
        *lora.lora_a_mut().data_mut() = ndarray::arr1(&a_data);
        *lora.lora_b_mut().data_mut() = ndarray::arr1(&b_data);

        // from_lora already copies the adapters; they are re-assigned here explicitly.
        let mut qlora = QLoRALayer::from_lora(lora.clone());
        *qlora.lora_a_mut().data_mut() = ndarray::arr1(&a_data);
        *qlora.lora_b_mut().data_mut() = ndarray::arr1(&b_data);

        // NOTE(review): presumably `merge()` folds the adapter into `lora`'s base
        // weight in place, which is then read back via `base_weight()` — confirm.
        lora.merge();
        let lora_merged: Vec<f32> = lora.base_weight().data().to_vec();
        let qlora_merged = qlora.merge_to_f32();

        assert_eq!(lora_merged.len(), qlora_merged.len());
        for i in 0..lora_merged.len() {
            let diff = (lora_merged[i] - qlora_merged[i]).abs();
            assert!(
                diff < 0.5,
                "Merge difference too large at {i}: lora={}, qlora={}, diff={diff}",
                lora_merged[i],
                qlora_merged[i]
            );
        }
    }

    // A 256x256 layer runs forward correctly and saves more than 70% of the
    // base-weight memory.
    #[test]
    fn test_qlora_large_matrix() {
        let d_model = 256;
        let base_weight = Tensor::from_vec(vec![1.0; d_model * d_model], false);
        let qlora = QLoRALayer::new(base_weight, d_model, d_model, 16, 32.0);

        let x = Tensor::from_vec(vec![0.5; d_model], true);
        let output = qlora.forward(&x);

        assert_eq!(output.len(), d_model);

        let stats = qlora.memory_stats();
        let savings_percent =
            (1.0 - stats.base_quantized_bytes as f32 / stats.base_unquantized_bytes as f32) * 100.0;

        assert!(savings_percent > 70.0, "Should save > 70% memory, got {savings_percent}%");
    }

    // from_lora_double_quant marks the layer as double-quantized and keeps dimensions.
    #[test]
    fn test_ent_lora_008_double_quant_creation() {
        let base_weight = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], false);
        let lora = LoRALayer::new(base_weight, 2, 2, 1, 2.0);
        let qlora = QLoRALayer::from_lora_double_quant(lora);

        assert!(qlora.is_double_quantized());
        assert_eq!(qlora.d_out(), 2);
        assert_eq!(qlora.d_in(), 2);
    }

    // Double-quantized forward output must stay within 1% + 0.1 of the
    // single-quantized output built from the same LoRA layer.
    #[test]
    fn test_ent_lora_008_double_quant_forward_close_to_single() {
        let d = 64;
        let base_weight =
            Tensor::from_vec((0..d * d).map(|i| (i as f32 * 0.1).sin() * 2.0).collect(), false);
        let lora = LoRALayer::new(base_weight, d, d, 4, 8.0);

        let single = QLoRALayer::from_lora(lora.clone());
        let double = QLoRALayer::from_lora_double_quant(lora);

        let x = Tensor::from_vec(vec![0.1; d], true);
        let single_out = single.forward(&x);
        let double_out = double.forward(&x);

        assert_eq!(single_out.len(), double_out.len());
        for i in 0..single_out.len() {
            let diff = (single_out.data()[i] - double_out.data()[i]).abs();
            let tol = single_out.data()[i].abs() * 0.01 + 0.1;
            assert!(
                diff <= tol,
                "Forward output diverged at [{i}]: single={}, double={}, diff={diff}",
                single_out.data()[i],
                double_out.data()[i]
            );
        }
    }

    // The default constructor uses single quantization only.
    #[test]
    fn test_ent_lora_008_single_quant_not_double() {
        let base_weight = Tensor::from_vec(vec![1.0; 16], false);
        let qlora = QLoRALayer::new(base_weight, 4, 4, 2, 4.0);
        assert!(!qlora.is_double_quantized());
    }

    // Reported base-weight memory for double quantization must not exceed
    // that of single quantization.
    #[test]
    fn test_ent_lora_008_double_quant_memory_stats() {
        let d = 256;
        let base_weight = Tensor::from_vec(vec![1.0; d * d], false);
        let lora = LoRALayer::new(base_weight, d, d, 16, 32.0);

        let single = QLoRALayer::from_lora(lora.clone());
        let double = QLoRALayer::from_lora_double_quant(lora);

        let single_stats = single.memory_stats();
        let double_stats = double.memory_stats();

        assert!(
            double_stats.base_quantized_bytes <= single_stats.base_quantized_bytes,
            "Double quant ({}) should use <= memory than single ({})",
            double_stats.base_quantized_bytes,
            single_stats.base_quantized_bytes
        );
    }

    // merge_to_f32 on a double-quantized layer yields the right length and
    // only finite values.
    #[test]
    fn test_qlora_merge_to_f32_double_quant() {
        let d_out = 8;
        let d_in = 8;
        let base_weight = Tensor::from_vec(
            (0..d_out * d_in).map(|i| (i as f32 * 0.2).sin() * 0.5).collect(),
            false,
        );
        let lora = LoRALayer::new(base_weight, d_out, d_in, 2, 4.0);
        let qlora_dq = QLoRALayer::from_lora_double_quant(lora);

        assert!(qlora_dq.is_double_quantized());

        let merged = qlora_dq.merge_to_f32();
        assert_eq!(merged.len(), d_out * d_in);

        for val in &merged {
            assert!(val.is_finite(), "Merged weight must be finite, got {val}");
        }
    }

    // Merged weights from single and double quantization must agree to
    // within 5% + 0.2 per element.
    #[test]
    fn test_qlora_merge_to_f32_single_vs_double_close() {
        let d_out = 8;
        let d_in = 8;
        let base_data: Vec<f32> =
            (0..d_out * d_in).map(|i| (i as f32 * 0.15).cos() * 0.3).collect();
        let base_weight = Tensor::from_vec(base_data, false);

        let lora = LoRALayer::new(base_weight, d_out, d_in, 2, 4.0);
        let single = QLoRALayer::from_lora(lora.clone());
        let double = QLoRALayer::from_lora_double_quant(lora);

        let merged_single = single.merge_to_f32();
        let merged_double = double.merge_to_f32();

        assert_eq!(merged_single.len(), merged_double.len());
        for i in 0..merged_single.len() {
            let diff = (merged_single[i] - merged_double[i]).abs();
            let tol = merged_single[i].abs() * 0.05 + 0.2;
            assert!(
                diff <= tol,
                "merge_to_f32 single vs double diverged at [{i}]: single={}, double={}, diff={diff}",
                merged_single[i],
                merged_double[i]
            );
        }
    }

    // The single-quantized accessor exposes a non-empty quantized buffer.
    #[test]
    fn test_qlora_base_weight_quantized_accessor() {
        let d = 8;
        let base_weight = Tensor::from_vec(vec![1.0; d * d], false);
        let qlora = QLoRALayer::new(base_weight, d, d, 2, 4.0);

        let quantized = qlora.base_weight_quantized();
        assert!(quantized.memory_bytes() > 0, "Quantized base weight should use memory");
    }

    // Forward on a double-quantized layer with deterministic non-trivial
    // adapters produces d_out finite values.
    #[test]
    fn test_qlora_double_quant_forward_with_known_adapter() {
        let d_out = 4;
        let d_in = 4;
        let base_weight = Tensor::from_vec(vec![0.5; d_out * d_in], false);
        let lora = LoRALayer::new(base_weight, d_out, d_in, 2, 4.0);
        let mut qlora = QLoRALayer::from_lora_double_quant(lora);

        assert!(qlora.is_double_quantized());

        // A is rank (2) x d_in, B is d_out x rank (2).
        let a_data: Vec<f32> = (0..2 * d_in).map(|i| (i as f32 * 0.1).sin() * 0.5).collect();
        let b_data: Vec<f32> = (0..d_out * 2).map(|i| (i as f32 * 0.2).cos() * 0.3).collect();
        *qlora.lora_a_mut().data_mut() = ndarray::Array1::from_vec(a_data);
        *qlora.lora_b_mut().data_mut() = ndarray::Array1::from_vec(b_data);

        let x = Tensor::from_vec(vec![1.0; d_in], true);
        let output = qlora.forward(&x);

        assert_eq!(output.len(), d_out);
        for val in output.data() {
            assert!(val.is_finite(), "Forward output must be finite, got {val}");
        }
    }

    // memory_stats invariants for a double-quantized layer, including the
    // dense size of d*d f32 values.
    #[test]
    fn test_qlora_memory_stats_double_quant() {
        let d = 16;
        let base_weight = Tensor::from_vec(vec![1.0; d * d], false);
        let lora = LoRALayer::new(base_weight, d, d, 4, 8.0);
        let qlora = QLoRALayer::from_lora_double_quant(lora);

        let stats = qlora.memory_stats();

        assert!(stats.base_quantized_bytes > 0);
        assert!(stats.lora_bytes > 0);
        assert_eq!(stats.total_bytes, stats.base_quantized_bytes + stats.lora_bytes);
        assert!(stats.compression_ratio >= 1.0);
        assert_eq!(stats.base_unquantized_bytes, d * d * 4);
    }

    // MemoryStats supports Clone (fields preserved) and Debug (type name printed).
    #[test]
    fn test_qlora_memory_stats_clone_and_debug() {
        let base_weight = Tensor::from_vec(vec![1.0; 16], false);
        let qlora = QLoRALayer::new(base_weight, 4, 4, 2, 4.0);

        let stats = qlora.memory_stats();
        let stats_clone = stats.clone();

        assert_eq!(stats.total_bytes, stats_clone.total_bytes);
        assert_eq!(stats.lora_bytes, stats_clone.lora_bytes);
        assert_eq!(stats.base_quantized_bytes, stats_clone.base_quantized_bytes);

        let debug_str = format!("{stats_clone:?}");
        assert!(debug_str.contains("MemoryStats"));
    }

    // Writes through lora_a_mut / lora_b_mut are visible via the read accessors.
    #[test]
    fn test_qlora_lora_a_mut_and_lora_b_mut() {
        let base_weight = Tensor::from_vec(vec![1.0; 4], false);
        let mut qlora = QLoRALayer::new(base_weight, 2, 2, 1, 2.0);

        *qlora.lora_a_mut().data_mut() = ndarray::arr1(&[10.0, 20.0]);
        assert_abs_diff_eq!(qlora.lora_a().data()[0], 10.0, epsilon = 1e-6);
        assert_abs_diff_eq!(qlora.lora_a().data()[1], 20.0, epsilon = 1e-6);

        *qlora.lora_b_mut().data_mut() = ndarray::arr1(&[30.0, 40.0]);
        assert_abs_diff_eq!(qlora.lora_b().data()[0], 30.0, epsilon = 1e-6);
        assert_abs_diff_eq!(qlora.lora_b().data()[1], 40.0, epsilon = 1e-6);
    }
}