//! LoRA (Low-Rank Adaptation) for a dense linear layer.
//!
//! The frozen base weight `W` is augmented with a trainable low-rank update
//! `scaling * B @ A`, where `A` has shape `(rank, in_features)` and `B` has
//! shape `(out_features, rank)`.

use scirs2_core::ndarray::{Array2, Axis};
use super::types::{LoRAConfig, LoRAStats};
use crate::{NeuralError, Result};
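
/// Minimal xorshift64 PRNG used for deterministic, dependency-free weight
/// initialization; not intended for cryptographic use.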
struct Xorshift64 {
    state: u64,
}

impl Xorshift64 {
    fn new(seed: u64) -> Self {
        Self {
            // Zero is a fixed point of xorshift, so substitute a nonzero
            // constant to keep the generator from getting stuck.
            state: if seed == 0 { 0xDEAD_BEEF_CAFE_1234 } else { seed },
        }
    }

    fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.state = x;
        x
    }

    /// Uniform sample in `[0, 1)` built from the top 53 bits.
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// Uniform sample in `[-bound, bound)`.
    fn next_uniform(&mut self, bound: f64) -> f64 {
        (self.next_f64() * 2.0 - 1.0) * bound
    }
}
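
/// Bound for Kaiming (He) uniform initialization: `gain * sqrt(3 / fan_in)`
/// with `gain = sqrt(2)`, matching the init used for LoRA's `A` matrix.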
fn kaiming_uniform_bound(fan_in: usize) -> f64 {
    let gain = std::f64::consts::SQRT_2;
    let std_dev = gain / (fan_in as f64).sqrt();
    // Uniform(-b, b) has standard deviation b / sqrt(3), so b = std_dev * sqrt(3).
    std_dev * 3.0_f64.sqrt()
}
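
/// A linear layer with a LoRA adapter.
///
/// The effective weight is `W + scaling * B @ A`, where the base weight `W`
/// is frozen and only the low-rank factors `A` and `B` train. `scaling`
/// comes from `LoRAConfig::scaling()` (conventionally `alpha / rank`).
///
/// A minimal usage sketch, assuming `LoRAConfig`'s `Default` impl:
///
/// ```ignore
/// let weight = Array2::<f64>::eye(8);
/// let config = LoRAConfig { rank: 4, ..Default::default() };
/// let lora = LoRALinear::new(weight, &config)?;
/// let output = lora.forward(&Array2::<f64>::ones((2, 8)))?;
/// ```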
pub struct LoRALinear {
    /// Frozen base weight, shape `(out_features, in_features)`.
    weight: Array2<f64>,
    /// Trainable down-projection, shape `(rank, in_features)`.
    lora_a: Array2<f64>,
    /// Trainable up-projection, shape `(out_features, rank)`.
    lora_b: Array2<f64>,
    rank: usize,
    alpha: f64,
    scaling: f64,
    /// Whether the LoRA delta has been folded into `weight`.
    merged: bool,
}
impl LoRALinear {
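    /// Builds a LoRA adapter around `weight`. Following the original LoRA
    /// recipe, `A` is Kaiming-uniform initialized and `B` starts at zero,
    /// so the adapted layer initially matches the base layer exactly.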
    pub fn new(weight: Array2<f64>, config: &LoRAConfig) -> Result<Self> {
        config.validate()?;
        let (out_features, in_features) = (weight.nrows(), weight.ncols());
        if config.rank > in_features.min(out_features) {
            return Err(NeuralError::InvalidArgument(format!(
                "LoRA rank {} exceeds min(in_features={}, out_features={})",
                config.rank, in_features, out_features
            )));
        }
        let rank = config.rank;
        let scaling = config.scaling();
        let bound = kaiming_uniform_bound(in_features);
        let mut rng = Xorshift64::new(config.seed);
        let lora_a = Array2::from_shape_fn((rank, in_features), |_| rng.next_uniform(bound));
        // B starts at zero so the initial LoRA delta B @ A is zero.
        let lora_b = Array2::zeros((out_features, rank));
        Ok(Self {
            weight,
            lora_a,
            lora_b,
            rank,
            alpha: config.alpha,
            scaling,
            merged: false,
        })
    }
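
    /// Applies the layer: `x @ W^T + scaling * ((x @ A^T) @ B^T)`. When
    /// merged, the delta already lives in `weight`, so only the base
    /// matmul runs.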
    pub fn forward(&self, input: &Array2<f64>) -> Result<Array2<f64>> {
        let in_features = self.weight.ncols();
        if input.ncols() != in_features {
            return Err(NeuralError::DimensionMismatch(format!(
                "Input has {} features but weight expects {}",
                input.ncols(),
                in_features
            )));
        }
        let output = input.dot(&self.weight.t());
        if self.merged {
            return Ok(output);
        }
        // Compute (x A^T) B^T as two thin matmuls instead of materializing
        // the full out_features x in_features delta.
        let lora_output = input.dot(&self.lora_a.t());
        let lora_output = lora_output.dot(&self.lora_b.t());
        let lora_scaled = &lora_output * self.scaling;
        Ok(&output + &lora_scaled)
    }
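
    /// Folds `scaling * B @ A` into the base weight so `forward` becomes a
    /// single matmul. Errors if already merged.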
    pub fn merge(&mut self) -> Result<()> {
        if self.merged {
            return Err(NeuralError::InvalidState(
                "LoRA weights are already merged".to_string(),
            ));
        }
        let delta = self.lora_b.dot(&self.lora_a) * self.scaling;
        self.weight = &self.weight + &delta;
        self.merged = true;
        Ok(())
    }
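
    /// Subtracts `scaling * B @ A` back out of the base weight. A merge
    /// followed by an unmerge is exact only up to floating-point round-off.
    /// Errors if not currently merged.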
    pub fn unmerge(&mut self) -> Result<()> {
        if !self.merged {
            return Err(NeuralError::InvalidState(
                "LoRA weights are not merged".to_string(),
            ));
        }
        let delta = self.lora_b.dot(&self.lora_a) * self.scaling;
        self.weight = &self.weight - &delta;
        self.merged = false;
        Ok(())
    }

    pub fn lora_a(&self) -> &Array2<f64> {
        &self.lora_a
    }

    pub fn lora_b(&self) -> &Array2<f64> {
        &self.lora_b
    }

    pub fn weight(&self) -> &Array2<f64> {
        &self.weight
    }

    pub fn rank(&self) -> usize {
        self.rank
    }

    pub fn alpha(&self) -> f64 {
        self.alpha
    }

    pub fn scaling(&self) -> f64 {
        self.scaling
    }

    pub fn is_merged(&self) -> bool {
        self.merged
    }
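
    /// Parameter counts for this layer. `compression_ratio` is the fraction
    /// of parameters that are trainable:
    /// `(r*in + out*r) / (out*in + r*in + out*r)`.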
    pub fn stats(&self) -> LoRAStats {
        let (out_features, in_features) = (self.weight.nrows(), self.weight.ncols());
        let frozen_params = out_features * in_features;
        let trainable_params = self.rank * in_features + out_features * self.rank;
        let total_params = frozen_params + trainable_params;
        let compression_ratio = if total_params > 0 {
            trainable_params as f64 / total_params as f64
        } else {
            0.0
        };
        LoRAStats {
            total_params,
            trainable_params,
            frozen_params,
            compression_ratio,
        }
    }
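
    /// Replaces the `A` matrix; the new matrix must match the existing
    /// `(rank, in_features)` shape.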
    pub fn set_lora_a(&mut self, a: Array2<f64>) -> Result<()> {
        if a.shape() != self.lora_a.shape() {
            return Err(NeuralError::ShapeMismatch(format!(
                "Expected A shape {:?}, got {:?}",
                self.lora_a.shape(),
                a.shape()
            )));
        }
        self.lora_a = a;
        Ok(())
    }

    /// Replaces the `B` matrix; the new matrix must match the existing
    /// `(out_features, rank)` shape.
    pub fn set_lora_b(&mut self, b: Array2<f64>) -> Result<()> {
        if b.shape() != self.lora_b.shape() {
            return Err(NeuralError::ShapeMismatch(format!(
                "Expected B shape {:?}, got {:?}",
                self.lora_b.shape(),
                b.shape()
            )));
        }
        self.lora_b = b;
        Ok(())
    }
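
    /// The dense low-rank update `scaling * B @ A`, shape
    /// `(out_features, in_features)`.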
    pub fn lora_delta(&self) -> Array2<f64> {
        self.lora_b.dot(&self.lora_a) * self.scaling
    }

    /// The weight the layer effectively applies: `weight` itself when
    /// merged, otherwise `weight + lora_delta()`.
    pub fn effective_weight(&self) -> Array2<f64> {
        if self.merged {
            self.weight.clone()
        } else {
            &self.weight + &self.lora_delta()
        }
    }
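
    /// Cheap proxy for the numerical rank of the LoRA delta: counts columns
    /// whose L2 norm exceeds `threshold`, capped at `rank`. This is a
    /// heuristic, not a true rank estimate (which would require an SVD of
    /// the delta).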
    pub fn effective_rank(&self, threshold: f64) -> usize {
        let delta = self.lora_delta();
        let mut non_zero_cols = 0;
        for col_idx in 0..delta.ncols() {
            let col = delta.index_axis(Axis(1), col_idx);
            let norm: f64 = col.iter().map(|x| x * x).sum::<f64>().sqrt();
            if norm > threshold {
                non_zero_cols += 1;
            }
        }
        non_zero_cols.min(self.rank)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    #[test]
    fn test_lora_linear_creation() {
        let weight = Array2::<f64>::eye(8);
        let config = LoRAConfig {
            rank: 4,
            ..Default::default()
        };
        let lora = LoRALinear::new(weight, &config);
        assert!(lora.is_ok());
        let lora = lora.expect("creation should succeed");
        assert_eq!(lora.rank(), 4);
        assert_eq!(lora.lora_a().shape(), &[4, 8]);
        assert_eq!(lora.lora_b().shape(), &[8, 4]);
    }

    #[test]
    fn test_zero_b_preserves_original_output() {
        // B is zero-initialized, so a fresh adapter must reproduce the
        // base layer's output exactly.
        let weight = Array2::from_shape_fn((4, 6), |(_i, _j)| 0.5);
        let config = LoRAConfig {
            rank: 2,
            ..Default::default()
        };
        let lora = LoRALinear::new(weight.clone(), &config).expect("creation failed");
        let input = Array2::from_shape_fn((2, 6), |(i, j)| (i * 6 + j) as f64 * 0.1);
        let lora_output = lora.forward(&input).expect("forward failed");
        let original_output = input.dot(&weight.t());
        for (a, b) in lora_output.iter().zip(original_output.iter()) {
            assert!((a - b).abs() < 1e-10, "outputs differ: {a} vs {b}");
        }
    }

    #[test]
    fn test_merged_vs_unmerged_same_output() {
        let weight = Array2::from_shape_fn((4, 6), |(i, j)| (i as f64 + j as f64) * 0.1);
        let config = LoRAConfig {
            rank: 2,
            ..Default::default()
        };
        let mut lora = LoRALinear::new(weight, &config).expect("creation failed");
        // Use a nonzero B so the LoRA delta actually contributes.
        let b = Array2::from_shape_fn((4, 2), |(i, j)| (i as f64 - j as f64) * 0.01);
        lora.set_lora_b(b).expect("set_lora_b failed");
        let input = Array2::from_shape_fn((3, 6), |(i, j)| (i * 6 + j) as f64 * 0.05);
        let unmerged_output = lora.forward(&input).expect("forward failed");
        lora.merge().expect("merge failed");
        let merged_output = lora.forward(&input).expect("forward failed");
        for (a, b) in unmerged_output.iter().zip(merged_output.iter()) {
            assert!(
                (a - b).abs() < 1e-10,
                "merged vs unmerged differ: {a} vs {b}"
            );
        }
    }

    #[test]
    fn test_merge_unmerge_roundtrip() {
        let weight = Array2::from_shape_fn((4, 6), |(i, j)| (i as f64 + j as f64) * 0.1);
        let original_weight = weight.clone();
        let config = LoRAConfig {
            rank: 2,
            ..Default::default()
        };
        let mut lora = LoRALinear::new(weight, &config).expect("creation failed");
        let b = Array2::from_shape_fn((4, 2), |(i, j)| (i as f64 - j as f64) * 0.01);
        lora.set_lora_b(b).expect("set_lora_b failed");
        lora.merge().expect("merge failed");
        lora.unmerge().expect("unmerge failed");
        for (a, b) in lora.weight().iter().zip(original_weight.iter()) {
            assert!(
                (a - b).abs() < 1e-10,
                "weight changed after merge+unmerge: {a} vs {b}"
            );
        }
    }

    #[test]
    fn test_rank1_lora() {
        let weight = Array2::<f64>::zeros((4, 6));
        let config = LoRAConfig {
            rank: 1,
            alpha: 1.0,
            ..Default::default()
        };
        let mut lora = LoRALinear::new(weight, &config).expect("creation failed");
        // With a = [0, 1, ..., 5] and b = [0, 1, 2, 3]^T, delta[i][j] is
        // proportional to i * j, so every row is a scalar multiple of row 1.
        let a = Array2::from_shape_fn((1, 6), |(_, j)| j as f64);
        let b = Array2::from_shape_fn((4, 1), |(i, _)| i as f64);
        lora.set_lora_a(a).expect("set_lora_a failed");
        lora.set_lora_b(b).expect("set_lora_b failed");
        let delta = lora.lora_delta();
        let row0 = delta.row(0).to_owned();
        if row0.iter().all(|x| x.abs() < 1e-10) {
            let row1 = delta.row(1).to_owned();
            for i in 2..4 {
                let row_i = delta.row(i).to_owned();
                let ratio = i as f64;
                for (a, b) in row_i.iter().zip(row1.iter()) {
                    if b.abs() > 1e-10 {
                        assert!(
                            (a / b - ratio).abs() < 1e-10,
                            "not rank-1: row ratio mismatch"
                        );
                    }
                }
            }
        }
    }

    #[test]
    fn test_stats_parameter_counts() {
        let weight = Array2::<f64>::eye(16);
        let config = LoRAConfig {
            rank: 4,
            ..Default::default()
        };
        let lora = LoRALinear::new(weight, &config).expect("creation failed");
        let stats = lora.stats();
        assert_eq!(stats.frozen_params, 16 * 16);
        assert_eq!(stats.trainable_params, 4 * 16 + 16 * 4);
        assert_eq!(stats.total_params, 256 + 128);
        assert!((stats.compression_ratio - 128.0 / 384.0).abs() < 1e-10);
    }

    #[test]
    fn test_rank_equals_min_dim() {
        let weight = Array2::<f64>::eye(4);
        let config = LoRAConfig {
            rank: 4,
            ..Default::default()
        };
        let lora = LoRALinear::new(weight, &config);
        assert!(lora.is_ok());
    }

    #[test]
    fn test_rank_exceeds_dim_error() {
        let weight = Array2::<f64>::eye(4);
        let config = LoRAConfig {
            rank: 5,
            ..Default::default()
        };
        let lora = LoRALinear::new(weight, &config);
        assert!(lora.is_err());
    }

    #[test]
    fn test_dimension_mismatch_forward() {
        let weight = Array2::<f64>::eye(4);
        let config = LoRAConfig {
            rank: 2,
            ..Default::default()
        };
        let lora = LoRALinear::new(weight, &config).expect("creation failed");
        let bad_input = Array2::<f64>::ones((2, 5));
        assert!(lora.forward(&bad_input).is_err());
    }

    #[test]
    fn test_double_merge_error() {
        let weight = Array2::<f64>::eye(4);
        let config = LoRAConfig {
            rank: 2,
            ..Default::default()
        };
        let mut lora = LoRALinear::new(weight, &config).expect("creation failed");
        lora.merge().expect("first merge failed");
        assert!(lora.merge().is_err());
    }

    #[test]
    fn test_unmerge_without_merge_error() {
        let weight = Array2::<f64>::eye(4);
        let config = LoRAConfig {
            rank: 2,
            ..Default::default()
        };
        let mut lora = LoRALinear::new(weight, &config).expect("creation failed");
        assert!(lora.unmerge().is_err());
    }
}