use super::LoraError;
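
/// Applies a LoRA update in place: `W += (alpha / rank) * B * A`.
///
/// * `base_weight` - row-major `rows x cols` weight matrix `W`, updated in place.
/// * `a` - row-major `rank x cols` LoRA matrix `A`.
/// * `b` - row-major `rows x rank` LoRA matrix `B`.
///
/// Returns `LoraError::ShapeMismatch` if any slice length disagrees with the
/// stated dimensions, if `rank` is zero, or if `alpha` is not finite.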
pub fn apply_lora_delta(
    base_weight: &mut [f32],
    rows: usize,
    cols: usize,
    a: &[f32],
    b: &[f32],
    rank: usize,
    alpha: f32,
) -> Result<(), LoraError> {
    // Reject degenerate inputs up front; both cases are reported as shape errors.
    if !alpha.is_finite() || rank == 0 {
        return Err(LoraError::ShapeMismatch);
    }
    // W is rows x cols, A is rank x cols, B is rows x rank (all row-major).
    if base_weight.len() != rows * cols || a.len() != rank * cols || b.len() != rows * rank {
        return Err(LoraError::ShapeMismatch);
    }
    let scale = alpha / rank as f32;
    for r in 0..rows {
        for c in 0..cols {
            // delta[r][c] = (B * A)[r][c]
            let mut delta = 0.0f32;
            for k in 0..rank {
                delta += b[r * rank + k] * a[k * cols + c];
            }
            base_weight[r * cols + c] += delta * scale;
        }
    }
    Ok(())
}
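
// A minimal usage sketch, assuming the standard `cargo test` harness; the
// rank-1 values below are illustrative assumptions, not fixtures from this crate.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn rank_one_update_adds_scaled_outer_product() {
        // W is 2x2 zeros, B is 2x1, A is 1x2, alpha = 2.0 => scale = alpha / rank = 2.0.
        let mut w = vec![0.0f32; 4];
        let a = vec![1.0f32, 2.0]; // 1 x 2
        let b = vec![3.0f32, 4.0]; // 2 x 1
        assert!(apply_lora_delta(&mut w, 2, 2, &a, &b, 1, 2.0).is_ok());
        // B * A = [[3, 6], [4, 8]]; scaled by 2.0 => [[6, 12], [8, 16]].
        assert_eq!(w, vec![6.0, 12.0, 8.0, 16.0]);

        // Mismatched shapes are rejected before any weight is touched.
        assert!(apply_lora_delta(&mut w, 2, 3, &a, &b, 1, 2.0).is_err());
    }
}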