entrenar/lora/layer/core.rs
1//! LoRA (Low-Rank Adaptation) layer implementation
2//!
3//! LoRA enables parameter-efficient fine-tuning by adding trainable low-rank
4//! decomposition matrices to frozen pretrained weights.
5//!
6//! For a frozen weight matrix W ∈ ℝ^(d_out × d_in), LoRA adds:
7//! ΔW = B @ A where A ∈ ℝ^(r × d_in) and B ∈ ℝ^(d_out × r)
8//!
//! Forward pass: y = (W + s·B·A) @ x = W@x + s·(B@(A@x))
//! where s is the scaling factor: alpha/r (standard) or alpha/sqrt(r) (rsLoRA)
11
12use crate::autograd::matmul;
13use crate::Tensor;
14
/// Selects how the LoRA scaling factor is derived from `alpha` and the rank (ENT-LoRA-004).
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LoRAScaling {
    /// Classic LoRA scaling: `alpha / rank`.
    Standard,
    /// Rank-stabilized (rsLoRA) scaling: `alpha / sqrt(rank)`; recommended for rank > 16.
    RsLoRA,
}

impl LoRAScaling {
    /// Returns the multiplier applied to the low-rank update for this mode.
    ///
    /// # Panics
    /// Panics if `rank` is zero.
    pub fn compute(self, alpha: f32, rank: usize) -> f32 {
        assert!(rank > 0, "LoRA rank must be > 0");

        // Pick the divisor first, then divide once.
        let rank_f = rank as f32;
        let divisor = match self {
            Self::Standard => rank_f,
            Self::RsLoRA => rank_f.sqrt(),
        };
        alpha / divisor
    }
}
37
/// LoRA layer: adds trainable low-rank adaptation to a frozen base weight
///
/// All matrices are held as flat 1D tensors; their logical 2D shapes are
/// tracked by the `d_out`, `d_in` and `rank` fields.
#[derive(Clone)]
pub struct LoRALayer {
    /// Base weight matrix stored as 1D [d_out * d_in]; `merge`/`unmerge`
    /// add/subtract the scaled adapter delta in place
    base_weight: Tensor,
    /// LoRA matrix A stored as 1D [r * d_in] - downprojection
    lora_a: Tensor,
    /// LoRA matrix B stored as 1D [d_out * r] - upprojection
    lora_b: Tensor,
    /// Output dimension
    d_out: usize,
    /// Input dimension
    d_in: usize,
    /// LoRA rank
    rank: usize,
    /// Scaling factor applied to B @ A: alpha/rank with standard scaling,
    /// alpha/sqrt(rank) when constructed with `LoRAScaling::RsLoRA`
    scale: f32,
    /// Whether the adapter is merged into base_weight
    merged: bool,
}
58
59impl LoRALayer {
60 /// Create a new LoRA layer
61 ///
62 /// # Arguments
63 /// * `base_weight` - Frozen pretrained weight [d_out * d_in]
64 /// * `d_out` - Output dimension
65 /// * `d_in` - Input dimension
66 /// * `rank` - LoRA rank (typically 4, 8, 16, 32, or 64)
67 /// * `alpha` - LoRA scaling parameter (often same as rank)
68 ///
69 /// # Returns
70 /// LoRA layer with randomly initialized A (Gaussian) and zero-initialized B
71 pub fn new(base_weight: Tensor, d_out: usize, d_in: usize, rank: usize, alpha: f32) -> Self {
72 assert!(rank > 0, "LoRA rank must be > 0");
73 assert_eq!(base_weight.len(), d_out * d_in, "Base weight size must match d_out * d_in");
74
75 // Initialize A with small Gaussian noise, B with zeros (standard LoRA init)
76 // This ensures that initially ΔW = B·A = 0
77 let lora_a_data: Vec<f32> = (0..rank * d_in)
78 .map(|i| {
79 // Simple deterministic "random" init for reproducibility in tests
80 let x = (i as f32 * 0.1).sin();
81 x * 0.01 // Small values
82 })
83 .collect();
84 let lora_a = Tensor::from_vec(lora_a_data, true);
85
86 let lora_b = Tensor::zeros(d_out * rank, true);
87
88 let scale = alpha / rank as f32;
89
90 Self { base_weight, lora_a, lora_b, d_out, d_in, rank, scale, merged: false }
91 }
92
93 /// Create a new LoRA layer with explicit scaling mode (ENT-LoRA-004)
94 ///
95 /// Use `LoRAScaling::RsLoRA` for rank-stable training (recommended for rank > 16).
96 pub fn new_with_scaling(
97 base_weight: Tensor,
98 d_out: usize,
99 d_in: usize,
100 rank: usize,
101 alpha: f32,
102 scaling: LoRAScaling,
103 ) -> Self {
104 let mut layer = Self::new(base_weight, d_out, d_in, rank, alpha);
105 layer.scale = scaling.compute(alpha, rank);
106 layer
107 }
108
109 /// Forward pass: y = W@x + scale * (B @ (A @ x))
110 ///
111 /// # Arguments
112 /// * `x` - Input tensor `[d_in]`
113 ///
114 /// # Returns
115 /// Output tensor `[d_out]`
116 pub fn forward(&self, x: &Tensor) -> Tensor {
117 assert_eq!(x.len(), self.d_in, "Input size must match d_in");
118
119 // Base forward: W @ x [d_out, d_in] @ [d_in, 1] -> [d_out, 1]
120 let base_output = matmul(&self.base_weight, x, self.d_out, self.d_in, 1);
121
122 if self.merged {
123 // If merged, W already includes LoRA adaptation
124 base_output
125 } else {
126 // LoRA forward: scale * (B @ (A @ x))
127 // Step 1: A @ x [r, d_in] @ [d_in, 1] -> [r, 1]
128 let lora_out_a = matmul(&self.lora_a, x, self.rank, self.d_in, 1);
129
130 // Step 2: B @ (A @ x) [d_out, r] @ [r, 1] -> [d_out, 1]
131 let lora_out_b = matmul(&self.lora_b, &lora_out_a, self.d_out, self.rank, 1);
132
133 // Step 3: scale * LoRA output
134 let mut scaled_lora_data = lora_out_b.data().to_owned();
135 for val in &mut scaled_lora_data {
136 *val *= self.scale;
137 }
138 let scaled_lora = Tensor::new(scaled_lora_data, false);
139
140 // Step 4: base + LoRA
141 let mut result_data = base_output.data().to_owned();
142 for (i, val) in result_data.iter_mut().enumerate() {
143 *val += scaled_lora.data()[i];
144 }
145 Tensor::new(result_data, base_output.requires_grad())
146 }
147 }
148
149 /// Merge LoRA weights into base weight: W' = W + scale * (B @ A)
150 ///
151 /// After merging, forward pass only uses W' (more efficient).
152 /// This is typically done for inference.
153 pub fn merge(&mut self) {
154 if self.merged {
155 return; // Already merged
156 }
157
158 // Compute B @ A [d_out, r] @ [r, d_in] -> [d_out, d_in]
159 let ba = matmul(&self.lora_b, &self.lora_a, self.d_out, self.rank, self.d_in);
160
161 // Scale and add to base weight: W' = W + scale * B @ A
162 for (i, val) in self.base_weight.data_mut().iter_mut().enumerate() {
163 *val += self.scale * ba.data()[i];
164 }
165
166 self.merged = true;
167 }
168
169 /// Unmerge LoRA weights from base weight: W = W' - scale * (B @ A)
170 ///
171 /// Reverses the merge operation. Useful for continuing training or
172 /// switching adapters.
173 pub fn unmerge(&mut self) {
174 if !self.merged {
175 return; // Not merged
176 }
177
178 // Compute B @ A
179 let ba = matmul(&self.lora_b, &self.lora_a, self.d_out, self.rank, self.d_in);
180
181 // Subtract from base weight: W = W' - scale * B @ A
182 for (i, val) in self.base_weight.data_mut().iter_mut().enumerate() {
183 *val -= self.scale * ba.data()[i];
184 }
185
186 self.merged = false;
187 }
188
189 /// Get reference to base weight matrix
190 pub fn base_weight(&self) -> &Tensor {
191 &self.base_weight
192 }
193
194 /// Get reference to LoRA A matrix
195 pub fn lora_a(&self) -> &Tensor {
196 &self.lora_a
197 }
198
199 /// Get mutable reference to LoRA A matrix
200 pub fn lora_a_mut(&mut self) -> &mut Tensor {
201 &mut self.lora_a
202 }
203
204 /// Get reference to LoRA B matrix
205 pub fn lora_b(&self) -> &Tensor {
206 &self.lora_b
207 }
208
209 /// Get mutable reference to LoRA B matrix
210 pub fn lora_b_mut(&mut self) -> &mut Tensor {
211 &mut self.lora_b
212 }
213
214 /// Get trainable parameters (A and B)
215 pub fn trainable_params(&mut self) -> Vec<&mut Tensor> {
216 vec![&mut self.lora_a, &mut self.lora_b]
217 }
218
219 /// Check if LoRA is merged
220 pub fn is_merged(&self) -> bool {
221 self.merged
222 }
223
224 /// Get rank
225 pub fn rank(&self) -> usize {
226 self.rank
227 }
228
229 /// Get scale factor
230 pub fn scale(&self) -> f32 {
231 self.scale
232 }
233
234 /// Get output dimension
235 pub fn d_out(&self) -> usize {
236 self.d_out
237 }
238
239 /// Get input dimension
240 pub fn d_in(&self) -> usize {
241 self.d_in
242 }
243}