Skip to main content

entrenar/lora/layer/
core.rs

//! LoRA (Low-Rank Adaptation) layer implementation
//!
//! LoRA enables parameter-efficient fine-tuning by adding trainable low-rank
//! decomposition matrices to frozen pretrained weights.
//!
//! For a frozen weight matrix W ∈ ℝ^(d_out × d_in), LoRA adds:
//! ΔW = B @ A where A ∈ ℝ^(r × d_in) and B ∈ ℝ^(d_out × r)
//!
//! Forward pass: y = (W + α·B·A) @ x = W@x + α·(B@(A@x))
//! where α is the scaling factor: alpha/r by default, or alpha/√r with rsLoRA
//! (see `LoRAScaling`).
11
12use crate::autograd::matmul;
13use crate::Tensor;
14
/// LoRA scaling mode (ENT-LoRA-004)
///
/// Selects how the adapter multiplier is derived from `alpha` and `rank`
/// (see [`LoRAScaling::compute`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum LoRAScaling {
    /// Standard: scale = alpha / rank.
    ///
    /// This is the formula `LoRALayer::new` uses, hence it is the `Default`.
    #[default]
    Standard,
    /// rsLoRA: scale = alpha / sqrt(rank) — rank-stable.
    ///
    /// Recommended for rank > 16, but never selected automatically; opt in via
    /// `LoRALayer::new_with_scaling`.
    RsLoRA,
}
23
24impl LoRAScaling {
25    /// Compute the scaling factor
26    ///
27    /// # Panics
28    /// Panics if rank is zero
29    pub fn compute(self, alpha: f32, rank: usize) -> f32 {
30        assert!(rank > 0, "LoRA rank must be > 0");
31        match self {
32            Self::Standard => alpha / rank as f32,
33            Self::RsLoRA => alpha / (rank as f32).sqrt(),
34        }
35    }
36}
37
38/// LoRA layer: adds trainable low-rank adaptation to a frozen base weight
39#[derive(Clone)]
40pub struct LoRALayer {
41    /// Frozen base weight matrix stored as 1D [d_out * d_in]
42    base_weight: Tensor,
43    /// LoRA matrix A stored as 1D [r * d_in] - downprojection
44    lora_a: Tensor,
45    /// LoRA matrix B stored as 1D [d_out * r] - upprojection
46    lora_b: Tensor,
47    /// Output dimension
48    d_out: usize,
49    /// Input dimension
50    d_in: usize,
51    /// LoRA rank
52    rank: usize,
53    /// Scaling factor (alpha/rank)
54    scale: f32,
55    /// Whether the adapter is merged into base_weight
56    merged: bool,
57}
58
59impl LoRALayer {
60    /// Create a new LoRA layer
61    ///
62    /// # Arguments
63    /// * `base_weight` - Frozen pretrained weight [d_out * d_in]
64    /// * `d_out` - Output dimension
65    /// * `d_in` - Input dimension
66    /// * `rank` - LoRA rank (typically 4, 8, 16, 32, or 64)
67    /// * `alpha` - LoRA scaling parameter (often same as rank)
68    ///
69    /// # Returns
70    /// LoRA layer with randomly initialized A (Gaussian) and zero-initialized B
71    pub fn new(base_weight: Tensor, d_out: usize, d_in: usize, rank: usize, alpha: f32) -> Self {
72        assert!(rank > 0, "LoRA rank must be > 0");
73        assert_eq!(base_weight.len(), d_out * d_in, "Base weight size must match d_out * d_in");
74
75        // Initialize A with small Gaussian noise, B with zeros (standard LoRA init)
76        // This ensures that initially ΔW = B·A = 0
77        let lora_a_data: Vec<f32> = (0..rank * d_in)
78            .map(|i| {
79                // Simple deterministic "random" init for reproducibility in tests
80                let x = (i as f32 * 0.1).sin();
81                x * 0.01 // Small values
82            })
83            .collect();
84        let lora_a = Tensor::from_vec(lora_a_data, true);
85
86        let lora_b = Tensor::zeros(d_out * rank, true);
87
88        let scale = alpha / rank as f32;
89
90        Self { base_weight, lora_a, lora_b, d_out, d_in, rank, scale, merged: false }
91    }
92
93    /// Create a new LoRA layer with explicit scaling mode (ENT-LoRA-004)
94    ///
95    /// Use `LoRAScaling::RsLoRA` for rank-stable training (recommended for rank > 16).
96    pub fn new_with_scaling(
97        base_weight: Tensor,
98        d_out: usize,
99        d_in: usize,
100        rank: usize,
101        alpha: f32,
102        scaling: LoRAScaling,
103    ) -> Self {
104        let mut layer = Self::new(base_weight, d_out, d_in, rank, alpha);
105        layer.scale = scaling.compute(alpha, rank);
106        layer
107    }
108
109    /// Forward pass: y = W@x + scale * (B @ (A @ x))
110    ///
111    /// # Arguments
112    /// * `x` - Input tensor `[d_in]`
113    ///
114    /// # Returns
115    /// Output tensor `[d_out]`
116    pub fn forward(&self, x: &Tensor) -> Tensor {
117        assert_eq!(x.len(), self.d_in, "Input size must match d_in");
118
119        // Base forward: W @ x [d_out, d_in] @ [d_in, 1] -> [d_out, 1]
120        let base_output = matmul(&self.base_weight, x, self.d_out, self.d_in, 1);
121
122        if self.merged {
123            // If merged, W already includes LoRA adaptation
124            base_output
125        } else {
126            // LoRA forward: scale * (B @ (A @ x))
127            // Step 1: A @ x [r, d_in] @ [d_in, 1] -> [r, 1]
128            let lora_out_a = matmul(&self.lora_a, x, self.rank, self.d_in, 1);
129
130            // Step 2: B @ (A @ x) [d_out, r] @ [r, 1] -> [d_out, 1]
131            let lora_out_b = matmul(&self.lora_b, &lora_out_a, self.d_out, self.rank, 1);
132
133            // Step 3: scale * LoRA output
134            let mut scaled_lora_data = lora_out_b.data().to_owned();
135            for val in &mut scaled_lora_data {
136                *val *= self.scale;
137            }
138            let scaled_lora = Tensor::new(scaled_lora_data, false);
139
140            // Step 4: base + LoRA
141            let mut result_data = base_output.data().to_owned();
142            for (i, val) in result_data.iter_mut().enumerate() {
143                *val += scaled_lora.data()[i];
144            }
145            Tensor::new(result_data, base_output.requires_grad())
146        }
147    }
148
149    /// Merge LoRA weights into base weight: W' = W + scale * (B @ A)
150    ///
151    /// After merging, forward pass only uses W' (more efficient).
152    /// This is typically done for inference.
153    pub fn merge(&mut self) {
154        if self.merged {
155            return; // Already merged
156        }
157
158        // Compute B @ A [d_out, r] @ [r, d_in] -> [d_out, d_in]
159        let ba = matmul(&self.lora_b, &self.lora_a, self.d_out, self.rank, self.d_in);
160
161        // Scale and add to base weight: W' = W + scale * B @ A
162        for (i, val) in self.base_weight.data_mut().iter_mut().enumerate() {
163            *val += self.scale * ba.data()[i];
164        }
165
166        self.merged = true;
167    }
168
169    /// Unmerge LoRA weights from base weight: W = W' - scale * (B @ A)
170    ///
171    /// Reverses the merge operation. Useful for continuing training or
172    /// switching adapters.
173    pub fn unmerge(&mut self) {
174        if !self.merged {
175            return; // Not merged
176        }
177
178        // Compute B @ A
179        let ba = matmul(&self.lora_b, &self.lora_a, self.d_out, self.rank, self.d_in);
180
181        // Subtract from base weight: W = W' - scale * B @ A
182        for (i, val) in self.base_weight.data_mut().iter_mut().enumerate() {
183            *val -= self.scale * ba.data()[i];
184        }
185
186        self.merged = false;
187    }
188
189    /// Get reference to base weight matrix
190    pub fn base_weight(&self) -> &Tensor {
191        &self.base_weight
192    }
193
194    /// Get reference to LoRA A matrix
195    pub fn lora_a(&self) -> &Tensor {
196        &self.lora_a
197    }
198
199    /// Get mutable reference to LoRA A matrix
200    pub fn lora_a_mut(&mut self) -> &mut Tensor {
201        &mut self.lora_a
202    }
203
204    /// Get reference to LoRA B matrix
205    pub fn lora_b(&self) -> &Tensor {
206        &self.lora_b
207    }
208
209    /// Get mutable reference to LoRA B matrix
210    pub fn lora_b_mut(&mut self) -> &mut Tensor {
211        &mut self.lora_b
212    }
213
214    /// Get trainable parameters (A and B)
215    pub fn trainable_params(&mut self) -> Vec<&mut Tensor> {
216        vec![&mut self.lora_a, &mut self.lora_b]
217    }
218
219    /// Check if LoRA is merged
220    pub fn is_merged(&self) -> bool {
221        self.merged
222    }
223
224    /// Get rank
225    pub fn rank(&self) -> usize {
226        self.rank
227    }
228
229    /// Get scale factor
230    pub fn scale(&self) -> f32 {
231        self.scale
232    }
233
234    /// Get output dimension
235    pub fn d_out(&self) -> usize {
236        self.d_out
237    }
238
239    /// Get input dimension
240    pub fn d_in(&self) -> usize {
241        self.d_in
242    }
243}