use crate::error::{ModelError, ModelResult};
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};

/// Minimal xorshift64 PRNG used for deterministic LoRA weight initialization.
struct SeededRng {
    state: u64,
}

impl SeededRng {
    fn new(seed: u64) -> Self {
        Self { state: seed.max(1) }
    }

    /// Returns a pseudo-random value in approximately [-1.0, 1.0].
    fn next_f32(&mut self) -> f32 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        (self.state as f64 / u64::MAX as f64 * 2.0 - 1.0) as f32
    }
}

/// Configuration for a set of LoRA adapters.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoraConfig {
    /// Rank of the low-rank update matrices.
    pub rank: usize,
    /// Scaling numerator; the effective scale applied to the update is `alpha / rank`.
    pub alpha: f32,
    /// Dropout probability for the LoRA path, in `[0.0, 1.0]`.
    pub dropout: f32,
    /// Names of the modules the adapter targets.
    pub target_modules: Vec<String>,
    /// Whether the wrapped weight is stored in fan-in/fan-out (transposed) layout.
    pub fan_in_fan_out: bool,
}

impl LoraConfig {
    /// Creates a config with the given rank and alpha; other fields take defaults.
    pub fn new(rank: usize, alpha: f32) -> Self {
        Self {
            rank,
            alpha,
            dropout: 0.0,
            target_modules: Vec::new(),
            fan_in_fan_out: false,
        }
    }

    pub fn with_dropout(mut self, dropout: f32) -> Self {
        self.dropout = dropout;
        self
    }

    pub fn with_target_modules(mut self, modules: Vec<String>) -> Self {
        self.target_modules = modules;
        self
    }

    pub fn with_fan_in_fan_out(mut self, fan_in_fan_out: bool) -> Self {
        self.fan_in_fan_out = fan_in_fan_out;
        self
    }

    /// Validates that rank, alpha, and dropout are within acceptable ranges.
    pub fn validate(&self) -> ModelResult<()> {
        if self.rank == 0 {
            return Err(ModelError::invalid_config("LoRA rank must be > 0"));
        }
        if self.alpha <= 0.0 {
            return Err(ModelError::invalid_config("LoRA alpha must be > 0.0"));
        }
        if !(0.0..=1.0).contains(&self.dropout) {
            return Err(ModelError::invalid_config(
                "LoRA dropout must be in [0.0, 1.0]",
            ));
        }
        Ok(())
    }
}
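
// Illustrative usage sketch (not from the original source); `q_proj` / `v_proj`
// are hypothetical module names used only for this example:
//
//     let cfg = LoraConfig::new(8, 16.0)
//         .with_dropout(0.05)
//         .with_target_modules(vec!["q_proj".into(), "v_proj".into()]);
//     cfg.validate()?;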

/// A dense linear layer with a LoRA low-rank update on top of a frozen base weight.
///
/// The effective weight is `W + (alpha / rank) * B * A`, where `A` has shape
/// `(rank, in_features)` and `B` has shape `(out_features, rank)`.
#[derive(Debug, Clone)]
pub struct LoraLinear {
    /// Frozen base weight of shape `(out_features, in_features)`.
    weight: Array2<f32>,
    /// Low-rank factor `A`, shape `(rank, in_features)`.
    lora_a: Array2<f32>,
    /// Low-rank factor `B`, shape `(out_features, rank)`.
    lora_b: Array2<f32>,
    rank: usize,
    alpha: f32,
    /// Precomputed `alpha / rank`.
    scaling: f32,
    /// Whether the LoRA delta is currently folded into `weight`.
    merged: bool,
    /// Whether the LoRA path is applied during the forward pass.
    enabled: bool,
}

impl LoraLinear {
    /// Wraps `weight` with LoRA factors of the given `rank` and `alpha`.
    ///
    /// `A` is initialized with Kaiming-style noise and `B` with zeros, so the
    /// initial LoRA delta is zero and the layer behaves like the base weight.
    pub fn new(weight: Array2<f32>, rank: usize, alpha: f32) -> ModelResult<Self> {
        if rank == 0 {
            return Err(ModelError::invalid_config("LoRA rank must be > 0"));
        }
        if alpha <= 0.0 {
            return Err(ModelError::invalid_config("LoRA alpha must be > 0.0"));
        }

        let (out_features, in_features) = weight.dim();
        if out_features == 0 || in_features == 0 {
            return Err(ModelError::invalid_config(
                "Weight matrix dimensions must be > 0",
            ));
        }
        if rank > out_features.min(in_features) {
            return Err(ModelError::invalid_config(format!(
                "LoRA rank ({}) must not exceed min(out_features, in_features) = {}",
                rank,
                out_features.min(in_features)
            )));
        }

        let kaiming_scale = (2.0 / in_features as f32).sqrt();
        let mut rng = SeededRng::new(42 + in_features as u64 + out_features as u64);
        let lora_a = Array2::from_shape_fn((rank, in_features), |_| rng.next_f32() * kaiming_scale);

        let lora_b = Array2::zeros((out_features, rank));

        let scaling = alpha / rank as f32;

        Ok(Self {
            weight,
            lora_a,
            lora_b,
            rank,
            alpha,
            scaling,
            merged: false,
            enabled: true,
        })
    }

    /// Computes `W x` plus the scaled LoRA update `scaling * B (A x)` when the
    /// adapter is enabled and not merged.
    pub fn forward(&self, input: &Array1<f32>) -> ModelResult<Array1<f32>> {
        let (out_features, in_features) = self.weight.dim();
        if input.len() != in_features {
            return Err(ModelError::dimension_mismatch(
                "LoraLinear forward input",
                in_features,
                input.len(),
            ));
        }

        // Base projection: output = W x.
        let mut output = Array1::zeros(out_features);
        for i in 0..out_features {
            let mut sum = 0.0_f32;
            for j in 0..in_features {
                sum += self.weight[[i, j]] * input[j];
            }
            output[i] = sum;
        }

        if self.enabled && !self.merged {
            // a_x = A x (length `rank`).
            let mut a_x = Array1::zeros(self.rank);
            for r in 0..self.rank {
                let mut sum = 0.0_f32;
                for j in 0..in_features {
                    sum += self.lora_a[[r, j]] * input[j];
                }
                a_x[r] = sum;
            }

            // output += scaling * B a_x.
            for i in 0..out_features {
                let mut sum = 0.0_f32;
                for r in 0..self.rank {
                    sum += self.lora_b[[i, r]] * a_x[r];
                }
                output[i] += self.scaling * sum;
            }
        }

        Ok(output)
    }

    /// Batched variant of `forward`; `input` has shape `(batch_size, in_features)`.
    pub fn forward_batch(&self, input: &Array2<f32>) -> ModelResult<Array2<f32>> {
        let (batch_size, input_dim) = input.dim();
        let (out_features, in_features) = self.weight.dim();

        if input_dim != in_features {
            return Err(ModelError::dimension_mismatch(
                "LoraLinear forward_batch input dim",
                in_features,
                input_dim,
            ));
        }

        let mut output = Array2::zeros((batch_size, out_features));
        for b in 0..batch_size {
            for i in 0..out_features {
                let mut sum = 0.0_f32;
                for j in 0..in_features {
                    sum += input[[b, j]] * self.weight[[i, j]];
                }
                output[[b, i]] = sum;
            }
        }

        if self.enabled && !self.merged {
            for b in 0..batch_size {
                let a_x: Vec<f32> = (0..self.rank)
                    .map(|r| {
                        let mut sum = 0.0_f32;
                        for j in 0..in_features {
                            sum += self.lora_a[[r, j]] * input[[b, j]];
                        }
                        sum
                    })
                    .collect();

                for i in 0..out_features {
                    let mut sum = 0.0_f32;
                    for (r, &ax_r) in a_x.iter().enumerate() {
                        sum += self.lora_b[[i, r]] * ax_r;
                    }
                    output[[b, i]] += self.scaling * sum;
                }
            }
        }

        Ok(output)
    }

    /// Folds the LoRA update into the base weight: `W += scaling * B * A`.
    pub fn merge(&mut self) -> ModelResult<()> {
        if self.merged {
            return Err(ModelError::invalid_config(
                "LoRA weights are already merged",
            ));
        }

        let (out_features, in_features) = self.weight.dim();

        for i in 0..out_features {
            for j in 0..in_features {
                let mut delta = 0.0_f32;
                for r in 0..self.rank {
                    delta += self.lora_b[[i, r]] * self.lora_a[[r, j]];
                }
                self.weight[[i, j]] += self.scaling * delta;
            }
        }

        self.merged = true;
        Ok(())
    }

    /// Reverses `merge`: subtracts `scaling * B * A` from the base weight.
    pub fn unmerge(&mut self) -> ModelResult<()> {
        if !self.merged {
            return Err(ModelError::invalid_config("LoRA weights are not merged"));
        }

        let (out_features, in_features) = self.weight.dim();

        for i in 0..out_features {
            for j in 0..in_features {
                let mut delta = 0.0_f32;
                for r in 0..self.rank {
                    delta += self.lora_b[[i, r]] * self.lora_a[[r, j]];
                }
                self.weight[[i, j]] -= self.scaling * delta;
            }
        }

        self.merged = false;
        Ok(())
    }

    /// Number of trainable LoRA parameters: `rank * (in_features + out_features)`.
    pub fn trainable_params(&self) -> usize {
        let (out_features, in_features) = self.weight.dim();
        self.rank * (in_features + out_features)
    }

    /// Total parameters including the frozen base weight and the LoRA factors.
    pub fn total_params(&self) -> usize {
        let (out_features, in_features) = self.weight.dim();
        in_features * out_features + self.rank * (in_features + out_features)
    }

    /// Fraction of parameters that are trainable (trainable / total).
    pub fn compression_ratio(&self) -> f32 {
        self.trainable_params() as f32 / self.total_params() as f32
    }
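
    // Worked example (illustrative, not from the original source): for a 64 x 32
    // weight with rank 8, trainable_params() = 8 * (32 + 64) = 768,
    // total_params() = 64 * 32 + 768 = 2816, and compression_ratio() ≈ 0.273.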

    pub fn lora_a(&self) -> &Array2<f32> {
        &self.lora_a
    }

    pub fn lora_b(&self) -> &Array2<f32> {
        &self.lora_b
    }

    pub fn set_lora_a(&mut self, a: Array2<f32>) -> ModelResult<()> {
        let (_, in_features) = self.weight.dim();
        let (a_rank, a_in) = a.dim();
        if a_rank != self.rank {
            return Err(ModelError::dimension_mismatch(
                "set_lora_a rank",
                self.rank,
                a_rank,
            ));
        }
        if a_in != in_features {
            return Err(ModelError::dimension_mismatch(
                "set_lora_a in_features",
                in_features,
                a_in,
            ));
        }
        self.lora_a = a;
        Ok(())
    }

    pub fn set_lora_b(&mut self, b: Array2<f32>) -> ModelResult<()> {
        let (out_features, _) = self.weight.dim();
        let (b_out, b_rank) = b.dim();
        if b_out != out_features {
            return Err(ModelError::dimension_mismatch(
                "set_lora_b out_features",
                out_features,
                b_out,
            ));
        }
        if b_rank != self.rank {
            return Err(ModelError::dimension_mismatch(
                "set_lora_b rank",
                self.rank,
                b_rank,
            ));
        }
        self.lora_b = b;
        Ok(())
    }

    pub fn enable(&mut self) {
        self.enabled = true;
    }

    pub fn disable(&mut self) {
        self.enabled = false;
    }

    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    pub fn is_merged(&self) -> bool {
        self.merged
    }

    pub fn weight(&self) -> &Array2<f32> {
        &self.weight
    }

    pub fn rank(&self) -> usize {
        self.rank
    }

    pub fn alpha(&self) -> f32 {
        self.alpha
    }

    pub fn scaling(&self) -> f32 {
        self.scaling
    }
}
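
// Illustrative sketch (not from the original source): wrapping a weight, running a
// forward pass, and folding the update in for inference. Dimensions are made up.
//
//     let weight = Array2::<f32>::zeros((64, 32));
//     let mut layer = LoraLinear::new(weight, 8, 16.0)?;
//     let x = Array1::from_vec(vec![0.0; 32]);
//     let y = layer.forward(&x)?;          // W x + (alpha / rank) * B (A x)
//     layer.merge()?;                      // fold the delta into W for deployment
//     let y_merged = layer.forward(&x)?;   // same result, single matmul path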

/// Aggregate statistics for a [`LoraAdapter`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoraAdapterSummary {
    /// Number of adapted layers.
    pub num_layers: usize,
    /// Total trainable LoRA parameters across all layers.
    pub total_trainable: usize,
    /// Total parameters in the original (frozen) weights.
    pub total_original: usize,
    /// Trainable parameters as a fraction of all parameters.
    pub compression_ratio: f32,
    /// LoRA rank shared by all layers.
    pub rank: usize,
    /// LoRA alpha shared by all layers.
    pub alpha: f32,
}

/// A collection of named [`LoraLinear`] layers sharing one [`LoraConfig`].
#[derive(Debug, Clone)]
pub struct LoraAdapter {
    config: LoraConfig,
    layers: Vec<(String, LoraLinear)>,
}

impl LoraAdapter {
    pub fn new(config: LoraConfig) -> Self {
        Self {
            config,
            layers: Vec::new(),
        }
    }

    /// Adds a LoRA-wrapped layer; fails if `name` is already registered.
    pub fn add_layer(&mut self, name: String, weight: Array2<f32>) -> ModelResult<()> {
        if self.layers.iter().any(|(n, _)| n == &name) {
            return Err(ModelError::invalid_config(format!(
                "LoRA layer '{}' already exists",
                name
            )));
        }

        let layer = LoraLinear::new(weight, self.config.rank, self.config.alpha)?;
        self.layers.push((name, layer));
        Ok(())
    }

    /// Runs the forward pass of the named layer.
    pub fn forward_layer(&self, name: &str, input: &Array1<f32>) -> ModelResult<Array1<f32>> {
        let layer = self.get_layer(name).ok_or_else(|| {
            ModelError::invalid_config(format!("LoRA layer '{}' not found", name))
        })?;
        layer.forward(input)
    }

    /// Merges the LoRA update into every layer that is not yet merged.
    pub fn merge_all(&mut self) -> ModelResult<()> {
        for (_, layer) in &mut self.layers {
            if !layer.is_merged() {
                layer.merge()?;
            }
        }
        Ok(())
    }

    /// Unmerges every layer that is currently merged.
    pub fn unmerge_all(&mut self) -> ModelResult<()> {
        for (_, layer) in &mut self.layers {
            if layer.is_merged() {
                layer.unmerge()?;
            }
        }
        Ok(())
    }

    pub fn total_trainable_params(&self) -> usize {
        self.layers.iter().map(|(_, l)| l.trainable_params()).sum()
    }

    pub fn total_original_params(&self) -> usize {
        self.layers
            .iter()
            .map(|(_, l)| {
                let (out, inp) = l.weight().dim();
                out * inp
            })
            .sum()
    }

    pub fn overall_compression_ratio(&self) -> f32 {
        let trainable = self.total_trainable_params();
        let total = self.total_original_params() + trainable;
        if total == 0 {
            return 0.0;
        }
        trainable as f32 / total as f32
    }

    pub fn layer_names(&self) -> Vec<&str> {
        self.layers.iter().map(|(n, _)| n.as_str()).collect()
    }

    pub fn get_layer(&self, name: &str) -> Option<&LoraLinear> {
        self.layers.iter().find(|(n, _)| n == name).map(|(_, l)| l)
    }

    pub fn get_layer_mut(&mut self, name: &str) -> Option<&mut LoraLinear> {
        self.layers
            .iter_mut()
            .find(|(n, _)| n == name)
            .map(|(_, l)| l)
    }

    pub fn config(&self) -> &LoraConfig {
        &self.config
    }

    pub fn summary(&self) -> LoraAdapterSummary {
        LoraAdapterSummary {
            num_layers: self.layers.len(),
            total_trainable: self.total_trainable_params(),
            total_original: self.total_original_params(),
            compression_ratio: self.overall_compression_ratio(),
            rank: self.config.rank,
            alpha: self.config.alpha,
        }
    }
}
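
// Illustrative sketch (not from the original source): registering a few layers and
// inspecting the summary. The layer names and shapes are made up for this example.
//
//     let mut adapter = LoraAdapter::new(LoraConfig::new(8, 16.0));
//     adapter.add_layer("q_proj".into(), Array2::zeros((64, 32)))?;
//     adapter.add_layer("v_proj".into(), Array2::zeros((64, 32)))?;
//     let summary = adapter.summary();
//     assert_eq!(summary.num_layers, 2);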

/// The 16 NF4 (4-bit NormalFloat) quantization levels used by QLoRA, normalized
/// to the range [-1.0, 1.0].
const NF4_LEVELS: [f32; 16] = [
    -1.0,
    -0.696_192_8,
    -0.525_073_05,
    -0.394_917_5,
    -0.284_441_38,
    -0.184_773_43,
    -0.091_050_04,
    0.0,
    0.079_580_3,
    0.160_930_2,
    0.246_112_3,
    0.337_915_24,
    0.440_709_83,
    0.562_617,
    0.722_956_84,
    1.0,
];
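
// Worked example (added for illustration, assuming the per-group absmax scaling
// implemented below): for a group with absmax 0.8, the value 0.3 is normalized to
// 0.375, which is closest to NF4_LEVELS[11] = 0.337_915_24; dequantization then
// yields 0.337_915_24 * 0.8 ≈ 0.270, so the round-trip error here is about 0.03.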

/// A linear layer whose base weight is stored in 4-bit NF4 form (QLoRA-style),
/// with full-precision LoRA factors on top.
#[derive(Debug, Clone)]
pub struct QLoraLinear {
    /// Packed 4-bit weight indices, two values per byte.
    quantized_weight: Vec<u8>,
    /// Per-group absmax scale factors.
    scale: Array1<f32>,
    /// Per-group zero points (always 0.0 for symmetric NF4 quantization).
    zero_point: Array1<f32>,
    /// Number of weight elements per quantization group.
    group_size: usize,
    /// Low-rank factor `A`, shape `(rank, in_features)`.
    lora_a: Array2<f32>,
    /// Low-rank factor `B`, shape `(out_features, rank)`.
    lora_b: Array2<f32>,
    out_features: usize,
    in_features: usize,
    rank: usize,
    alpha: f32,
    /// Precomputed `alpha / rank`.
    scaling: f32,
}

impl QLoraLinear {
    /// Quantizes `weight` to NF4 with per-group absmax scaling and attaches
    /// freshly initialized LoRA factors.
    pub fn from_weight(
        weight: Array2<f32>,
        rank: usize,
        alpha: f32,
        group_size: usize,
    ) -> ModelResult<Self> {
        if rank == 0 {
            return Err(ModelError::invalid_config("QLoRA rank must be > 0"));
        }
        if alpha <= 0.0 {
            return Err(ModelError::invalid_config("QLoRA alpha must be > 0.0"));
        }
        if group_size == 0 {
            return Err(ModelError::invalid_config("QLoRA group_size must be > 0"));
        }

        let (out_features, in_features) = weight.dim();
        if out_features == 0 || in_features == 0 {
            return Err(ModelError::invalid_config(
                "Weight matrix dimensions must be > 0",
            ));
        }
        if rank > out_features.min(in_features) {
            return Err(ModelError::invalid_config(format!(
                "QLoRA rank ({}) must not exceed min(out, in) = {}",
                rank,
                out_features.min(in_features)
            )));
        }

        let total_elements = out_features * in_features;
        let num_groups = total_elements.div_ceil(group_size);

        let flat: Vec<f32> = weight.iter().copied().collect();

        let mut scale = Array1::zeros(num_groups);
        let mut zero_point = Array1::zeros(num_groups);
        // Two 4-bit indices are packed into each byte.
        let packed_len = total_elements.div_ceil(2);
        let mut quantized_weight = vec![0u8; packed_len];

        for g in 0..num_groups {
            let start = g * group_size;
            let end = (start + group_size).min(total_elements);
            let group = &flat[start..end];

            let abs_max = group
                .iter()
                .map(|v| v.abs())
                .fold(0.0_f32, f32::max)
                .max(1e-10);

            scale[g] = abs_max;
            // Absmax NF4 quantization is symmetric, so the zero point stays at 0.
            zero_point[g] = 0.0;

            for (k, &val) in group.iter().enumerate() {
                let normalized = (val / abs_max).clamp(-1.0, 1.0);
                let quant_idx = find_nearest_nf4(normalized);
                let flat_idx = start + k;
                let byte_idx = flat_idx / 2;
                if flat_idx.is_multiple_of(2) {
                    quantized_weight[byte_idx] |= quant_idx;
                } else {
                    quantized_weight[byte_idx] |= quant_idx << 4;
                }
            }
        }

        let kaiming_scale = (2.0 / in_features as f32).sqrt();
        let mut rng = SeededRng::new(137 + in_features as u64 + out_features as u64);
        let lora_a = Array2::from_shape_fn((rank, in_features), |_| rng.next_f32() * kaiming_scale);
        let lora_b = Array2::zeros((out_features, rank));

        let scaling = alpha / rank as f32;

        Ok(Self {
            quantized_weight,
            scale,
            zero_point,
            group_size,
            lora_a,
            lora_b,
            out_features,
            in_features,
            rank,
            alpha,
            scaling,
        })
    }

    /// Reconstructs the full-precision weight matrix from the packed NF4 indices
    /// and per-group scales.
    pub fn dequantize_weight(&self) -> ModelResult<Array2<f32>> {
        let total_elements = self.out_features * self.in_features;
        let num_groups = total_elements.div_ceil(self.group_size);
        let mut flat = vec![0.0_f32; total_elements];

        for g in 0..num_groups {
            let start = g * self.group_size;
            let end = (start + self.group_size).min(total_elements);
            let s = self.scale[g];

            for (offset, val) in flat[start..end].iter_mut().enumerate() {
                let flat_idx = start + offset;
                let byte_idx = flat_idx / 2;
                let quant_idx = if flat_idx.is_multiple_of(2) {
                    self.quantized_weight[byte_idx] & 0x0F
                } else {
                    (self.quantized_weight[byte_idx] >> 4) & 0x0F
                };
                *val = NF4_LEVELS[quant_idx as usize] * s;
            }
        }

        Array2::from_shape_vec((self.out_features, self.in_features), flat).map_err(|e| {
            ModelError::invalid_config(format!("Failed to reshape dequantized weight: {}", e))
        })
    }

    pub fn forward(&self, input: &Array1<f32>) -> ModelResult<Array1<f32>> {
        if input.len() != self.in_features {
            return Err(ModelError::dimension_mismatch(
                "QLoraLinear forward input",
                self.in_features,
                input.len(),
            ));
        }

        let weight = self.dequantize_weight()?;

        let mut output = Array1::zeros(self.out_features);
        for i in 0..self.out_features {
            let mut sum = 0.0_f32;
            for j in 0..self.in_features {
                sum += weight[[i, j]] * input[j];
            }
            output[i] = sum;
        }

        let mut a_x = Array1::zeros(self.rank);
        for r in 0..self.rank {
            let mut sum = 0.0_f32;
            for j in 0..self.in_features {
                sum += self.lora_a[[r, j]] * input[j];
            }
            a_x[r] = sum;
        }

        for i in 0..self.out_features {
            let mut sum = 0.0_f32;
            for r in 0..self.rank {
                sum += self.lora_b[[i, r]] * a_x[r];
            }
            output[i] += self.scaling * sum;
        }

        Ok(output)
    }

    /// Bytes saved by storing the base weight in NF4 instead of fp32, net of the
    /// per-group scale and zero-point overhead.
    pub fn memory_saved_bytes(&self) -> usize {
        let total_elements = self.out_features * self.in_features;
        let fp32_bytes = total_elements * 4; // 4 bytes per f32
        let packed_bytes = self.quantized_weight.len(); // two 4-bit values per byte
        let num_groups = total_elements.div_ceil(self.group_size);
        let scale_bytes = num_groups * 4; // one f32 scale per group
        let zero_point_bytes = num_groups * 4; // one f32 zero point per group
        let quantized_total = packed_bytes + scale_bytes + zero_point_bytes;

        fp32_bytes.saturating_sub(quantized_total)
    }

    pub fn trainable_params(&self) -> usize {
        self.rank * (self.in_features + self.out_features)
    }

    pub fn lora_a(&self) -> &Array2<f32> {
        &self.lora_a
    }

    pub fn lora_b(&self) -> &Array2<f32> {
        &self.lora_b
    }

    pub fn group_size(&self) -> usize {
        self.group_size
    }

    pub fn rank(&self) -> usize {
        self.rank
    }

    pub fn out_features(&self) -> usize {
        self.out_features
    }

    pub fn in_features(&self) -> usize {
        self.in_features
    }

    pub fn alpha(&self) -> f32 {
        self.alpha
    }

    pub fn zero_point(&self) -> &Array1<f32> {
        &self.zero_point
    }

    pub fn scale(&self) -> &Array1<f32> {
        &self.scale
    }
}
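
// Illustrative sketch (not from the original source): quantizing a weight matrix
// and checking the memory saved. The 512x512 shape and group size of 64 are
// arbitrary example values.
//
//     let weight = Array2::<f32>::zeros((512, 512));
//     let qlayer = QLoraLinear::from_weight(weight, 8, 16.0, 64)?;
//     println!("saved {} bytes", qlayer.memory_saved_bytes());
//     let y = qlayer.forward(&Array1::from_vec(vec![0.0; 512]))?;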

/// Returns the index (0..16) of the NF4 level closest to `value`.
fn find_nearest_nf4(value: f32) -> u8 {
    let mut best_idx = 0u8;
    let mut best_dist = f32::MAX;
    for (i, &level) in NF4_LEVELS.iter().enumerate() {
        let dist = (value - level).abs();
        if dist < best_dist {
            best_dist = dist;
            best_idx = i as u8;
        }
    }
    best_idx
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builds a deterministic test weight of shape `(out, inp)`.
    fn make_weight(out: usize, inp: usize) -> Array2<f32> {
        Array2::from_shape_fn((out, inp), |(i, j)| (i * inp + j) as f32 * 0.01)
    }

    #[test]
    fn test_lora_linear_creation() -> ModelResult<()> {
        let weight = make_weight(64, 32);
        let lora = LoraLinear::new(weight.clone(), 8, 16.0)?;

        let input = Array1::from_vec(vec![1.0; 32]);
        let output_lora = lora.forward(&input)?;

        // Reference: plain W x without any LoRA contribution.
        let mut output_plain = Array1::zeros(64);
        for i in 0..64 {
            let mut sum = 0.0_f32;
            for j in 0..32 {
                sum += weight[[i, j]] * input[j];
            }
            output_plain[i] = sum;
        }

        // B is initialized to zero, so a fresh LoRA layer must match the base layer.
        for i in 0..64 {
            assert!(
                (output_lora[i] - output_plain[i]).abs() < 1e-5,
                "Mismatch at index {}: lora={}, plain={}",
                i,
                output_lora[i],
                output_plain[i]
            );
        }
        Ok(())
    }

    #[test]
    fn test_lora_linear_forward_with_nonzero_b() -> ModelResult<()> {
        let weight = make_weight(16, 8);
        let mut lora = LoraLinear::new(weight.clone(), 4, 8.0)?;

        let b = Array2::from_shape_fn((16, 4), |(i, j)| (i + j) as f32 * 0.1);
        lora.set_lora_b(b)?;

        let input = Array1::from_vec(vec![1.0; 8]);
        let output_lora = lora.forward(&input)?;

        let mut output_plain = Array1::zeros(16);
        for i in 0..16 {
            let mut sum = 0.0_f32;
            for j in 0..8 {
                sum += weight[[i, j]] * input[j];
            }
            output_plain[i] = sum;
        }

        let mut any_diff = false;
        for i in 0..16 {
            if (output_lora[i] - output_plain[i]).abs() > 1e-6 {
                any_diff = true;
                break;
            }
        }
        assert!(
            any_diff,
            "LoRA output should differ from plain output when B != 0"
        );
        Ok(())
    }

    #[test]
    fn test_lora_linear_merge_unmerge() -> ModelResult<()> {
        let weight = make_weight(16, 8);
        let mut lora = LoraLinear::new(weight.clone(), 4, 8.0)?;

        let b = Array2::from_shape_fn((16, 4), |(i, j)| (i + j) as f32 * 0.01);
        lora.set_lora_b(b)?;

        let input = Array1::from_vec(vec![0.5; 8]);

        let output_before = lora.forward(&input)?;

        lora.merge()?;
        assert!(lora.is_merged());

        let output_merged = lora.forward(&input)?;
        for i in 0..16 {
            assert!(
                (output_before[i] - output_merged[i]).abs() < 1e-4,
                "Merge changed output at {}: before={}, after={}",
                i,
                output_before[i],
                output_merged[i]
            );
        }

        lora.unmerge()?;
        assert!(!lora.is_merged());

        for i in 0..16 {
            for j in 0..8 {
                assert!(
                    (lora.weight()[[i, j]] - weight[[i, j]]).abs() < 1e-4,
                    "Unmerge did not restore weight at [{}, {}]",
                    i,
                    j
                );
            }
        }
        Ok(())
    }

    #[test]
    fn test_lora_linear_trainable_params() -> ModelResult<()> {
        let weight = make_weight(64, 32);
        let lora = LoraLinear::new(weight, 8, 16.0)?;

        // 8 * (32 + 64) = 768 trainable LoRA parameters.
        assert_eq!(lora.trainable_params(), 768);
        // 64 * 32 base weights + 768 LoRA parameters = 2816 total.
        assert_eq!(lora.total_params(), 2816);
        Ok(())
    }

    #[test]
    fn test_lora_linear_compression_ratio() -> ModelResult<()> {
        let weight = make_weight(256, 128);
        let lora = LoraLinear::new(weight, 8, 16.0)?;

        let ratio = lora.compression_ratio();
        assert!(
            ratio < 1.0,
            "Compression ratio should be < 1.0, got {}",
            ratio
        );
        assert!(
            ratio > 0.0,
            "Compression ratio should be > 0.0, got {}",
            ratio
        );

        // trainable = 8 * (128 + 256) = 3072; total = 256 * 128 + 3072 = 35840.
        let expected = 3072.0 / 35840.0;
        assert!(
            (ratio - expected).abs() < 1e-5,
            "Expected ratio ~{}, got {}",
            expected,
            ratio
        );
        Ok(())
    }

    #[test]
    fn test_lora_adapter_multi_layer() -> ModelResult<()> {
        let config = LoraConfig::new(4, 8.0).with_target_modules(vec![
            "q_proj".into(),
            "k_proj".into(),
            "v_proj".into(),
        ]);

        let mut adapter = LoraAdapter::new(config);
        adapter.add_layer("q_proj".into(), make_weight(32, 16))?;
        adapter.add_layer("k_proj".into(), make_weight(32, 16))?;
        adapter.add_layer("v_proj".into(), make_weight(32, 16))?;

        assert_eq!(adapter.layer_names().len(), 3);

        let input = Array1::from_vec(vec![1.0; 16]);
        for name in &["q_proj", "k_proj", "v_proj"] {
            let output = adapter.forward_layer(name, &input)?;
            assert_eq!(output.len(), 32);
        }

        let result = adapter.forward_layer("nonexistent", &input);
        assert!(result.is_err());

        Ok(())
    }

    #[test]
    fn test_lora_adapter_merge_all() -> ModelResult<()> {
        let config = LoraConfig::new(4, 8.0);
        let mut adapter = LoraAdapter::new(config);

        adapter.add_layer("layer_0".into(), make_weight(16, 8))?;
        adapter.add_layer("layer_1".into(), make_weight(16, 8))?;

        if let Some(layer) = adapter.get_layer_mut("layer_0") {
            let b = Array2::from_shape_fn((16, 4), |(i, j)| (i + j) as f32 * 0.01);
            layer.set_lora_b(b)?;
        }

        let input = Array1::from_vec(vec![0.5; 8]);

        let out_before_0 = adapter.forward_layer("layer_0", &input)?;
        let out_before_1 = adapter.forward_layer("layer_1", &input)?;

        adapter.merge_all()?;

        let out_after_0 = adapter.forward_layer("layer_0", &input)?;
        let out_after_1 = adapter.forward_layer("layer_1", &input)?;

        for i in 0..16 {
            assert!(
                (out_before_0[i] - out_after_0[i]).abs() < 1e-4,
                "layer_0 merge changed output"
            );
            assert!(
                (out_before_1[i] - out_after_1[i]).abs() < 1e-4,
                "layer_1 merge changed output"
            );
        }
        Ok(())
    }

    #[test]
    fn test_lora_adapter_summary() -> ModelResult<()> {
        let config = LoraConfig::new(8, 16.0);
        let mut adapter = LoraAdapter::new(config);

        adapter.add_layer("proj_q".into(), make_weight(64, 32))?;
        adapter.add_layer("proj_v".into(), make_weight(64, 32))?;

        let summary = adapter.summary();
        assert_eq!(summary.num_layers, 2);
        assert_eq!(summary.rank, 8);
        assert!((summary.alpha - 16.0).abs() < 1e-6);
        // Two layers, each with 8 * (32 + 64) = 768 trainable parameters.
        assert_eq!(summary.total_trainable, 1536);
        // Two layers, each with 64 * 32 = 2048 original parameters.
        assert_eq!(summary.total_original, 4096);
        assert!(summary.compression_ratio > 0.0);
        assert!(summary.compression_ratio < 1.0);
        Ok(())
    }

    #[test]
    fn test_lora_disable_enable() -> ModelResult<()> {
        let weight = make_weight(16, 8);
        let mut lora = LoraLinear::new(weight.clone(), 4, 8.0)?;

        let b = Array2::from_shape_fn((16, 4), |(i, j)| (i + j) as f32 * 0.1);
        lora.set_lora_b(b)?;

        let input = Array1::from_vec(vec![1.0; 8]);

        let mut output_plain = Array1::zeros(16);
        for i in 0..16 {
            let mut sum = 0.0_f32;
            for j in 0..8 {
                sum += weight[[i, j]] * input[j];
            }
            output_plain[i] = sum;
        }

        let output_enabled = lora.forward(&input)?;
        let mut any_diff = false;
        for i in 0..16 {
            if (output_enabled[i] - output_plain[i]).abs() > 1e-6 {
                any_diff = true;
                break;
            }
        }
        assert!(any_diff, "Enabled LoRA should produce different output");

        lora.disable();
        assert!(!lora.is_enabled());

        let output_disabled = lora.forward(&input)?;
        for i in 0..16 {
            assert!(
                (output_disabled[i] - output_plain[i]).abs() < 1e-5,
                "Disabled LoRA should produce same output as plain W"
            );
        }

        lora.enable();
        assert!(lora.is_enabled());
        let output_reenabled = lora.forward(&input)?;
        for i in 0..16 {
            assert!(
                (output_reenabled[i] - output_enabled[i]).abs() < 1e-5,
                "Re-enabled LoRA should match original enabled output"
            );
        }
        Ok(())
    }

    #[test]
    fn test_qlora_creation() -> ModelResult<()> {
        let weight = make_weight(32, 16);
        let qlora = QLoraLinear::from_weight(weight, 4, 8.0, 64)?;

        assert_eq!(qlora.out_features(), 32);
        assert_eq!(qlora.in_features(), 16);
        assert_eq!(qlora.rank(), 4);
        assert_eq!(qlora.group_size(), 64);
        assert_eq!(qlora.trainable_params(), 4 * (16 + 32));
        Ok(())
    }

    #[test]
    fn test_qlora_forward() -> ModelResult<()> {
        let weight = make_weight(16, 8);
        let qlora = QLoraLinear::from_weight(weight, 4, 8.0, 32)?;

        let input = Array1::from_vec(vec![1.0; 8]);
        let output = qlora.forward(&input)?;

        assert_eq!(output.len(), 16);
        for &val in output.iter() {
            assert!(
                val.is_finite(),
                "QLoRA output contains non-finite value: {}",
                val
            );
        }
        Ok(())
    }

    #[test]
    fn test_qlora_memory_savings() -> ModelResult<()> {
        let weight = make_weight(256, 128);
        let qlora = QLoraLinear::from_weight(weight, 8, 16.0, 64)?;

        let saved = qlora.memory_saved_bytes();
        assert!(
            saved > 0,
            "QLoRA should save memory compared to fp32, got saved={} bytes",
            saved
        );

        assert!(
            saved > 100_000,
            "Expected significant savings for 256x128 matrix, got {} bytes",
            saved
        );
        Ok(())
    }

    #[test]
    fn test_lora_config_validation() -> ModelResult<()> {
        let config = LoraConfig::new(8, 16.0);
        assert!(config.validate().is_ok());

        let bad_rank = LoraConfig::new(0, 16.0);
        assert!(bad_rank.validate().is_err());

        let bad_alpha = LoraConfig::new(8, -1.0);
        assert!(bad_alpha.validate().is_err());

        let bad_dropout = LoraConfig::new(8, 16.0).with_dropout(1.5);
        assert!(bad_dropout.validate().is_err());

        Ok(())
    }

    #[test]
    fn test_lora_batch_forward() -> ModelResult<()> {
        let weight = make_weight(16, 8);
        let lora = LoraLinear::new(weight, 4, 8.0)?;

        let batch = Array2::from_shape_fn((3, 8), |(b, j)| (b * 8 + j) as f32 * 0.1);
        let output = lora.forward_batch(&batch)?;

        assert_eq!(output.dim(), (3, 16));

        // Each batch row must match the single-sample forward pass.
        for b in 0..3 {
            let single_input = Array1::from_vec(batch.row(b).to_vec());
            let single_output = lora.forward(&single_input)?;
            for i in 0..16 {
                assert!(
                    (output[[b, i]] - single_output[i]).abs() < 1e-4,
                    "Batch output[{},{}]={} != single output[{}]={}",
                    b,
                    i,
                    output[[b, i]],
                    i,
                    single_output[i]
                );
            }
        }
        Ok(())
    }

    #[test]
    fn test_qlora_dequantize_roundtrip() -> ModelResult<()> {
        let weight = Array2::from_shape_fn((8, 4), |(i, j)| {
            ((i as f32 - 4.0) * 0.2 + (j as f32 - 2.0) * 0.1).clamp(-0.9, 0.9)
        });

        let qlora = QLoraLinear::from_weight(weight.clone(), 2, 4.0, 16)?;
        let deq = qlora.dequantize_weight()?;

        assert_eq!(deq.dim(), (8, 4));

        let mut max_err = 0.0_f32;
        for i in 0..8 {
            for j in 0..4 {
                let err = (weight[[i, j]] - deq[[i, j]]).abs();
                if err > max_err {
                    max_err = err;
                }
            }
        }
        assert!(
            max_err < 0.5,
            "Maximum dequantization error {} is too large",
            max_err
        );
        Ok(())
    }
}