Skip to main content

entrenar/lora/
qlora.rs

//! QLoRA (Quantized LoRA) implementation
//!
//! QLoRA reduces memory usage by storing frozen base weights in 4-bit quantized format.
//! During the forward pass, weights are dequantized on-the-fly.
//!
//! Memory savings: ~75% for base weights (4-bit vs 32-bit)
//! LoRA adapters remain in full precision for training.
8
9use crate::autograd::matmul;
10use crate::lora::LoRALayer;
11use crate::quant::{
12    dequantize_4bit, dequantize_4bit_double, quantize_4bit, quantize_4bit_double,
13    DoubleQuantized4Bit, Quantized4Bit,
14};
15use crate::Tensor;
16
17/// QLoRA layer with 4-bit quantized base weight
18///
19/// Memory-efficient variant of LoRALayer that stores the frozen base weight
20/// in 4-bit quantized format, reducing memory usage by ~75%.
21/// Optionally uses double quantization (ENT-LoRA-008) for additional savings.
22pub struct QLoRALayer {
23    /// Quantized base weight (4-bit, single quantization)
24    base_weight_quantized: Quantized4Bit,
25    /// Double-quantized base weight (ENT-LoRA-008, None if not enabled)
26    base_weight_double: Option<DoubleQuantized4Bit>,
27    /// LoRA matrix A [r * d_in] - full precision, trainable
28    lora_a: Tensor,
29    /// LoRA matrix B [d_out * r] - full precision, trainable
30    lora_b: Tensor,
31    /// Output dimension
32    d_out: usize,
33    /// Input dimension
34    d_in: usize,
35    /// LoRA rank
36    rank: usize,
37    /// Scaling factor (alpha/rank)
38    scale: f32,
39    /// Whether the adapter is merged (not supported for quantized weights)
40    merged: bool,
41}
42
43impl QLoRALayer {
44    /// Create QLoRA layer from existing LoRA layer
45    ///
46    /// # Arguments
47    /// * `lora_layer` - Existing LoRALayer to convert
48    ///
49    /// # Returns
50    /// QLoRALayer with quantized base weight
51    pub fn from_lora(lora_layer: LoRALayer) -> Self {
52        let base_weight_data = lora_layer.base_weight().data().to_vec();
53        let base_weight_quantized = quantize_4bit(&base_weight_data);
54
55        Self {
56            base_weight_quantized,
57            base_weight_double: None,
58            lora_a: lora_layer.lora_a().clone(),
59            lora_b: lora_layer.lora_b().clone(),
60            d_out: lora_layer.d_out(),
61            d_in: lora_layer.d_in(),
62            rank: lora_layer.rank(),
63            scale: lora_layer.scale(),
64            merged: false,
65        }
66    }
67
68    /// Create QLoRA layer from LoRALayer with double quantization (ENT-LoRA-008)
69    pub fn from_lora_double_quant(lora_layer: LoRALayer) -> Self {
70        let base_weight_data = lora_layer.base_weight().data().to_vec();
71        let base_weight_quantized = quantize_4bit(&base_weight_data);
72        let base_weight_double = Some(quantize_4bit_double(&base_weight_data));
73
74        Self {
75            base_weight_quantized,
76            base_weight_double,
77            lora_a: lora_layer.lora_a().clone(),
78            lora_b: lora_layer.lora_b().clone(),
79            d_out: lora_layer.d_out(),
80            d_in: lora_layer.d_in(),
81            rank: lora_layer.rank(),
82            scale: lora_layer.scale(),
83            merged: false,
84        }
85    }
86
87    /// Create QLoRA layer directly with quantized base weight
88    ///
89    /// # Arguments
90    /// * `base_weight` - Base weight to quantize [d_out * d_in]
91    /// * `d_out` - Output dimension
92    /// * `d_in` - Input dimension
93    /// * `rank` - LoRA rank
94    /// * `alpha` - LoRA alpha parameter
95    pub fn new(base_weight: Tensor, d_out: usize, d_in: usize, rank: usize, alpha: f32) -> Self {
96        // Create LoRALayer first, then convert
97        let lora_layer = LoRALayer::new(base_weight, d_out, d_in, rank, alpha);
98        Self::from_lora(lora_layer)
99    }
100
101    /// Forward pass with on-the-fly dequantization
102    ///
103    /// # Arguments
104    /// * `x` - Input tensor `[d_in]`
105    ///
106    /// # Returns
107    /// Output tensor `[d_out]`
108    pub fn forward(&self, x: &Tensor) -> Tensor {
109        // Contract: inference-pipeline-v1.yaml precondition (pv codegen)
110        // contract_pre_prefill_phase!(x.data()); // TODO: macro not yet generated
111
112        assert_eq!(x.len(), self.d_in, "Input size must match d_in");
113
114        // Dequantize base weight on-the-fly (use double quant path if available)
115        let base_weight_data = if let Some(ref dq) = self.base_weight_double {
116            dequantize_4bit_double(dq)
117        } else {
118            dequantize_4bit(&self.base_weight_quantized)
119        };
120        let base_weight = Tensor::new(ndarray::arr1(&base_weight_data), false);
121
122        // Base forward: W @ x
123        let base_output = matmul(&base_weight, x, self.d_out, self.d_in, 1);
124
125        if self.merged {
126            base_output
127        } else {
128            // LoRA forward: scale * (B @ (A @ x))
129            let lora_out_a = matmul(&self.lora_a, x, self.rank, self.d_in, 1);
130            let lora_out_b = matmul(&self.lora_b, &lora_out_a, self.d_out, self.rank, 1);
131
132            // Scale and add
133            let mut scaled_lora_data = lora_out_b.data().to_owned();
134            for val in &mut scaled_lora_data {
135                *val *= self.scale;
136            }
137            let scaled_lora = Tensor::new(scaled_lora_data, false);
138
139            let mut result_data = base_output.data().to_owned();
140            for (i, val) in result_data.iter_mut().enumerate() {
141                *val += scaled_lora.data()[i];
142            }
143            Tensor::new(result_data, base_output.requires_grad())
144        }
145    }
146
147    /// Get reference to LoRA A matrix
148    pub fn lora_a(&self) -> &Tensor {
149        &self.lora_a
150    }
151
152    /// Get mutable reference to LoRA A matrix
153    pub fn lora_a_mut(&mut self) -> &mut Tensor {
154        &mut self.lora_a
155    }
156
157    /// Get reference to LoRA B matrix
158    pub fn lora_b(&self) -> &Tensor {
159        &self.lora_b
160    }
161
162    /// Get mutable reference to LoRA B matrix
163    pub fn lora_b_mut(&mut self) -> &mut Tensor {
164        &mut self.lora_b
165    }
166
167    /// Get trainable parameters (A and B)
168    pub fn trainable_params(&mut self) -> Vec<&mut Tensor> {
169        vec![&mut self.lora_a, &mut self.lora_b]
170    }
171
172    /// Get rank
173    pub fn rank(&self) -> usize {
174        self.rank
175    }
176
177    /// Get scale factor
178    pub fn scale(&self) -> f32 {
179        self.scale
180    }
181
182    /// Get output dimension
183    pub fn d_out(&self) -> usize {
184        self.d_out
185    }
186
187    /// Get input dimension
188    pub fn d_in(&self) -> usize {
189        self.d_in
190    }
191
192    /// Get memory usage statistics
193    pub fn memory_stats(&self) -> MemoryStats {
194        let base_unquantized_bytes = self.d_out * self.d_in * 4; // f32
195        let base_quantized_bytes = if let Some(ref dq) = self.base_weight_double {
196            dq.memory_bytes()
197        } else {
198            self.base_weight_quantized.memory_bytes()
199        };
200        let lora_a_bytes = self.lora_a.len() * 4;
201        let lora_b_bytes = self.lora_b.len() * 4;
202
203        MemoryStats {
204            base_unquantized_bytes,
205            base_quantized_bytes,
206            lora_bytes: lora_a_bytes + lora_b_bytes,
207            total_bytes: base_quantized_bytes + lora_a_bytes + lora_b_bytes,
208            compression_ratio: base_unquantized_bytes as f32 / base_quantized_bytes.max(1) as f32,
209        }
210    }
211
212    /// Check if merged (always false for quantized layers)
213    pub fn is_merged(&self) -> bool {
214        self.merged
215    }
216
217    /// Merge adapter into dequantized base weight, returning full-precision f32 weights
218    ///
219    /// Computes: `dequantize(base_4bit) + scale * B @ A`
220    ///
221    /// This produces a merged weight matrix that can be exported as SafeTensors or GGUF.
222    /// The result is a flat Vec<f32> of shape [d_out, d_in] in row-major layout.
223    pub fn merge_to_f32(&self) -> Vec<f32> {
224        // Dequantize base weight (use double quant path if available)
225        let mut merged = if let Some(ref dq) = self.base_weight_double {
226            dequantize_4bit_double(dq)
227        } else {
228            dequantize_4bit(&self.base_weight_quantized)
229        };
230
231        // Compute scale * B @ A and add to merged weights
232        // A: [rank, d_in], B: [d_out, rank]
233        // B @ A = [d_out, d_in]
234        let a_data = self.lora_a.data();
235        let b_data = self.lora_b.data();
236
237        for row in 0..self.d_out {
238            for col in 0..self.d_in {
239                let mut sum = 0.0f32;
240                for r in 0..self.rank {
241                    let b_val = b_data[row * self.rank + r];
242                    let a_val = a_data[r * self.d_in + col];
243                    sum += b_val * a_val;
244                }
245                merged[row * self.d_in + col] += self.scale * sum;
246            }
247        }
248
249        merged
250    }
251
252    /// Get reference to quantized base weights
253    pub fn base_weight_quantized(&self) -> &Quantized4Bit {
254        &self.base_weight_quantized
255    }
256
257    /// Check if double quantization is enabled (ENT-LoRA-008)
258    pub fn is_double_quantized(&self) -> bool {
259        self.base_weight_double.is_some()
260    }
261}
262
/// Byte-level memory report for a QLoRA layer
#[derive(Clone, Debug)]
pub struct MemoryStats {
    /// Bytes the base weight would take if stored unquantized
    pub base_unquantized_bytes: usize,
    /// Bytes the base weight takes in quantized form
    pub base_quantized_bytes: usize,
    /// Bytes taken by the LoRA adapter matrices
    pub lora_bytes: usize,
    /// Quantized base weight plus adapters, in bytes
    pub total_bytes: usize,
    /// Unquantized base size divided by quantized base size
    pub compression_ratio: f32,
}
277
278#[cfg(test)]
279mod tests {
280    use super::*;
281    use approx::assert_abs_diff_eq;
282    use proptest::prelude::*;
283
284    // ========================================================================
285    // PROPERTY TESTS - Mathematical correctness validation
286    // ========================================================================
287
288    proptest! {
289        #![proptest_config(proptest::test_runner::Config::with_cases(200))]
290
291        /// Memory savings should be consistent with dimensions
292        #[test]
293        fn prop_qlora_memory_savings_consistent(
294            d in 8usize..32,
295            rank in 1usize..8,
296            alpha in 1.0f32..32.0
297        ) {
298            let size = d * d;
299            let base_weight = Tensor::from_vec(vec![0.5; size], false);
300            let qlora = QLoRALayer::new(base_weight, d, d, rank, alpha);
301
302            let stats = qlora.memory_stats();
303
304            // Quantized should always be smaller than unquantized
305            prop_assert!(stats.base_quantized_bytes <= stats.base_unquantized_bytes);
306
307            // Compression ratio should be > 1.0
308            prop_assert!(stats.compression_ratio >= 1.0);
309
310            // Total bytes = quantized base + lora
311            prop_assert_eq!(
312                stats.total_bytes,
313                stats.base_quantized_bytes + stats.lora_bytes
314            );
315
316            // LoRA bytes = (d*rank + d*rank) * 4 bytes
317            let expected_lora_bytes = (d * rank + d * rank) * 4;
318            prop_assert_eq!(stats.lora_bytes, expected_lora_bytes);
319        }
320
321        /// LoRA parameters should be preserved during quantization
322        #[test]
323        fn prop_lora_params_preserved_after_quantization(
324            d_out in 4usize..16,
325            d_in in 4usize..16,
326            rank in 1usize..4,
327            alpha in 1.0f32..16.0
328        ) {
329            let size = d_out * d_in;
330            let base_weight = Tensor::from_vec(vec![1.0; size], false);
331            let lora = LoRALayer::new(base_weight.clone(), d_out, d_in, rank, alpha);
332
333            let qlora = QLoRALayer::from_lora(lora.clone());
334
335            // Dimensions should match exactly
336            prop_assert_eq!(qlora.d_out(), lora.d_out());
337            prop_assert_eq!(qlora.d_in(), lora.d_in());
338            prop_assert_eq!(qlora.rank(), lora.rank());
339
340            // Scale should match
341            prop_assert!((qlora.scale() - lora.scale()).abs() < 1e-6);
342
343            // LoRA A and B data should be identical
344            prop_assert_eq!(qlora.lora_a().data().len(), lora.lora_a().data().len());
345            prop_assert_eq!(qlora.lora_b().data().len(), lora.lora_b().data().len());
346
347            for (a, b) in qlora.lora_a().data().iter().zip(lora.lora_a().data().iter()) {
348                prop_assert!((a - b).abs() < 1e-6);
349            }
350            for (a, b) in qlora.lora_b().data().iter().zip(lora.lora_b().data().iter()) {
351                prop_assert!((a - b).abs() < 1e-6);
352            }
353        }
354
355        /// Quantization error should be bounded
356        #[test]
357        fn prop_quantization_error_bounded(
358            d in 8usize..24,
359        ) {
360            let size = d * d;
361            // Use values in reasonable range for 4-bit quantization
362            let base_weight = Tensor::from_vec(
363                (0..size).map(|i| ((i % 16) as f32 - 8.0) * 0.1).collect(),
364                false
365            );
366            let lora = LoRALayer::new(base_weight.clone(), d, d, 2, 4.0);
367            let qlora = QLoRALayer::from_lora(lora.clone());
368
369            // Forward with same input
370            let x = Tensor::from_vec(vec![0.1; d], true);
371            let lora_out = lora.forward(&x);
372            let qlora_out = qlora.forward(&x);
373
374            // Outputs should be close (quantization introduces some error)
375            prop_assert_eq!(lora_out.len(), qlora_out.len());
376            for i in 0..lora_out.len() {
377                let diff = (lora_out.data()[i] - qlora_out.data()[i]).abs();
378                // Allow 30% relative error due to 4-bit quantization
379                let max_diff = lora_out.data()[i].abs() * 0.3 + 0.5;
380                prop_assert!(
381                    diff < max_diff,
382                    "Quantization error {} > {} at index {}",
383                    diff, max_diff, i
384                );
385            }
386        }
387
388        /// Forward output dimensions should always be correct
389        #[test]
390        fn prop_forward_dimensions_correct(
391            d_out in 4usize..16,
392            d_in in 4usize..16,
393            rank in 1usize..4,
394        ) {
395            let size = d_out * d_in;
396            let base_weight = Tensor::from_vec(vec![1.0; size], false);
397            let qlora = QLoRALayer::new(base_weight, d_out, d_in, rank, 4.0);
398
399            let x = Tensor::from_vec(vec![0.5; d_in], true);
400            let output = qlora.forward(&x);
401
402            prop_assert_eq!(output.len(), d_out);
403        }
404
405        /// Trainable params should have correct dimensions
406        #[test]
407        fn prop_trainable_params_dimensions(
408            d_out in 4usize..16,
409            d_in in 4usize..16,
410            rank in 1usize..4,
411        ) {
412            let size = d_out * d_in;
413            let base_weight = Tensor::from_vec(vec![1.0; size], false);
414            let mut qlora = QLoRALayer::new(base_weight, d_out, d_in, rank, 4.0);
415
416            let params = qlora.trainable_params();
417            prop_assert_eq!(params.len(), 2);
418
419            // A: [rank * d_in]
420            prop_assert_eq!(params[0].len(), rank * d_in);
421            // B: [d_out * rank]
422            prop_assert_eq!(params[1].len(), d_out * rank);
423
424            // Both should require grad
425            prop_assert!(params[0].requires_grad());
426            prop_assert!(params[1].requires_grad());
427        }
428    }
429
430    // ========================================================================
431    // UNIT TESTS
432    // ========================================================================
433
434    #[test]
435    fn test_qlora_creation() {
436        let base_weight = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], false);
437        let qlora = QLoRALayer::new(base_weight, 2, 2, 1, 2.0);
438
439        assert_eq!(qlora.rank(), 1);
440        assert_eq!(qlora.d_out(), 2);
441        assert_eq!(qlora.d_in(), 2);
442        assert_abs_diff_eq!(qlora.scale(), 2.0, epsilon = 1e-6); // alpha/rank = 2/1
443        assert!(!qlora.is_merged());
444    }
445
446    #[test]
447    fn test_qlora_forward_matches_lora() {
448        // Test that QLoRA forward pass approximates LoRA
449        let base_weight = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], false);
450        let mut lora = LoRALayer::new(base_weight.clone(), 2, 2, 1, 1.0);
451        *lora.lora_a_mut().data_mut() = ndarray::arr1(&[0.5, 0.5]);
452        *lora.lora_b_mut().data_mut() = ndarray::arr1(&[0.3, 0.3]);
453
454        let mut qlora = QLoRALayer::new(base_weight, 2, 2, 1, 1.0);
455        *qlora.lora_a_mut().data_mut() = ndarray::arr1(&[0.5, 0.5]);
456        *qlora.lora_b_mut().data_mut() = ndarray::arr1(&[0.3, 0.3]);
457
458        let x = Tensor::from_vec(vec![2.0, 3.0], true);
459
460        let lora_output = lora.forward(&x);
461        let qlora_output = qlora.forward(&x);
462
463        // Outputs should be close (within quantization error)
464        assert_eq!(lora_output.len(), qlora_output.len());
465        for i in 0..lora_output.len() {
466            let diff = (lora_output.data()[i] - qlora_output.data()[i]).abs();
467            assert!(
468                diff < 0.2,
469                "Output mismatch at {}: {} vs {} (diff: {})",
470                i,
471                lora_output.data()[i],
472                qlora_output.data()[i],
473                diff
474            );
475        }
476    }
477
478    #[test]
479    fn test_qlora_memory_savings() {
480        // Test with large enough weight to show compression (use perfect square)
481        let d = 16; // 16x16 = 256 elements
482        let size = d * d;
483        let base_weight = Tensor::from_vec(vec![1.0; size], false);
484        let qlora = QLoRALayer::new(base_weight, d, d, 8, 16.0);
485
486        let stats = qlora.memory_stats();
487
488        // Should see significant memory savings
489        assert!(
490            stats.base_quantized_bytes < stats.base_unquantized_bytes,
491            "Quantized should use less memory"
492        );
493
494        // Compression ratio should be > 6x
495        assert!(
496            stats.compression_ratio > 6.0,
497            "Compression ratio {} should be > 6.0",
498            stats.compression_ratio
499        );
500    }
501
502    #[test]
503    fn test_qlora_trainable_params() {
504        let base_weight = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], false);
505        let mut qlora = QLoRALayer::new(base_weight, 2, 2, 2, 4.0);
506
507        let params = qlora.trainable_params();
508        assert_eq!(params.len(), 2);
509
510        // Both should be trainable
511        assert!(params[0].requires_grad());
512        assert!(params[1].requires_grad());
513    }
514
515    #[test]
516    fn test_qlora_from_lora() {
517        let base_weight = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], false);
518        let lora = LoRALayer::new(base_weight, 3, 2, 2, 8.0);
519
520        let qlora = QLoRALayer::from_lora(lora);
521
522        assert_eq!(qlora.rank(), 2);
523        assert_eq!(qlora.d_out(), 3);
524        assert_eq!(qlora.d_in(), 2);
525        assert_abs_diff_eq!(qlora.scale(), 4.0, epsilon = 1e-6); // 8/2 = 4
526    }
527
528    #[test]
529    fn test_qlora_merge_to_f32_dimensions() {
530        let d_out = 8;
531        let d_in = 16;
532        let base_weight = Tensor::from_vec(vec![1.0; d_out * d_in], false);
533        let qlora = QLoRALayer::new(base_weight, d_out, d_in, 4, 8.0);
534
535        let merged = qlora.merge_to_f32();
536        assert_eq!(merged.len(), d_out * d_in);
537    }
538
539    #[test]
540    fn test_qlora_merge_to_f32_includes_adapter() {
541        let d_out = 4;
542        let d_in = 4;
543        let base_weight = Tensor::from_vec(vec![0.0; d_out * d_in], false);
544        let mut qlora = QLoRALayer::new(base_weight, d_out, d_in, 2, 2.0);
545
546        // Set adapter weights to known values
547        *qlora.lora_a_mut().data_mut() = ndarray::arr1(&[1.0, 0.0, 0.0, 1.0, 0.0, 1.0, 1.0, 0.0]);
548        *qlora.lora_b_mut().data_mut() = ndarray::arr1(&[1.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 1.0]);
549
550        let merged = qlora.merge_to_f32();
551
552        // Base is all zeros (within quant error), adapter should contribute non-zero values
553        let adapter_contribution: f32 = merged.iter().map(|v| v.abs()).sum();
554        assert!(adapter_contribution > 0.0, "Merged weights should include adapter contribution");
555    }
556
557    #[test]
558    fn test_qlora_merge_to_f32_equivalence_with_lora() {
559        // For small values, QLoRA merge should approximate LoRA merge
560        let d_out = 4;
561        let d_in = 4;
562        let base_data = vec![
563            0.5, 0.3, -0.2, 0.1, 0.4, -0.1, 0.6, 0.2, -0.3, 0.5, 0.1, -0.4, 0.2, 0.3, -0.5, 0.6,
564        ];
565        let base_weight = Tensor::from_vec(base_data.clone(), false);
566        let mut lora = LoRALayer::new(base_weight.clone(), d_out, d_in, 2, 4.0);
567
568        let a_data = vec![0.1, 0.2, -0.1, 0.3, 0.2, -0.2, 0.1, 0.1];
569        let b_data = vec![0.3, -0.1, 0.2, 0.1, -0.2, 0.3, 0.1, -0.1];
570        *lora.lora_a_mut().data_mut() = ndarray::arr1(&a_data);
571        *lora.lora_b_mut().data_mut() = ndarray::arr1(&b_data);
572
573        let mut qlora = QLoRALayer::from_lora(lora.clone());
574        // Copy same adapter weights
575        *qlora.lora_a_mut().data_mut() = ndarray::arr1(&a_data);
576        *qlora.lora_b_mut().data_mut() = ndarray::arr1(&b_data);
577
578        // Merge LoRA in-place and extract base weight as the merged result
579        lora.merge();
580        let lora_merged: Vec<f32> = lora.base_weight().data().to_vec();
581        let qlora_merged = qlora.merge_to_f32();
582
583        assert_eq!(lora_merged.len(), qlora_merged.len());
584        for i in 0..lora_merged.len() {
585            let diff = (lora_merged[i] - qlora_merged[i]).abs();
586            assert!(
587                diff < 0.5,
588                "Merge difference too large at {i}: lora={}, qlora={}, diff={diff}",
589                lora_merged[i],
590                qlora_merged[i]
591            );
592        }
593    }
594
595    #[test]
596    fn test_qlora_large_matrix() {
597        // Test with realistic transformer dimensions
598        let d_model = 256;
599        let base_weight = Tensor::from_vec(vec![1.0; d_model * d_model], false);
600        let qlora = QLoRALayer::new(base_weight, d_model, d_model, 16, 32.0);
601
602        let x = Tensor::from_vec(vec![0.5; d_model], true);
603        let output = qlora.forward(&x);
604
605        assert_eq!(output.len(), d_model);
606
607        // Check memory savings
608        let stats = qlora.memory_stats();
609        let savings_percent =
610            (1.0 - stats.base_quantized_bytes as f32 / stats.base_unquantized_bytes as f32) * 100.0;
611
612        assert!(savings_percent > 70.0, "Should save > 70% memory, got {savings_percent}%");
613    }
614
615    // ========================================================================
616    // ENT-LoRA-008: Double Quantization Tests
617    // ========================================================================
618
619    #[test]
620    fn test_ent_lora_008_double_quant_creation() {
621        let base_weight = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], false);
622        let lora = LoRALayer::new(base_weight, 2, 2, 1, 2.0);
623        let qlora = QLoRALayer::from_lora_double_quant(lora);
624
625        assert!(qlora.is_double_quantized());
626        assert_eq!(qlora.d_out(), 2);
627        assert_eq!(qlora.d_in(), 2);
628    }
629
630    #[test]
631    fn test_ent_lora_008_double_quant_forward_close_to_single() {
632        let d = 64;
633        let base_weight =
634            Tensor::from_vec((0..d * d).map(|i| (i as f32 * 0.1).sin() * 2.0).collect(), false);
635        let lora = LoRALayer::new(base_weight, d, d, 4, 8.0);
636
637        let single = QLoRALayer::from_lora(lora.clone());
638        let double = QLoRALayer::from_lora_double_quant(lora);
639
640        let x = Tensor::from_vec(vec![0.1; d], true);
641        let single_out = single.forward(&x);
642        let double_out = double.forward(&x);
643
644        assert_eq!(single_out.len(), double_out.len());
645        for i in 0..single_out.len() {
646            let diff = (single_out.data()[i] - double_out.data()[i]).abs();
647            let tol = single_out.data()[i].abs() * 0.01 + 0.1;
648            assert!(
649                diff <= tol,
650                "Forward output diverged at [{i}]: single={}, double={}, diff={diff}",
651                single_out.data()[i],
652                double_out.data()[i]
653            );
654        }
655    }
656
657    #[test]
658    fn test_ent_lora_008_single_quant_not_double() {
659        let base_weight = Tensor::from_vec(vec![1.0; 16], false);
660        let qlora = QLoRALayer::new(base_weight, 4, 4, 2, 4.0);
661        assert!(!qlora.is_double_quantized());
662    }
663
664    #[test]
665    fn test_ent_lora_008_double_quant_memory_stats() {
666        let d = 256;
667        let base_weight = Tensor::from_vec(vec![1.0; d * d], false);
668        let lora = LoRALayer::new(base_weight, d, d, 16, 32.0);
669
670        let single = QLoRALayer::from_lora(lora.clone());
671        let double = QLoRALayer::from_lora_double_quant(lora);
672
673        let single_stats = single.memory_stats();
674        let double_stats = double.memory_stats();
675
676        // Double quant should use less memory for base weights
677        assert!(
678            double_stats.base_quantized_bytes <= single_stats.base_quantized_bytes,
679            "Double quant ({}) should use <= memory than single ({})",
680            double_stats.base_quantized_bytes,
681            single_stats.base_quantized_bytes
682        );
683    }
684
685    // ========================================================================
686    // COVERAGE GAP TESTS — double-quant merge, accessors, merged forward
687    // ========================================================================
688
689    #[test]
690    fn test_qlora_merge_to_f32_double_quant() {
691        // Covers the double quantization path in merge_to_f32()
692        let d_out = 8;
693        let d_in = 8;
694        let base_weight = Tensor::from_vec(
695            (0..d_out * d_in).map(|i| (i as f32 * 0.2).sin() * 0.5).collect(),
696            false,
697        );
698        let lora = LoRALayer::new(base_weight, d_out, d_in, 2, 4.0);
699        let qlora_dq = QLoRALayer::from_lora_double_quant(lora);
700
701        assert!(qlora_dq.is_double_quantized());
702
703        let merged = qlora_dq.merge_to_f32();
704        assert_eq!(merged.len(), d_out * d_in);
705
706        // All values should be finite
707        for val in &merged {
708            assert!(val.is_finite(), "Merged weight must be finite, got {val}");
709        }
710    }
711
712    #[test]
713    fn test_qlora_merge_to_f32_single_vs_double_close() {
714        // Verify single-quant and double-quant merge paths produce similar results
715        let d_out = 8;
716        let d_in = 8;
717        let base_data: Vec<f32> =
718            (0..d_out * d_in).map(|i| (i as f32 * 0.15).cos() * 0.3).collect();
719        let base_weight = Tensor::from_vec(base_data, false);
720
721        let lora = LoRALayer::new(base_weight, d_out, d_in, 2, 4.0);
722        let single = QLoRALayer::from_lora(lora.clone());
723        let double = QLoRALayer::from_lora_double_quant(lora);
724
725        let merged_single = single.merge_to_f32();
726        let merged_double = double.merge_to_f32();
727
728        assert_eq!(merged_single.len(), merged_double.len());
729        for i in 0..merged_single.len() {
730            let diff = (merged_single[i] - merged_double[i]).abs();
731            let tol = merged_single[i].abs() * 0.05 + 0.2;
732            assert!(
733                diff <= tol,
734                "merge_to_f32 single vs double diverged at [{i}]: single={}, double={}, diff={diff}",
735                merged_single[i],
736                merged_double[i]
737            );
738        }
739    }
740
741    #[test]
742    fn test_qlora_base_weight_quantized_accessor() {
743        // Covers base_weight_quantized() accessor
744        let d = 8;
745        let base_weight = Tensor::from_vec(vec![1.0; d * d], false);
746        let qlora = QLoRALayer::new(base_weight, d, d, 2, 4.0);
747
748        let quantized = qlora.base_weight_quantized();
749        // The quantized struct should have data — just verify it exists and has content
750        assert!(quantized.memory_bytes() > 0, "Quantized base weight should use memory");
751    }
752
753    #[test]
754    fn test_qlora_double_quant_forward_with_known_adapter() {
755        // Covers the double-quant forward path with explicit adapter weights
756        let d_out = 4;
757        let d_in = 4;
758        let base_weight = Tensor::from_vec(vec![0.5; d_out * d_in], false);
759        let lora = LoRALayer::new(base_weight, d_out, d_in, 2, 4.0);
760        let mut qlora = QLoRALayer::from_lora_double_quant(lora);
761
762        assert!(qlora.is_double_quantized());
763
764        // Set non-zero adapter weights
765        let a_data: Vec<f32> = (0..2 * d_in).map(|i| (i as f32 * 0.1).sin() * 0.5).collect();
766        let b_data: Vec<f32> = (0..d_out * 2).map(|i| (i as f32 * 0.2).cos() * 0.3).collect();
767        *qlora.lora_a_mut().data_mut() = ndarray::Array1::from_vec(a_data);
768        *qlora.lora_b_mut().data_mut() = ndarray::Array1::from_vec(b_data);
769
770        let x = Tensor::from_vec(vec![1.0; d_in], true);
771        let output = qlora.forward(&x);
772
773        assert_eq!(output.len(), d_out);
774        for val in output.data() {
775            assert!(val.is_finite(), "Forward output must be finite, got {val}");
776        }
777    }
778
779    #[test]
780    fn test_qlora_memory_stats_double_quant() {
781        // Covers the double-quant path in memory_stats()
782        let d = 16;
783        let base_weight = Tensor::from_vec(vec![1.0; d * d], false);
784        let lora = LoRALayer::new(base_weight, d, d, 4, 8.0);
785        let qlora = QLoRALayer::from_lora_double_quant(lora);
786
787        let stats = qlora.memory_stats();
788
789        // Basic sanity checks
790        assert!(stats.base_quantized_bytes > 0);
791        assert!(stats.lora_bytes > 0);
792        assert_eq!(stats.total_bytes, stats.base_quantized_bytes + stats.lora_bytes);
793        assert!(stats.compression_ratio >= 1.0);
794        assert_eq!(stats.base_unquantized_bytes, d * d * 4);
795    }
796
797    #[test]
798    fn test_qlora_memory_stats_clone_and_debug() {
799        // Covers Clone and Debug derives on MemoryStats
800        let base_weight = Tensor::from_vec(vec![1.0; 16], false);
801        let qlora = QLoRALayer::new(base_weight, 4, 4, 2, 4.0);
802
803        let stats = qlora.memory_stats();
804        let stats_clone = stats.clone();
805
806        assert_eq!(stats.total_bytes, stats_clone.total_bytes);
807        assert_eq!(stats.lora_bytes, stats_clone.lora_bytes);
808        assert_eq!(stats.base_quantized_bytes, stats_clone.base_quantized_bytes);
809
810        let debug_str = format!("{stats_clone:?}");
811        assert!(debug_str.contains("MemoryStats"));
812    }
813
814    #[test]
815    fn test_qlora_lora_a_mut_and_lora_b_mut() {
816        // Explicitly covers lora_a_mut() and lora_b_mut() on QLoRALayer
817        let base_weight = Tensor::from_vec(vec![1.0; 4], false);
818        let mut qlora = QLoRALayer::new(base_weight, 2, 2, 1, 2.0);
819
820        // Mutate A
821        *qlora.lora_a_mut().data_mut() = ndarray::arr1(&[10.0, 20.0]);
822        assert_abs_diff_eq!(qlora.lora_a().data()[0], 10.0, epsilon = 1e-6);
823        assert_abs_diff_eq!(qlora.lora_a().data()[1], 20.0, epsilon = 1e-6);
824
825        // Mutate B
826        *qlora.lora_b_mut().data_mut() = ndarray::arr1(&[30.0, 40.0]);
827        assert_abs_diff_eq!(qlora.lora_b().data()[0], 30.0, epsilon = 1e-6);
828        assert_abs_diff_eq!(qlora.lora_b().data()[1], 40.0, epsilon = 1e-6);
829    }
830}