1use crate::lora::LoRALayer;
14use crate::Tensor;
15
16pub fn pissa_init(
27 base_weight: &Tensor,
28 d_out: usize,
29 d_in: usize,
30 rank: usize,
31 alpha: f32,
32) -> LoRALayer {
33 assert_eq!(base_weight.len(), d_out * d_in);
34 assert!(rank <= d_out.min(d_in), "Rank must be <= min(d_out, d_in)");
35
36 let (u_r, s_r, v_r) =
38 truncated_svd(base_weight.data().as_slice().expect("contiguous"), d_out, d_in, rank);
39
40 let mut a_data = vec![0.0f32; rank * d_in];
42 for r in 0..rank {
43 let sqrt_s = s_r[r].sqrt();
44 for j in 0..d_in {
45 a_data[r * d_in + j] = sqrt_s * v_r[r * d_in + j];
46 }
47 }
48
49 let mut b_data = vec![0.0f32; d_out * rank];
51 for i in 0..d_out {
52 for r in 0..rank {
53 let sqrt_s = s_r[r].sqrt();
54 b_data[i * rank + r] = u_r[i * rank + r] * sqrt_s;
55 }
56 }
57
58 let scale = alpha / rank as f32;
60 let mut residual = base_weight.data().to_vec();
61 for i in 0..d_out {
62 for j in 0..d_in {
63 let mut reconstruction = 0.0f32;
64 for r in 0..rank {
65 reconstruction += u_r[i * rank + r] * s_r[r] * v_r[r * d_in + j];
66 }
67 residual[i * d_in + j] -= scale * reconstruction;
70 }
71 }
72
73 let residual_tensor = Tensor::from_vec(residual, false);
74 let mut layer = LoRALayer::new(residual_tensor, d_out, d_in, rank, alpha);
75
76 *layer.lora_a_mut().data_mut() = ndarray::arr1(&a_data);
78 *layer.lora_b_mut().data_mut() = ndarray::arr1(&b_data);
79
80 layer
81}
82
/// Computes a rank-`rank` truncated SVD of the row-major `d_out x d_in` matrix
/// `w` using power iteration with deflation.
///
/// Returns `(u_r, s_r, v_r)` where
///
/// * `u_r` is `d_out x rank`, row-major (column `r` is the `r`-th left singular vector),
/// * `s_r` holds the `rank` singular-value estimates in extraction order,
/// * `v_r` is `rank x d_in`, row-major (row `r` is the `r`-th right singular vector),
///
/// so that `w ≈ Σ_r s_r[r] · u_r[:, r] · v_r[r, :]`.
///
/// Each component runs a fixed number of power iterations on the current
/// residual and is then subtracted from it. Norms are floored at `1e-10`, so an
/// exhausted (all-zero) residual yields a tiny singular value and zero vectors
/// instead of NaNs. Because the iteration count is fixed, closely spaced
/// singular values may not be fully separated.
fn truncated_svd(
    w: &[f32],
    d_out: usize,
    d_in: usize,
    rank: usize,
) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    const ITERATIONS: usize = 20;
    debug_assert_eq!(w.len(), d_out * d_in, "w must hold d_out * d_in values");

    let mut u_r = vec![0.0f32; d_out * rank];
    let mut s_r = vec![0.0f32; rank];
    let mut v_r = vec![0.0f32; rank * d_in];

    let mut w_residual = w.to_vec();
    let mut u = vec![0.0f32; d_out];

    for r in 0..rank {
        // Deterministic start vector with every entry in [0.5, 1.5]. Keeping the
        // entries away from zero guarantees overlap with every coordinate axis.
        // A seed with an exact zero (e.g. sin(0) at i = 0) can never recover a
        // singular vector along that axis. That breaks `d_out x 1` inputs and
        // matrices such as diag(3, 1).
        let mut v: Vec<f32> = (0..d_in)
            .map(|i| 1.0 + 0.5 * (i as f32 * 0.7 + r as f32 * 1.3).sin())
            .collect();
        normalize(&mut v);

        let mut sigma = 0.0f32;
        for _ in 0..ITERATIONS {
            mat_vec_mul(&w_residual, &v, &mut u, d_out, d_in);
            normalize(&mut u);

            mat_t_vec_mul(&w_residual, &u, &mut v, d_out, d_in);
            // For the normalised v below, ‖Wᵀu‖ = uᵀ·W·v. That makes it the
            // singular-value estimate matching the (u, v) pair that is
            // stored and deflated, rather than one from the previous half-step.
            sigma = norm(&v).max(1e-10);
            for val in &mut v {
                *val /= sigma;
            }
        }

        for (i, &ui) in u.iter().enumerate() {
            u_r[i * rank + r] = ui;
        }
        s_r[r] = sigma;
        v_r[r * d_in..(r + 1) * d_in].copy_from_slice(&v);

        // Deflate: remove the component just found before extracting the next.
        for (i, &ui) in u.iter().enumerate() {
            let row = &mut w_residual[i * d_in..(i + 1) * d_in];
            for (x, &vj) in row.iter_mut().zip(&v) {
                *x -= sigma * ui * vj;
            }
        }
    }

    (u_r, s_r, v_r)
}

/// Writes `out[i] = Σ_j w[i, j] · v[j]` for a row-major `rows x cols` matrix `w`.
fn mat_vec_mul(w: &[f32], v: &[f32], out: &mut [f32], rows: usize, cols: usize) {
    for (i, o) in out[..rows].iter_mut().enumerate() {
        let row = &w[i * cols..(i + 1) * cols];
        *o = row.iter().zip(&v[..cols]).map(|(a, b)| a * b).sum();
    }
}

/// Writes `out[j] = Σ_i w[i, j] · u[i]` (i.e. `wᵀ·u`) for a row-major
/// `rows x cols` matrix `w`.
///
/// Accumulates row by row so `w` is read contiguously. Each `out[j]` is still
/// summed in increasing `i` order.
fn mat_t_vec_mul(w: &[f32], u: &[f32], out: &mut [f32], rows: usize, cols: usize) {
    let out = &mut out[..cols];
    out.fill(0.0);
    for (i, &ui) in u[..rows].iter().enumerate() {
        let row = &w[i * cols..(i + 1) * cols];
        for (o, &wij) in out.iter_mut().zip(row) {
            *o += wij * ui;
        }
    }
}

/// Euclidean (L2) norm of `v`.
fn norm(v: &[f32]) -> f32 {
    v.iter().map(|x| x * x).sum::<f32>().sqrt()
}

/// Scales `v` to unit length in place. The norm is floored at `1e-10`, so a
/// zero vector stays zero instead of becoming NaN.
fn normalize(v: &mut [f32]) {
    let n = norm(v).max(1e-10);
    for val in v.iter_mut() {
        *val /= n;
    }
}
177
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use proptest::prelude::*;

    #[test]
    fn test_ent_lora_012_pissa_init_dimensions() {
        let weights = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], false);
        let lora = pissa_init(&weights, 2, 3, 1, 2.0);

        // Shape metadata, then adapter sizes: A is rank x d_in, B is d_out x rank.
        assert_eq!(lora.d_out(), 2);
        assert_eq!(lora.d_in(), 3);
        assert_eq!(lora.rank(), 1);
        assert_eq!(lora.lora_a().len(), 3);
        assert_eq!(lora.lora_b().len(), 2);
    }

    #[test]
    fn test_ent_lora_012_pissa_nonzero_init() {
        // Symmetric 2x2 matrix; its principal component must land in the adapters.
        let weights = Tensor::from_vec(vec![1.0, 0.5, 0.5, 1.0], false);
        let lora = pissa_init(&weights, 2, 2, 1, 2.0);

        let b_norm = lora.lora_b().data().iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(b_norm > 0.01, "PiSSA B should be non-zero, got norm={b_norm}");

        let a_norm = lora.lora_a().data().iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(a_norm > 0.01, "PiSSA A should be non-zero, got norm={a_norm}");
    }

    #[test]
    fn test_ent_lora_012_pissa_reconstruction_close() {
        let (d_out, d_in) = (4, 4);
        let base_data: Vec<f32> = (0..d_out * d_in).map(|k| (k as f32 * 0.3).sin()).collect();
        let lora = pissa_init(&Tensor::from_vec(base_data.clone(), false), d_out, d_in, 2, 2.0);

        let scale = lora.scale();
        let w_res = lora.base_weight().data();
        let a = lora.lora_a().data();
        let b = lora.lora_b().data();
        let rank = lora.rank();

        // residual + scale * (B · A) should give back the original weights.
        for i in 0..d_out {
            for j in 0..d_in {
                let ba: f32 = (0..rank).map(|k| b[i * rank + k] * a[k * d_in + j]).sum();
                let reconstructed = w_res[i * d_in + j] + scale * ba;
                assert_abs_diff_eq!(base_data[i * d_in + j], reconstructed, epsilon = 0.3);
            }
        }
    }

    #[test]
    fn test_ent_lora_012_pissa_forward_works() {
        let lora = pissa_init(&Tensor::from_vec(vec![1.0; 16], false), 4, 4, 2, 4.0);
        let input = Tensor::from_vec(vec![0.5; 4], true);

        let out = lora.forward(&input);

        assert_eq!(out.len(), 4);
        for val in out.data() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_ent_lora_012_truncated_svd_singular_values_descending() {
        let matrix: Vec<f32> = (0..24).map(|k| (k as f32 * 0.2).sin()).collect();
        let (_, s, _) = truncated_svd(&matrix, 4, 6, 3);

        for (prev, pair) in s.windows(2).enumerate() {
            assert!(
                pair[0] >= pair[1] - 1e-4,
                "Singular values should descend: s[{}]={} < s[{}]={}",
                prev,
                pair[0],
                prev + 1,
                pair[1]
            );
        }
    }

    #[test]
    fn test_ent_lora_012_truncated_svd_orthogonal_u() {
        let matrix: Vec<f32> = (0..24).map(|k| (k as f32 * 0.3).cos()).collect();
        let (u, _, _) = truncated_svd(&matrix, 4, 6, 2);

        // u is 4 x 2 row-major, so each 2-element row pairs column 0 with column 1.
        let dot: f32 = u.chunks_exact(2).map(|row| row[0] * row[1]).sum();
        assert!(dot.abs() < 0.15, "U columns should be ~orthogonal, dot={dot}");
    }

    proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(30))]

        #[test]
        fn prop_pissa_forward_finite(
            d_out in 2usize..8,
            d_in in 2usize..8,
        ) {
            // Both dimensions are at least 2, so this always resolves to rank 1.
            let rank = d_out.min(d_in).min(1);
            let weights = Tensor::from_vec(vec![0.5; d_out * d_in], false);
            let lora = pissa_init(&weights, d_out, d_in, rank, 4.0);

            let x = Tensor::from_vec(vec![0.1; d_in], true);
            let out = lora.forward(&x);

            prop_assert_eq!(out.len(), d_out);
            for val in out.data() {
                prop_assert!(val.is_finite());
            }
        }
    }
}