Skip to main content

entrenar/lora/
dora.rs

1//! DoRA (Weight-Decomposed Low-Rank Adaptation) — ENT-LoRA-011
2//!
3//! DoRA decomposes the weight into magnitude `m` and direction `V/||V||`, then applies
4//! LoRA only to the direction component:
5//!
6//!   W' = m * (V + scale * B @ A) / ||V + scale * B @ A||
7//!
8//! This achieves +1-3% accuracy over standard LoRA on many benchmarks (ICML 2024 Oral).
9//!
10//! Reference: Liu et al. (2024). "DoRA: Weight-Decomposed Low-Rank Adaptation."
11
12use crate::lora::{LoRALayer, LoRAScaling};
13use crate::Tensor;
14
15/// DoRA layer: magnitude-direction decomposed LoRA
16///
17/// The base weight W is decomposed into:
18/// - `magnitude`: column norms `m = ||W_col||` for each output neuron
19/// - `direction`: normalized columns `V = W / m`
20///
21/// LoRA is applied to direction only, preserving the magnitude structure.
22pub struct DoRALayer {
23    /// Per-output-neuron magnitudes [d_out], trainable
24    magnitude: Tensor,
25    /// Underlying LoRA layer (applied to direction)
26    lora: LoRALayer,
27    /// Cached column norms of (V + scale * B @ A) for forward
28    d_out: usize,
29    d_in: usize,
30}
31
32impl DoRALayer {
33    /// Create a DoRA layer from a base weight
34    ///
35    /// Decomposes W into magnitude and direction, creates LoRA on the direction.
36    pub fn new(
37        base_weight: Tensor,
38        d_out: usize,
39        d_in: usize,
40        rank: usize,
41        alpha: f32,
42        scaling: LoRAScaling,
43    ) -> Self {
44        // Compute per-row magnitudes: m[i] = ||W[i, :]||
45        let magnitude_data: Vec<f32> = (0..d_out)
46            .map(|row| {
47                let row_start = row * d_in;
48                let row_end = row_start + d_in;
49                let row_norm_sq: f32 = base_weight
50                    .data()
51                    .slice(ndarray::s![row_start..row_end])
52                    .iter()
53                    .map(|x| x * x)
54                    .sum();
55                row_norm_sq.sqrt().max(1e-8)
56            })
57            .collect();
58        let magnitude = Tensor::from_vec(magnitude_data, true); // trainable
59
60        // Create LoRA layer on the base weight (direction component)
61        let lora = LoRALayer::new_with_scaling(base_weight, d_out, d_in, rank, alpha, scaling);
62
63        Self { magnitude, lora, d_out, d_in }
64    }
65
66    /// Forward pass: m * normalize(V + scale * B @ A) @ x
67    pub fn forward(&self, x: &Tensor) -> Tensor {
68        assert_eq!(x.len(), self.d_in, "Input size must match d_in");
69
70        // Compute V + scale * B @ A direction matrix
71        // For efficiency, compute (W + scale * B @ A) @ x first, then normalize
72        let lora_output = self.lora.forward(x); // (W + scale*B@A) @ x
73
74        // Normalize per row and apply magnitude
75        // Row norms of the effective weight matrix
76        let row_norms = self.compute_effective_row_norms();
77
78        let mut result = lora_output.data().to_owned();
79        for (i, val) in result.iter_mut().enumerate() {
80            let norm = row_norms[i].max(1e-8);
81            *val = self.magnitude.data()[i] * (*val / norm);
82        }
83
84        Tensor::new(result, self.magnitude.requires_grad())
85    }
86
87    /// Compute row norms of the effective weight matrix (W + scale * B @ A)
88    fn compute_effective_row_norms(&self) -> Vec<f32> {
89        let base = self.lora.base_weight().data();
90        let scale = self.lora.scale();
91        let a_data = self.lora.lora_a().data();
92        let b_data = self.lora.lora_b().data();
93        let rank = self.lora.rank();
94
95        let mut norms = vec![0.0f32; self.d_out];
96        for row in 0..self.d_out {
97            let mut row_norm_sq = 0.0f32;
98            for col in 0..self.d_in {
99                let base_val = base[row * self.d_in + col];
100                // Compute (B @ A)[row, col]
101                let mut ba_val = 0.0f32;
102                for r in 0..rank {
103                    ba_val += b_data[row * rank + r] * a_data[r * self.d_in + col];
104                }
105                let effective = base_val + scale * ba_val;
106                row_norm_sq += effective * effective;
107            }
108            norms[row] = row_norm_sq.sqrt();
109        }
110        norms
111    }
112
113    /// Merge DoRA into a single weight matrix for inference
114    pub fn merge_to_f32(&self) -> Vec<f32> {
115        let row_norms = self.compute_effective_row_norms();
116        let base = self.lora.base_weight().data();
117        let scale = self.lora.scale();
118        let a_data = self.lora.lora_a().data();
119        let b_data = self.lora.lora_b().data();
120        let rank = self.lora.rank();
121
122        let mut merged = vec![0.0f32; self.d_out * self.d_in];
123        for row in 0..self.d_out {
124            let m = self.magnitude.data()[row];
125            let norm = row_norms[row].max(1e-8);
126            for col in 0..self.d_in {
127                let base_val = base[row * self.d_in + col];
128                let mut ba_val = 0.0f32;
129                for r in 0..rank {
130                    ba_val += b_data[row * rank + r] * a_data[r * self.d_in + col];
131                }
132                merged[row * self.d_in + col] = m * (base_val + scale * ba_val) / norm;
133            }
134        }
135        merged
136    }
137
138    /// Get trainable parameters (magnitude + LoRA A + LoRA B)
139    pub fn trainable_params(&mut self) -> Vec<&mut Tensor> {
140        let mut params = vec![&mut self.magnitude];
141        params.extend(self.lora.trainable_params());
142        params
143    }
144
145    /// Get the magnitude vector
146    pub fn magnitude(&self) -> &Tensor {
147        &self.magnitude
148    }
149
150    /// Get the underlying LoRA layer
151    pub fn lora(&self) -> &LoRALayer {
152        &self.lora
153    }
154
155    /// Trainable param count: magnitude (d_out) + LoRA A (r*d_in) + LoRA B (d_out*r)
156    pub fn trainable_param_count(&self) -> usize {
157        self.d_out + self.lora.rank() * self.d_in + self.d_out * self.lora.rank()
158    }
159}
160
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use proptest::prelude::*;

    #[test]
    fn test_ent_lora_011_dora_creation() {
        let base = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], false);
        let dora = DoRALayer::new(base, 2, 2, 1, 2.0, LoRAScaling::Standard);
        assert_eq!(dora.d_out, 2);
        assert_eq!(dora.d_in, 2);
        assert!(dora.magnitude().len() == 2);
    }

    #[test]
    fn test_ent_lora_011_dora_magnitude_init() {
        // Identity matrix: each row has norm 1.0
        let base = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], false);
        let dora = DoRALayer::new(base, 2, 2, 1, 2.0, LoRAScaling::Standard);
        assert_abs_diff_eq!(dora.magnitude().data()[0], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(dora.magnitude().data()[1], 1.0, epsilon = 1e-6);
    }

    #[test]
    fn test_ent_lora_011_dora_magnitude_non_unit_and_zero_rows() {
        // Row [3, 4] has norm 5; an all-zero row is floored to a tiny positive value.
        let base = Tensor::from_vec(vec![3.0, 4.0, 0.0, 0.0], false);
        let dora = DoRALayer::new(base, 2, 2, 1, 2.0, LoRAScaling::Standard);
        assert_abs_diff_eq!(dora.magnitude().data()[0], 5.0, epsilon = 1e-6);
        assert!(dora.magnitude().data()[1] > 0.0);
        for val in dora.merge_to_f32() {
            assert!(val.is_finite(), "Merged weight must be finite, got {val}");
        }
    }

    #[test]
    fn test_ent_lora_011_dora_merged_row_norms_match_magnitude() {
        // Each merged row is m[i] * eff_row / ||eff_row||, so its L2 norm must equal m[i].
        let (d_out, d_in) = (2, 3);
        let base = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], false);
        let dora = DoRALayer::new(base, d_out, d_in, 1, 2.0, LoRAScaling::Standard);
        let merged = dora.merge_to_f32();
        for row in 0..d_out {
            let norm = merged[row * d_in..(row + 1) * d_in]
                .iter()
                .map(|v| v * v)
                .sum::<f32>()
                .sqrt();
            assert_abs_diff_eq!(norm, dora.magnitude().data()[row], epsilon = 1e-4);
        }
    }

    #[test]
    #[should_panic]
    fn test_ent_lora_011_dora_rejects_short_base_weight() {
        // 3 elements cannot form a 2x2 weight.
        let base = Tensor::from_vec(vec![1.0; 3], false);
        let _ = DoRALayer::new(base, 2, 2, 1, 2.0, LoRAScaling::Standard);
    }

    #[test]
    fn test_ent_lora_011_dora_forward_dimensions() {
        let base = Tensor::from_vec(vec![1.0; 12], false);
        let dora = DoRALayer::new(base, 3, 4, 2, 4.0, LoRAScaling::RsLoRA);
        let x = Tensor::from_vec(vec![0.5; 4], true);
        let out = dora.forward(&x);
        assert_eq!(out.len(), 3);
    }

    #[test]
    fn test_ent_lora_011_dora_trainable_count() {
        let base = Tensor::from_vec(vec![1.0; 16], false);
        let dora = DoRALayer::new(base, 4, 4, 2, 4.0, LoRAScaling::Standard);
        // magnitude: 4 + A: 2*4=8 + B: 4*2=8 = 20
        assert_eq!(dora.trainable_param_count(), 20);
    }

    #[test]
    fn test_ent_lora_011_dora_merge_dimensions() {
        let base = Tensor::from_vec(vec![1.0; 12], false);
        let dora = DoRALayer::new(base, 3, 4, 2, 4.0, LoRAScaling::Standard);
        let merged = dora.merge_to_f32();
        assert_eq!(merged.len(), 12);
    }

    #[test]
    fn test_ent_lora_011_dora_trainable_params() {
        let base = Tensor::from_vec(vec![1.0; 16], false);
        let mut dora = DoRALayer::new(base, 4, 4, 2, 4.0, LoRAScaling::Standard);
        let params = dora.trainable_params();
        // magnitude + A + B = 3 tensors
        assert_eq!(params.len(), 3);
    }

    proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(50))]

        #[test]
        fn prop_dora_forward_finite(
            d_out in 2usize..8,
            d_in in 2usize..8,
            rank in 1usize..4,
        ) {
            let base = Tensor::from_vec(vec![0.5; d_out * d_in], false);
            let dora = DoRALayer::new(base, d_out, d_in, rank, 4.0, LoRAScaling::Standard);
            let x = Tensor::from_vec(vec![0.1; d_in], true);
            let out = dora.forward(&x);
            prop_assert_eq!(out.len(), d_out);
            for val in out.data() {
                prop_assert!(val.is_finite(), "Output must be finite, got {val}");
            }
        }

        #[test]
        fn prop_dora_merge_finite(
            d_out in 2usize..8,
            d_in in 2usize..8,
            rank in 1usize..4,
        ) {
            let base = Tensor::from_vec(vec![0.5; d_out * d_in], false);
            let dora = DoRALayer::new(base, d_out, d_in, rank, 4.0, LoRAScaling::Standard);
            let merged = dora.merge_to_f32();
            prop_assert_eq!(merged.len(), d_out * d_in);
            for val in &merged {
                prop_assert!(val.is_finite());
            }
        }
    }
}