use crate::lora::LoRALayer;
use crate::Tensor;
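/// PiSSA (Principal Singular values and Singular vectors Adaptation)
/// initialization for a LoRA layer.
///
/// Computes a rank-`rank` truncated SVD of the row-major `d_out x d_in` base
/// weight, `W ~ U_r * S_r * V_r^T`, and initializes the adapter as
/// `A = sqrt(S_r) * V_r^T` and `B = U_r * sqrt(S_r)`, so `B * A` carries the
/// principal components of `W`. The layer's frozen base weight becomes the
/// residual `W - (alpha / rank) * B * A`, so the initial forward pass still
/// reproduces `W x` while the trainable adapter starts from the dominant
/// directions of `W` rather than from zero.
///
/// A minimal usage sketch (it mirrors the unit tests below; the `Tensor`
/// constructor and layer accessors are assumed from this crate, as used
/// elsewhere in this file):
///
/// ```ignore
/// let base = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], false);
/// let layer = pissa_init(&base, 2, 3, 1, 2.0);
/// assert_eq!(layer.rank(), 1);
/// let out = layer.forward(&Tensor::from_vec(vec![0.5, 0.5, 0.5], true));
/// assert_eq!(out.len(), 2);
/// ```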
pub fn pissa_init(
    base_weight: &Tensor,
    d_out: usize,
    d_in: usize,
    rank: usize,
    alpha: f32,
) -> LoRALayer {
    assert_eq!(base_weight.len(), d_out * d_in);
    assert!(rank <= d_out.min(d_in), "Rank must be <= min(d_out, d_in)");
    let (u_r, s_r, v_r) =
        truncated_svd(base_weight.data().as_slice().expect("contiguous"), d_out, d_in, rank);
    // A = sqrt(S_r) * V_r^T, stored row-major as (rank, d_in).
    let mut a_data = vec![0.0f32; rank * d_in];
    for r in 0..rank {
        let sqrt_s = s_r[r].sqrt();
        for j in 0..d_in {
            a_data[r * d_in + j] = sqrt_s * v_r[r * d_in + j];
        }
    }
    // B = U_r * sqrt(S_r), stored row-major as (d_out, rank).
    let mut b_data = vec![0.0f32; d_out * rank];
    for r in 0..rank {
        let sqrt_s = s_r[r].sqrt();
        for i in 0..d_out {
            b_data[i * rank + r] = u_r[i * rank + r] * sqrt_s;
        }
    }
    // Subtract the scaled principal component so that, at initialization,
    // residual + scale * B * A reproduces the original weight.
    let scale = alpha / rank as f32;
    let mut residual = base_weight.data().to_vec();
    for i in 0..d_out {
        for j in 0..d_in {
            let mut reconstruction = 0.0f32;
            for r in 0..rank {
                reconstruction += u_r[i * rank + r] * s_r[r] * v_r[r * d_in + j];
            }
            residual[i * d_in + j] -= scale * reconstruction;
        }
    }
    let residual_tensor = Tensor::from_vec(residual, false);
    let mut layer = LoRALayer::new(residual_tensor, d_out, d_in, rank, alpha);
    *layer.lora_a_mut().data_mut() = ndarray::arr1(&a_data);
    *layer.lora_b_mut().data_mut() = ndarray::arr1(&b_data);
    layer
}
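/// Rank-`rank` truncated SVD of the row-major `d_out x d_in` matrix `w`,
/// approximated by power iteration with deflation (no external linear-algebra
/// dependency).
///
/// Returns `(u_r, s_r, v_r)`:
/// * `u_r`: left singular vectors, row-major `d_out x rank`,
/// * `s_r`: the `rank` leading singular values, approximately descending,
/// * `v_r`: right singular vectors, row-major `rank x d_in` (one per row).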
fn truncated_svd(
    w: &[f32],
    d_out: usize,
    d_in: usize,
    rank: usize,
) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
    let iterations = 20;
    let mut u_r = vec![0.0f32; d_out * rank];
    let mut s_r = vec![0.0f32; rank];
    let mut v_r = vec![0.0f32; rank * d_in];
    let mut w_residual = w.to_vec();
    for r in 0..rank {
        // Deterministic pseudo-random start vector so results are reproducible.
        let mut v: Vec<f32> =
            (0..d_in).map(|i| (i as f32 * 0.7 + r as f32 * 1.3).sin()).collect();
        normalize(&mut v);
        let mut u = vec![0.0f32; d_out];
        let mut sigma = 0.0f32;
        // Power iteration: alternate u <- W v / ||W v|| and v <- W^T u / ||W^T u||.
        for _ in 0..iterations {
            mat_vec_mul(&w_residual, &v, &mut u, d_out, d_in);
            sigma = norm(&u).max(1e-10);
            for val in &mut u {
                *val /= sigma;
            }
            mat_t_vec_mul(&w_residual, &u, &mut v, d_out, d_in);
            let v_norm = norm(&v).max(1e-10);
            for val in &mut v {
                *val /= v_norm;
            }
        }
        for i in 0..d_out {
            u_r[i * rank + r] = u[i];
        }
        s_r[r] = sigma;
        for j in 0..d_in {
            v_r[r * d_in + j] = v[j];
        }
        // Deflate: remove the component just found before extracting the next one.
        for i in 0..d_out {
            for j in 0..d_in {
                w_residual[i * d_in + j] -= sigma * u[i] * v[j];
            }
        }
    }
    (u_r, s_r, v_r)
}
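/// `out = W * v` for a row-major `rows x cols` matrix `W`.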
fn mat_vec_mul(w: &[f32], v: &[f32], out: &mut [f32], rows: usize, cols: usize) {
    for i in 0..rows {
        let mut sum = 0.0f32;
        for j in 0..cols {
            sum += w[i * cols + j] * v[j];
        }
        out[i] = sum;
    }
}
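/// `out = W^T * u` for a row-major `rows x cols` matrix `W`.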
fn mat_t_vec_mul(w: &[f32], u: &[f32], out: &mut [f32], rows: usize, cols: usize) {
    for j in 0..cols {
        let mut sum = 0.0f32;
        for i in 0..rows {
            sum += w[i * cols + j] * u[i];
        }
        out[j] = sum;
    }
}
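/// Euclidean (L2) norm of `v`.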
fn norm(v: &[f32]) -> f32 {
    v.iter().map(|x| x * x).sum::<f32>().sqrt()
}
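/// Scales `v` in place to unit L2 norm, guarding against division by zero.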
fn normalize(v: &mut [f32]) {
    let n = norm(v).max(1e-10);
    for val in v.iter_mut() {
        *val /= n;
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use proptest::prelude::*;
    #[test]
    fn test_ent_lora_012_pissa_init_dimensions() {
        let base = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], false);
        let layer = pissa_init(&base, 2, 3, 1, 2.0);
        assert_eq!(layer.d_out(), 2);
        assert_eq!(layer.d_in(), 3);
        assert_eq!(layer.rank(), 1);
        assert_eq!(layer.lora_a().len(), 3);
        assert_eq!(layer.lora_b().len(), 2);
    }
    #[test]
    fn test_ent_lora_012_pissa_nonzero_init() {
        let base = Tensor::from_vec(vec![1.0, 0.5, 0.5, 1.0], false);
        let layer = pissa_init(&base, 2, 2, 1, 2.0);
        let b_norm: f32 = layer.lora_b().data().iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(b_norm > 0.01, "PiSSA B should be non-zero, got norm={b_norm}");
        let a_norm: f32 = layer.lora_a().data().iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!(a_norm > 0.01, "PiSSA A should be non-zero, got norm={a_norm}");
    }
    #[test]
    fn test_ent_lora_012_pissa_reconstruction_close() {
        let d_out = 4;
        let d_in = 4;
        let base_data: Vec<f32> = (0..d_out * d_in).map(|i| (i as f32 * 0.3).sin()).collect();
        let base = Tensor::from_vec(base_data.clone(), false);
        let layer = pissa_init(&base, d_out, d_in, 2, 2.0);
        let scale = layer.scale();
        let residual = layer.base_weight().data();
        let a = layer.lora_a().data();
        let b = layer.lora_b().data();
        let rank = layer.rank();
        for i in 0..d_out {
            for j in 0..d_in {
                let mut ba = 0.0f32;
                for r in 0..rank {
                    ba += b[i * rank + r] * a[r * d_in + j];
                }
                let reconstructed = residual[i * d_in + j] + scale * ba;
                assert_abs_diff_eq!(base_data[i * d_in + j], reconstructed, epsilon = 0.3);
            }
        }
    }
    #[test]
    fn test_ent_lora_012_pissa_forward_works() {
        let base = Tensor::from_vec(vec![1.0; 16], false);
        let layer = pissa_init(&base, 4, 4, 2, 4.0);
        let x = Tensor::from_vec(vec![0.5; 4], true);
        let out = layer.forward(&x);
        assert_eq!(out.len(), 4);
        for val in out.data() {
            assert!(val.is_finite());
        }
    }
    #[test]
    fn test_ent_lora_012_truncated_svd_singular_values_descending() {
        let w: Vec<f32> = (0..24).map(|i| (i as f32 * 0.2).sin()).collect();
        let (_, s, _) = truncated_svd(&w, 4, 6, 3);
        for i in 1..s.len() {
            assert!(
                s[i - 1] >= s[i] - 1e-4,
                "Singular values should descend: s[{}]={} < s[{}]={}",
                i - 1,
                s[i - 1],
                i,
                s[i]
            );
        }
    }
    #[test]
    fn test_ent_lora_012_truncated_svd_orthogonal_u() {
        let w: Vec<f32> = (0..24).map(|i| (i as f32 * 0.3).cos()).collect();
        let (u, _, _) = truncated_svd(&w, 4, 6, 2);
        let mut dot = 0.0f32;
        for i in 0..4 {
            dot += u[i * 2] * u[i * 2 + 1];
        }
        assert!(dot.abs() < 0.15, "U columns should be ~orthogonal, dot={dot}");
    }
    proptest! {
        #![proptest_config(proptest::test_runner::Config::with_cases(30))]
        #[test]
        fn prop_pissa_forward_finite(
            d_out in 2usize..8,
            d_in in 2usize..8,
        ) {
            // Rank 1 is always valid here because d_out and d_in are at least 2.
            let rank = 1;
            let base = Tensor::from_vec(vec![0.5; d_out * d_in], false);
            let layer = pissa_init(&base, d_out, d_in, rank, 4.0);
            let x = Tensor::from_vec(vec![0.1; d_in], true);
            let out = layer.forward(&x);
            prop_assert_eq!(out.len(), d_out);
            for val in out.data() {
                prop_assert!(val.is_finite());
            }
        }
    }
}