//! LoRA (Low-Rank Adaptation) and QLoRA (NF4-quantized LoRA) linear layers.
//!
//! [`LoraLinear`] augments a frozen weight `W` with a trainable low-rank
//! update `B A`, scaled by `alpha / rank`; [`QLoraLinear`] additionally stores
//! the frozen base weight as packed 4-bit NF4 codes with per-group scales.
use crate::error::{ModelError, ModelResult};
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
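/// Minimal deterministic xorshift64 generator used to initialize the LoRA `A`
/// matrices reproducibly without an external RNG dependency. `next_f32`
/// returns values in `(-1.0, 1.0]`.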
struct SeededRng {
state: u64,
}
impl SeededRng {
fn new(seed: u64) -> Self {
Self { state: seed.max(1) }
}
fn next_f32(&mut self) -> f32 {
self.state ^= self.state << 13;
self.state ^= self.state >> 7;
self.state ^= self.state << 17;
(self.state as f64 / u64::MAX as f64 * 2.0 - 1.0) as f32
}
}
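/// Configuration for building LoRA adapters.
///
/// `rank` is the dimension of the low-rank update, `alpha` the scaling
/// numerator (the effective scale is `alpha / rank`), and `target_modules`
/// names the layers an adapter should wrap. `dropout` is validated but not
/// applied by the forward passes in this module, and `fan_in_fan_out` is
/// stored but unused here.
///
/// A minimal construction sketch using the builder methods defined below:
///
/// ```ignore
/// let config = LoraConfig::new(8, 16.0)
///     .with_dropout(0.05)
///     .with_target_modules(vec!["q_proj".into(), "v_proj".into()]);
/// config.validate()?;
/// ```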
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoraConfig {
pub rank: usize,
pub alpha: f32,
pub dropout: f32,
pub target_modules: Vec<String>,
pub fan_in_fan_out: bool,
}
impl LoraConfig {
pub fn new(rank: usize, alpha: f32) -> Self {
Self {
rank,
alpha,
dropout: 0.0,
target_modules: Vec::new(),
fan_in_fan_out: false,
}
}
pub fn with_dropout(mut self, dropout: f32) -> Self {
self.dropout = dropout;
self
}
pub fn with_target_modules(mut self, modules: Vec<String>) -> Self {
self.target_modules = modules;
self
}
pub fn with_fan_in_fan_out(mut self, fan_in_fan_out: bool) -> Self {
self.fan_in_fan_out = fan_in_fan_out;
self
}
pub fn validate(&self) -> ModelResult<()> {
if self.rank == 0 {
return Err(ModelError::invalid_config("LoRA rank must be > 0"));
}
if self.alpha <= 0.0 {
return Err(ModelError::invalid_config("LoRA alpha must be > 0.0"));
}
if !(0.0..=1.0).contains(&self.dropout) {
return Err(ModelError::invalid_config(
"LoRA dropout must be in [0.0, 1.0]",
));
}
Ok(())
}
}
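/// A linear layer with a frozen base weight `W` and a trainable low-rank
/// update, following LoRA: `y = W x + (alpha / rank) * B (A x)`.
///
/// `lora_a` has shape `(rank, in_features)` and `lora_b` has shape
/// `(out_features, rank)`. When `merged` is true the update has been folded
/// into `weight`; when `enabled` is false the update is skipped entirely.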
#[derive(Debug, Clone)]
pub struct LoraLinear {
weight: Array2<f32>,
lora_a: Array2<f32>,
lora_b: Array2<f32>,
rank: usize,
alpha: f32,
scaling: f32,
merged: bool,
enabled: bool,
}
impl LoraLinear {
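/// Wraps a frozen `(out_features, in_features)` weight matrix with a LoRA
/// update of the given rank. `A` is initialized with small Kaiming-scaled
/// pseudo-random values and `B` with zeros, so the adapter initially adds
/// nothing to the output.
///
/// A minimal sketch (dimensions are illustrative):
///
/// ```ignore
/// let weight = Array2::<f32>::zeros((64, 32));
/// let layer = LoraLinear::new(weight, 8, 16.0)?;
/// assert_eq!(layer.scaling(), 2.0); // alpha / rank
/// assert!(!layer.is_merged());
/// ```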
pub fn new(weight: Array2<f32>, rank: usize, alpha: f32) -> ModelResult<Self> {
if rank == 0 {
return Err(ModelError::invalid_config("LoRA rank must be > 0"));
}
if alpha <= 0.0 {
return Err(ModelError::invalid_config("LoRA alpha must be > 0.0"));
}
let (out_features, in_features) = weight.dim();
if out_features == 0 || in_features == 0 {
return Err(ModelError::invalid_config(
"Weight matrix dimensions must be > 0",
));
}
if rank > out_features.min(in_features) {
return Err(ModelError::invalid_config(format!(
"LoRA rank ({}) must not exceed min(out_features, in_features) = {}",
rank,
out_features.min(in_features)
)));
}
let kaiming_scale = (2.0 / in_features as f32).sqrt();
let mut rng = SeededRng::new(42 + in_features as u64 + out_features as u64);
let lora_a = Array2::from_shape_fn((rank, in_features), |_| rng.next_f32() * kaiming_scale);
let lora_b = Array2::zeros((out_features, rank));
let scaling = alpha / rank as f32;
Ok(Self {
weight,
lora_a,
lora_b,
rank,
alpha,
scaling,
merged: false,
enabled: true,
})
}
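/// Computes `W x`, and adds the scaled LoRA term `(alpha / rank) * B (A x)`
/// when the adapter is enabled and not merged. Returns a dimension-mismatch
/// error if `input.len() != in_features`.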
pub fn forward(&self, input: &Array1<f32>) -> ModelResult<Array1<f32>> {
let (out_features, in_features) = self.weight.dim();
if input.len() != in_features {
return Err(ModelError::dimension_mismatch(
"LoraLinear forward input",
in_features,
input.len(),
));
}
let mut output = Array1::zeros(out_features);
for i in 0..out_features {
let mut sum = 0.0_f32;
for j in 0..in_features {
sum += self.weight[[i, j]] * input[j];
}
output[i] = sum;
}
if self.enabled && !self.merged {
let mut a_x = Array1::zeros(self.rank);
for r in 0..self.rank {
let mut sum = 0.0_f32;
for j in 0..in_features {
sum += self.lora_a[[r, j]] * input[j];
}
a_x[r] = sum;
}
for i in 0..out_features {
let mut sum = 0.0_f32;
for r in 0..self.rank {
sum += self.lora_b[[i, r]] * a_x[r];
}
output[i] += self.scaling * sum;
}
}
Ok(output)
}
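/// Batched variant of [`forward`](Self::forward): each row of `input` is one
/// sample of length `in_features`, and the output has shape
/// `(batch_size, out_features)`.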
pub fn forward_batch(&self, input: &Array2<f32>) -> ModelResult<Array2<f32>> {
let (batch_size, input_dim) = input.dim();
let (out_features, in_features) = self.weight.dim();
if input_dim != in_features {
return Err(ModelError::dimension_mismatch(
"LoraLinear forward_batch input dim",
in_features,
input_dim,
));
}
let mut output = Array2::zeros((batch_size, out_features));
for b in 0..batch_size {
for i in 0..out_features {
let mut sum = 0.0_f32;
for j in 0..in_features {
sum += input[[b, j]] * self.weight[[i, j]];
}
output[[b, i]] = sum;
}
}
if self.enabled && !self.merged {
for b in 0..batch_size {
let a_x: Vec<f32> = (0..self.rank)
.map(|r| {
let mut sum = 0.0_f32;
for j in 0..in_features {
sum += self.lora_a[[r, j]] * input[[b, j]];
}
sum
})
.collect();
for i in 0..out_features {
let mut sum = 0.0_f32;
for (r, &ax_r) in a_x.iter().enumerate() {
sum += self.lora_b[[i, r]] * ax_r;
}
output[[b, i]] += self.scaling * sum;
}
}
}
Ok(output)
}
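/// Folds the LoRA update into the base weight in place
/// (`W += (alpha / rank) * B A`), so subsequent forwards skip the low-rank
/// branch. Errors if the weights are already merged.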
pub fn merge(&mut self) -> ModelResult<()> {
if self.merged {
return Err(ModelError::invalid_config(
"LoRA weights are already merged",
));
}
let (out_features, in_features) = self.weight.dim();
for i in 0..out_features {
for j in 0..in_features {
let mut delta = 0.0_f32;
for r in 0..self.rank {
delta += self.lora_b[[i, r]] * self.lora_a[[r, j]];
}
self.weight[[i, j]] += self.scaling * delta;
}
}
self.merged = true;
Ok(())
}
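/// Reverses [`merge`](Self::merge) by subtracting the LoRA update from the
/// base weight. Errors if the weights are not currently merged.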
pub fn unmerge(&mut self) -> ModelResult<()> {
if !self.merged {
return Err(ModelError::invalid_config("LoRA weights are not merged"));
}
let (out_features, in_features) = self.weight.dim();
for i in 0..out_features {
for j in 0..in_features {
let mut delta = 0.0_f32;
for r in 0..self.rank {
delta += self.lora_b[[i, r]] * self.lora_a[[r, j]];
}
self.weight[[i, j]] -= self.scaling * delta;
}
}
self.merged = false;
Ok(())
}
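/// Number of trainable LoRA parameters: `rank * (in_features + out_features)`.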
pub fn trainable_params(&self) -> usize {
let (out_features, in_features) = self.weight.dim();
self.rank * (in_features + out_features)
}
pub fn total_params(&self) -> usize {
let (out_features, in_features) = self.weight.dim();
in_features * out_features + self.rank * (in_features + out_features)
}
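/// Fraction of all parameters that are trainable:
/// `trainable_params / (base + trainable)`. Smaller values mean the adapter
/// is lighter relative to the frozen base weight.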
pub fn compression_ratio(&self) -> f32 {
self.trainable_params() as f32 / self.total_params() as f32
}
pub fn lora_a(&self) -> &Array2<f32> {
&self.lora_a
}
pub fn lora_b(&self) -> &Array2<f32> {
&self.lora_b
}
pub fn set_lora_a(&mut self, a: Array2<f32>) -> ModelResult<()> {
let (_, in_features) = self.weight.dim();
let (a_rank, a_in) = a.dim();
if a_rank != self.rank {
return Err(ModelError::dimension_mismatch(
"set_lora_a rank",
self.rank,
a_rank,
));
}
if a_in != in_features {
return Err(ModelError::dimension_mismatch(
"set_lora_a in_features",
in_features,
a_in,
));
}
self.lora_a = a;
Ok(())
}
pub fn set_lora_b(&mut self, b: Array2<f32>) -> ModelResult<()> {
let (out_features, _) = self.weight.dim();
let (b_out, b_rank) = b.dim();
if b_out != out_features {
return Err(ModelError::dimension_mismatch(
"set_lora_b out_features",
out_features,
b_out,
));
}
if b_rank != self.rank {
return Err(ModelError::dimension_mismatch(
"set_lora_b rank",
self.rank,
b_rank,
));
}
self.lora_b = b;
Ok(())
}
pub fn enable(&mut self) {
self.enabled = true;
}
pub fn disable(&mut self) {
self.enabled = false;
}
pub fn is_enabled(&self) -> bool {
self.enabled
}
pub fn is_merged(&self) -> bool {
self.merged
}
pub fn weight(&self) -> &Array2<f32> {
&self.weight
}
pub fn rank(&self) -> usize {
self.rank
}
pub fn alpha(&self) -> f32 {
self.alpha
}
pub fn scaling(&self) -> f32 {
self.scaling
}
}
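/// Aggregate statistics over all layers managed by a [`LoraAdapter`].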
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoraAdapterSummary {
pub num_layers: usize,
pub total_trainable: usize,
pub total_original: usize,
pub compression_ratio: f32,
pub rank: usize,
pub alpha: f32,
}
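/// A collection of named [`LoraLinear`] layers that share one [`LoraConfig`].
///
/// A minimal usage sketch (layer names and shapes are illustrative):
///
/// ```ignore
/// let mut adapter = LoraAdapter::new(LoraConfig::new(4, 8.0));
/// adapter.add_layer("q_proj".into(), Array2::<f32>::zeros((32, 16)))?;
/// let y = adapter.forward_layer("q_proj", &Array1::from_vec(vec![0.0; 16]))?;
/// adapter.merge_all()?; // fold every layer's LoRA update into its base weight
/// ```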
#[derive(Debug, Clone)]
pub struct LoraAdapter {
config: LoraConfig,
layers: Vec<(String, LoraLinear)>,
}
impl LoraAdapter {
pub fn new(config: LoraConfig) -> Self {
Self {
config,
layers: Vec::new(),
}
}
pub fn add_layer(&mut self, name: String, weight: Array2<f32>) -> ModelResult<()> {
if self.layers.iter().any(|(n, _)| n == &name) {
return Err(ModelError::invalid_config(format!(
"LoRA layer '{}' already exists",
name
)));
}
let layer = LoraLinear::new(weight, self.config.rank, self.config.alpha)?;
self.layers.push((name, layer));
Ok(())
}
pub fn forward_layer(&self, name: &str, input: &Array1<f32>) -> ModelResult<Array1<f32>> {
let layer = self.get_layer(name).ok_or_else(|| {
ModelError::invalid_config(format!("LoRA layer '{}' not found", name))
})?;
layer.forward(input)
}
pub fn merge_all(&mut self) -> ModelResult<()> {
for (_, layer) in &mut self.layers {
if !layer.is_merged() {
layer.merge()?;
}
}
Ok(())
}
pub fn unmerge_all(&mut self) -> ModelResult<()> {
for (_, layer) in &mut self.layers {
if layer.is_merged() {
layer.unmerge()?;
}
}
Ok(())
}
pub fn total_trainable_params(&self) -> usize {
self.layers.iter().map(|(_, l)| l.trainable_params()).sum()
}
pub fn total_original_params(&self) -> usize {
self.layers
.iter()
.map(|(_, l)| {
let (out, inp) = l.weight().dim();
out * inp
})
.sum()
}
pub fn overall_compression_ratio(&self) -> f32 {
let trainable = self.total_trainable_params();
let total = self.total_original_params() + trainable;
if total == 0 {
return 0.0;
}
trainable as f32 / total as f32
}
pub fn layer_names(&self) -> Vec<&str> {
self.layers.iter().map(|(n, _)| n.as_str()).collect()
}
pub fn get_layer(&self, name: &str) -> Option<&LoraLinear> {
self.layers.iter().find(|(n, _)| n == name).map(|(_, l)| l)
}
pub fn get_layer_mut(&mut self, name: &str) -> Option<&mut LoraLinear> {
self.layers
.iter_mut()
.find(|(n, _)| n == name)
.map(|(_, l)| l)
}
pub fn config(&self) -> &LoraConfig {
&self.config
}
pub fn summary(&self) -> LoraAdapterSummary {
LoraAdapterSummary {
num_layers: self.layers.len(),
total_trainable: self.total_trainable_params(),
total_original: self.total_original_params(),
compression_ratio: self.overall_compression_ratio(),
rank: self.config.rank,
alpha: self.config.alpha,
}
}
}
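/// The 16 NormalFloat4 (NF4) quantization levels from the QLoRA paper,
/// normalized to `[-1.0, 1.0]`. Each weight is stored as a 4-bit index into
/// this table and rescaled by its group's absolute maximum.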
const NF4_LEVELS: [f32; 16] = [
-1.0,
-0.696_192_8,
-0.525_073_05,
-0.394_917_5,
-0.284_441_38,
-0.184_773_43,
-0.091_050_04,
0.0,
0.079_580_3,
0.160_930_2,
0.246_112_3,
0.337_915_24,
0.440_709_83,
0.562_617,
0.722_956_84,
1.0,
];
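/// A LoRA layer whose frozen base weight is stored as packed 4-bit NF4 codes
/// (two codes per byte) with one fp32 absmax scale per `group_size` elements,
/// while the trainable `lora_a` / `lora_b` matrices stay in fp32.
/// `zero_point` is kept for symmetry with asymmetric schemes but is always
/// zero here, since the quantization is symmetric absmax.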
#[derive(Debug, Clone)]
pub struct QLoraLinear {
quantized_weight: Vec<u8>,
scale: Array1<f32>,
zero_point: Array1<f32>,
group_size: usize,
lora_a: Array2<f32>,
lora_b: Array2<f32>,
out_features: usize,
in_features: usize,
rank: usize,
alpha: f32,
scaling: f32,
}
impl QLoraLinear {
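/// Quantizes `weight` to NF4 and attaches a freshly initialized LoRA adapter.
///
/// Each group of `group_size` consecutive elements (row-major order) is
/// normalized by its absolute maximum, mapped to the nearest NF4 level, and
/// packed two 4-bit indices per byte (even element in the low nibble, odd in
/// the high nibble).
///
/// A minimal sketch (shapes are illustrative):
///
/// ```ignore
/// let weight = Array2::<f32>::zeros((256, 128));
/// let qlayer = QLoraLinear::from_weight(weight, 8, 16.0, 64)?;
/// assert!(qlayer.memory_saved_bytes() > 0);
/// ```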
pub fn from_weight(
weight: Array2<f32>,
rank: usize,
alpha: f32,
group_size: usize,
) -> ModelResult<Self> {
if rank == 0 {
return Err(ModelError::invalid_config("QLoRA rank must be > 0"));
}
if alpha <= 0.0 {
return Err(ModelError::invalid_config("QLoRA alpha must be > 0.0"));
}
if group_size == 0 {
return Err(ModelError::invalid_config("QLoRA group_size must be > 0"));
}
let (out_features, in_features) = weight.dim();
if out_features == 0 || in_features == 0 {
return Err(ModelError::invalid_config(
"Weight matrix dimensions must be > 0",
));
}
if rank > out_features.min(in_features) {
return Err(ModelError::invalid_config(format!(
"QLoRA rank ({}) must not exceed min(out, in) = {}",
rank,
out_features.min(in_features)
)));
}
let total_elements = out_features * in_features;
let num_groups = total_elements.div_ceil(group_size);
let flat: Vec<f32> = weight.iter().copied().collect();
let mut scale = Array1::zeros(num_groups);
let mut zero_point = Array1::zeros(num_groups);
let packed_len = total_elements.div_ceil(2);
let mut quantized_weight = vec![0u8; packed_len];
for g in 0..num_groups {
let start = g * group_size;
let end = (start + group_size).min(total_elements);
let group = &flat[start..end];
let abs_max = group
.iter()
.map(|v| v.abs())
.fold(0.0_f32, f32::max)
.max(1e-10);
scale[g] = abs_max;
zero_point[g] = 0.0;
for (k, &val) in group.iter().enumerate() {
let normalized = (val / abs_max).clamp(-1.0, 1.0);
let quant_idx = find_nearest_nf4(normalized);
let flat_idx = start + k;
let byte_idx = flat_idx / 2;
if flat_idx.is_multiple_of(2) {
quantized_weight[byte_idx] |= quant_idx;
} else {
quantized_weight[byte_idx] |= quant_idx << 4;
}
}
}
let kaiming_scale = (2.0 / in_features as f32).sqrt();
let mut rng = SeededRng::new(137 + in_features as u64 + out_features as u64);
let lora_a = Array2::from_shape_fn((rank, in_features), |_| rng.next_f32() * kaiming_scale);
let lora_b = Array2::zeros((out_features, rank));
let scaling = alpha / rank as f32;
Ok(Self {
quantized_weight,
scale,
zero_point,
group_size,
lora_a,
lora_b,
out_features,
in_features,
rank,
alpha,
scaling,
})
}
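/// Reconstructs the full fp32 weight matrix by looking up each packed 4-bit
/// code in `NF4_LEVELS` and multiplying by its group's scale.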
pub fn dequantize_weight(&self) -> ModelResult<Array2<f32>> {
let total_elements = self.out_features * self.in_features;
let num_groups = total_elements.div_ceil(self.group_size);
let mut flat = vec![0.0_f32; total_elements];
for g in 0..num_groups {
let start = g * self.group_size;
let end = (start + self.group_size).min(total_elements);
let s = self.scale[g];
for (offset, val) in flat[start..end].iter_mut().enumerate() {
let flat_idx = start + offset;
let byte_idx = flat_idx / 2;
let quant_idx = if flat_idx.is_multiple_of(2) {
self.quantized_weight[byte_idx] & 0x0F
} else {
(self.quantized_weight[byte_idx] >> 4) & 0x0F
};
*val = NF4_LEVELS[quant_idx as usize] * s;
}
}
Array2::from_shape_vec((self.out_features, self.in_features), flat).map_err(|e| {
ModelError::invalid_config(format!("Failed to reshape dequantized weight: {}", e))
})
}
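/// Dequantizes the base weight and computes `W x + (alpha / rank) * B (A x)`.
/// Unlike [`LoraLinear`], the LoRA branch is always applied (there is no
/// merge/disable state).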
pub fn forward(&self, input: &Array1<f32>) -> ModelResult<Array1<f32>> {
if input.len() != self.in_features {
return Err(ModelError::dimension_mismatch(
"QLoraLinear forward input",
self.in_features,
input.len(),
));
}
let weight = self.dequantize_weight()?;
let mut output = Array1::zeros(self.out_features);
for i in 0..self.out_features {
let mut sum = 0.0_f32;
for j in 0..self.in_features {
sum += weight[[i, j]] * input[j];
}
output[i] = sum;
}
let mut a_x = Array1::zeros(self.rank);
for r in 0..self.rank {
let mut sum = 0.0_f32;
for j in 0..self.in_features {
sum += self.lora_a[[r, j]] * input[j];
}
a_x[r] = sum;
}
for i in 0..self.out_features {
let mut sum = 0.0_f32;
for r in 0..self.rank {
sum += self.lora_b[[i, r]] * a_x[r];
}
output[i] += self.scaling * sum;
}
Ok(output)
}
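/// Bytes saved versus storing the base weight in fp32: compares
/// `4 * out_features * in_features` against the packed 4-bit codes plus the
/// per-group fp32 `scale` and `zero_point` arrays.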
pub fn memory_saved_bytes(&self) -> usize {
let total_elements = self.out_features * self.in_features;
let fp32_bytes = total_elements * 4;
let packed_bytes = self.quantized_weight.len();
let num_groups = total_elements.div_ceil(self.group_size);
let scale_bytes = num_groups * 4;
let zero_point_bytes = num_groups * 4;
let quantized_total = packed_bytes + scale_bytes + zero_point_bytes;
fp32_bytes.saturating_sub(quantized_total)
}
pub fn trainable_params(&self) -> usize {
self.rank * (self.in_features + self.out_features)
}
pub fn lora_a(&self) -> &Array2<f32> {
&self.lora_a
}
pub fn lora_b(&self) -> &Array2<f32> {
&self.lora_b
}
pub fn group_size(&self) -> usize {
self.group_size
}
pub fn rank(&self) -> usize {
self.rank
}
pub fn out_features(&self) -> usize {
self.out_features
}
pub fn in_features(&self) -> usize {
self.in_features
}
pub fn alpha(&self) -> f32 {
self.alpha
}
pub fn zero_point(&self) -> &Array1<f32> {
&self.zero_point
}
pub fn scale(&self) -> &Array1<f32> {
&self.scale
}
}
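/// Returns the index (0..16) of the NF4 level closest to `value` by absolute
/// distance.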
fn find_nearest_nf4(value: f32) -> u8 {
let mut best_idx = 0u8;
let mut best_dist = f32::MAX;
for (i, &level) in NF4_LEVELS.iter().enumerate() {
let dist = (value - level).abs();
if dist < best_dist {
best_dist = dist;
best_idx = i as u8;
}
}
best_idx
}
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::Array2;
fn make_weight(out: usize, inp: usize) -> Array2<f32> {
Array2::from_shape_fn((out, inp), |(i, j)| (i * inp + j) as f32 * 0.01)
}
#[test]
fn test_lora_linear_creation() -> ModelResult<()> {
let weight = make_weight(64, 32);
let lora = LoraLinear::new(weight.clone(), 8, 16.0)?;
let input = Array1::from_vec(vec![1.0; 32]);
let output_lora = lora.forward(&input)?;
let mut output_plain = Array1::zeros(64);
for i in 0..64 {
let mut sum = 0.0_f32;
for j in 0..32 {
sum += weight[[i, j]] * input[j];
}
output_plain[i] = sum;
}
for i in 0..64 {
assert!(
(output_lora[i] - output_plain[i]).abs() < 1e-5,
"Mismatch at index {}: lora={}, plain={}",
i,
output_lora[i],
output_plain[i]
);
}
Ok(())
}
#[test]
fn test_lora_linear_forward_with_nonzero_b() -> ModelResult<()> {
let weight = make_weight(16, 8);
let mut lora = LoraLinear::new(weight.clone(), 4, 8.0)?;
let b = Array2::from_shape_fn((16, 4), |(i, j)| (i + j) as f32 * 0.1);
lora.set_lora_b(b)?;
let input = Array1::from_vec(vec![1.0; 8]);
let output_lora = lora.forward(&input)?;
let mut output_plain = Array1::zeros(16);
for i in 0..16 {
let mut sum = 0.0_f32;
for j in 0..8 {
sum += weight[[i, j]] * input[j];
}
output_plain[i] = sum;
}
let mut any_diff = false;
for i in 0..16 {
if (output_lora[i] - output_plain[i]).abs() > 1e-6 {
any_diff = true;
break;
}
}
assert!(
any_diff,
"LoRA output should differ from plain output when B != 0"
);
Ok(())
}
#[test]
fn test_lora_linear_merge_unmerge() -> ModelResult<()> {
let weight = make_weight(16, 8);
let mut lora = LoraLinear::new(weight.clone(), 4, 8.0)?;
let b = Array2::from_shape_fn((16, 4), |(i, j)| (i + j) as f32 * 0.01);
lora.set_lora_b(b)?;
let input = Array1::from_vec(vec![0.5; 8]);
let output_before = lora.forward(&input)?;
lora.merge()?;
assert!(lora.is_merged());
let output_merged = lora.forward(&input)?;
for i in 0..16 {
assert!(
(output_before[i] - output_merged[i]).abs() < 1e-4,
"Merge changed output at {}: before={}, after={}",
i,
output_before[i],
output_merged[i]
);
}
lora.unmerge()?;
assert!(!lora.is_merged());
for i in 0..16 {
for j in 0..8 {
assert!(
(lora.weight()[[i, j]] - weight[[i, j]]).abs() < 1e-4,
"Unmerge did not restore weight at [{}, {}]",
i,
j
);
}
}
Ok(())
}
#[test]
fn test_lora_linear_trainable_params() -> ModelResult<()> {
let weight = make_weight(64, 32);
let lora = LoraLinear::new(weight, 8, 16.0)?;
assert_eq!(lora.trainable_params(), 768);
assert_eq!(lora.total_params(), 2816);
Ok(())
}
#[test]
fn test_lora_linear_compression_ratio() -> ModelResult<()> {
let weight = make_weight(256, 128);
let lora = LoraLinear::new(weight, 8, 16.0)?;
let ratio = lora.compression_ratio();
assert!(
ratio < 1.0,
"Compression ratio should be < 1.0, got {}",
ratio
);
assert!(
ratio > 0.0,
"Compression ratio should be > 0.0, got {}",
ratio
);
let expected = 3072.0 / 35840.0;
assert!(
(ratio - expected).abs() < 1e-5,
"Expected ratio ~{}, got {}",
expected,
ratio
);
Ok(())
}
#[test]
fn test_lora_adapter_multi_layer() -> ModelResult<()> {
let config = LoraConfig::new(4, 8.0).with_target_modules(vec![
"q_proj".into(),
"k_proj".into(),
"v_proj".into(),
]);
let mut adapter = LoraAdapter::new(config);
adapter.add_layer("q_proj".into(), make_weight(32, 16))?;
adapter.add_layer("k_proj".into(), make_weight(32, 16))?;
adapter.add_layer("v_proj".into(), make_weight(32, 16))?;
assert_eq!(adapter.layer_names().len(), 3);
let input = Array1::from_vec(vec![1.0; 16]);
for name in &["q_proj", "k_proj", "v_proj"] {
let output = adapter.forward_layer(name, &input)?;
assert_eq!(output.len(), 32);
}
let result = adapter.forward_layer("nonexistent", &input);
assert!(result.is_err());
Ok(())
}
#[test]
fn test_lora_adapter_merge_all() -> ModelResult<()> {
let config = LoraConfig::new(4, 8.0);
let mut adapter = LoraAdapter::new(config);
adapter.add_layer("layer_0".into(), make_weight(16, 8))?;
adapter.add_layer("layer_1".into(), make_weight(16, 8))?;
if let Some(layer) = adapter.get_layer_mut("layer_0") {
let b = Array2::from_shape_fn((16, 4), |(i, j)| (i + j) as f32 * 0.01);
layer.set_lora_b(b)?;
}
let input = Array1::from_vec(vec![0.5; 8]);
let out_before_0 = adapter.forward_layer("layer_0", &input)?;
let out_before_1 = adapter.forward_layer("layer_1", &input)?;
adapter.merge_all()?;
let out_after_0 = adapter.forward_layer("layer_0", &input)?;
let out_after_1 = adapter.forward_layer("layer_1", &input)?;
for i in 0..16 {
assert!(
(out_before_0[i] - out_after_0[i]).abs() < 1e-4,
"layer_0 merge changed output"
);
assert!(
(out_before_1[i] - out_after_1[i]).abs() < 1e-4,
"layer_1 merge changed output"
);
}
Ok(())
}
#[test]
fn test_lora_adapter_summary() -> ModelResult<()> {
let config = LoraConfig::new(8, 16.0);
let mut adapter = LoraAdapter::new(config);
adapter.add_layer("proj_q".into(), make_weight(64, 32))?;
adapter.add_layer("proj_v".into(), make_weight(64, 32))?;
let summary = adapter.summary();
assert_eq!(summary.num_layers, 2);
assert_eq!(summary.rank, 8);
assert!((summary.alpha - 16.0).abs() < 1e-6);
assert_eq!(summary.total_trainable, 1536);
assert_eq!(summary.total_original, 4096);
assert!(summary.compression_ratio > 0.0);
assert!(summary.compression_ratio < 1.0);
Ok(())
}
#[test]
fn test_lora_disable_enable() -> ModelResult<()> {
let weight = make_weight(16, 8);
let mut lora = LoraLinear::new(weight.clone(), 4, 8.0)?;
let b = Array2::from_shape_fn((16, 4), |(i, j)| (i + j) as f32 * 0.1);
lora.set_lora_b(b)?;
let input = Array1::from_vec(vec![1.0; 8]);
let mut output_plain = Array1::zeros(16);
for i in 0..16 {
let mut sum = 0.0_f32;
for j in 0..8 {
sum += weight[[i, j]] * input[j];
}
output_plain[i] = sum;
}
let output_enabled = lora.forward(&input)?;
let mut any_diff = false;
for i in 0..16 {
if (output_enabled[i] - output_plain[i]).abs() > 1e-6 {
any_diff = true;
break;
}
}
assert!(any_diff, "Enabled LoRA should produce different output");
lora.disable();
assert!(!lora.is_enabled());
let output_disabled = lora.forward(&input)?;
for i in 0..16 {
assert!(
(output_disabled[i] - output_plain[i]).abs() < 1e-5,
"Disabled LoRA should produce same output as plain W"
);
}
lora.enable();
assert!(lora.is_enabled());
let output_reenabled = lora.forward(&input)?;
for i in 0..16 {
assert!(
(output_reenabled[i] - output_enabled[i]).abs() < 1e-5,
"Re-enabled LoRA should match original enabled output"
);
}
Ok(())
}
#[test]
fn test_qlora_creation() -> ModelResult<()> {
let weight = make_weight(32, 16);
let qlora = QLoraLinear::from_weight(weight, 4, 8.0, 64)?;
assert_eq!(qlora.out_features(), 32);
assert_eq!(qlora.in_features(), 16);
assert_eq!(qlora.rank(), 4);
assert_eq!(qlora.group_size(), 64);
assert_eq!(qlora.trainable_params(), 4 * (16 + 32));
Ok(())
}
#[test]
fn test_qlora_forward() -> ModelResult<()> {
let weight = make_weight(16, 8);
let qlora = QLoraLinear::from_weight(weight, 4, 8.0, 32)?;
let input = Array1::from_vec(vec![1.0; 8]);
let output = qlora.forward(&input)?;
assert_eq!(output.len(), 16);
for &val in output.iter() {
assert!(
val.is_finite(),
"QLoRA output contains non-finite value: {}",
val
);
}
Ok(())
}
#[test]
fn test_qlora_memory_savings() -> ModelResult<()> {
let weight = make_weight(256, 128);
let qlora = QLoraLinear::from_weight(weight, 8, 16.0, 64)?;
let saved = qlora.memory_saved_bytes();
assert!(
saved > 0,
"QLoRA should save memory compared to fp32, got saved={} bytes",
saved
);
assert!(
saved > 100_000,
"Expected significant savings for 256x128 matrix, got {} bytes",
saved
);
Ok(())
}
#[test]
fn test_lora_config_validation() -> ModelResult<()> {
let config = LoraConfig::new(8, 16.0);
assert!(config.validate().is_ok());
let bad_rank = LoraConfig::new(0, 16.0);
assert!(bad_rank.validate().is_err());
let bad_alpha = LoraConfig::new(8, -1.0);
assert!(bad_alpha.validate().is_err());
let bad_dropout = LoraConfig::new(8, 16.0).with_dropout(1.5);
assert!(bad_dropout.validate().is_err());
Ok(())
}
#[test]
fn test_lora_batch_forward() -> ModelResult<()> {
let weight = make_weight(16, 8);
let lora = LoraLinear::new(weight, 4, 8.0)?;
let batch = Array2::from_shape_fn((3, 8), |(b, j)| (b * 8 + j) as f32 * 0.1);
let output = lora.forward_batch(&batch)?;
assert_eq!(output.dim(), (3, 16));
for b in 0..3 {
let single_input = Array1::from_vec(batch.row(b).to_vec());
let single_output = lora.forward(&single_input)?;
for i in 0..16 {
assert!(
(output[[b, i]] - single_output[i]).abs() < 1e-4,
"Batch output[{},{}]={} != single output[{}]={}",
b,
i,
output[[b, i]],
i,
single_output[i]
);
}
}
Ok(())
}
#[test]
fn test_qlora_dequantize_roundtrip() -> ModelResult<()> {
let weight = Array2::from_shape_fn((8, 4), |(i, j)| {
((i as f32 - 4.0) * 0.2 + (j as f32 - 2.0) * 0.1).clamp(-0.9, 0.9)
});
let qlora = QLoraLinear::from_weight(weight.clone(), 2, 4.0, 16)?;
let deq = qlora.dequantize_weight()?;
assert_eq!(deq.dim(), (8, 4));
let mut max_err = 0.0_f32;
for i in 0..8 {
for j in 0..4 {
let err = (weight[[i, j]] - deq[[i, j]]).abs();
if err > max_err {
max_err = err;
}
}
}
assert!(
max_err < 0.5,
"Maximum dequantization error {} is too large",
max_err
);
Ok(())
}
}