use scirs2_core::ndarray_ext::Array2;
use scirs2_core::random::rand_prelude::StdRng;
use scirs2_core::random::{seeded_rng, CoreRandom};
/// A LoRA-style low-rank adapter that produces an additive correction
/// `scale() * x · A · B` for an input batch `x` with `d_in` columns.
///
/// The trainable factors are `a_matrix` (`d_in × rank`) and `b_matrix`
/// (`rank × d_out`). Gradients for both are accumulated in private buffers of
/// matching shape by `backward`, applied by `sgd_step`, and cleared by `zero_grad`.
///
/// NOTE(review): the shape fields and the matrices are all `pub`, so callers can
/// make them mutually inconsistent; the methods assume the shapes set by `new`.
#[derive(Debug, Clone)]
pub struct LoraAdapter {
    /// Inner (bottleneck) dimension of the factorization; `new` requires it to be >= 1.
    pub rank: usize,
    /// Scaling numerator; the low-rank product is multiplied by `alpha / rank`.
    pub alpha: f64,
    /// Down-projection `A`, shape `(d_in, rank)`, randomly initialized by `new`.
    pub a_matrix: Array2<f64>,
    /// Up-projection `B`, shape `(rank, d_out)`, zero-initialized by `new`.
    pub b_matrix: Array2<f64>,
    /// Number of input features (columns of the input batch).
    pub d_in: usize,
    /// Number of output features (columns of the produced delta).
    pub d_out: usize,
    /// Accumulated loss gradient w.r.t. `a_matrix`, shape `(d_in, rank)`.
    grad_a: Array2<f64>,
    /// Accumulated loss gradient w.r.t. `b_matrix`, shape `(rank, d_out)`.
    grad_b: Array2<f64>,
}
impl LoraAdapter {
    /// Creates an adapter mapping `d_in` input features to a `d_out`-wide delta
    /// through a rank-`rank` bottleneck, scaled by `alpha / rank`.
    ///
    /// `a_matrix` (`d_in × rank`) is filled with values drawn uniformly from
    /// `[-limit, limit)`, where `limit = sqrt(6 / (d_in + rank))` (the
    /// Glorot/Xavier-uniform bound for that shape), using an RNG seeded with `seed`.
    /// `b_matrix` (`rank × d_out`) and both gradient buffers start at zero, so a
    /// freshly created adapter produces an all-zero delta.
    ///
    /// # Panics
    ///
    /// Panics if `rank == 0`, `d_in == 0` or `d_out == 0`.
    pub fn new(d_in: usize, d_out: usize, rank: usize, alpha: f64, seed: u64) -> Self {
        assert!(rank > 0, "LoRA rank must be at least 1");
        assert!(d_in > 0 && d_out > 0, "dimensions must be non-zero");
        let mut rng: CoreRandom<StdRng> = seeded_rng(seed);
        let limit = (6.0_f64 / (d_in + rank) as f64).sqrt();
        let a_data: Vec<f64> = (0..d_in * rank)
            .map(|_| {
                // Map u in [0, 1) onto [-limit, limit).
                let u = rng.random_range(0.0_f64..1.0_f64);
                u * 2.0 * limit - limit
            })
            .collect();
        let a_matrix = Array2::from_shape_vec((d_in, rank), a_data)
            .expect("a_matrix shape is consistent by construction");
        Self {
            rank,
            alpha,
            a_matrix,
            b_matrix: Array2::zeros((rank, d_out)),
            d_in,
            d_out,
            grad_a: Array2::zeros((d_in, rank)),
            grad_b: Array2::zeros((rank, d_out)),
        }
    }

    /// Scaling factor `alpha / rank` applied to the low-rank product.
    #[inline]
    pub fn scale(&self) -> f64 {
        self.alpha / self.rank as f64
    }

    /// Computes the adapter's additive correction `scale() * input · A · B`.
    ///
    /// `input` holds one sample per row and must have `d_in` columns; the result
    /// has the same number of rows and `d_out` columns.
    ///
    /// # Panics
    ///
    /// Panics if `input.ncols() != d_in`.
    pub fn forward_delta(&self, input: &Array2<f64>) -> Array2<f64> {
        // A hard assert: with only a debug_assert, release builds would silently
        // ignore any extra input columns.
        assert_eq!(
            input.ncols(),
            self.d_in,
            "input column count must equal d_in"
        );
        // Multiply left to right through the (batch, rank) bottleneck so the dense
        // (d_in, d_out) product A·B is never materialised.
        input.dot(&self.a_matrix).dot(&self.b_matrix) * self.scale()
    }

    /// Back-propagates `d_output`, the loss gradient w.r.t. the delta returned by
    /// [`forward_delta`](Self::forward_delta) for the same `input`.
    ///
    /// The gradients w.r.t. `A` and `B` are *added* to the internal buffers (not
    /// overwritten), so call [`zero_grad`](Self::zero_grad) between independent
    /// steps. Returns the loss gradient w.r.t. `input` through the adapter path
    /// only, shape `(batch, d_in)`.
    ///
    /// # Panics
    ///
    /// Panics if `input` does not have `d_in` columns, or if `d_output` is not
    /// `(input.nrows(), d_out)`.
    pub fn backward(&mut self, input: &Array2<f64>, d_output: &Array2<f64>) -> Array2<f64> {
        assert_eq!(
            input.ncols(),
            self.d_in,
            "input column count must equal d_in"
        );
        assert_eq!(
            d_output.nrows(),
            input.nrows(),
            "d_output row count must match input batch size"
        );
        assert_eq!(
            d_output.ncols(),
            self.d_out,
            "d_output column count must equal d_out"
        );
        let s = self.scale();
        // Bottleneck activations z = x·A (recomputed; forward_delta caches nothing).
        let z = input.dot(&self.a_matrix);
        // dL/dz = s · dL/dΔ · Bᵀ
        let d_z = d_output.dot(&self.b_matrix.t()) * s;
        // dL/dB = s · zᵀ · dL/dΔ   and   dL/dA = xᵀ · dL/dz
        self.grad_b.scaled_add(s, &z.t().dot(d_output));
        self.grad_a += &input.t().dot(&d_z);
        // dL/dx = dL/dz · Aᵀ
        d_z.dot(&self.a_matrix.t())
    }

    /// Applies one plain SGD update, `A -= lr · ∇A` and `B -= lr · ∇B`, using
    /// the currently accumulated gradients. The gradient buffers are not cleared.
    pub fn sgd_step(&mut self, learning_rate: f64) {
        self.a_matrix.scaled_add(-learning_rate, &self.grad_a);
        self.b_matrix.scaled_add(-learning_rate, &self.grad_b);
    }

    /// Resets both accumulated gradient buffers to zero.
    pub fn zero_grad(&mut self) {
        self.grad_a.fill(0.0);
        self.grad_b.fill(0.0);
    }

    /// Euclidean (L2) norm of all accumulated gradients, `∇A` and `∇B` together.
    pub fn grad_norm(&self) -> f64 {
        let sq_sum: f64 = self
            .grad_a
            .iter()
            .chain(self.grad_b.iter())
            .map(|&v| v * v)
            .sum();
        sq_sum.sqrt()
    }
}
/// Minimal full-batch SGD trainer that owns a single [`LoraAdapter`] and a
/// fixed learning rate; see [`LoraTrainer::train_epoch`].
#[derive(Debug, Clone)]
pub struct LoraTrainer {
    /// The adapter being trained.
    lora: LoraAdapter,
    /// Step size used for every SGD update.
    learning_rate: f64,
}
impl LoraTrainer {
pub fn new(lora: LoraAdapter, learning_rate: f64) -> Self {
Self {
lora,
learning_rate,
}
}
pub fn train_epoch(&mut self, base_output: &Array2<f64>, targets: &Array2<f64>) -> f64 {
assert_eq!(
self.lora.d_in, self.lora.d_out,
"LoraTrainer::train_epoch requires d_in == d_out; got d_in={}, d_out={}",
self.lora.d_in, self.lora.d_out,
);
assert_eq!(
base_output.ncols(),
self.lora.d_in,
"base_output column count must equal adapter d_in ({}), got {}",
self.lora.d_in,
base_output.ncols(),
);
let batch = base_output.nrows();
let d_out = base_output.ncols();
assert_eq!(
targets.nrows(),
batch,
"targets row count must match base_output batch size"
);
assert_eq!(
targets.ncols(),
d_out,
"targets column count must match d_out"
);
let input = base_output;
let delta = self.lora.forward_delta(input);
let mut total_loss = 0.0_f64;
let scale = (batch * d_out).max(1) as f64;
let mut d_output = Array2::zeros((batch, d_out));
for i in 0..batch {
for j in 0..d_out {
let out_val = base_output[[i, j]] + delta[[i, j]];
let target_val = targets[[i, j]];
let diff = out_val - target_val;
total_loss += diff * diff;
d_output[[i, j]] = 2.0 * diff / scale;
}
}
self.lora.zero_grad();
self.lora.backward(input, &d_output);
self.lora.sgd_step(self.learning_rate);
total_loss / scale
}
pub fn adapter(&self) -> &LoraAdapter {
&self.lora
}
pub fn into_adapter(self) -> LoraAdapter {
self.lora
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::Array2;

    /// Builds an adapter with `alpha == rank` (so `scale() == 1`) and a fixed seed.
    fn make_adapter(d_in: usize, d_out: usize, rank: usize) -> LoraAdapter {
        LoraAdapter::new(d_in, d_out, rank, rank as f64, 42)
    }

    #[test]
    fn test_new_b_matrix_is_zero() {
        let adapter = make_adapter(4, 6, 2);
        for &v in adapter.b_matrix.iter() {
            assert_eq!(v, 0.0, "B should be zero-initialised");
        }
    }

    #[test]
    fn test_forward_delta_zero_ab_gives_zero() {
        let mut adapter = make_adapter(4, 6, 2);
        for v in adapter.a_matrix.iter_mut() {
            *v = 0.0;
        }
        let input = Array2::from_elem((3, 4), 1.0);
        let delta = adapter.forward_delta(&input);
        for &v in delta.iter() {
            assert!(
                v.abs() < 1e-14,
                "delta must be zero when A=0 and B=0, got {v}"
            );
        }
    }

    #[test]
    fn test_forward_delta_shape() {
        let adapter = make_adapter(8, 12, 3);
        let input = Array2::zeros((5, 8));
        let delta = adapter.forward_delta(&input);
        assert_eq!(delta.nrows(), 5, "batch dimension must be preserved");
        assert_eq!(delta.ncols(), 12, "column count must equal d_out");
    }

    #[test]
    fn test_backward_d_input_shape() {
        let mut adapter = make_adapter(8, 12, 3);
        let input = Array2::from_elem((5, 8), 0.1);
        let d_output = Array2::from_elem((5, 12), 0.01);
        let d_input = adapter.backward(&input, &d_output);
        assert_eq!(d_input.nrows(), 5, "d_input batch must match");
        assert_eq!(d_input.ncols(), 8, "d_input columns must equal d_in");
    }

    #[test]
    fn test_sgd_step_updates_matrices() {
        let mut adapter = make_adapter(4, 6, 2);
        let a_before = adapter.a_matrix.clone();
        let b_before = adapter.b_matrix.clone();
        let input = Array2::from_elem((2, 4), 1.0);
        let d_output = Array2::from_elem((2, 6), 0.1);
        adapter.backward(&input, &d_output);
        adapter.sgd_step(0.01);
        let b_changed = b_before
            .iter()
            .zip(adapter.b_matrix.iter())
            .any(|(old, new)| (old - new).abs() > 1e-15);
        assert!(b_changed, "B matrix must change after sgd_step");
        // B starts at zero, so dL/dA = xᵀ·(s·dL/dΔ·Bᵀ) is exactly zero on the first step.
        assert_eq!(
            adapter.a_matrix, a_before,
            "A must not change on the first step while B is still zero"
        );
    }

    #[test]
    fn test_zero_grad_resets_gradients() {
        let mut adapter = make_adapter(4, 6, 2);
        let input = Array2::from_elem((2, 4), 1.0);
        let d_output = Array2::from_elem((2, 6), 0.5);
        adapter.backward(&input, &d_output);
        adapter.zero_grad();
        for &v in adapter.grad_a.iter() {
            assert_eq!(v, 0.0, "grad_a must be zero after zero_grad");
        }
        for &v in adapter.grad_b.iter() {
            assert_eq!(v, 0.0, "grad_b must be zero after zero_grad");
        }
    }

    #[test]
    fn test_grad_norm_zero_after_zero_grad() {
        let mut adapter = make_adapter(4, 6, 2);
        let input = Array2::from_elem((2, 4), 1.0);
        let d_output = Array2::from_elem((2, 6), 0.5);
        adapter.backward(&input, &d_output);
        adapter.zero_grad();
        assert_eq!(
            adapter.grad_norm(),
            0.0,
            "grad_norm must be 0 after zero_grad"
        );
    }

    #[test]
    fn test_trainer_loss_converges() {
        let adapter = LoraAdapter::new(4, 4, 1, 1.0, 7);
        let mut trainer = LoraTrainer::new(adapter, 0.05);
        let base_out = Array2::from_elem((3, 4), 1.0);
        let targets = Array2::zeros((3, 4));
        let initial_loss = trainer.train_epoch(&base_out, &targets);
        let mut final_loss = initial_loss;
        for _ in 0..99 {
            final_loss = trainer.train_epoch(&base_out, &targets);
        }
        assert!(
            final_loss < initial_loss * 0.9 || final_loss < 1e-6,
            "loss should decrease: initial={initial_loss:.6}, final={final_loss:.6}"
        );
    }

    #[test]
    fn test_rank_one_adapter() {
        let adapter = make_adapter(6, 4, 1);
        assert_eq!(adapter.rank, 1);
        let input = Array2::from_elem((2, 6), 0.5);
        let delta = adapter.forward_delta(&input);
        assert_eq!(delta.shape(), &[2, 4]);
    }

    #[test]
    fn test_scale_equals_alpha_over_rank() {
        let adapter = LoraAdapter::new(4, 4, 3, 9.0, 0);
        let expected = 9.0_f64 / 3.0_f64;
        assert!(
            (adapter.scale() - expected).abs() < 1e-15,
            "scale should be alpha/rank = {expected}, got {}",
            adapter.scale()
        );
    }

    #[test]
    fn test_fd_gradient_check_grad_b() {
        let d_in = 2;
        let d_out = 3;
        let rank = 1;
        let alpha = 1.0;
        let eps = 1e-5;
        let input =
            Array2::from_shape_vec((2, d_in), vec![0.3, -0.5, 1.2, 0.1]).expect("shape ok");
        let d_output = Array2::from_shape_vec((2, d_out), vec![0.1, -0.2, 0.4, 0.6, -0.1, 0.3])
            .expect("shape ok");
        let mut adapter_a = LoraAdapter::new(d_in, d_out, rank, alpha, 99);
        adapter_a.backward(&input, &d_output);
        let analytic = adapter_a.grad_b[[0, 0]];
        let mut adapter_p = LoraAdapter::new(d_in, d_out, rank, alpha, 99);
        let mut adapter_n = LoraAdapter::new(d_in, d_out, rank, alpha, 99);
        adapter_p.b_matrix[[0, 0]] += eps;
        adapter_n.b_matrix[[0, 0]] -= eps;
        let loss_fn = |ad: &LoraAdapter| -> f64 {
            let delta = ad.forward_delta(&input);
            delta
                .iter()
                .zip(d_output.iter())
                .map(|(a, b)| a * b)
                .sum::<f64>()
        };
        let fd = (loss_fn(&adapter_p) - loss_fn(&adapter_n)) / (2.0 * eps);
        let rel_err = (analytic - fd).abs() / (fd.abs().max(1e-10));
        assert!(
            rel_err < 1e-4,
            "FD gradient check failed: analytic={analytic:.8}, fd={fd:.8}, rel_err={rel_err:.6}"
        );
    }

    #[test]
    fn test_fd_gradient_check_grad_a() {
        let d_in = 2;
        let d_out = 3;
        let rank = 2;
        let alpha = 1.0;
        let eps = 1e-5;
        let input =
            Array2::from_shape_vec((2, d_in), vec![0.3, -0.5, 1.2, 0.1]).expect("shape ok");
        let d_output = Array2::from_shape_vec((2, d_out), vec![0.1, -0.2, 0.4, 0.6, -0.1, 0.3])
            .expect("shape ok");
        // With the default zero B, dL/dA vanishes identically; use a non-trivial B.
        let b_values =
            Array2::from_shape_vec((rank, d_out), vec![0.5, -0.3, 0.8, -0.7, 0.2, 0.4])
                .expect("shape ok");
        let build = || {
            let mut ad = LoraAdapter::new(d_in, d_out, rank, alpha, 123);
            ad.b_matrix = b_values.clone();
            ad
        };
        // Linear surrogate loss whose gradient w.r.t. the delta is exactly d_output.
        let loss_fn = |ad: &LoraAdapter| -> f64 {
            ad.forward_delta(&input)
                .iter()
                .zip(d_output.iter())
                .map(|(p, g)| p * g)
                .sum::<f64>()
        };
        let mut analytic_ad = build();
        analytic_ad.backward(&input, &d_output);
        for j in 0..d_in {
            for k in 0..rank {
                let mut ad_p = build();
                let mut ad_n = build();
                ad_p.a_matrix[[j, k]] += eps;
                ad_n.a_matrix[[j, k]] -= eps;
                let fd = (loss_fn(&ad_p) - loss_fn(&ad_n)) / (2.0 * eps);
                let analytic = analytic_ad.grad_a[[j, k]];
                let rel_err = (analytic - fd).abs() / fd.abs().max(1e-6);
                assert!(
                    rel_err < 1e-4,
                    "grad_a[{j},{k}] FD check failed: analytic={analytic:.8}, fd={fd:.8}, rel_err={rel_err:.6}"
                );
            }
        }
    }
}