Skip to main content

entrenar/lora/
pissa.rs

//! PiSSA (Principal Singular Values and Singular Vectors Adaptation) — ENT-LoRA-012
//!
//! PiSSA initializes LoRA A and B from the top-r singular components of the base weight,
//! achieving faster convergence (+5% on some benchmarks).
//!
//! Standard LoRA: A ~ N(0, σ²), B = 0 → ΔW = 0 at init
//! PiSSA: SVD(W) = U·S·V^T → A = sqrt(S_r)·V_r^T, B = U_r·sqrt(S_r) → B·A = U_r·S_r·V_r^T
//! Residual: W_residual = W - (α/r)·U_r·S_r·V_r^T, so that W_residual + (α/r)·B·A = W at init
//! (the frozen base becomes the residual; α/r is the LoRA scale)
//!
//! Reference: Meng et al. (2024). "PiSSA: Principal Singular Values and Singular Vectors
//! Adaptation." NeurIPS 2024 Spotlight.
12
13use crate::lora::LoRALayer;
14use crate::Tensor;
15
16/// Initialize a LoRA layer using PiSSA (SVD-based initialization)
17///
18/// Returns a LoRALayer where:
19/// - `base_weight` is the residual (W - U_r·S_r·V_r^T)
20/// - `lora_a` = sqrt(S_r) · V_r^T [rank × d_in]
21/// - `lora_b` = U_r · sqrt(S_r) [d_out × rank]
22///
23/// The key insight: instead of starting from ΔW=0, PiSSA starts from the principal
24/// components, so the residual base weight has lower effective rank and LoRA adapters
25/// start from a better initialization.
26pub fn pissa_init(
27    base_weight: &Tensor,
28    d_out: usize,
29    d_in: usize,
30    rank: usize,
31    alpha: f32,
32) -> LoRALayer {
33    assert_eq!(base_weight.len(), d_out * d_in);
34    assert!(rank <= d_out.min(d_in), "Rank must be <= min(d_out, d_in)");
35
36    // Truncated SVD via power iteration
37    let (u_r, s_r, v_r) =
38        truncated_svd(base_weight.data().as_slice().expect("contiguous"), d_out, d_in, rank);
39
40    // Compute A = sqrt(S_r) · V_r^T [rank × d_in]
41    let mut a_data = vec![0.0f32; rank * d_in];
42    for r in 0..rank {
43        let sqrt_s = s_r[r].sqrt();
44        for j in 0..d_in {
45            a_data[r * d_in + j] = sqrt_s * v_r[r * d_in + j];
46        }
47    }
48
49    // Compute B = U_r · sqrt(S_r) [d_out × rank]
50    let mut b_data = vec![0.0f32; d_out * rank];
51    for i in 0..d_out {
52        for r in 0..rank {
53            let sqrt_s = s_r[r].sqrt();
54            b_data[i * rank + r] = u_r[i * rank + r] * sqrt_s;
55        }
56    }
57
58    // Compute residual: W_res = W - U_r · S_r · V_r^T
59    let scale = alpha / rank as f32;
60    let mut residual = base_weight.data().to_vec();
61    for i in 0..d_out {
62        for j in 0..d_in {
63            let mut reconstruction = 0.0f32;
64            for r in 0..rank {
65                reconstruction += u_r[i * rank + r] * s_r[r] * v_r[r * d_in + j];
66            }
67            // Adjust: the LoRA contribution will be scale * B @ A = scale * U_r·S_r·V_r^T
68            // So we subtract scale * reconstruction from base
69            residual[i * d_in + j] -= scale * reconstruction;
70        }
71    }
72
73    let residual_tensor = Tensor::from_vec(residual, false);
74    let mut layer = LoRALayer::new(residual_tensor, d_out, d_in, rank, alpha);
75
76    // Override the default random init with PiSSA init
77    *layer.lora_a_mut().data_mut() = ndarray::arr1(&a_data);
78    *layer.lora_b_mut().data_mut() = ndarray::arr1(&b_data);
79
80    layer
81}
82
83/// Truncated SVD via power iteration method
84///
85/// Returns (U_r, S_r, V_r) where:
86/// - U_r: [d_out × rank] left singular vectors (column-major stored as row-major)
87/// - S_r: [rank] singular values (descending)
88/// - V_r: [rank × d_in] right singular vectors
89fn truncated_svd(
90    w: &[f32],
91    d_out: usize,
92    d_in: usize,
93    rank: usize,
94) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
95    let iterations = 20;
96    let mut u_r = vec![0.0f32; d_out * rank];
97    let mut s_r = vec![0.0f32; rank];
98    let mut v_r = vec![0.0f32; rank * d_in];
99
100    // Work on a copy so we can deflate
101    let mut w_residual = w.to_vec();
102
103    for r in 0..rank {
104        // Initialize random vector v
105        let mut v: Vec<f32> = (0..d_in).map(|i| (i as f32 * 0.7 + r as f32 * 1.3).sin()).collect();
106        normalize(&mut v);
107
108        let mut u = vec![0.0f32; d_out];
109        let mut sigma = 0.0f32;
110
111        for _ in 0..iterations {
112            // u = W @ v
113            mat_vec_mul(&w_residual, &v, &mut u, d_out, d_in);
114            sigma = norm(&u).max(1e-10);
115            for val in &mut u {
116                *val /= sigma;
117            }
118
119            // v = W^T @ u
120            mat_t_vec_mul(&w_residual, &u, &mut v, d_out, d_in);
121            let v_norm = norm(&v).max(1e-10);
122            for val in &mut v {
123                *val /= v_norm;
124            }
125        }
126
127        // Store results
128        for i in 0..d_out {
129            u_r[i * rank + r] = u[i];
130        }
131        s_r[r] = sigma;
132        for j in 0..d_in {
133            v_r[r * d_in + j] = v[j];
134        }
135
136        // Deflate: W_residual -= sigma * u * v^T
137        for i in 0..d_out {
138            for j in 0..d_in {
139                w_residual[i * d_in + j] -= sigma * u[i] * v[j];
140            }
141        }
142    }
143
144    (u_r, s_r, v_r)
145}
146
/// Dense matrix–vector product `out = W · v`.
///
/// `w` is read as a row-major `rows × cols` matrix; the first `rows` entries of `out`
/// are overwritten with the dot product of the corresponding row and `v[..cols]`.
fn mat_vec_mul(w: &[f32], v: &[f32], out: &mut [f32], rows: usize, cols: usize) {
    for (i, slot) in out[..rows].iter_mut().enumerate() {
        let row = &w[i * cols..(i + 1) * cols];
        // Left-to-right accumulation starting from zero.
        *slot = row.iter().zip(&v[..cols]).fold(0.0f32, |acc, (&a, &b)| acc + a * b);
    }
}
156
/// Transposed matrix–vector product `out = Wᵀ · u`.
///
/// `w` is read as a row-major `rows × cols` matrix; the first `cols` entries of `out`
/// are overwritten with `Σ_i w[i][j] · u[i]`.
///
/// The matrix is streamed row by row (contiguous memory) and each row is accumulated
/// into `out`, instead of walking every column with a stride of `cols`. Each `out[j]`
/// still receives its terms in ascending row order starting from zero, so the
/// floating-point result is the same as the column-by-column formulation.
fn mat_t_vec_mul(w: &[f32], u: &[f32], out: &mut [f32], rows: usize, cols: usize) {
    let out = &mut out[..cols];
    out.fill(0.0);
    if cols == 0 {
        // Nothing to accumulate; also keeps `chunks_exact` away from a zero chunk size.
        return;
    }

    for (row, &u_i) in w[..rows * cols].chunks_exact(cols).zip(&u[..rows]) {
        for (acc, &w_ij) in out.iter_mut().zip(row) {
            *acc += w_ij * u_i;
        }
    }
}
166
/// Euclidean (L2) length of `v`: the square root of the sum of squared components.
fn norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}
170
171fn normalize(v: &mut [f32]) {
172    let n = norm(v).max(1e-10);
173    for val in v.iter_mut() {
174        *val /= n;
175    }
176}
177
/// Unit and property tests for PiSSA initialization and the truncated SVD helper.
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use proptest::prelude::*;

    /// The returned layer reports the requested shape and adapter sizes.
    #[test]
    fn test_ent_lora_012_pissa_init_dimensions() {
        let base = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], false);
        let layer = pissa_init(&base, 2, 3, 1, 2.0);
        assert_eq!(layer.d_out(), 2);
        assert_eq!(layer.d_in(), 3);
        assert_eq!(layer.rank(), 1);
        assert_eq!(layer.lora_a().len(), 3);
        assert_eq!(layer.lora_b().len(), 2);
    }

    #[test]
    fn test_ent_lora_012_pissa_nonzero_init() {
        // Unlike standard LoRA where B=0, PiSSA initializes both A and B from SVD
        let base = Tensor::from_vec(vec![1.0, 0.5, 0.5, 1.0], false);
        let layer = pissa_init(&base, 2, 2, 1, 2.0);

        // B should be non-zero (unlike standard LoRA)
        let b_norm: f32 = layer.lora_b().data().iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(b_norm > 0.01, "PiSSA B should be non-zero, got norm={b_norm}");

        let a_norm: f32 = layer.lora_a().data().iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(a_norm > 0.01, "PiSSA A should be non-zero, got norm={a_norm}");
    }

    #[test]
    fn test_ent_lora_012_pissa_reconstruction_close() {
        // W = residual + scale * B @ A holds by construction, independent of how accurate
        // the SVD is, so only floating-point rounding separates the two sides.
        let d_out = 4;
        let d_in = 4;
        let base_data: Vec<f32> = (0..d_out * d_in).map(|i| (i as f32 * 0.3).sin()).collect();
        let base = Tensor::from_vec(base_data.clone(), false);
        let layer = pissa_init(&base, d_out, d_in, 2, 2.0);

        // Compute reconstruction: residual + scale * B @ A
        let scale = layer.scale();
        let residual = layer.base_weight().data();
        let a = layer.lora_a().data();
        let b = layer.lora_b().data();
        let rank = layer.rank();

        for i in 0..d_out {
            for j in 0..d_in {
                let mut ba = 0.0f32;
                for r in 0..rank {
                    ba += b[i * rank + r] * a[r * d_in + j];
                }
                let reconstructed = residual[i * d_in + j] + scale * ba;
                assert_abs_diff_eq!(base_data[i * d_in + j], reconstructed, epsilon = 1e-4);
            }
        }
    }

    /// A forward pass through a PiSSA-initialized layer yields finite outputs.
    #[test]
    fn test_ent_lora_012_pissa_forward_works() {
        let base = Tensor::from_vec(vec![1.0; 16], false);
        let layer = pissa_init(&base, 4, 4, 2, 4.0);
        let x = Tensor::from_vec(vec![0.5; 4], true);
        let out = layer.forward(&x);
        assert_eq!(out.len(), 4);
        for val in out.data() {
            assert!(val.is_finite());
        }
    }

    #[test]
    fn test_ent_lora_012_truncated_svd_singular_values_descending() {
        let w: Vec<f32> = (0..24).map(|i| (i as f32 * 0.2).sin()).collect();
        let (_, s, _) = truncated_svd(&w, 4, 6, 3);

        for i in 1..s.len() {
            assert!(
                s[i - 1] >= s[i] - 1e-4,
                "Singular values should descend: s[{}]={} < s[{}]={}",
                i - 1,
                s[i - 1],
                i,
                s[i]
            );
        }
    }

    #[test]
    fn test_ent_lora_012_truncated_svd_orthogonal_u() {
        let w: Vec<f32> = (0..24).map(|i| (i as f32 * 0.3).cos()).collect();
        let (u, _, _) = truncated_svd(&w, 4, 6, 2);

        // Check approximate orthogonality of U columns
        // U is stored as [d_out × rank], column r is u[i*rank + r]
        let mut dot = 0.0f32;
        for i in 0..4 {
            dot += u[i * 2] * u[i * 2 + 1];
        }
        assert!(dot.abs() < 0.15, "U columns should be ~orthogonal, dot={dot}");
    }

    proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(30))]

        #[test]
        fn prop_pissa_forward_finite(
            d_out in 2usize..8,
            d_in in 2usize..8,
            rank_seed in 0usize..8,
        ) {
            // Map the seed into 1..=min(d_out, d_in) so ranks above 1 (and therefore the
            // deflation path) are exercised too.
            let rank = 1 + rank_seed % d_out.min(d_in);
            let base = Tensor::from_vec(vec![0.5; d_out * d_in], false);
            let layer = pissa_init(&base, d_out, d_in, rank, 4.0);
            let x = Tensor::from_vec(vec![0.1; d_in], true);
            let out = layer.forward(&x);
            prop_assert_eq!(out.len(), d_out);
            for val in out.data() {
                prop_assert!(val.is_finite());
            }
        }
    }
}