ferrotorch_nn/lora.rs
1//! Low-Rank Adaptation (LoRA) for parameter-efficient fine-tuning.
2//!
3//! Instead of fine-tuning all weights of a pretrained model, LoRA freezes
4//! the original weights and injects a trainable low-rank decomposition:
5//!
6//! ```text
7//! W' = W + (alpha / r) * B @ A
8//! ```
9//!
10//! where `A` is `[r, in_features]` and `B` is `[out_features, r]`. Only `A`
11//! and `B` are trainable — the original `W` stays frozen. This dramatically
12//! reduces the number of trainable parameters while preserving model quality.
13//!
14//! # References
15//!
16//! Hu et al., "LoRA: Low-Rank Adaptation of Large Language Models", 2021.
17//!
18//! ## REQ status (per `.design/ferrotorch-nn/lora.md`)
19//!
20//! | REQ | Status | Evidence |
21//! |---|---|---|
22//! | REQ-1 | SHIPPED | impl: `pub struct LoRALinear<T: Float>` here with `base` / `lora_a` / `lora_b` / `alpha` / `rank` / `dropout` / `training` fields per Hu et al. 2021; non-test consumer: `pub use lora::LoRALinear` in `lib.rs` makes the type available to `ferrotorch-train`'s fine-tuning scaffolding. |
23//! | REQ-2 | SHIPPED | impl: the `LoRALinear::new` constructor body here with rank validation + N(0, 1/sqrt(rank)) init of A + zeros init of B + optional `Dropout` construction; non-test consumer: PEFT fine-tuning code calls `LoRALinear::new(base, rank, alpha, dropout_p)?`. |
24//! | REQ-3 | SHIPPED | impl: `<LoRALinear as Module>::forward` body (base + transposed matmul chain + scale + add) here; non-test consumer: fine-tuning training loops call `lora.forward(input)` every step. |
25//! | REQ-4 | SHIPPED | impl: `Module::parameters` returns `vec![&self.lora_a, &self.lora_b]` here, excluding the base; non-test consumer: `ferrotorch_optim::Optimizer::step` iterates `model.parameters_mut()` and only sees `lora_a` / `lora_b` (the frozen base is skipped). This is THE LoRA invariant. |
26//! | REQ-5 | SHIPPED | impl: the `LoRALinear::merge` body (triple-nested B @ A + weight update + LoRA reset) here; non-test consumer: inference-serving code calls `lora.merge()` then `lora.into_base()` to fuse the adapter for deployment. |
27//! | REQ-6 | SHIPPED | impl: `impl<T: Float> Module<T> for LoRALinear<T>` block here with `train` / `eval` cascading to `base` and `dropout`; non-test consumer: training-loop control flow toggles `model.train()` / `model.eval()` between training and validation, which cascades through `LoRALinear` to `Dropout`. |
28//! | REQ-7 | SHIPPED | impl: `impl<T: Float> Display for LoRALinear<T>` block here; non-test consumer: any `format!("{layer}")` in model summary logging (the same path that prints `Linear(...)` for the base). |
29//! | REQ-8 | SHIPPED | `LoRALinear` is `Send + Sync` by composition of `Send + Sync` fields; compile-time-asserted via `assert_send_sync::<LoRALinear<f32>>()` in tests; non-test consumer: any multi-threaded training scaffolding requiring `Send + Sync`. |
30//! | REQ-9 | SHIPPED | impl: the `rank` / `alpha` / `base` / `into_base` accessors here; non-test consumer: inference-serving code calls `lora.into_base()` after `lora.merge()` to drop the LoRA wrapper. |
31
32use ferrotorch_core::grad_fns::arithmetic::{add, mul};
33use ferrotorch_core::grad_fns::linalg::mm_differentiable;
34use ferrotorch_core::grad_fns::shape::transpose_2d;
35use ferrotorch_core::{FerrotorchError, FerrotorchResult, Float, Tensor, scalar};
36
37use crate::dropout::Dropout;
38use crate::init;
39use crate::linear::Linear;
40use crate::module::Module;
41use crate::parameter::Parameter;
42
/// Low-Rank Adaptation wrapper for a [`Linear`] layer.
///
/// Owns a base `Linear` layer whose parameters are treated as frozen, plus
/// two trainable low-rank matrices `A` and `B`. The forward pass computes:
///
/// ```text
/// y = x @ W^T + x @ (B @ A)^T * (alpha / r) + bias
/// ```
///
/// Only `lora_a` and `lora_b` appear in [`parameters()`](Module::parameters),
/// so optimizers only update the low-rank matrices. The base layer's weight
/// and bias are excluded from the parameter list (frozen).
///
/// # Initialization
///
/// - **A**: `N(0, 1/sqrt(r))` — Kaiming-style for the rank dimension.
/// - **B**: Zeros — so the LoRA contribution starts at zero and training
///   begins from the pretrained checkpoint.
///
/// # Merging
///
/// After fine-tuning, call [`merge()`](LoRALinear::merge) to fold the LoRA
/// weights into the base layer, then [`into_base()`](LoRALinear::into_base)
/// to obtain a standard `Linear` layer for inference without the extra
/// matmuls. `merge()` on its own keeps the wrapper (with freshly reset
/// `A` / `B`), so the wrapper's forward pass still evaluates the LoRA path.
///
/// # Examples
///
/// ```ignore
/// let base = Linear::<f32>::new(768, 768, true)?;
/// let lora = LoRALinear::new(base, 8, 1.0, 0.0)?;
/// let output = lora.forward(&input)?;  // only lora_a, lora_b are trainable
/// ```
#[derive(Debug)]
pub struct LoRALinear<T: Float> {
    /// Original frozen linear layer (not included in `parameters()`).
    base: Linear<T>,
    /// Low-rank A matrix: `[r, in_features]`, trainable.
    lora_a: Parameter<T>,
    /// Low-rank B matrix: `[out_features, r]`, trainable.
    lora_b: Parameter<T>,
    /// Scaling factor (numerator of `alpha / r`).
    alpha: f64,
    /// Rank of the low-rank decomposition; the constructor rejects zero.
    rank: usize,
    /// Optional dropout on the LoRA input path; `None` when the constructor
    /// was given a dropout probability of zero. Only applied while
    /// `training` is `true`.
    dropout: Option<Dropout<T>>,
    /// Whether the module is in training mode. Starts as `true` and is
    /// toggled by `train()` / `eval()`.
    training: bool,
}
92
93impl<T: Float> LoRALinear<T> {
94 /// Create a LoRA wrapper around an existing `Linear` layer.
95 ///
96 /// # Arguments
97 ///
98 /// - `base` — The pretrained linear layer to adapt. Its parameters are
99 /// frozen (excluded from `parameters()`).
100 /// - `rank` — Rank of the low-rank decomposition. Typical values: 1–64.
101 /// - `alpha` — Scaling factor. The LoRA contribution is scaled by
102 /// `alpha / rank`. Common choice: `alpha == rank` (scale = 1).
103 /// - `dropout_p` — Dropout probability on the LoRA input path. Set to
104 /// `0.0` to disable.
105 ///
106 /// # Errors
107 ///
108 /// Returns an error if `rank` is zero, if `dropout_p` is invalid, or if
109 /// parameter allocation fails.
110 pub fn new(base: Linear<T>, rank: usize, alpha: f64, dropout_p: f64) -> FerrotorchResult<Self> {
111 if rank == 0 {
112 return Err(FerrotorchError::InvalidArgument {
113 message: "LoRALinear: rank must be > 0".into(),
114 });
115 }
116
117 let in_features = base.in_features();
118 let out_features = base.out_features();
119
120 // A initialized from N(0, 1/sqrt(r)) — so the initial LoRA output
121 // has variance independent of rank.
122 let mut lora_a = Parameter::zeros(&[rank, in_features])?;
123 init::normal(&mut lora_a, 0.0, 1.0 / (rank as f64).sqrt())?;
124
125 // B initialized to zeros — LoRA contribution starts at zero.
126 let lora_b = Parameter::zeros(&[out_features, rank])?;
127
128 let dropout = if dropout_p > 0.0 {
129 Some(Dropout::new(dropout_p)?)
130 } else {
131 None
132 };
133
134 Ok(Self {
135 base,
136 lora_a,
137 lora_b,
138 alpha,
139 rank,
140 dropout,
141 training: true,
142 })
143 }
144
145 /// Merge LoRA weights into the base layer for inference efficiency.
146 ///
147 /// Computes `W_merged = W + (alpha/r) * B @ A` and replaces the base
148 /// weight. After merging, the forward pass is a single matmul with no
149 /// overhead. The LoRA matrices are reset to their initial state (A
150 /// re-initialized, B zeroed) so that additional fine-tuning can continue
151 /// from the merged checkpoint if desired.
152 pub fn merge(&mut self) -> FerrotorchResult<()> {
153 let scale = T::from(self.alpha / self.rank as f64).unwrap();
154
155 // B @ A: [out_features, r] @ [r, in_features] = [out_features, in_features]
156 let b_data = self.lora_b.data()?;
157 let a_data = self.lora_a.data()?;
158 let out_features = self.base.out_features();
159 let in_features = self.base.in_features();
160 let r = self.rank;
161
162 let zero = <T as num_traits::Zero>::zero();
163 let mut ba = vec![zero; out_features * in_features];
164 for i in 0..out_features {
165 for j in 0..in_features {
166 let mut sum = zero;
167 for k in 0..r {
168 sum += b_data[i * r + k] * a_data[k * in_features + j];
169 }
170 ba[i * in_features + j] = sum;
171 }
172 }
173
174 // W_merged = W + scale * B @ A
175 let w_data = self.base.weight.data()?;
176 let merged: Vec<T> = w_data
177 .iter()
178 .zip(ba.iter())
179 .map(|(&w, &d)| w + scale * d)
180 .collect();
181
182 self.base.weight = Parameter::from_slice(&merged, &[out_features, in_features])?;
183
184 // Reset LoRA matrices so the module can be fine-tuned again.
185 self.lora_a = Parameter::zeros(&[r, in_features])?;
186 init::normal(&mut self.lora_a, 0.0, 1.0 / (r as f64).sqrt())?;
187 self.lora_b = Parameter::zeros(&[out_features, r])?;
188
189 Ok(())
190 }
191
192 /// The effective rank of the adaptation.
193 #[inline]
194 pub fn rank(&self) -> usize {
195 self.rank
196 }
197
198 /// The scaling factor alpha.
199 #[inline]
200 pub fn alpha(&self) -> f64 {
201 self.alpha
202 }
203
204 /// Borrow the underlying base linear layer.
205 #[inline]
206 pub fn base(&self) -> &Linear<T> {
207 &self.base
208 }
209
210 /// Consume the LoRA wrapper and return the base linear layer.
211 ///
212 /// Call [`merge()`](LoRALinear::merge) first if you want the LoRA
213 /// weights folded into the base.
214 pub fn into_base(self) -> Linear<T> {
215 self.base
216 }
217}
218
219impl<T: Float> Module<T> for LoRALinear<T> {
220 /// Forward pass: base linear output plus scaled low-rank adaptation.
221 ///
222 /// ```text
223 /// y = base.forward(x) + (x @ A^T @ B^T) * (alpha / r)
224 /// ```
225 ///
226 /// When dropout is configured and the module is in training mode,
227 /// dropout is applied to the input on the LoRA path only (the base
228 /// path is unaffected).
229 fn forward(&self, input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
230 // Base forward (frozen weights — not in parameters()).
231 let base_out = self.base.forward(input)?;
232
233 // LoRA path: optionally apply dropout to input.
234 let lora_input = if let Some(ref dropout) = self.dropout {
235 if self.training {
236 dropout.forward(input)?
237 } else {
238 input.clone()
239 }
240 } else {
241 input.clone()
242 };
243
244 // lora_out = input @ A^T @ B^T
245 // A^T: [in_features, r]
246 let a_t = transpose_2d(self.lora_a.tensor())?;
247 // xa: [batch, r]
248 let xa = mm_differentiable(&lora_input, &a_t)?;
249 // B^T: [r, out_features]
250 let b_t = transpose_2d(self.lora_b.tensor())?;
251 // lora_out: [batch, out_features]
252 let lora_out = mm_differentiable(&xa, &b_t)?;
253
254 // Scale by alpha / r.
255 let scale_val = T::from(self.alpha / self.rank as f64).unwrap();
256 let scale_tensor = scalar(scale_val)?;
257 let scaled = mul(&lora_out, &scale_tensor)?;
258
259 // Add to base output.
260 add(&base_out, &scaled)
261 }
262
263 /// Returns only the LoRA parameters (A and B). The base layer's
264 /// parameters are frozen and excluded.
265 fn parameters(&self) -> Vec<&Parameter<T>> {
266 vec![&self.lora_a, &self.lora_b]
267 }
268
269 fn parameters_mut(&mut self) -> Vec<&mut Parameter<T>> {
270 vec![&mut self.lora_a, &mut self.lora_b]
271 }
272
273 fn named_parameters(&self) -> Vec<(String, &Parameter<T>)> {
274 vec![
275 ("lora_a".to_string(), &self.lora_a),
276 ("lora_b".to_string(), &self.lora_b),
277 ]
278 }
279
280 fn train(&mut self) {
281 self.training = true;
282 self.base.train();
283 if let Some(ref mut d) = self.dropout {
284 d.train();
285 }
286 }
287
288 fn eval(&mut self) {
289 self.training = false;
290 self.base.eval();
291 if let Some(ref mut d) = self.dropout {
292 d.eval();
293 }
294 }
295
296 fn is_training(&self) -> bool {
297 self.training
298 }
299}
300
301// ---------------------------------------------------------------------------
302// Display
303// ---------------------------------------------------------------------------
304
305impl<T: Float> std::fmt::Display for LoRALinear<T> {
306 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
307 write!(
308 f,
309 "LoRALinear(in_features={}, out_features={}, rank={}, alpha={}, bias={}, dropout={})",
310 self.base.in_features(),
311 self.base.out_features(),
312 self.rank,
313 self.alpha,
314 self.base.bias.is_some(),
315 self.dropout.is_some(),
316 )
317 }
318}
319
320// ---------------------------------------------------------------------------
321// Tests
322// ---------------------------------------------------------------------------
323
#[cfg(test)]
mod tests {
    use super::*;
    use ferrotorch_core::{Tensor, TensorStorage};

    // -----------------------------------------------------------------------
    // Fixtures
    // -----------------------------------------------------------------------

    /// Build a CPU leaf tensor from `data` laid out as `shape`.
    fn leaf(data: &[f32], shape: &[usize], requires_grad: bool) -> Tensor<f32> {
        let storage = TensorStorage::cpu(data.to_vec());
        Tensor::from_storage(storage, shape.to_vec(), requires_grad).unwrap()
    }

    /// Check that two float slices have equal length and agree within `tol`.
    fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
        assert_eq!(
            actual.len(),
            expected.len(),
            "length mismatch: {} vs {}",
            actual.len(),
            expected.len()
        );
        let pairs = actual.iter().copied().zip(expected.iter().copied());
        for (i, (a, e)) in pairs.enumerate() {
            let diff = (a - e).abs();
            assert!(diff < tol, "index {i}: actual={a} expected={e} diff={}", diff);
        }
    }

    /// An all-zero, non-grad `[rows, cols]` input.
    fn zeros_input(rows: usize, cols: usize) -> Tensor<f32> {
        leaf(&vec![0.0; rows * cols], &[rows, cols], false)
    }

    /// A freshly constructed adapter around a freshly constructed base.
    fn adapter(
        in_features: usize,
        out_features: usize,
        bias: bool,
        rank: usize,
        alpha: f64,
        dropout_p: f64,
    ) -> LoRALinear<f32> {
        let base = Linear::<f32>::new(in_features, out_features, bias).unwrap();
        LoRALinear::new(base, rank, alpha, dropout_p).unwrap()
    }

    /// A biased `Linear` whose weight and bias are overwritten with known values.
    fn linear_with(
        in_features: usize,
        out_features: usize,
        weight: &[f32],
        bias: &[f32],
    ) -> Linear<f32> {
        let mut layer = Linear::<f32>::new(in_features, out_features, true).unwrap();
        layer.weight = Parameter::from_slice(weight, &[out_features, in_features]).unwrap();
        *layer.bias.as_mut().unwrap() = Parameter::from_slice(bias, &[out_features]).unwrap();
        layer
    }

    /// Shared body of the per-rank tests: accessor, A/B shapes, and the
    /// output shape for a batch of two zero rows.
    fn check_rank(in_features: usize, out_features: usize, bias: bool, rank: usize, alpha: f64) {
        let lora = adapter(in_features, out_features, bias, rank, alpha, 0.0);
        assert_eq!(lora.rank(), rank);
        assert_eq!(lora.lora_a.shape(), &[rank, in_features]);
        assert_eq!(lora.lora_b.shape(), &[out_features, rank]);
        let output = lora.forward(&zeros_input(2, in_features)).unwrap();
        assert_eq!(output.shape(), &[2, out_features]);
    }

    // -----------------------------------------------------------------------
    // Construction
    // -----------------------------------------------------------------------

    #[test]
    fn test_construction() {
        let lora = adapter(10, 5, true, 4, 1.0, 0.0);
        assert_eq!(lora.rank(), 4);
        assert_eq!(lora.alpha(), 1.0);
        assert_eq!(lora.lora_a.shape(), &[4, 10]);
        assert_eq!(lora.lora_b.shape(), &[5, 4]);
    }

    #[test]
    fn test_construction_zero_rank_rejected() {
        let base = Linear::<f32>::new(10, 5, true).unwrap();
        let result = LoRALinear::new(base, 0, 1.0, 0.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_construction_with_dropout() {
        let lora = adapter(10, 5, true, 4, 1.0, 0.1);
        assert!(lora.dropout.is_some());
    }

    #[test]
    fn test_construction_invalid_dropout_rejected() {
        let base = Linear::<f32>::new(10, 5, true).unwrap();
        let result = LoRALinear::new(base, 4, 1.0, 1.5);
        assert!(result.is_err());
    }

    // -----------------------------------------------------------------------
    // Forward shape
    // -----------------------------------------------------------------------

    #[test]
    fn test_forward_shape() {
        let lora = adapter(8, 4, true, 2, 1.0, 0.0);
        let output = lora.forward(&zeros_input(3, 8)).unwrap();
        assert_eq!(output.shape(), &[3, 4]);
    }

    #[test]
    fn test_forward_shape_no_bias() {
        let lora = adapter(6, 3, false, 2, 1.0, 0.0);
        let output = lora.forward(&zeros_input(2, 6)).unwrap();
        assert_eq!(output.shape(), &[2, 3]);
    }

    // -----------------------------------------------------------------------
    // Parameters — only LoRA A and B, not base
    // -----------------------------------------------------------------------

    #[test]
    fn test_parameters_only_lora() {
        let lora = adapter(10, 5, true, 4, 1.0, 0.0);
        let params = lora.parameters();
        // Exactly two entries: lora_a and lora_b, never the base weight/bias.
        assert_eq!(params.len(), 2);
        // 4 * 10 elements in A plus 5 * 4 elements in B.
        let mut total = 0usize;
        for p in &params {
            total += p.numel();
        }
        assert_eq!(total, 60);
    }

    #[test]
    fn test_named_parameters_keys() {
        let lora = adapter(10, 5, true, 4, 1.0, 0.0);
        let named = lora.named_parameters();
        assert_eq!(named.len(), 2);
        assert_eq!(named[0].0, "lora_a");
        assert_eq!(named[1].0, "lora_b");
    }

    // -----------------------------------------------------------------------
    // Zero-initialized B means output matches base
    // -----------------------------------------------------------------------

    #[test]
    fn test_zero_b_matches_base_output() {
        // With B all zeros the adapter adds nothing, so the wrapped layer
        // must reproduce the bare base layer's output.
        let base = linear_with(3, 2, &[1.0, 0.0, 0.0, 0.0, 1.0, 0.0], &[10.0, 20.0]);
        let make_input = || leaf(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], false);

        // Reference output straight from the base layer.
        let base_data = base.forward(&make_input()).unwrap().data().unwrap().to_vec();

        // Rank-1 adapter; B is zeros, so the LoRA contribution is zero.
        let lora = LoRALinear::new(base, 1, 1.0, 0.0).unwrap();
        let lora_out = lora.forward(&make_input()).unwrap();

        assert_eq!(lora_out.shape(), &[2, 2]);
        assert_close(lora_out.data().unwrap(), &base_data, 1e-5);
    }

    // -----------------------------------------------------------------------
    // Different ranks
    // -----------------------------------------------------------------------

    #[test]
    fn test_rank_1() {
        check_rank(8, 4, true, 1, 1.0);
    }

    #[test]
    fn test_rank_4() {
        check_rank(16, 8, false, 4, 2.0);
    }

    #[test]
    fn test_rank_16() {
        check_rank(64, 32, true, 16, 8.0);
    }

    // -----------------------------------------------------------------------
    // Merge produces equivalent output
    // -----------------------------------------------------------------------

    #[test]
    fn test_merge_produces_same_output() {
        // Base layer with known weights.
        let weight = [
            1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
        ];
        let base = linear_with(4, 3, &weight, &[0.1, 0.2, 0.3]);
        let mut lora = LoRALinear::new(base, 2, 1.0, 0.0).unwrap();

        // Known, non-zero adapter weights.
        lora.lora_a =
            Parameter::from_slice(&[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8], &[2, 4]).unwrap();
        lora.lora_b = Parameter::from_slice(&[1.0, 0.0, 0.0, 1.0, 0.5, 0.5], &[3, 2]).unwrap();

        let make_input = || leaf(&[1.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0], &[2, 4], false);

        // Adapter output before merging.
        let pre_data = lora.forward(&make_input()).unwrap().data().unwrap().to_vec();

        // After merging, the base layer alone must give the same result.
        lora.merge().unwrap();
        let post_merge_out = lora.base.forward(&make_input()).unwrap();

        assert_close(post_merge_out.data().unwrap(), &pre_data, 1e-5);
    }

    // -----------------------------------------------------------------------
    // Forward correctness with known weights
    // -----------------------------------------------------------------------

    #[test]
    fn test_forward_correctness_known_weights() {
        // Identity base (2 -> 2) with zero bias.
        let base = linear_with(2, 2, &[1.0, 0.0, 0.0, 1.0], &[0.0, 0.0]);
        let mut lora = LoRALinear::new(base, 1, 2.0, 0.0).unwrap();

        // A = [[1, 0]]   (rank=1, in=2)
        // B = [[1], [0]] (out=2, rank=1)
        lora.lora_a = Parameter::from_slice(&[1.0, 0.0], &[1, 2]).unwrap();
        lora.lora_b = Parameter::from_slice(&[1.0, 0.0], &[2, 1]).unwrap();

        // input = [[1, 2]]
        let output = lora.forward(&leaf(&[1.0, 2.0], &[1, 2], false)).unwrap();

        // base path:  [1, 2]
        // x @ A^T:    [1, 2] @ [[1], [0]] = [1]
        // ... @ B^T:  [1] @ [[1, 0]]      = [1, 0]
        // scaled:     [1, 0] * (2.0 / 1)  = [2, 0]
        // sum:        [1 + 2, 2 + 0]      = [3, 2]
        assert_eq!(output.shape(), &[1, 2]);
        assert_close(output.data().unwrap(), &[3.0, 2.0], 1e-5);
    }

    // -----------------------------------------------------------------------
    // Train / Eval
    // -----------------------------------------------------------------------

    #[test]
    fn test_train_eval() {
        let mut lora = adapter(4, 3, true, 2, 1.0, 0.1);
        assert!(lora.is_training());
        lora.eval();
        assert!(!lora.is_training());
        lora.train();
        assert!(lora.is_training());
    }

    // -----------------------------------------------------------------------
    // State dict
    // -----------------------------------------------------------------------

    #[test]
    fn test_state_dict_keys() {
        let lora = adapter(8, 4, true, 2, 1.0, 0.0);
        let sd = lora.state_dict();
        for present in ["lora_a", "lora_b"] {
            assert!(sd.contains_key(present));
        }
        for absent in ["weight", "bias"] {
            assert!(!sd.contains_key(absent));
        }
        assert_eq!(sd["lora_a"].shape(), &[2, 8]);
        assert_eq!(sd["lora_b"].shape(), &[4, 2]);
    }

    #[test]
    fn test_state_dict_roundtrip() {
        let source = adapter(6, 3, true, 2, 1.0, 0.0);
        let sd = source.state_dict();

        let mut target = adapter(6, 3, true, 2, 1.0, 0.0);
        target.load_state_dict(&sd, true).unwrap();

        assert_close(
            target.lora_a.data().unwrap(),
            source.lora_a.data().unwrap(),
            1e-7,
        );
        assert_close(
            target.lora_b.data().unwrap(),
            source.lora_b.data().unwrap(),
            1e-7,
        );
    }

    // -----------------------------------------------------------------------
    // Display
    // -----------------------------------------------------------------------

    #[test]
    fn test_display() {
        let lora = adapter(10, 5, true, 4, 2.0, 0.0);
        let rendered = format!("{lora}");
        assert_eq!(
            rendered,
            "LoRALinear(in_features=10, out_features=5, rank=4, alpha=2, bias=true, dropout=false)"
        );
    }

    // -----------------------------------------------------------------------
    // Send + Sync
    // -----------------------------------------------------------------------

    #[test]
    fn test_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<LoRALinear<f32>>();
        assert_send_sync::<LoRALinear<f64>>();
    }
}