Skip to main content

ferrotorch_nn/
lora.rs

//! Low-Rank Adaptation (LoRA) for parameter-efficient fine-tuning.
//!
//! Instead of fine-tuning all weights of a pretrained model, LoRA freezes
//! the original weights and injects a trainable low-rank decomposition:
//!
//! ```text
//! W' = W + (alpha / r) * B @ A
//! ```
//!
//! where `A` is `[r, in_features]` and `B` is `[out_features, r]`. Only `A`
//! and `B` are trainable — the original `W` stays frozen. Because `r` is
//! typically much smaller than `in_features` and `out_features`, this cuts the
//! number of trainable parameters from `out_features * in_features` to
//! `r * (in_features + out_features)`, while (per Hu et al.) largely
//! preserving model quality.
//!
//! # References
//!
//! Hu et al., "LoRA: Low-Rank Adaptation of Large Language Models", 2021
//! (arXiv:2106.09685).
//!
//! ## REQ status (per `.design/ferrotorch-nn/lora.md`)
//!
//! | REQ | Status | Evidence |
//! |---|---|---|
//! | REQ-1 | SHIPPED | impl: `pub struct LoRALinear<T: Float>` here with `base` / `lora_a` / `lora_b` / `alpha` / `rank` / `dropout` / `training` fields per Hu et al. 2021; non-test consumer: `pub use lora::LoRALinear` in `lib.rs` makes the type available to `ferrotorch-train`'s fine-tuning scaffolding. |
//! | REQ-2 | SHIPPED | impl: the `LoRALinear::new` constructor body here with rank validation + N(0, 1/sqrt(rank)) init of A + zeros init of B + optional `Dropout` construction; non-test consumer: PEFT fine-tuning code calls `LoRALinear::new(base, rank, alpha, dropout_p)?`. |
//! | REQ-3 | SHIPPED | impl: `<LoRALinear as Module>::forward` body (base + transposed matmul chain + scale + add) here; non-test consumer: fine-tuning training loops call `lora.forward(input)` every step. |
//! | REQ-4 | SHIPPED | impl: `Module::parameters` returns `vec![&self.lora_a, &self.lora_b]` here, excluding the base; non-test consumer: `ferrotorch_optim::Optimizer::step` iterates `model.parameters_mut()` and only sees `lora_a` / `lora_b` (the frozen base is skipped). This is THE LoRA invariant. |
//! | REQ-5 | SHIPPED | impl: the `LoRALinear::merge` body (explicit B @ A product + weight update + LoRA reset) here; non-test consumer: inference-serving code calls `lora.merge()` then `lora.into_base()` to fuse the adapter for deployment. |
//! | REQ-6 | SHIPPED | impl: `impl<T: Float> Module<T> for LoRALinear<T>` block here with `train` / `eval` cascading to `base` and `dropout`; non-test consumer: training-loop control flow toggles `model.train()` / `model.eval()` between training and validation, which cascades through `LoRALinear` to `Dropout`. |
//! | REQ-7 | SHIPPED | impl: `impl<T: Float> Display for LoRALinear<T>` block here; non-test consumer: any `format!("{layer}")` in model summary logging (the same path that prints `Linear(...)` for the base). |
//! | REQ-8 | SHIPPED | `LoRALinear` is `Send + Sync` by composition of `Send + Sync` fields; compile-time-asserted via `assert_send_sync::<LoRALinear<f32>>()` in tests; non-test consumer: any multi-threaded training scaffolding requiring `Send + Sync`. |
//! | REQ-9 | SHIPPED | impl: the `rank` / `alpha` / `base` / `into_base` accessors here; non-test consumer: inference-serving code calls `lora.into_base()` after `lora.merge()` to drop the LoRA wrapper. |
31
32use ferrotorch_core::grad_fns::arithmetic::{add, mul};
33use ferrotorch_core::grad_fns::linalg::mm_differentiable;
34use ferrotorch_core::grad_fns::shape::transpose_2d;
35use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, scalar};
36
37use crate::dropout::Dropout;
38use crate::init;
39use crate::linear::Linear;
40use crate::module::Module;
41use crate::parameter::Parameter;
42
43/// Low-Rank Adaptation wrapper for a [`Linear`] layer.
44///
45/// Freezes the original weight and adds a trainable low-rank decomposition.
46/// The forward pass computes:
47///
48/// ```text
49/// y = x @ W^T + x @ (B @ A)^T * (alpha / r) + bias
50/// ```
51///
52/// Only `lora_a` and `lora_b` appear in [`parameters()`](Module::parameters),
53/// so optimizers only update the low-rank matrices. The base layer's weight
54/// and bias are excluded from the parameter list (frozen).
55///
56/// # Initialization
57///
58/// - **A**: `N(0, 1/sqrt(r))` — Kaiming-style for the rank dimension.
59/// - **B**: Zeros — so the LoRA contribution starts at zero and training
60///   begins from the pretrained checkpoint.
61///
62/// # Merging
63///
64/// After fine-tuning, call [`merge()`](LoRALinear::merge) to fold the LoRA
65/// weights into the base layer. This eliminates the runtime overhead of the
66/// extra matmuls, producing a standard `Linear` layer for inference.
67///
68/// # Examples
69///
70/// ```ignore
71/// let base = Linear::<f32>::new(768, 768, true)?;
72/// let lora = LoRALinear::new(base, 8, 1.0, 0.0)?;
73/// let output = lora.forward(&input)?;   // only lora_a, lora_b are trainable
74/// ```
75#[derive(Debug)]
76pub struct LoRALinear<T: Float> {
77    /// Original frozen linear layer (not included in `parameters()`).
78    base: Linear<T>,
79    /// Low-rank A matrix: `[r, in_features]`, trainable.
80    lora_a: Parameter<T>,
81    /// Low-rank B matrix: `[out_features, r]`, trainable.
82    lora_b: Parameter<T>,
83    /// Scaling factor (numerator of `alpha / r`).
84    alpha: f64,
85    /// Rank of the low-rank decomposition.
86    rank: usize,
87    /// Optional dropout on the LoRA input path.
88    dropout: Option<Dropout<T>>,
89    /// Whether the module is in training mode.
90    training: bool,
91}
92
93impl<T: Float> LoRALinear<T> {
94    /// Create a LoRA wrapper around an existing `Linear` layer.
95    ///
96    /// # Arguments
97    ///
98    /// - `base` — The pretrained linear layer to adapt. Its parameters are
99    ///   frozen (excluded from `parameters()`).
100    /// - `rank` — Rank of the low-rank decomposition. Typical values: 1–64.
101    /// - `alpha` — Scaling factor. The LoRA contribution is scaled by
102    ///   `alpha / rank`. Common choice: `alpha == rank` (scale = 1).
103    /// - `dropout_p` — Dropout probability on the LoRA input path. Set to
104    ///   `0.0` to disable.
105    ///
106    /// # Errors
107    ///
108    /// Returns an error if `rank` is zero, if `dropout_p` is invalid, or if
109    /// parameter allocation fails.
110    pub fn new(base: Linear<T>, rank: usize, alpha: f64, dropout_p: f64) -> FerrotorchResult<Self> {
111        if rank == 0 {
112            return Err(FerrotorchError::InvalidArgument {
113                message: "LoRALinear: rank must be > 0".into(),
114            });
115        }
116
117        let in_features = base.in_features();
118        let out_features = base.out_features();
119
120        // A initialized from N(0, 1/sqrt(r)) — so the initial LoRA output
121        // has variance independent of rank.
122        let mut lora_a = Parameter::zeros(&[rank, in_features])?;
123        init::normal(&mut lora_a, 0.0, 1.0 / (rank as f64).sqrt())?;
124
125        // B initialized to zeros — LoRA contribution starts at zero.
126        let lora_b = Parameter::zeros(&[out_features, rank])?;
127
128        let dropout = if dropout_p > 0.0 {
129            Some(Dropout::new(dropout_p)?)
130        } else {
131            None
132        };
133
134        Ok(Self {
135            base,
136            lora_a,
137            lora_b,
138            alpha,
139            rank,
140            dropout,
141            training: true,
142        })
143    }
144
145    /// Merge LoRA weights into the base layer for inference efficiency.
146    ///
147    /// Computes `W_merged = W + (alpha/r) * B @ A` and replaces the base
148    /// weight. After merging, the forward pass is a single matmul with no
149    /// overhead. The LoRA matrices are reset to their initial state (A
150    /// re-initialized, B zeroed) so that additional fine-tuning can continue
151    /// from the merged checkpoint if desired.
152    pub fn merge(&mut self) -> FerrotorchResult<()> {
153        let scale = T::from(self.alpha / self.rank as f64).unwrap();
154
155        // B @ A: [out_features, r] @ [r, in_features] = [out_features, in_features]
156        let b_data = self.lora_b.data()?;
157        let a_data = self.lora_a.data()?;
158        let out_features = self.base.out_features();
159        let in_features = self.base.in_features();
160        let r = self.rank;
161
162        let zero = <T as num_traits::Zero>::zero();
163        let mut ba = vec![zero; out_features * in_features];
164        for i in 0..out_features {
165            for j in 0..in_features {
166                let mut sum = zero;
167                for k in 0..r {
168                    sum += b_data[i * r + k] * a_data[k * in_features + j];
169                }
170                ba[i * in_features + j] = sum;
171            }
172        }
173
174        // W_merged = W + scale * B @ A
175        let w_data = self.base.weight.data()?;
176        let merged: Vec<T> = w_data
177            .iter()
178            .zip(ba.iter())
179            .map(|(&w, &d)| w + scale * d)
180            .collect();
181
182        self.base.weight = Parameter::from_slice(&merged, &[out_features, in_features])?;
183
184        // Reset LoRA matrices so the module can be fine-tuned again.
185        self.lora_a = Parameter::zeros(&[r, in_features])?;
186        init::normal(&mut self.lora_a, 0.0, 1.0 / (r as f64).sqrt())?;
187        self.lora_b = Parameter::zeros(&[out_features, r])?;
188
189        Ok(())
190    }
191
192    /// The effective rank of the adaptation.
193    #[inline]
194    pub fn rank(&self) -> usize {
195        self.rank
196    }
197
198    /// The scaling factor alpha.
199    #[inline]
200    pub fn alpha(&self) -> f64 {
201        self.alpha
202    }
203
204    /// Borrow the underlying base linear layer.
205    #[inline]
206    pub fn base(&self) -> &Linear<T> {
207        &self.base
208    }
209
210    /// Consume the LoRA wrapper and return the base linear layer.
211    ///
212    /// Call [`merge()`](LoRALinear::merge) first if you want the LoRA
213    /// weights folded into the base.
214    pub fn into_base(self) -> Linear<T> {
215        self.base
216    }
217}
218
219impl<T: Float> Module<T> for LoRALinear<T> {
220    /// Forward pass: base linear output plus scaled low-rank adaptation.
221    ///
222    /// ```text
223    /// y = base.forward(x) + (x @ A^T @ B^T) * (alpha / r)
224    /// ```
225    ///
226    /// When dropout is configured and the module is in training mode,
227    /// dropout is applied to the input on the LoRA path only (the base
228    /// path is unaffected).
229    fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
230        // Base forward (frozen weights — not in parameters()).
231        let base_out = self.base.forward(input)?;
232
233        // LoRA path: optionally apply dropout to input.
234        let lora_input = if let Some(ref dropout) = self.dropout {
235            if self.training {
236                dropout.forward(input)?
237            } else {
238                input.clone()
239            }
240        } else {
241            input.clone()
242        };
243
244        // lora_out = input @ A^T @ B^T
245        // A^T: [in_features, r]
246        let a_t = transpose_2d(self.lora_a.tensor())?;
247        // xa: [batch, r]
248        let xa = mm_differentiable(&lora_input, &a_t)?;
249        // B^T: [r, out_features]
250        let b_t = transpose_2d(self.lora_b.tensor())?;
251        // lora_out: [batch, out_features]
252        let lora_out = mm_differentiable(&xa, &b_t)?;
253
254        // Scale by alpha / r.
255        let scale_val = T::from(self.alpha / self.rank as f64).unwrap();
256        let scale_tensor = scalar(scale_val)?;
257        let scaled = mul(&lora_out, &scale_tensor)?;
258
259        // Add to base output.
260        add(&base_out, &scaled)
261    }
262
263    /// Returns only the LoRA parameters (A and B). The base layer's
264    /// parameters are frozen and excluded.
265    fn parameters(&self) -> Vec<&Parameter<T>> {
266        vec![&self.lora_a, &self.lora_b]
267    }
268
269    fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
270        vec![&mut self.lora_a, &mut self.lora_b]
271    }
272
273    fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
274        vec![
275            ("lora_a".to_string(), &self.lora_a),
276            ("lora_b".to_string(), &self.lora_b),
277        ]
278    }
279
280    fn train(&mut self) {
281        self.training = true;
282        self.base.train();
283        if let Some(ref mut d) = self.dropout {
284            d.train();
285        }
286    }
287
288    fn eval(&mut self) {
289        self.training = false;
290        self.base.eval();
291        if let Some(ref mut d) = self.dropout {
292            d.eval();
293        }
294    }
295
296    fn is_training(&self) -> bool {
297        self.training
298    }
299}
300
301// ---------------------------------------------------------------------------
302// Display
303// ---------------------------------------------------------------------------
304
305impl<T: Float> std::fmt::Display for LoRALinear<T> {
306    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307        write!(
308            f,
309            "LoRALinear(in_features={}, out_features={}, rank={}, alpha={}, bias={}, dropout={})",
310            self.base.in_features(),
311            self.base.out_features(),
312            self.rank,
313            self.alpha,
314            self.base.bias.is_some(),
315            self.dropout.is_some(),
316        )
317    }
318}
319
320// ---------------------------------------------------------------------------
321// Tests
322// ---------------------------------------------------------------------------
323
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::{Tensor, TensorStorage};

    /// Create a leaf tensor with given data and shape.
    fn leaf(data: &[f32], shape: &[usize], requires_grad: bool) -> Tensor<f32> {
        Tensor::from_storage(
            TensorStorage::cpu(data.to_vec()),
            shape.to_vec(),
            requires_grad,
        )
        .unwrap()
    }

    /// Assert two float slices are element-wise close.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {} vs {}",
            actual.len(),
            expected.len()
        );
        for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(
                (a - e).abs() < tol,
                "index {i}: actual={a} expected={e} diff={}",
                (a - e).abs()
            );
        }
    }

    /// 3x4 base layer with fixed weights/bias plus fixed, non-zero LoRA
    /// factors (rank 2, alpha 1), shared by the merge tests.
    fn merge_fixture(dropout_p: f64) -> LoRALinear<f32> {
        let mut base = Linear::<f32>::new(4, 3, true).unwrap();
        base.weight = Parameter::from_slice(
            &[
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
            &[3, 4],
        )
        .unwrap();
        *base.bias.as_mut().unwrap() = Parameter::from_slice(&[0.1, 0.2, 0.3], &[3]).unwrap();

        let mut lora = LoRALinear::new(base, 2, 1.0, dropout_p).unwrap();
        lora.lora_a =
            Parameter::from_slice(&[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], &[2, 4]).unwrap();
        lora.lora_b = Parameter::from_slice(&[1.0, 0.0, 0.0, 1.0, 0.5, 0.5], &[3, 2]).unwrap();
        lora
    }

    // -----------------------------------------------------------------------
    // Construction
    // -----------------------------------------------------------------------

    #[test]
    fn test_construction() {
        let base = Linear::<f32>::new(10, 5, true).unwrap();
        let lora = LoRALinear::new(base, 4, 1.0, 0.0).unwrap();
        assert_eq!(lora.rank(), 4);
        assert_eq!(lora.alpha(), 1.0);
        assert_eq!(lora.lora_a.shape(), &[4, 10]);
        assert_eq!(lora.lora_b.shape(), &[5, 4]);
    }

    #[test]
    fn test_construction_zero_rank_rejected() {
        let base = Linear::<f32>::new(10, 5, true).unwrap();
        assert!(LoRALinear::new(base, 0, 1.0, 0.0).is_err());
    }

    #[test]
    fn test_construction_with_dropout() {
        let base = Linear::<f32>::new(10, 5, true).unwrap();
        let lora = LoRALinear::new(base, 4, 1.0, 0.1).unwrap();
        assert!(lora.dropout.is_some());
    }

    #[test]
    fn test_construction_invalid_dropout_rejected() {
        let base = Linear::<f32>::new(10, 5, true).unwrap();
        assert!(LoRALinear::new(base, 4, 1.0, 1.5).is_err());
    }

    // -----------------------------------------------------------------------
    // Forward shape
    // -----------------------------------------------------------------------

    #[test]
    fn test_forward_shape() {
        let base = Linear::<f32>::new(8, 4, true).unwrap();
        let lora = LoRALinear::new(base, 2, 1.0, 0.0).unwrap();
        let input = leaf(&[0.0; 24], &[3, 8], false);
        let output = lora.forward(&input).unwrap();
        assert_eq!(output.shape(), &[3, 4]);
    }

    #[test]
    fn test_forward_shape_no_bias() {
        let base = Linear::<f32>::new(6, 3, false).unwrap();
        let lora = LoRALinear::new(base, 2, 1.0, 0.0).unwrap();
        let input = leaf(&[0.0; 12], &[2, 6], false);
        let output = lora.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 3]);
    }

    // -----------------------------------------------------------------------
    // Parameters — only LoRA A and B, not base
    // -----------------------------------------------------------------------

    #[test]
    fn test_parameters_only_lora() {
        let base = Linear::<f32>::new(10, 5, true).unwrap();
        let lora = LoRALinear::new(base, 4, 1.0, 0.0).unwrap();
        let params = lora.parameters();
        // Only lora_a and lora_b — NOT base weight/bias.
        assert_eq!(params.len(), 2);
        // lora_a: 4 * 10 = 40, lora_b: 5 * 4 = 20
        let total: usize = params.iter().map(|p| p.numel()).sum();
        assert_eq!(total, 60);
    }

    #[test]
    fn test_named_parameters_keys() {
        let base = Linear::<f32>::new(10, 5, true).unwrap();
        let lora = LoRALinear::new(base, 4, 1.0, 0.0).unwrap();
        let named = lora.named_parameters();
        assert_eq!(named.len(), 2);
        assert_eq!(named[0].0, "lora_a");
        assert_eq!(named[1].0, "lora_b");
    }

    // -----------------------------------------------------------------------
    // Zero-initialized B means output matches base
    // -----------------------------------------------------------------------

    #[test]
    fn test_zero_b_matches_base_output() {
        // Since B is initialized to zeros, the LoRA contribution is zero.
        // The LoRA output should exactly match the base Linear output.
        let mut base = Linear::<f32>::new(3, 2, true).unwrap();
        base.weight = Parameter::from_slice(&[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &[2, 3]).unwrap();
        *base.bias.as_mut().unwrap() = Parameter::from_slice(&[10.0, 20.0], &[2]).unwrap();

        // Compute base output for reference.
        let input = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let base_out = base.forward(&input).unwrap();
        let base_data = base_out.data().unwrap().to_vec();

        // Wrap in LoRA with rank=1. B is zeros, so LoRA contribution is zero.
        let lora = LoRALinear::new(base, 1, 1.0, 0.0).unwrap();
        let input2 = leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);
        let lora_out = lora.forward(&input2).unwrap();

        assert_eq!(lora_out.shape(), &[2, 2]);
        assert_close(lora_out.data().unwrap(), &base_data, 1e-5);
    }

    // -----------------------------------------------------------------------
    // Different ranks
    // -----------------------------------------------------------------------

    #[test]
    fn test_rank_1() {
        let base = Linear::<f32>::new(8, 4, true).unwrap();
        let lora = LoRALinear::new(base, 1, 1.0, 0.0).unwrap();
        assert_eq!(lora.rank(), 1);
        assert_eq!(lora.lora_a.shape(), &[1, 8]);
        assert_eq!(lora.lora_b.shape(), &[4, 1]);
        let input = leaf(&[0.0; 16], &[2, 8], false);
        let output = lora.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 4]);
    }

    #[test]
    fn test_rank_4() {
        let base = Linear::<f32>::new(16, 8, false).unwrap();
        let lora = LoRALinear::new(base, 4, 2.0, 0.0).unwrap();
        assert_eq!(lora.rank(), 4);
        assert_eq!(lora.lora_a.shape(), &[4, 16]);
        assert_eq!(lora.lora_b.shape(), &[8, 4]);
        let input = leaf(&[0.0; 32], &[2, 16], false);
        let output = lora.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 8]);
    }

    #[test]
    fn test_rank_16() {
        let base = Linear::<f32>::new(64, 32, true).unwrap();
        let lora = LoRALinear::new(base, 16, 8.0, 0.0).unwrap();
        assert_eq!(lora.rank(), 16);
        assert_eq!(lora.lora_a.shape(), &[16, 64]);
        assert_eq!(lora.lora_b.shape(), &[32, 16]);
        let input = leaf(&[0.0; 128], &[2, 64], false);
        let output = lora.forward(&input).unwrap();
        assert_eq!(output.shape(), &[2, 32]);
    }

    // -----------------------------------------------------------------------
    // Merge produces equivalent output
    // -----------------------------------------------------------------------

    #[test]
    fn test_merge_produces_same_output() {
        let mut lora = merge_fixture(0.0);

        // Compute output before merge.
        let input = leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], &[2, 4], false);
        let pre_merge_out = lora.forward(&input).unwrap();
        let pre_data = pre_merge_out.data().unwrap().to_vec();

        // Merge and compute output from the base layer directly.
        lora.merge().unwrap();
        let merged_base = &lora.base;
        let input2 = leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], &[2, 4], false);
        let post_merge_out = merged_base.forward(&input2).unwrap();

        assert_close(post_merge_out.data().unwrap(), &pre_data, 1e-5);
    }

    #[test]
    fn test_forward_after_merge_matches_pre_merge() {
        // After merge the adapter is reset (B = 0), so the wrapper itself —
        // not just the base — must still produce the pre-merge output.
        let mut lora = merge_fixture(0.0);
        let input = leaf(&[1.0, 2.0, -1.0, 0.5, 0.0, 1.0, 3.0, -2.0], &[2, 4], false);
        let pre_data = lora.forward(&input).unwrap().data().unwrap().to_vec();

        lora.merge().unwrap();
        let post = lora.forward(&input).unwrap();

        assert_close(post.data().unwrap(), &pre_data, 1e-4);
    }

    #[test]
    fn test_merge_preserves_bias_and_resets_lora() {
        let mut lora = merge_fixture(0.0);
        lora.merge().unwrap();

        assert_eq!(lora.base.weight.shape(), &[3, 4]);
        assert_close(
            lora.base.bias.as_ref().unwrap().data().unwrap(),
            &[0.1, 0.2, 0.3],
            1e-7,
        );
        assert_eq!(lora.lora_a.shape(), &[2, 4]);
        assert_eq!(lora.lora_b.shape(), &[3, 2]);
        assert!(lora.lora_b.data().unwrap().iter().all(|&v| v == 0.0));
    }

    // -----------------------------------------------------------------------
    // Forward correctness with known weights
    // -----------------------------------------------------------------------

    #[test]
    fn test_forward_correctness_known_weights() {
        // base: W = [[1, 0], [0, 1]], bias = [0, 0]  (identity, 2->2)
        let mut base = Linear::<f32>::new(2, 2, true).unwrap();
        base.weight = Parameter::from_slice(&[1.0, 0.0, 0.0, 1.0], &[2, 2]).unwrap();
        *base.bias.as_mut().unwrap() = Parameter::from_slice(&[0.0, 0.0], &[2]).unwrap();

        let mut lora = LoRALinear::new(base, 1, 2.0, 0.0).unwrap();

        // A = [[1, 0]]  (rank=1, in=2)
        // B = [[1], [0]] (out=2, rank=1)
        lora.lora_a = Parameter::from_slice(&[1.0, 0.0], &[1, 2]).unwrap();
        lora.lora_b = Parameter::from_slice(&[1.0, 0.0], &[2, 1]).unwrap();

        // input = [[1, 2]]
        let input = leaf(&[1.0, 2.0], &[1, 2], false);
        let output = lora.forward(&input).unwrap();

        // base_out = [1, 2]  (identity)
        // LoRA: x @ A^T = [1,2] @ [[1],[0]] = [1]
        //       [1] @ B^T = [1] @ [[1, 0]] = [1, 0]
        //       scaled = [1, 0] * (2.0 / 1) = [2, 0]
        // total = [1+2, 2+0] = [3, 2]
        assert_eq!(output.shape(), &[1, 2]);
        assert_close(output.data().unwrap(), &[3.0, 2.0], 1e-5);
    }

    // -----------------------------------------------------------------------
    // Train / Eval
    // -----------------------------------------------------------------------

    #[test]
    fn test_train_eval() {
        let base = Linear::<f32>::new(4, 3, true).unwrap();
        let mut lora = LoRALinear::new(base, 2, 1.0, 0.1).unwrap();
        assert!(lora.is_training());
        lora.eval();
        assert!(!lora.is_training());
        lora.train();
        assert!(lora.is_training());
    }

    #[test]
    fn test_eval_mode_bypasses_dropout() {
        // Same setup as test_forward_correctness_known_weights but with a
        // large dropout; in eval mode the LoRA path must be deterministic.
        let mut base = Linear::<f32>::new(2, 2, true).unwrap();
        base.weight = Parameter::from_slice(&[1.0, 0.0, 0.0, 1.0], &[2, 2]).unwrap();
        *base.bias.as_mut().unwrap() = Parameter::from_slice(&[0.0, 0.0], &[2]).unwrap();

        let mut lora = LoRALinear::new(base, 1, 2.0, 0.5).unwrap();
        lora.lora_a = Parameter::from_slice(&[1.0, 0.0], &[1, 2]).unwrap();
        lora.lora_b = Parameter::from_slice(&[1.0, 0.0], &[2, 1]).unwrap();
        lora.eval();

        let input = leaf(&[1.0, 2.0], &[1, 2], false);
        for _ in 0..3 {
            let output = lora.forward(&input).unwrap();
            assert_close(output.data().unwrap(), &[3.0, 2.0], 1e-5);
        }
    }

    // -----------------------------------------------------------------------
    // State dict
    // -----------------------------------------------------------------------

    #[test]
    fn test_state_dict_keys() {
        let base = Linear::<f32>::new(8, 4, true).unwrap();
        let lora = LoRALinear::new(base, 2, 1.0, 0.0).unwrap();
        let sd = lora.state_dict();
        assert!(sd.contains_key("lora_a"));
        assert!(sd.contains_key("lora_b"));
        assert!(!sd.contains_key("weight"));
        assert!(!sd.contains_key("bias"));
        assert_eq!(sd["lora_a"].shape(), &[2, 8]);
        assert_eq!(sd["lora_b"].shape(), &[4, 2]);
    }

    #[test]
    fn test_state_dict_roundtrip() {
        let base = Linear::<f32>::new(6, 3, true).unwrap();
        let lora = LoRALinear::new(base, 2, 1.0, 0.0).unwrap();
        let sd = lora.state_dict();

        let base2 = Linear::<f32>::new(6, 3, true).unwrap();
        let mut lora2 = LoRALinear::new(base2, 2, 1.0, 0.0).unwrap();
        lora2.load_state_dict(&sd, true).unwrap();

        assert_close(
            lora2.lora_a.data().unwrap(),
            lora.lora_a.data().unwrap(),
            1e-7,
        );
        assert_close(
            lora2.lora_b.data().unwrap(),
            lora.lora_b.data().unwrap(),
            1e-7,
        );
    }

    // -----------------------------------------------------------------------
    // Display
    // -----------------------------------------------------------------------

    #[test]
    fn test_display() {
        let base = Linear::<f32>::new(10, 5, true).unwrap();
        let lora = LoRALinear::new(base, 4, 2.0, 0.0).unwrap();
        let s = format!("{lora}");
        assert_eq!(
            s,
            "LoRALinear(in_features=10, out_features=5, rank=4, alpha=2, bias=true, dropout=false)"
        );
    }

    #[test]
    fn test_display_with_dropout_no_bias() {
        let base = Linear::<f32>::new(6, 3, false).unwrap();
        let lora = LoRALinear::new(base, 2, 0.5, 0.2).unwrap();
        assert_eq!(
            format!("{lora}"),
            "LoRALinear(in_features=6, out_features=3, rank=2, alpha=0.5, bias=false, dropout=true)"
        );
    }

    // -----------------------------------------------------------------------
    // Send + Sync
    // -----------------------------------------------------------------------

    #[test]
    fn test_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<LoRALinear<f32>>();
        assert_send_sync::<LoRALinear<f64>>();
    }
}