1use crate::lora::{LoRALayer, LoRAScaling};
13use crate::Tensor;
14
15pub struct DoRALayer {
23 magnitude: Tensor,
25 lora: LoRALayer,
27 d_out: usize,
29 d_in: usize,
30}
31
32impl DoRALayer {
33 pub fn new(
37 base_weight: Tensor,
38 d_out: usize,
39 d_in: usize,
40 rank: usize,
41 alpha: f32,
42 scaling: LoRAScaling,
43 ) -> Self {
44 let magnitude_data: Vec<f32> = (0..d_out)
46 .map(|row| {
47 let row_start = row * d_in;
48 let row_end = row_start + d_in;
49 let row_norm_sq: f32 = base_weight
50 .data()
51 .slice(ndarray::s![row_start..row_end])
52 .iter()
53 .map(|x| x * x)
54 .sum();
55 row_norm_sq.sqrt().max(1e-8)
56 })
57 .collect();
58 let magnitude = Tensor::from_vec(magnitude_data, true); let lora = LoRALayer::new_with_scaling(base_weight, d_out, d_in, rank, alpha, scaling);
62
63 Self { magnitude, lora, d_out, d_in }
64 }
65
66 pub fn forward(&self, x: &Tensor) -> Tensor {
68 assert_eq!(x.len(), self.d_in, "Input size must match d_in");
69
70 let lora_output = self.lora.forward(x); let row_norms = self.compute_effective_row_norms();
77
78 let mut result = lora_output.data().to_owned();
79 for (i, val) in result.iter_mut().enumerate() {
80 let norm = row_norms[i].max(1e-8);
81 *val = self.magnitude.data()[i] * (*val / norm);
82 }
83
84 Tensor::new(result, self.magnitude.requires_grad())
85 }
86
87 fn compute_effective_row_norms(&self) -> Vec<f32> {
89 let base = self.lora.base_weight().data();
90 let scale = self.lora.scale();
91 let a_data = self.lora.lora_a().data();
92 let b_data = self.lora.lora_b().data();
93 let rank = self.lora.rank();
94
95 let mut norms = vec![0.0f32; self.d_out];
96 for row in 0..self.d_out {
97 let mut row_norm_sq = 0.0f32;
98 for col in 0..self.d_in {
99 let base_val = base[row * self.d_in + col];
100 let mut ba_val = 0.0f32;
102 for r in 0..rank {
103 ba_val += b_data[row * rank + r] * a_data[r * self.d_in + col];
104 }
105 let effective = base_val + scale * ba_val;
106 row_norm_sq += effective * effective;
107 }
108 norms[row] = row_norm_sq.sqrt();
109 }
110 norms
111 }
112
113 pub fn merge_to_f32(&self) -> Vec<f32> {
115 let row_norms = self.compute_effective_row_norms();
116 let base = self.lora.base_weight().data();
117 let scale = self.lora.scale();
118 let a_data = self.lora.lora_a().data();
119 let b_data = self.lora.lora_b().data();
120 let rank = self.lora.rank();
121
122 let mut merged = vec![0.0f32; self.d_out * self.d_in];
123 for row in 0..self.d_out {
124 let m = self.magnitude.data()[row];
125 let norm = row_norms[row].max(1e-8);
126 for col in 0..self.d_in {
127 let base_val = base[row * self.d_in + col];
128 let mut ba_val = 0.0f32;
129 for r in 0..rank {
130 ba_val += b_data[row * rank + r] * a_data[r * self.d_in + col];
131 }
132 merged[row * self.d_in + col] = m * (base_val + scale * ba_val) / norm;
133 }
134 }
135 merged
136 }
137
138 pub fn trainable_params(&mut self) -> Vec<&mut Tensor> {
140 let mut params = vec![&mut self.magnitude];
141 params.extend(self.lora.trainable_params());
142 params
143 }
144
145 pub fn magnitude(&self) -> &Tensor {
147 &self.magnitude
148 }
149
150 pub fn lora(&self) -> &LoRALayer {
152 &self.lora
153 }
154
155 pub fn trainable_param_count(&self) -> usize {
157 self.d_out + self.lora.rank() * self.d_in + self.d_out * self.lora.rank()
158 }
159}
160
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use proptest::prelude::*;

    /// Base weight of `rows * cols` entries, all equal to `value`, created
    /// with the gradient flag off.
    fn constant_base(value: f32, rows: usize, cols: usize) -> Tensor {
        Tensor::from_vec(vec![value; rows * cols], false)
    }

    /// Rank-1 layer over a 2x2 identity base weight.
    fn identity_layer() -> DoRALayer {
        let identity = Tensor::from_vec(vec![1.0, 0.0, 0.0, 1.0], false);
        DoRALayer::new(identity, 2, 2, 1, 2.0, LoRAScaling::Standard)
    }

    #[test]
    fn test_ent_lora_011_dora_creation() {
        let layer = identity_layer();
        assert_eq!(layer.d_out, 2);
        assert_eq!(layer.d_in, 2);
        assert_eq!(layer.magnitude().len(), 2);
    }

    #[test]
    fn test_ent_lora_011_dora_magnitude_init() {
        // Both rows of the identity have unit norm, so both magnitudes start at 1.
        let layer = identity_layer();
        for row in 0..2 {
            assert_abs_diff_eq!(layer.magnitude().data()[row], 1.0, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_ent_lora_011_dora_forward_dimensions() {
        let layer = DoRALayer::new(constant_base(1.0, 3, 4), 3, 4, 2, 4.0, LoRAScaling::RsLoRA);
        let input = Tensor::from_vec(vec![0.5; 4], true);
        assert_eq!(layer.forward(&input).len(), 3);
    }

    #[test]
    fn test_ent_lora_011_dora_trainable_count() {
        // 4 magnitudes + 2 * 4 entries of A + 4 * 2 entries of B.
        let layer = DoRALayer::new(constant_base(1.0, 4, 4), 4, 4, 2, 4.0, LoRAScaling::Standard);
        assert_eq!(layer.trainable_param_count(), 20);
    }

    #[test]
    fn test_ent_lora_011_dora_merge_dimensions() {
        let layer = DoRALayer::new(constant_base(1.0, 3, 4), 3, 4, 2, 4.0, LoRAScaling::Standard);
        assert_eq!(layer.merge_to_f32().len(), 12);
    }

    #[test]
    fn test_ent_lora_011_dora_trainable_params() {
        let mut layer =
            DoRALayer::new(constant_base(1.0, 4, 4), 4, 4, 2, 4.0, LoRAScaling::Standard);
        assert_eq!(layer.trainable_params().len(), 3);
    }

    proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(50))]

        #[test]
        fn prop_dora_forward_finite(
            d_out in 2usize..8,
            d_in in 2usize..8,
            rank in 1usize..4,
        ) {
            let layer = DoRALayer::new(
                constant_base(0.5, d_out, d_in),
                d_out,
                d_in,
                rank,
                4.0,
                LoRAScaling::Standard,
            );
            let input = Tensor::from_vec(vec![0.1; d_in], true);
            let output = layer.forward(&input);
            prop_assert_eq!(output.len(), d_out);
            for val in output.data() {
                prop_assert!(val.is_finite(), "Output must be finite, got {val}");
            }
        }

        #[test]
        fn prop_dora_merge_finite(
            d_out in 2usize..8,
            d_in in 2usize..8,
            rank in 1usize..4,
        ) {
            let layer = DoRALayer::new(
                constant_base(0.5, d_out, d_in),
                d_out,
                d_in,
                rank,
                4.0,
                LoRAScaling::Standard,
            );
            let dense = layer.merge_to_f32();
            prop_assert_eq!(dense.len(), d_out * d_in);
            for val in &dense {
                prop_assert!(val.is_finite());
            }
        }
    }
}