// qlora_rs/qlora.rs

//! `QLoRA` layer implementation.
//!
//! Combines quantized base weights with trainable `LoRA` adapters.
//!
//! # Training Configuration
//!
//! **CRITICAL**: Always use BF16 compute dtype for training stability.
//! Using FP16 results in a ~20% training failure rate due to numerical instability.
//!
//! # References
//! - `QLoRA` paper: <https://arxiv.org/abs/2305.14314>
//! - PEFT `prepare_model_for_kbit_training`: upcasts non-quantized modules to FP32
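//!
//! # Example
//!
//! A minimal end-to-end sketch. The CPU device and layer shapes are illustrative,
//! and the crate-root re-export of `QuantizedLinear` is assumed (as `QLoraConfig` is):
//!
//! ```rust,no_run
//! use candle_core::{DType, Device, Tensor};
//! use qlora_rs::{QLoraConfig, QuantizedLinear};
//!
//! let device = Device::Cpu;
//! let config = QLoraConfig::preset_all_bf16(64, 16);
//!
//! // Stand-in for a pretrained weight of shape [out_features, in_features].
//! let weight = Tensor::zeros(&[768, 768], DType::F32, &device).unwrap();
//! let layer = QuantizedLinear::from_weight(&weight, None, &config, &device).unwrap();
//!
//! // Forward pass over a [batch, seq, in_features] input.
//! let input = Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap();
//! let output = layer.forward(&input).unwrap();
//! assert_eq!(output.shape().dims(), &[1, 10, 768]);
//! ```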

use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use peft_rs::{Adapter, LoraConfig, LoraLayer};
use serde::{Deserialize, Serialize};

use crate::error::{QLoraError, Result};
use crate::quantization::{
    dequantize_nf4, quantize_nf4_with_config, ComputeDType, QuantizationConfig, QuantizedTensor,
};

fn warn_cpu_fallback(device: &Device) {
    static WARN_ONCE: std::sync::Once = std::sync::Once::new();
    if matches!(device, Device::Cpu) {
        WARN_ONCE.call_once(|| {
            eprintln!(
                "qlora-rs: CPU device in use. CUDA is the intended default; enable the 'cuda' feature and use Device::cuda_if_available(0) when possible."
            );
        });
    }
}

/// Configuration for `QLoRA` training and inference.
///
/// # Compute Dtype
///
/// **CRITICAL**: The `compute_dtype` field controls numerical precision during training.
/// - `BF16`: **Required for training** - 100% stability rate
/// - `FP16`: **Do NOT use for training** - 20% failure rate due to numerical instability
/// - `FP32`: Stable but slower, useful for debugging
///
/// # Target Modules
///
/// The `target_modules` field controls which linear layers get `LoRA` adapters:
/// - Minimal: `["q_proj", "v_proj"]` - 98% of full fine-tuning quality
/// - Recommended: `["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"]`
///   - Matches full fine-tuning quality (~99.3%)
///
/// # Example
///
/// ```rust
/// use qlora_rs::QLoraConfig;
///
/// // Use preset for best stability
/// let config = QLoraConfig::preset_all_bf16(64, 16);
///
/// // Or customize
/// let config = QLoraConfig {
///     target_modules: vec!["q_proj".into(), "v_proj".into()],
///     ..QLoraConfig::preset_all_bf16(32, 8)
/// };
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QLoraConfig {
    /// `LoRA` configuration (rank, alpha, dropout).
    pub lora: LoraConfig,
    /// Quantization configuration (block size, double quant).
    pub quantization: QuantizationConfig,
    /// Target modules to apply `LoRA` to.
    /// Default: all linear layers in transformer blocks.
    #[serde(default = "default_target_modules")]
    pub target_modules: Vec<String>,
    /// Whether to cache dequantized weights (opt-in, for inference speedup).
    /// Default: false (on-the-fly dequantization saves memory).
    #[serde(default)]
    pub cache_dequantized: bool,
}

fn default_target_modules() -> Vec<String> {
    vec![
        "q_proj".into(),
        "k_proj".into(),
        "v_proj".into(),
        "o_proj".into(),
        "gate_proj".into(),
        "up_proj".into(),
        "down_proj".into(),
    ]
}

impl Default for QLoraConfig {
    /// Default configuration: BF16 compute, all linear layers targeted.
    ///
    /// **Note**: Default uses BF16 compute dtype for training stability.
    fn default() -> Self {
        Self {
            lora: LoraConfig {
                r: 64,
                alpha: 16,
                dropout: 0.05,
                ..Default::default()
            },
            quantization: QuantizationConfig {
                block_size: 64,
                double_quant: true,
                compute_dtype: ComputeDType::BF16, // CRITICAL: BF16 for stability
                ..Default::default()
            },
            target_modules: default_target_modules(),
            cache_dequantized: false, // On-the-fly by default (memory optimal)
        }
    }
}

impl QLoraConfig {
    /// Create preset targeting all linear layers with BF16 compute.
    ///
    /// **Recommended for training**. Matches `QLoRA` paper configuration.
    ///
    /// # Arguments
    /// * `r` - `LoRA` rank (typical: 8-64)
    /// * `alpha` - `LoRA` scaling factor (typical: 16-32)
    #[must_use]
    pub fn preset_all_bf16(r: usize, alpha: usize) -> Self {
        Self {
            lora: LoraConfig {
                r,
                alpha,
                dropout: 0.05,
                ..Default::default()
            },
            quantization: QuantizationConfig {
                block_size: 64,
                double_quant: true,
                compute_dtype: ComputeDType::BF16,
                ..Default::default()
            },
            target_modules: default_target_modules(),
            cache_dequantized: false,
        }
    }

    /// Create preset targeting only attention Q/V projections with BF16 compute.
    ///
    /// Memory-optimal preset: fewer trainable parameters.
    /// Achieves ~98% of full fine-tuning quality.
    #[must_use]
    pub fn preset_qv_bf16(r: usize, alpha: usize) -> Self {
        Self {
            lora: LoraConfig {
                r,
                alpha,
                dropout: 0.05,
                ..Default::default()
            },
            quantization: QuantizationConfig {
                block_size: 64,
                double_quant: true,
                compute_dtype: ComputeDType::BF16,
                ..Default::default()
            },
            target_modules: vec!["q_proj".into(), "v_proj".into()],
            cache_dequantized: false,
        }
    }

    /// Create preset for inference with weight caching enabled.
    ///
    /// Uses cached dequantization for faster inference at the cost of memory.
    #[must_use]
    pub fn preset_inference(r: usize, alpha: usize) -> Self {
        Self {
            cache_dequantized: true, // Enable caching for inference speed
            ..Self::preset_all_bf16(r, alpha)
        }
    }

    /// Check if a module should have `LoRA` applied.
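    ///
    /// Matching is by substring, so fully qualified module names match their target
    /// component (module paths below are illustrative):
    ///
    /// ```rust
    /// use qlora_rs::QLoraConfig;
    ///
    /// let config = QLoraConfig::preset_all_bf16(8, 16);
    /// assert!(config.is_target("model.layers.0.self_attn.q_proj"));
    /// assert!(!config.is_target("lm_head"));
    /// ```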
    #[must_use]
    pub fn is_target(&self, module_name: &str) -> bool {
        self.target_modules.iter().any(|t| module_name.contains(t))
    }

    /// Get the `LoRA` scaling factor (alpha / r).
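    ///
    /// For example, rank 64 with alpha 16 gives a scale of `16 / 64 = 0.25`:
    ///
    /// ```rust
    /// use qlora_rs::QLoraConfig;
    ///
    /// let config = QLoraConfig::preset_all_bf16(64, 16);
    /// assert!((config.scale() - 0.25).abs() < 1e-10);
    /// ```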
    #[must_use]
    #[allow(clippy::cast_precision_loss)]
    pub fn scale(&self) -> f64 {
        self.lora.alpha as f64 / self.lora.r as f64
    }

    /// Validate configuration for training.
    ///
    /// # Errors
    /// Returns error if configuration is invalid for training.
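    ///
    /// # Example
    ///
    /// A quick check that an empty `target_modules` list is rejected:
    ///
    /// ```rust
    /// use qlora_rs::QLoraConfig;
    ///
    /// let mut config = QLoraConfig::preset_all_bf16(8, 16);
    /// assert!(config.validate_for_training().is_ok());
    ///
    /// config.target_modules.clear();
    /// assert!(config.validate_for_training().is_err());
    /// ```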
    pub fn validate_for_training(&self) -> Result<()> {
        if self.lora.r == 0 {
            return Err(QLoraError::InvalidConfig("LoRA rank must be > 0".into()));
        }
        if self.target_modules.is_empty() {
            return Err(QLoraError::InvalidConfig(
                "At least one target module required".into(),
            ));
        }
        // Warn about FP16 but don't error - user might know what they're doing
        if matches!(self.quantization.compute_dtype, ComputeDType::F16) {
            tracing::warn!(
                "FP16 compute dtype may cause training instability (20% failure rate). \
                 Consider using BF16 instead."
            );
        }
        Ok(())
    }
}

/// A linear layer with quantized base weights and trainable `LoRA` adapters.
///
/// # Dequantization Modes
///
/// - **On-the-fly** (default): Dequantizes during each forward pass, saves memory.
/// - **Cached** (opt-in via `cache_dequantized`): Dequantizes once, faster inference.
///
/// For training, always use on-the-fly mode (default) to save memory.
/// For inference, consider enabling caching for ~30% speedup.
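///
/// # Example
///
/// A caching sketch for inference. The CPU device and shapes are illustrative, and the
/// crate-root re-export of `QuantizedLinear` is assumed:
///
/// ```rust,no_run
/// use candle_core::{DType, Device, Tensor};
/// use qlora_rs::{QLoraConfig, QuantizedLinear};
///
/// let device = Device::Cpu;
/// let weight = Tensor::zeros(&[768, 768], DType::F32, &device).unwrap();
/// let mut layer =
///     QuantizedLinear::from_weight(&weight, None, &QLoraConfig::default(), &device).unwrap();
///
/// // Opt in to cached dequantization before serving.
/// layer.enable_weight_caching().unwrap();
/// assert!(layer.is_weight_cached());
/// ```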
pub struct QuantizedLinear {
    /// Quantized base weight (frozen, NF4 format).
    quantized_weight: QuantizedTensor,
    /// Cached dequantized weight (opt-in for inference speedup).
    cached_weight: Option<Tensor>,
    /// Optional bias (not quantized, kept in full precision).
    bias: Option<Tensor>,
    /// `LoRA` adapter (trainable).
    lora: LoraLayer,
    /// Device for dequantization.
    device: Device,
    /// Configuration for quantization (needed for on-the-fly dequant).
    config: QLoraConfig,
}

impl QuantizedLinear {
    /// Create a new quantized linear layer from existing weights.
    ///
    /// Uses on-the-fly dequantization by default (memory-optimal).
    /// Set `config.cache_dequantized = true` for inference speedup.
    ///
    /// # Arguments
    /// * `weight` - Full-precision weight tensor to quantize
    /// * `bias` - Optional bias tensor (kept in full precision)
    /// * `config` - `QLoRA` configuration
    /// * `device` - Device for computation
    ///
    /// # Errors
    /// Returns error if weight tensor has invalid shape or quantization fails
    pub fn from_weight(
        weight: &Tensor,
        bias: Option<Tensor>,
        config: &QLoraConfig,
        device: &Device,
    ) -> Result<Self> {
        warn_cpu_fallback(device);
        let shape = weight.shape().dims();
        if shape.len() != 2 {
            return Err(QLoraError::InvalidConfig("weight must be 2D".into()));
        }
        let (out_features, in_features) = (shape[0], shape[1]);

        // Quantize the base weight using full config
        let quantized_weight = quantize_nf4_with_config(weight, &config.quantization)?;

        // Only cache if explicitly requested (opt-in for inference)
        let cached_weight = if config.cache_dequantized {
            Some(dequantize_nf4(&quantized_weight, device)?)
        } else {
            None
        };

        // Create LoRA adapter
        let lora =
            LoraLayer::new_with_zeros(in_features, out_features, config.lora.clone(), device)?;

        Ok(Self {
            quantized_weight,
            cached_weight,
            bias,
            lora,
            device: device.clone(),
            config: config.clone(),
        })
    }

    /// Create a quantized linear layer with trainable `LoRA` weights registered via `VarBuilder`.
    ///
    /// This constructor ensures `LoRA` A/B weights are tracked for gradient computation.
    /// Use this for training; use `from_weight` for inference.
    ///
    /// # Arguments
    /// * `weight` - Full-precision weight tensor to quantize
    /// * `bias` - Optional bias tensor (kept in full precision)
    /// * `config` - `QLoRA` configuration
    /// * `vb` - `VarBuilder` backed by `VarMap` for gradient tracking
    ///
    /// # Errors
    /// Returns error if weight tensor has invalid shape or quantization fails
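    ///
    /// # Example
    ///
    /// A training-setup sketch using the usual `candle_nn` `VarMap`/`VarBuilder` pattern.
    /// Shapes, dtype, and the crate-root re-export of `QuantizedLinear` are illustrative
    /// assumptions:
    ///
    /// ```rust,no_run
    /// use candle_core::{DType, Device, Tensor};
    /// use candle_nn::{VarBuilder, VarMap};
    /// use qlora_rs::{QLoraConfig, QuantizedLinear};
    ///
    /// let device = Device::Cpu;
    /// let weight = Tensor::zeros(&[768, 768], DType::F32, &device).unwrap();
    ///
    /// // The VarMap owns the trainable LoRA variables; hand it to the optimizer later.
    /// let varmap = VarMap::new();
    /// let vb = VarBuilder::from_varmap(&varmap, DType::F32, &device);
    ///
    /// let config = QLoraConfig::preset_all_bf16(64, 16);
    /// let layer =
    ///     QuantizedLinear::from_weight_with_varbuilder(&weight, None, &config, vb).unwrap();
    /// assert!(layer.num_trainable_parameters() > 0);
    /// ```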
    pub fn from_weight_with_varbuilder(
        weight: &Tensor,
        bias: Option<Tensor>,
        config: &QLoraConfig,
        vb: VarBuilder,
    ) -> Result<Self> {
        let shape = weight.shape().dims();
        if shape.len() != 2 {
            return Err(QLoraError::InvalidConfig("weight must be 2D".into()));
        }
        let (out_features, in_features) = (shape[0], shape[1]);
        let device = weight.device();
        warn_cpu_fallback(device);

        // Quantize the base weight using full config
        let quantized_weight = quantize_nf4_with_config(weight, &config.quantization)?;

        // Only cache if explicitly requested (should be false for training)
        let cached_weight = if config.cache_dequantized {
            Some(dequantize_nf4(&quantized_weight, device)?)
        } else {
            None
        };

        // Create LoRA adapter with VarBuilder for gradient tracking
        let lora = LoraLayer::new(in_features, out_features, config.lora.clone(), vb)?;

        Ok(Self {
            quantized_weight,
            cached_weight,
            bias,
            lora,
            device: device.clone(),
            config: config.clone(),
        })
    }

    /// Create a new quantized linear layer with zero-initialized quantized weights.
    ///
    /// Primarily for testing; use `from_weight` for actual models.
    ///
    /// # Errors
    /// Returns error if tensor creation or quantization fails
    pub fn new(
        in_features: usize,
        out_features: usize,
        config: &QLoraConfig,
        device: &Device,
    ) -> Result<Self> {
        let weight = Tensor::zeros(&[out_features, in_features], DType::F32, device)?;
        Self::from_weight(&weight, None, config, device)
    }

    /// Forward pass through the quantized linear layer.
    ///
    /// Computes: `output = x @ W_q^T + x @ (B @ A)^T * scaling + bias`
    ///
    /// Uses on-the-fly dequantization unless `cache_dequantized` was enabled.
    ///
    /// # Errors
    /// Returns error if tensor operations fail
    pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
        // Get dequantized weight (either from cache or on-the-fly)
        let weight = if let Some(cached) = &self.cached_weight {
            cached.clone()
        } else {
            // On-the-fly dequantization (default, memory-optimal)
            dequantize_nf4(&self.quantized_weight, &self.device)?
        };
        let weight_t = weight.t()?;

        // Handle both 2D and 3D inputs for batch processing
        let base_output = if input.dims().len() == 3 {
            // For [batch, seq, in_features], reshape to [batch * seq, in_features]
            let (batch, seq, in_features) = input.dims3()?;
            let reshaped = input.reshape(&[batch * seq, in_features])?;
            let out = reshaped.matmul(&weight_t)?;
            // Reshape back to [batch, seq, out_features]
            let out_features = weight_t.dim(1)?;
            out.reshape(&[batch, seq, out_features])?
        } else {
            // For 2D [batch, in_features], standard matmul
            input.matmul(&weight_t)?
        };

        // LoRA forward: adds x @ A^T @ B^T * scaling
        let output = self.lora.forward(input, Some(&base_output))?;

        // Add bias if present
        match &self.bias {
            Some(bias) => Ok(output.broadcast_add(bias)?),
            None => Ok(output),
        }
    }

    /// Enable weight caching for faster inference.
    ///
    /// Call this after loading a trained model for inference.
    /// Not recommended for training (wastes memory).
    ///
    /// # Errors
    /// Returns error if dequantization fails.
    pub fn enable_weight_caching(&mut self) -> Result<()> {
        if self.cached_weight.is_none() {
            self.cached_weight = Some(dequantize_nf4(&self.quantized_weight, &self.device)?);
        }
        Ok(())
    }

    /// Disable weight caching to save memory.
    pub fn disable_weight_caching(&mut self) {
        self.cached_weight = None;
    }

    /// Check if weight caching is enabled.
    #[must_use]
    pub fn is_weight_cached(&self) -> bool {
        self.cached_weight.is_some()
    }

    /// Get the `QLoRA` configuration used to create this layer.
    #[must_use]
    pub fn config(&self) -> &QLoraConfig {
        &self.config
    }

    /// Get the `LoRA` adapter.
    #[must_use]
    pub fn lora(&self) -> &LoraLayer {
        &self.lora
    }

    /// Get mutable access to the `LoRA` adapter.
    pub fn lora_mut(&mut self) -> &mut LoraLayer {
        &mut self.lora
    }

    /// Get the `LoRA` A and B weight tensors.
    ///
    /// Returns (`lora_a`, `lora_b`) where:
    /// - `lora_a` has shape `[r, in_features]`
    /// - `lora_b` has shape `[out_features, r]`
    #[must_use]
    pub fn lora_weights(&self) -> (&Tensor, &Tensor) {
        self.lora.weights()
    }

    /// Get the number of trainable parameters (`LoRA` only).
    #[must_use]
    pub fn num_trainable_parameters(&self) -> usize {
        self.lora.num_parameters()
    }

    /// Get total memory usage in bytes.
    #[must_use]
    pub fn memory_bytes(&self) -> usize {
        let quantized_size = self.quantized_weight.size_bytes();
        let lora_size = self.lora.num_parameters() * 4; // f32
        let bias_size = self.bias.as_ref().map_or(0, |b| b.elem_count() * 4);
        quantized_size + lora_size + bias_size
    }
}

/// `QLoRA` adapter wrapping a model's linear layers.
pub struct QLoraLayer {
    /// Underlying quantized linear layer.
    linear: QuantizedLinear,
}

impl QLoraLayer {
    /// Create a new `QLoRA` layer.
    #[must_use]
    pub fn new(linear: QuantizedLinear) -> Self {
        Self { linear }
    }

    /// Forward pass.
    ///
    /// # Errors
    /// Returns error if the underlying linear layer forward fails
    pub fn forward(&self, input: &Tensor) -> Result<Tensor> {
        self.linear.forward(input)
    }

    /// Get a reference to the quantized base weight tensor.
    #[must_use]
    pub fn quantized_weight(&self) -> &QuantizedTensor {
        &self.linear.quantized_weight
    }

    /// Get the `LoRA` A and B weight tensors.
    ///
    /// Returns (`lora_a`, `lora_b`) where:
    /// - `lora_a` has shape `[r, in_features]`
    /// - `lora_b` has shape `[out_features, r]`
    #[must_use]
    pub fn lora_weights(&self) -> (&Tensor, &Tensor) {
        self.linear.lora_weights()
    }

    /// Get the `LoRA` scaling factor (alpha / rank).
    #[must_use]
    pub fn lora_scale(&self) -> f64 {
        self.linear.config.scale()
    }

    /// Get the device used by this layer.
    #[must_use]
    pub fn device(&self) -> &Device {
        &self.linear.device
    }

    /// Get the quantization configuration.
    #[must_use]
    pub fn config(&self) -> &QLoraConfig {
        &self.linear.config
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_qlora_creation() {
        let config = QLoraConfig::default();
        let device = Device::Cpu;
        let layer = QuantizedLinear::new(768, 768, &config, &device);
        assert!(layer.is_ok());
    }

    #[test]
    fn test_qlora_forward_shape() {
        let config = QLoraConfig::default();
        let device = Device::Cpu;
        let layer = QuantizedLinear::new(768, 768, &config, &device).unwrap();

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &device).unwrap();
        let output = layer.forward(&input).unwrap();

        assert_eq!(output.shape().dims(), &[1, 10, 768]);
    }

    #[test]
    fn test_qlora_memory_reduction() {
        let config = QLoraConfig::default();
        let device = Device::Cpu;
        let layer = QuantizedLinear::new(4096, 4096, &config, &device).unwrap();

        // Full precision would be 4096 * 4096 * 4 = 67MB
        let full_size = 4096 * 4096 * 4;
        let actual_size = layer.memory_bytes();

        // Should be significantly smaller due to quantization
        #[allow(clippy::cast_precision_loss)]
        let ratio = f64::from(full_size) / actual_size as f64;
        assert!(ratio > 2.0, "Expected >2x reduction, got {ratio:.2}x");
    }

    // Tests for QLoraConfig presets and helper methods
    // Addresses PR #10 review comment: "No test coverage for presets and helper methods"

    #[test]
    fn test_preset_all_bf16() {
        let config = QLoraConfig::preset_all_bf16(64, 16);

        // Check LoRA config
        assert_eq!(config.lora.r, 64);
        assert_eq!(config.lora.alpha, 16);
        assert!((config.lora.dropout - 0.05).abs() < 1e-10);

        // Check quantization config
        assert!(matches!(
            config.quantization.compute_dtype,
            ComputeDType::BF16
        ));
        assert!(config.quantization.double_quant);

        // Check target modules (should be all linear layers)
        assert!(config.target_modules.contains(&"q_proj".to_string()));
        assert!(config.target_modules.contains(&"k_proj".to_string()));
        assert!(config.target_modules.contains(&"v_proj".to_string()));
        assert!(config.target_modules.contains(&"o_proj".to_string()));
        assert!(config.target_modules.contains(&"gate_proj".to_string()));

        // Should not cache weights by default (memory-optimal for training)
        assert!(!config.cache_dequantized);
    }

    #[test]
    fn test_preset_qv_bf16() {
        let config = QLoraConfig::preset_qv_bf16(32, 8);

        // Check LoRA config
        assert_eq!(config.lora.r, 32);
        assert_eq!(config.lora.alpha, 8);

        // Check target modules (should only be Q/V)
        assert_eq!(config.target_modules.len(), 2);
        assert!(config.target_modules.contains(&"q_proj".to_string()));
        assert!(config.target_modules.contains(&"v_proj".to_string()));

        // Should NOT contain other modules
        assert!(!config.target_modules.contains(&"k_proj".to_string()));
        assert!(!config.target_modules.contains(&"o_proj".to_string()));
    }

    #[test]
    fn test_preset_inference() {
        let config = QLoraConfig::preset_inference(16, 32);

        // Check LoRA config
        assert_eq!(config.lora.r, 16);
        assert_eq!(config.lora.alpha, 32);

        // Key difference: should cache weights for inference speed
        assert!(config.cache_dequantized);

        // Should still use BF16 compute
        assert!(matches!(
            config.quantization.compute_dtype,
            ComputeDType::BF16
        ));
    }

    #[test]
    fn test_is_target() {
        let config = QLoraConfig::preset_all_bf16(8, 16);

        // Should match target modules
        assert!(config.is_target("model.layer.q_proj"));
        assert!(config.is_target("transformer.blocks.0.attn.v_proj"));
        assert!(config.is_target("gate_proj"));

        // Should NOT match non-target modules
        assert!(!config.is_target("embed_tokens"));
        assert!(!config.is_target("lm_head"));
        assert!(!config.is_target("layer_norm"));
    }

    #[test]
    fn test_scale() {
        let config = QLoraConfig::preset_all_bf16(64, 16);
        let scale = config.scale();

        // scale = alpha / r = 16 / 64 = 0.25
        assert!((scale - 0.25).abs() < 1e-10);

        let config2 = QLoraConfig::preset_all_bf16(8, 32);
        let scale2 = config2.scale();

        // scale = alpha / r = 32 / 8 = 4.0
        assert!((scale2 - 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_validate_for_training_success() {
        let config = QLoraConfig::preset_all_bf16(8, 16);
        assert!(config.validate_for_training().is_ok());
    }

    #[test]
    fn test_validate_for_training_zero_rank() {
        let mut config = QLoraConfig::preset_all_bf16(8, 16);
        config.lora.r = 0;

        let result = config.validate_for_training();
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("rank"));
        }
    }

    #[test]
    fn test_validate_for_training_empty_targets() {
        let mut config = QLoraConfig::preset_all_bf16(8, 16);
        config.target_modules.clear();

        let result = config.validate_for_training();
        assert!(result.is_err());
        if let Err(e) = result {
            assert!(e.to_string().contains("target module"));
        }
    }

    #[test]
    fn test_default_config() {
        let config = QLoraConfig::default();

        // Should use BF16 by default for training stability
        assert!(matches!(
            config.quantization.compute_dtype,
            ComputeDType::BF16
        ));

        // Should have standard LoRA defaults
        assert_eq!(config.lora.r, 64);
        assert_eq!(config.lora.alpha, 16);

        // Should target all linear layers
        assert!(!config.target_modules.is_empty());

        // Should not cache by default (memory-optimal)
        assert!(!config.cache_dequantized);
    }

    #[test]
    fn test_lora_weights() {
        let config = QLoraConfig::preset_all_bf16(8, 16);
        let device = Device::Cpu;
        let layer = QuantizedLinear::new(64, 128, &config, &device).unwrap();

        let (a_weight, b_weight) = layer.lora_weights();

        // A: [r, in_features] = [8, 64]
        assert_eq!(a_weight.dims(), &[8, 64]);

        // B: [out_features, r] = [128, 8]
        assert_eq!(b_weight.dims(), &[128, 8]);
    }
725}