use scirs2_core::random::{RngExt, SeedableRng, StdRng};
use std::f64::consts::PI;
use super::config::LoraConfig;
use super::error::{LoraError, LoraResult};

/// A dense layer augmented with a LoRA adapter: the effective weight is
/// `W + scaling * B * A`, where `W` is `d x k`, `B` is `d x rank`, and
/// `A` is `rank x k`.
pub struct LoraLayer {
    /// Low-rank factor `A` (`rank x k`), Gaussian-initialized.
    pub weight_a: Vec<Vec<f64>>,
    /// Low-rank factor `B` (`d x rank`), zero-initialized.
    pub weight_b: Vec<Vec<f64>>,
    /// Base weight `W` (`d x k`); holds the merged weight while `merged` is true.
    pub base_weight: Vec<Vec<f64>>,
    pub config: LoraConfig,
    pub merged: bool,
    rng: StdRng,
}

/// Draws a standard-normal sample via the Box-Muller transform.
fn next_normal(rng: &mut StdRng) -> f64 {
    // Clamp u1 away from zero so `ln` stays finite.
    let u1 = rng.random::<f64>().max(f64::MIN_POSITIVE);
    let u2 = rng.random::<f64>();
    (-2.0 * u1.ln()).sqrt() * (2.0 * PI * u2).cos()
}

/// Multiplies `a` (`m x n`) by `b` (`n x p`), checking that the inner
/// dimensions agree.
fn matmul(a: &[Vec<f64>], b: &[Vec<f64>]) -> LoraResult<Vec<Vec<f64>>> {
    if a.is_empty() || b.is_empty() {
        return Err(LoraError::DimensionMismatch {
            expected: "non-empty matrices".into(),
            got: format!("a rows={}, b rows={}", a.len(), b.len()),
        });
    }
    let a_cols = a[0].len();
    let b_rows = b.len();
    if a_cols != b_rows {
        return Err(LoraError::DimensionMismatch {
            expected: format!("a_cols ({a_cols}) == b_rows"),
            got: format!("{b_rows}"),
        });
    }
    let b_cols = b[0].len();
    let mut out = vec![vec![0.0; b_cols]; a.len()];
    // i-k-j loop order keeps the inner loop contiguous over `out[i]` and `b[k]`.
    for i in 0..a.len() {
        for k in 0..a_cols {
            let a_ik = a[i][k];
            for j in 0..b_cols {
                out[i][j] += a_ik * b[k][j];
            }
        }
    }
    Ok(out)
}

/// Returns the transpose of `m`; an empty matrix transposes to an empty matrix.
fn transpose(m: &[Vec<f64>]) -> Vec<Vec<f64>> {
    if m.is_empty() {
        return Vec::new();
    }
    let rows = m.len();
    let cols = m[0].len();
    let mut t = vec![vec![0.0; rows]; cols];
    for i in 0..rows {
        for j in 0..cols {
            t[j][i] = m[i][j];
        }
    }
    t
}

/// Element-wise sum of two equally shaped matrices.
fn add_matrices(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    a.iter()
        .zip(b.iter())
        .map(|(ra, rb)| ra.iter().zip(rb.iter()).map(|(x, y)| x + y).collect())
        .collect()
}

/// Element-wise difference `a - b` of two equally shaped matrices.
fn sub_matrices(a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
    a.iter()
        .zip(b.iter())
        .map(|(ra, rb)| ra.iter().zip(rb.iter()).map(|(x, y)| x - y).collect())
        .collect()
}

/// Multiplies every entry of `m` by the scalar `s`.
fn scale_matrix(s: f64, m: &[Vec<f64>]) -> Vec<Vec<f64>> {
    m.iter()
        .map(|row| row.iter().map(|v| v * s).collect())
        .collect()
}

impl LoraLayer {
    /// Builds a LoRA layer around `base_weight` (`d x k`).
    ///
    /// `A` is drawn from a Gaussian with standard deviation `1 / sqrt(rank)`
    /// and `B` starts at zero, so the adapter initially contributes nothing
    /// and the layer matches the base layer.
    pub fn new(base_weight: Vec<Vec<f64>>, config: LoraConfig) -> LoraResult<Self> {
        let d = base_weight.len();
        if d == 0 {
            return Err(LoraError::DimensionMismatch {
                expected: "d > 0".into(),
                got: "0".into(),
            });
        }
        let k = base_weight[0].len();
        if k == 0 {
            return Err(LoraError::DimensionMismatch {
                expected: "k > 0".into(),
                got: "0".into(),
            });
        }
        let rank = config.rank;
        if rank == 0 || rank > d.min(k) {
            return Err(LoraError::InvalidRank(rank));
        }
        let mut rng = StdRng::seed_from_u64(config.seed);
        let stddev = 1.0 / (rank as f64).sqrt();
        let weight_a: Vec<Vec<f64>> = (0..rank)
            .map(|_| (0..k).map(|_| next_normal(&mut rng) * stddev).collect())
            .collect();
        let weight_b: Vec<Vec<f64>> = vec![vec![0.0; rank]; d];
        Ok(Self {
            weight_a,
            weight_b,
            base_weight,
            config,
            merged: false,
            rng,
        })
    }

    /// LoRA scaling factor `alpha / rank`.
    fn scaling(&self) -> f64 {
        self.config.alpha / self.config.rank as f64
    }

    /// Unscaled adapter update `B * A` (`d x k`).
    fn delta_weight(&self) -> LoraResult<Vec<Vec<f64>>> {
        matmul(&self.weight_b, &self.weight_a)
    }

    /// Returns `W + scaling * B * A`. While merged, `base_weight` already
    /// contains the delta, so it is returned as-is.
    pub fn effective_weight(&self) -> LoraResult<Vec<Vec<f64>>> {
        if self.merged {
            return Ok(self.base_weight.clone());
        }
        let dw = self.delta_weight()?;
        Ok(add_matrices(
            &self.base_weight,
            &scale_matrix(self.scaling(), &dw),
        ))
    }

    /// Applies the layer to a batch `input` (`n x k`) and returns `n x d`
    /// outputs: `input * W^T + scaling * dropout(input * A^T) * B^T`.
    ///
    /// Takes `&mut self` only because dropout draws from the layer's RNG.
    pub fn forward(&mut self, input: &[Vec<f64>]) -> LoraResult<Vec<Vec<f64>>> {
        if input.is_empty() {
            return Ok(Vec::new());
        }
        let k = self.base_weight[0].len();
        if input[0].len() != k {
            return Err(LoraError::DimensionMismatch {
                expected: format!("input cols = {k}"),
                got: format!("{}", input[0].len()),
            });
        }
        if self.merged {
            // The adapter delta is already folded into the base weight.
            let wt = transpose(&self.base_weight);
            return matmul(input, &wt);
        }
        let wt = transpose(&self.base_weight);
        let base_out = matmul(input, &wt)?;
        // Low-rank path: project into the rank-dimensional space first.
        let at = transpose(&self.weight_a);
        let mut lora_hidden = matmul(input, &at)?;
        // Inverted dropout on the low-rank activations: surviving entries are
        // rescaled by 1 / keep_prob so the expected value is unchanged.
        if self.config.dropout > 0.0 && self.config.dropout < 1.0 {
            let inv_keep = 1.0 / (1.0 - self.config.dropout);
            for row in &mut lora_hidden {
                for v in row.iter_mut() {
                    if self.rng.random::<f64>() < self.config.dropout {
                        *v = 0.0;
                    } else {
                        *v *= inv_keep;
                    }
                }
            }
        }
        let bt = transpose(&self.weight_b);
        let lora_out = matmul(&lora_hidden, &bt)?;
        let scaled = scale_matrix(self.scaling(), &lora_out);
        Ok(add_matrices(&base_out, &scaled))
    }

    /// Folds `scaling * B * A` into `base_weight` so `forward` needs only one
    /// matrix multiply. Fails if already merged.
    pub fn merge(&mut self) -> LoraResult<()> {
        if self.merged {
            return Err(LoraError::MergeError("already merged".into()));
        }
        let dw = self.delta_weight()?;
        self.base_weight = add_matrices(&self.base_weight, &scale_matrix(self.scaling(), &dw));
        self.merged = true;
        Ok(())
    }

    /// Subtracts `scaling * B * A` from `base_weight`, undoing `merge`.
    /// Fails if not currently merged.
    pub fn unmerge(&mut self) -> LoraResult<()> {
        if !self.merged {
            return Err(LoraError::MergeError("not merged".into()));
        }
        let dw = self.delta_weight()?;
        self.base_weight = sub_matrices(&self.base_weight, &scale_matrix(self.scaling(), &dw));
        self.merged = false;
        Ok(())
    }

    /// Number of adapter parameters: `rank * (d + k)` for `B` and `A`.
    pub fn trainable_params(&self) -> usize {
        let d = self.base_weight.len();
        let k = self.base_weight[0].len();
        self.config.rank * (d + k)
    }

    /// Base parameters plus adapter parameters: `d * k + rank * (d + k)`.
    pub fn total_params(&self) -> usize {
        let d = self.base_weight.len();
        let k = self.base_weight[0].len();
        d * k + self.trainable_params()
    }

    /// Fraction of all parameters that are trainable (the adapter's share).
    pub fn compression_ratio(&self) -> f64 {
        self.trainable_params() as f64 / self.total_params() as f64
    }
}
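
// A minimal usage sketch of the layer above. It assumes `LoraConfig` is a plain
// struct whose public fields are exactly the ones this module reads (`rank`,
// `alpha`, `dropout`, `seed`); adapt the `config` helper if the real
// `LoraConfig` API differs.
#[cfg(test)]
mod tests {
    use super::*;

    fn config(rank: usize, alpha: f64, dropout: f64, seed: u64) -> LoraConfig {
        // Hypothetical struct-literal construction; see the note above.
        LoraConfig { rank, alpha, dropout, seed }
    }

    #[test]
    fn forward_matches_base_layer_at_init() {
        // With B = 0 the adapter contributes nothing, so the output is x * W^T.
        let base = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]]; // d = 2, k = 3
        let mut layer = LoraLayer::new(base, config(2, 4.0, 0.0, 42)).unwrap();
        let out = layer.forward(&[vec![1.0, 0.0, -1.0]]).unwrap();
        assert_eq!(out, vec![vec![-2.0, -2.0]]);
    }

    #[test]
    fn merge_then_unmerge_restores_base_weight() {
        let base = vec![vec![0.5, -0.5], vec![1.0, 2.0], vec![0.0, 3.0]]; // d = 3, k = 2
        let mut layer = LoraLayer::new(base.clone(), config(1, 2.0, 0.0, 7)).unwrap();
        // Give B a nonzero value so the merge actually changes the weight.
        layer.weight_b = vec![vec![0.1], vec![0.2], vec![0.3]];
        let effective = layer.effective_weight().unwrap();
        layer.merge().unwrap();
        assert_eq!(layer.base_weight, effective);
        layer.unmerge().unwrap();
        // Unmerging reverses the addition up to floating-point round-off.
        for (row, orig_row) in layer.base_weight.iter().zip(base.iter()) {
            for (v, o) in row.iter().zip(orig_row.iter()) {
                assert!((v - o).abs() < 1e-9);
            }
        }
    }
}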