// kizzasi_model/lora.rs

//! LoRA (Low-Rank Adaptation) for efficient fine-tuning
//!
//! Implements LoRA and QLoRA for parameter-efficient fine-tuning of neural networks.
//! LoRA adds trainable low-rank decomposition matrices A and B to frozen weight matrices,
//! so the effective weight becomes W' = W + (alpha/rank) * B @ A.
//!
//! # Key Benefits
//!
//! - **Memory Efficient**: Only rank * (in + out) parameters are trainable vs in * out (see the worked example below)
//! - **No Inference Latency**: LoRA weights can be merged into the original weights
//! - **Composable**: Multiple LoRA adapters can be swapped without reloading the base model
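//!
//! As a concrete illustration of the first point: a 4096 x 4096 projection has
//! 16,777,216 frozen parameters, while a rank-8 adapter trains only
//! 8 * (4096 + 4096) = 65,536 parameters (about 0.4% of the original).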
//!
//! # Example
//!
//! ```rust,ignore
//! use kizzasi_model::lora::{LoraConfig, LoraAdapter};
//! use scirs2_core::ndarray::Array2;
//!
//! let config = LoraConfig::new(8, 16.0)
//!     .with_target_modules(vec!["q_proj".into(), "v_proj".into()]);
//!
//! let mut adapter = LoraAdapter::new(config);
//! let weight = Array2::zeros((512, 256));
//! adapter.add_layer("q_proj".into(), weight)?;
//! ```
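//!
//! Merging for inference is a one-call operation; a minimal sketch reusing the
//! adapter from above (`input` stands for an `Array1<f32>` of matching width):
//!
//! ```rust,ignore
//! // Fold the LoRA matrices into the frozen weights, run inference,
//! // then unmerge to resume training.
//! adapter.merge_all()?;
//! let output = adapter.forward_layer("q_proj", &input)?;
//! adapter.unmerge_all()?;
//! ```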

use crate::error::{ModelError, ModelResult};
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};

// ---------------------------------------------------------------------------
// Deterministic PRNG (xorshift64) — same approach as rwkv7
// ---------------------------------------------------------------------------

/// Simple xorshift64 PRNG for deterministic LoRA weight initialization.
struct SeededRng {
    state: u64,
}

impl SeededRng {
    fn new(seed: u64) -> Self {
        Self { state: seed.max(1) }
    }

    /// Returns a float in (-1, 1]
    fn next_f32(&mut self) -> f32 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        // Map the nonzero u64 state to (-1, 1]
        (self.state as f64 / u64::MAX as f64 * 2.0 - 1.0) as f32
    }
}

// ---------------------------------------------------------------------------
// LoraConfig
// ---------------------------------------------------------------------------

/// Configuration for LoRA adaptation
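///
/// A small usage sketch (builder methods are defined below):
///
/// ```rust,ignore
/// let config = LoraConfig::new(8, 16.0)
///     .with_dropout(0.05)
///     .with_target_modules(vec!["q_proj".into(), "v_proj".into()]);
/// config.validate()?;
/// ```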
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoraConfig {
    /// Rank of the low-rank decomposition (typically 4-64)
    pub rank: usize,
    /// Scaling factor (typically equal to rank or 2*rank)
    pub alpha: f32,
    /// Dropout probability for LoRA layers (0.0 = no dropout)
    pub dropout: f32,
    /// Which module names to apply LoRA to (e.g., "q_proj", "v_proj")
    pub target_modules: Vec<String>,
    /// Whether weight is stored as (fan_in, fan_out) instead of (fan_out, fan_in)
    pub fan_in_fan_out: bool,
}

impl LoraConfig {
    /// Create a new LoRA configuration with the given rank and alpha
    pub fn new(rank: usize, alpha: f32) -> Self {
        Self {
            rank,
            alpha,
            dropout: 0.0,
            target_modules: Vec::new(),
            fan_in_fan_out: false,
        }
    }

    /// Set dropout probability
    pub fn with_dropout(mut self, dropout: f32) -> Self {
        self.dropout = dropout;
        self
    }

    /// Set target module names
    pub fn with_target_modules(mut self, modules: Vec<String>) -> Self {
        self.target_modules = modules;
        self
    }

    /// Set fan_in_fan_out flag
    pub fn with_fan_in_fan_out(mut self, fan_in_fan_out: bool) -> Self {
        self.fan_in_fan_out = fan_in_fan_out;
        self
    }

    /// Validate the configuration
    pub fn validate(&self) -> ModelResult<()> {
        if self.rank == 0 {
            return Err(ModelError::invalid_config("LoRA rank must be > 0"));
        }
        if self.alpha <= 0.0 {
            return Err(ModelError::invalid_config("LoRA alpha must be > 0.0"));
        }
        if !(0.0..=1.0).contains(&self.dropout) {
            return Err(ModelError::invalid_config(
                "LoRA dropout must be in [0.0, 1.0]",
            ));
        }
        Ok(())
    }
}

// ---------------------------------------------------------------------------
// LoraLinear
// ---------------------------------------------------------------------------

/// A single LoRA-adapted linear layer.
///
/// Computes: output = W @ x + (alpha/rank) * B @ (A @ x)
///
/// - `lora_a` has shape (rank, in_features) — initialized with Kaiming uniform
/// - `lora_b` has shape (out_features, rank) — initialized with zeros
///
/// Because B starts at zero, the initial LoRA contribution is zero and the
/// model produces the same output as the original frozen weights.
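///
/// A minimal sketch of adapting a single weight matrix:
///
/// ```rust,ignore
/// // 64 output features, 32 input features, rank 8.
/// let weight = Array2::<f32>::zeros((64, 32));
/// let layer = LoraLinear::new(weight, 8, 16.0)?;
/// let y = layer.forward(&Array1::from_vec(vec![1.0; 32]))?;
/// assert_eq!(y.len(), 64);
/// assert_eq!(layer.trainable_params(), 8 * (32 + 64));
/// ```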
#[derive(Debug, Clone)]
pub struct LoraLinear {
    /// Original frozen weight matrix (out_features, in_features)
    weight: Array2<f32>,
    /// LoRA A matrix (rank, in_features)
    lora_a: Array2<f32>,
    /// LoRA B matrix (out_features, rank)
    lora_b: Array2<f32>,
    /// Rank of decomposition
    rank: usize,
    /// Alpha scaling factor
    alpha: f32,
    /// Computed scaling = alpha / rank
    scaling: f32,
    /// Whether LoRA weights have been merged into W
    merged: bool,
    /// Whether LoRA adaptation is active
    enabled: bool,
}

impl LoraLinear {
    /// Create a new LoRA-adapted linear layer.
    ///
    /// The weight matrix should have shape (out_features, in_features).
    /// LoRA A is initialized with Kaiming-uniform values, B with zeros.
    pub fn new(weight: Array2<f32>, rank: usize, alpha: f32) -> ModelResult<Self> {
        if rank == 0 {
            return Err(ModelError::invalid_config("LoRA rank must be > 0"));
        }
        if alpha <= 0.0 {
            return Err(ModelError::invalid_config("LoRA alpha must be > 0.0"));
        }

        let (out_features, in_features) = weight.dim();
        if out_features == 0 || in_features == 0 {
            return Err(ModelError::invalid_config(
                "Weight matrix dimensions must be > 0",
            ));
        }
        if rank > out_features.min(in_features) {
            return Err(ModelError::invalid_config(format!(
                "LoRA rank ({}) must not exceed min(out_features, in_features) = {}",
                rank,
                out_features.min(in_features)
            )));
        }

        // Kaiming uniform initialization for A: scale = sqrt(2 / in_features)
        let kaiming_scale = (2.0 / in_features as f32).sqrt();
        let mut rng = SeededRng::new(42 + in_features as u64 + out_features as u64);
        let lora_a = Array2::from_shape_fn((rank, in_features), |_| rng.next_f32() * kaiming_scale);

        // B initialized to zero so initial LoRA contribution is zero
        let lora_b = Array2::zeros((out_features, rank));

        let scaling = alpha / rank as f32;

        Ok(Self {
            weight,
            lora_a,
            lora_b,
            rank,
            alpha,
            scaling,
            merged: false,
            enabled: true,
        })
    }

    /// Forward pass for a single input vector.
    ///
    /// Computes: output = W @ x + scaling * B @ (A @ x)
    pub fn forward(&self, input: &Array1<f32>) -> ModelResult<Array1<f32>> {
        let (out_features, in_features) = self.weight.dim();
        if input.len() != in_features {
            return Err(ModelError::dimension_mismatch(
                "LoraLinear forward input",
                in_features,
                input.len(),
            ));
        }

        // W @ x
        let mut output = Array1::zeros(out_features);
        for i in 0..out_features {
            let mut sum = 0.0_f32;
            for j in 0..in_features {
                sum += self.weight[[i, j]] * input[j];
            }
            output[i] = sum;
        }

        // Add LoRA contribution if enabled and not already merged
        if self.enabled && !self.merged {
            // A @ x  -> shape (rank,)
            let mut a_x = Array1::zeros(self.rank);
            for r in 0..self.rank {
                let mut sum = 0.0_f32;
                for j in 0..in_features {
                    sum += self.lora_a[[r, j]] * input[j];
                }
                a_x[r] = sum;
            }

            // B @ (A @ x) -> shape (out_features,)
            for i in 0..out_features {
                let mut sum = 0.0_f32;
                for r in 0..self.rank {
                    sum += self.lora_b[[i, r]] * a_x[r];
                }
                output[i] += self.scaling * sum;
            }
        }

        Ok(output)
    }

    /// Forward pass for a batch of inputs.
    ///
    /// Input shape: (batch_size, in_features)
    /// Output shape: (batch_size, out_features)
    pub fn forward_batch(&self, input: &Array2<f32>) -> ModelResult<Array2<f32>> {
        let (batch_size, input_dim) = input.dim();
        let (out_features, in_features) = self.weight.dim();

        if input_dim != in_features {
            return Err(ModelError::dimension_mismatch(
                "LoraLinear forward_batch input dim",
                in_features,
                input_dim,
            ));
        }

        // output = input @ W^T  (batch_size, out_features)
        let mut output = Array2::zeros((batch_size, out_features));
        for b in 0..batch_size {
            for i in 0..out_features {
                let mut sum = 0.0_f32;
                for j in 0..in_features {
                    sum += input[[b, j]] * self.weight[[i, j]];
                }
                output[[b, i]] = sum;
            }
        }

        // Add LoRA contribution
        if self.enabled && !self.merged {
            for b in 0..batch_size {
                // A @ x_b -> (rank,)
                let a_x: Vec<f32> = (0..self.rank)
                    .map(|r| {
                        let mut sum = 0.0_f32;
                        for j in 0..in_features {
                            sum += self.lora_a[[r, j]] * input[[b, j]];
                        }
                        sum
                    })
                    .collect();

                // B @ (A @ x_b) -> (out_features,)
                for i in 0..out_features {
                    let mut sum = 0.0_f32;
                    for (r, &ax_r) in a_x.iter().enumerate() {
                        sum += self.lora_b[[i, r]] * ax_r;
                    }
                    output[[b, i]] += self.scaling * sum;
                }
            }
        }

        Ok(output)
    }

    /// Merge LoRA weights into the original weight matrix for inference.
    ///
    /// After merging, forward passes use only the modified W with no extra computation.
    /// W = W + scaling * B @ A
    pub fn merge(&mut self) -> ModelResult<()> {
        if self.merged {
            return Err(ModelError::invalid_config(
                "LoRA weights are already merged",
            ));
        }

        let (out_features, in_features) = self.weight.dim();

        // W += scaling * B @ A
        for i in 0..out_features {
            for j in 0..in_features {
                let mut delta = 0.0_f32;
                for r in 0..self.rank {
                    delta += self.lora_b[[i, r]] * self.lora_a[[r, j]];
                }
                self.weight[[i, j]] += self.scaling * delta;
            }
        }

        self.merged = true;
        Ok(())
    }

    /// Unmerge LoRA weights from the original weight matrix.
    ///
    /// Restores W to its original values for continued training.
    /// W = W - scaling * B @ A
    pub fn unmerge(&mut self) -> ModelResult<()> {
        if !self.merged {
            return Err(ModelError::invalid_config("LoRA weights are not merged"));
        }

        let (out_features, in_features) = self.weight.dim();

        // W -= scaling * B @ A
        for i in 0..out_features {
            for j in 0..in_features {
                let mut delta = 0.0_f32;
                for r in 0..self.rank {
                    delta += self.lora_b[[i, r]] * self.lora_a[[r, j]];
                }
                self.weight[[i, j]] -= self.scaling * delta;
            }
        }

        self.merged = false;
        Ok(())
    }

    /// Number of trainable parameters: rank * (in_features + out_features)
    pub fn trainable_params(&self) -> usize {
        let (out_features, in_features) = self.weight.dim();
        self.rank * (in_features + out_features)
    }

    /// Total parameters including frozen weights
    pub fn total_params(&self) -> usize {
        let (out_features, in_features) = self.weight.dim();
        in_features * out_features + self.rank * (in_features + out_features)
    }

    /// Ratio of trainable to total parameters
    pub fn compression_ratio(&self) -> f32 {
        self.trainable_params() as f32 / self.total_params() as f32
    }

    /// Get reference to LoRA A matrix
    pub fn lora_a(&self) -> &Array2<f32> {
        &self.lora_a
    }

    /// Get reference to LoRA B matrix
    pub fn lora_b(&self) -> &Array2<f32> {
        &self.lora_b
    }

    /// Set the LoRA A matrix, validating dimensions
    pub fn set_lora_a(&mut self, a: Array2<f32>) -> ModelResult<()> {
        let (_, in_features) = self.weight.dim();
        let (a_rank, a_in) = a.dim();
        if a_rank != self.rank {
            return Err(ModelError::dimension_mismatch(
                "set_lora_a rank",
                self.rank,
                a_rank,
            ));
        }
        if a_in != in_features {
            return Err(ModelError::dimension_mismatch(
                "set_lora_a in_features",
                in_features,
                a_in,
            ));
        }
        self.lora_a = a;
        Ok(())
    }

    /// Set the LoRA B matrix, validating dimensions
    pub fn set_lora_b(&mut self, b: Array2<f32>) -> ModelResult<()> {
        let (out_features, _) = self.weight.dim();
        let (b_out, b_rank) = b.dim();
        if b_out != out_features {
            return Err(ModelError::dimension_mismatch(
                "set_lora_b out_features",
                out_features,
                b_out,
            ));
        }
        if b_rank != self.rank {
            return Err(ModelError::dimension_mismatch(
                "set_lora_b rank",
                self.rank,
                b_rank,
            ));
        }
        self.lora_b = b;
        Ok(())
    }

    /// Enable LoRA adaptation
    pub fn enable(&mut self) {
        self.enabled = true;
    }

    /// Disable LoRA adaptation (output equals original W @ x)
    pub fn disable(&mut self) {
        self.enabled = false;
    }

    /// Whether LoRA is currently enabled
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }

    /// Whether LoRA weights are merged into W
    pub fn is_merged(&self) -> bool {
        self.merged
    }

    /// Get the weight matrix reference
    pub fn weight(&self) -> &Array2<f32> {
        &self.weight
    }

    /// Get the rank
    pub fn rank(&self) -> usize {
        self.rank
    }

    /// Get the alpha
    pub fn alpha(&self) -> f32 {
        self.alpha
    }

    /// Get the scaling factor
    pub fn scaling(&self) -> f32 {
        self.scaling
    }
}

// ---------------------------------------------------------------------------
// LoraAdapter
// ---------------------------------------------------------------------------

/// Summary statistics for a LoRA adapter
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoraAdapterSummary {
    /// Number of LoRA-adapted layers
    pub num_layers: usize,
    /// Total trainable parameters across all layers
    pub total_trainable: usize,
    /// Total original (frozen) parameters across all layers
    pub total_original: usize,
    /// Overall compression ratio (trainable / total)
    pub compression_ratio: f32,
    /// LoRA rank
    pub rank: usize,
    /// LoRA alpha
    pub alpha: f32,
}

/// Manages LoRA adaptation for a collection of layers
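///
/// A usage sketch covering the typical lifecycle:
///
/// ```rust,ignore
/// let mut adapter = LoraAdapter::new(LoraConfig::new(8, 16.0));
/// adapter.add_layer("q_proj".into(), Array2::zeros((64, 32)))?;
/// adapter.add_layer("v_proj".into(), Array2::zeros((64, 32)))?;
///
/// // Inspect how few parameters are actually trainable.
/// let summary = adapter.summary();
/// assert_eq!(summary.num_layers, 2);
///
/// // Fold the adapters into the frozen weights for inference.
/// adapter.merge_all()?;
/// ```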
#[derive(Debug, Clone)]
pub struct LoraAdapter {
    /// LoRA configuration
    config: LoraConfig,
    /// Named LoRA layers
    layers: Vec<(String, LoraLinear)>,
}

impl LoraAdapter {
    /// Create a new LoRA adapter with the given configuration
    pub fn new(config: LoraConfig) -> Self {
        Self {
            config,
            layers: Vec::new(),
        }
    }

    /// Add a layer to the adapter with the given name and weight matrix
    pub fn add_layer(&mut self, name: String, weight: Array2<f32>) -> ModelResult<()> {
        // Check for duplicate names
        if self.layers.iter().any(|(n, _)| n == &name) {
            return Err(ModelError::invalid_config(format!(
                "LoRA layer '{}' already exists",
                name
            )));
        }

        let layer = LoraLinear::new(weight, self.config.rank, self.config.alpha)?;
        self.layers.push((name, layer));
        Ok(())
    }

    /// Forward pass through a named layer
    pub fn forward_layer(&self, name: &str, input: &Array1<f32>) -> ModelResult<Array1<f32>> {
        let layer = self.get_layer(name).ok_or_else(|| {
            ModelError::invalid_config(format!("LoRA layer '{}' not found", name))
        })?;
        layer.forward(input)
    }

    /// Merge all LoRA weights into the original weight matrices
    pub fn merge_all(&mut self) -> ModelResult<()> {
        for (_, layer) in &mut self.layers {
            if !layer.is_merged() {
                layer.merge()?;
            }
        }
        Ok(())
    }

    /// Unmerge all LoRA weights from the original weight matrices
    pub fn unmerge_all(&mut self) -> ModelResult<()> {
        for (_, layer) in &mut self.layers {
            if layer.is_merged() {
                layer.unmerge()?;
            }
        }
        Ok(())
    }

    /// Total trainable parameters across all layers
    pub fn total_trainable_params(&self) -> usize {
        self.layers.iter().map(|(_, l)| l.trainable_params()).sum()
    }

    /// Total original (frozen) parameters across all layers
    pub fn total_original_params(&self) -> usize {
        self.layers
            .iter()
            .map(|(_, l)| {
                let (out, inp) = l.weight().dim();
                out * inp
            })
            .sum()
    }

    /// Overall compression ratio
    pub fn overall_compression_ratio(&self) -> f32 {
        let trainable = self.total_trainable_params();
        let total = self.total_original_params() + trainable;
        if total == 0 {
            return 0.0;
        }
        trainable as f32 / total as f32
    }

    /// Get names of all layers
    pub fn layer_names(&self) -> Vec<&str> {
        self.layers.iter().map(|(n, _)| n.as_str()).collect()
    }

    /// Get an immutable reference to a named layer
    pub fn get_layer(&self, name: &str) -> Option<&LoraLinear> {
        self.layers.iter().find(|(n, _)| n == name).map(|(_, l)| l)
    }

    /// Get a mutable reference to a named layer
    pub fn get_layer_mut(&mut self, name: &str) -> Option<&mut LoraLinear> {
        self.layers
            .iter_mut()
            .find(|(n, _)| n == name)
            .map(|(_, l)| l)
    }

    /// Get the adapter configuration
    pub fn config(&self) -> &LoraConfig {
        &self.config
    }

    /// Get a summary of the adapter
    pub fn summary(&self) -> LoraAdapterSummary {
        LoraAdapterSummary {
            num_layers: self.layers.len(),
            total_trainable: self.total_trainable_params(),
            total_original: self.total_original_params(),
            compression_ratio: self.overall_compression_ratio(),
            rank: self.config.rank,
            alpha: self.config.alpha,
        }
    }
}

// ---------------------------------------------------------------------------
// QLoRA (Quantized LoRA)
// ---------------------------------------------------------------------------

/// NF4 (Normal Float 4-bit) quantization values.
/// These are the 16 quantization levels optimized for normally distributed weights.
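///
/// Quantization is symmetric and group-wise: each value `w` in a group is
/// normalized by the group's absolute maximum, snapped to the nearest level,
/// and later dequantized as `level * absmax`. For example, with
/// `absmax = 0.8`, the value `w = 0.3` normalizes to `0.375`, whose nearest
/// level is `0.337_915_24`, so it dequantizes to roughly `0.270`.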
const NF4_LEVELS: [f32; 16] = [
    -1.0,
    -0.696_192_8,
    -0.525_073_05,
    -0.394_917_5,
    -0.284_441_38,
    -0.184_773_43,
    -0.091_050_04,
    0.0,
    0.079_580_3,
    0.160_930_2,
    0.246_112_3,
    0.337_915_24,
    0.440_709_83,
    0.562_617,
    0.722_956_84,
    1.0,
];

/// QLoRA: Quantized LoRA for memory-efficient fine-tuning.
///
/// The base weight matrix is quantized to 4-bit NF4 format with group-wise
/// quantization, while LoRA matrices A and B remain in full fp32 precision.
/// This dramatically reduces memory usage for the frozen base weights.
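///
/// A minimal sketch (group size 64 is an arbitrary choice here):
///
/// ```rust,ignore
/// let weight = Array2::<f32>::zeros((256, 128));
/// let qlora = QLoraLinear::from_weight(weight, 8, 16.0, 64)?;
/// let y = qlora.forward(&Array1::from_vec(vec![0.5; 128]))?;
/// assert_eq!(y.len(), 256);
/// println!("saved {} bytes", qlora.memory_saved_bytes());
/// ```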
#[derive(Debug, Clone)]
pub struct QLoraLinear {
    /// 4-bit quantized weight (two values packed per byte)
    quantized_weight: Vec<u8>,
    /// Per-group dequantization scale
    scale: Array1<f32>,
    /// Per-group zero point
    zero_point: Array1<f32>,
    /// Quantization group size
    group_size: usize,
    /// LoRA A matrix (rank, in_features) — full precision
    lora_a: Array2<f32>,
    /// LoRA B matrix (out_features, rank) — full precision
    lora_b: Array2<f32>,
    /// Output features dimension
    out_features: usize,
    /// Input features dimension
    in_features: usize,
    /// LoRA rank
    rank: usize,
    /// LoRA alpha
    alpha: f32,
    /// Computed scaling = alpha / rank
    scaling: f32,
}

impl QLoraLinear {
    /// Create a QLoRA layer from a full-precision weight matrix.
    ///
    /// The weight is quantized to 4-bit NF4 format with the given group size.
    /// LoRA matrices are initialized as in standard LoRA (A=Kaiming, B=zeros).
    pub fn from_weight(
        weight: Array2<f32>,
        rank: usize,
        alpha: f32,
        group_size: usize,
    ) -> ModelResult<Self> {
        if rank == 0 {
            return Err(ModelError::invalid_config("QLoRA rank must be > 0"));
        }
        if alpha <= 0.0 {
            return Err(ModelError::invalid_config("QLoRA alpha must be > 0.0"));
        }
        if group_size == 0 {
            return Err(ModelError::invalid_config("QLoRA group_size must be > 0"));
        }

        let (out_features, in_features) = weight.dim();
        if out_features == 0 || in_features == 0 {
            return Err(ModelError::invalid_config(
                "Weight matrix dimensions must be > 0",
            ));
        }
        if rank > out_features.min(in_features) {
            return Err(ModelError::invalid_config(format!(
                "QLoRA rank ({}) must not exceed min(out, in) = {}",
                rank,
                out_features.min(in_features)
            )));
        }

        // Flatten weight for quantization
        let total_elements = out_features * in_features;
        let num_groups = total_elements.div_ceil(group_size);

        let flat: Vec<f32> = weight.iter().copied().collect();

        let mut scale = Array1::zeros(num_groups);
        let mut zero_point = Array1::zeros(num_groups);
        // Two 4-bit values per byte
        let packed_len = total_elements.div_ceil(2);
        let mut quantized_weight = vec![0u8; packed_len];

        // Quantize group by group
        for g in 0..num_groups {
            let start = g * group_size;
            let end = (start + group_size).min(total_elements);
            let group = &flat[start..end];

            // Find absmax for the group
            let abs_max = group
                .iter()
                .map(|v| v.abs())
                .fold(0.0_f32, f32::max)
                .max(1e-10);

            scale[g] = abs_max;
            zero_point[g] = 0.0; // symmetric quantization

            // Quantize each element to nearest NF4 level
            for (k, &val) in group.iter().enumerate() {
                let normalized = (val / abs_max).clamp(-1.0, 1.0);
                let quant_idx = find_nearest_nf4(normalized);
                let flat_idx = start + k;
                let byte_idx = flat_idx / 2;
                if flat_idx.is_multiple_of(2) {
                    quantized_weight[byte_idx] |= quant_idx;
                } else {
                    quantized_weight[byte_idx] |= quant_idx << 4;
                }
            }
        }

        // Initialize LoRA matrices
        let kaiming_scale = (2.0 / in_features as f32).sqrt();
        let mut rng = SeededRng::new(137 + in_features as u64 + out_features as u64);
        let lora_a = Array2::from_shape_fn((rank, in_features), |_| rng.next_f32() * kaiming_scale);
        let lora_b = Array2::zeros((out_features, rank));

        let scaling = alpha / rank as f32;

        Ok(Self {
            quantized_weight,
            scale,
            zero_point,
            group_size,
            lora_a,
            lora_b,
            out_features,
            in_features,
            rank,
            alpha,
            scaling,
        })
    }

    /// Dequantize the weight matrix back to full precision.
    ///
    /// This is an approximate reconstruction — quantization is lossy.
    pub fn dequantize_weight(&self) -> ModelResult<Array2<f32>> {
        let total_elements = self.out_features * self.in_features;
        let num_groups = total_elements.div_ceil(self.group_size);
        let mut flat = vec![0.0_f32; total_elements];

        for g in 0..num_groups {
            let start = g * self.group_size;
            let end = (start + self.group_size).min(total_elements);
            let s = self.scale[g];

            for (offset, val) in flat[start..end].iter_mut().enumerate() {
                let flat_idx = start + offset;
                let byte_idx = flat_idx / 2;
                let quant_idx = if flat_idx.is_multiple_of(2) {
                    self.quantized_weight[byte_idx] & 0x0F
                } else {
                    (self.quantized_weight[byte_idx] >> 4) & 0x0F
                };
                *val = NF4_LEVELS[quant_idx as usize] * s;
            }
        }

        Array2::from_shape_vec((self.out_features, self.in_features), flat).map_err(|e| {
            ModelError::invalid_config(format!("Failed to reshape dequantized weight: {}", e))
        })
    }

    /// Forward pass: dequantize weight, compute W @ x + scaling * B @ (A @ x)
    pub fn forward(&self, input: &Array1<f32>) -> ModelResult<Array1<f32>> {
        if input.len() != self.in_features {
            return Err(ModelError::dimension_mismatch(
                "QLoraLinear forward input",
                self.in_features,
                input.len(),
            ));
        }

        let weight = self.dequantize_weight()?;

        // W @ x
        let mut output = Array1::zeros(self.out_features);
        for i in 0..self.out_features {
            let mut sum = 0.0_f32;
            for j in 0..self.in_features {
                sum += weight[[i, j]] * input[j];
            }
            output[i] = sum;
        }

        // LoRA contribution: scaling * B @ (A @ x)
        let mut a_x = Array1::zeros(self.rank);
        for r in 0..self.rank {
            let mut sum = 0.0_f32;
            for j in 0..self.in_features {
                sum += self.lora_a[[r, j]] * input[j];
            }
            a_x[r] = sum;
        }

        for i in 0..self.out_features {
            let mut sum = 0.0_f32;
            for r in 0..self.rank {
                sum += self.lora_b[[i, r]] * a_x[r];
            }
            output[i] += self.scaling * sum;
        }

        Ok(output)
    }

    /// Memory saved compared to storing full fp32 weights, in bytes.
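    ///
    /// For example, a 256 x 128 matrix takes 131,072 bytes in fp32; packed NF4
    /// needs 16,384 bytes plus 4,096 bytes of per-group scales and zero points
    /// (group size 64), saving roughly 110,592 bytes.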
    pub fn memory_saved_bytes(&self) -> usize {
        let total_elements = self.out_features * self.in_features;
        let fp32_bytes = total_elements * 4; // 4 bytes per f32
        let packed_bytes = self.quantized_weight.len(); // 0.5 bytes per element
        let num_groups = total_elements.div_ceil(self.group_size);
        let scale_bytes = num_groups * 4; // scale: f32 per group
        let zero_point_bytes = num_groups * 4; // zero_point: f32 per group
        let quantized_total = packed_bytes + scale_bytes + zero_point_bytes;

        fp32_bytes.saturating_sub(quantized_total)
    }

    /// Number of trainable parameters (LoRA A and B)
    pub fn trainable_params(&self) -> usize {
        self.rank * (self.in_features + self.out_features)
    }

    /// Get the LoRA A matrix
    pub fn lora_a(&self) -> &Array2<f32> {
        &self.lora_a
    }

    /// Get the LoRA B matrix
    pub fn lora_b(&self) -> &Array2<f32> {
        &self.lora_b
    }

    /// Get the quantization group size
    pub fn group_size(&self) -> usize {
        self.group_size
    }

    /// Get the rank
    pub fn rank(&self) -> usize {
        self.rank
    }

    /// Get output features
    pub fn out_features(&self) -> usize {
        self.out_features
    }

    /// Get input features
    pub fn in_features(&self) -> usize {
        self.in_features
    }

    /// Get the alpha scaling factor
    pub fn alpha(&self) -> f32 {
        self.alpha
    }

    /// Get the per-group zero points
    pub fn zero_point(&self) -> &Array1<f32> {
        &self.zero_point
    }

    /// Get the per-group scales
    pub fn scale(&self) -> &Array1<f32> {
        &self.scale
    }
}

/// Find the nearest NF4 quantization level index for a normalized value in [-1, 1].
fn find_nearest_nf4(value: f32) -> u8 {
    let mut best_idx = 0u8;
    let mut best_dist = f32::MAX;
    for (i, &level) in NF4_LEVELS.iter().enumerate() {
        let dist = (value - level).abs();
        if dist < best_dist {
            best_dist = dist;
            best_idx = i as u8;
        }
    }
    best_idx
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Helper: create a simple weight matrix with known values
    fn make_weight(out: usize, inp: usize) -> Array2<f32> {
        Array2::from_shape_fn((out, inp), |(i, j)| (i * inp + j) as f32 * 0.01)
    }

    #[test]
    fn test_lora_linear_creation() -> ModelResult<()> {
        let weight = make_weight(64, 32);
        let lora = LoraLinear::new(weight.clone(), 8, 16.0)?;

        // B is zero, so forward should equal W @ x
        let input = Array1::from_vec(vec![1.0; 32]);
        let output_lora = lora.forward(&input)?;

        // Compute W @ x directly
        let mut output_plain = Array1::zeros(64);
        for i in 0..64 {
            let mut sum = 0.0_f32;
            for j in 0..32 {
                sum += weight[[i, j]] * input[j];
            }
            output_plain[i] = sum;
        }

        // Should be identical since B = 0
        for i in 0..64 {
            assert!(
                (output_lora[i] - output_plain[i]).abs() < 1e-5,
                "Mismatch at index {}: lora={}, plain={}",
                i,
                output_lora[i],
                output_plain[i]
            );
        }
        Ok(())
    }

    #[test]
    fn test_lora_linear_forward_with_nonzero_b() -> ModelResult<()> {
        let weight = make_weight(16, 8);
        let mut lora = LoraLinear::new(weight.clone(), 4, 8.0)?;

        // Set B to non-zero
        let b = Array2::from_shape_fn((16, 4), |(i, j)| (i + j) as f32 * 0.1);
        lora.set_lora_b(b)?;

        let input = Array1::from_vec(vec![1.0; 8]);
        let output_lora = lora.forward(&input)?;

        // Plain W @ x
        let mut output_plain = Array1::zeros(16);
        for i in 0..16 {
            let mut sum = 0.0_f32;
            for j in 0..8 {
                sum += weight[[i, j]] * input[j];
            }
            output_plain[i] = sum;
        }

        // Output should differ from plain since B != 0
        let mut any_diff = false;
        for i in 0..16 {
            if (output_lora[i] - output_plain[i]).abs() > 1e-6 {
                any_diff = true;
                break;
            }
        }
        assert!(
            any_diff,
            "LoRA output should differ from plain output when B != 0"
        );
        Ok(())
    }

    #[test]
    fn test_lora_linear_merge_unmerge() -> ModelResult<()> {
        let weight = make_weight(16, 8);
        let mut lora = LoraLinear::new(weight.clone(), 4, 8.0)?;

        // Set non-zero B
        let b = Array2::from_shape_fn((16, 4), |(i, j)| (i + j) as f32 * 0.01);
        lora.set_lora_b(b)?;

        let input = Array1::from_vec(vec![0.5; 8]);

        // Get output before merge
        let output_before = lora.forward(&input)?;

        // Merge
        lora.merge()?;
        assert!(lora.is_merged());

        // Output after merge should be the same
        let output_merged = lora.forward(&input)?;
        for i in 0..16 {
            assert!(
                (output_before[i] - output_merged[i]).abs() < 1e-4,
                "Merge changed output at {}: before={}, after={}",
                i,
                output_before[i],
                output_merged[i]
            );
        }

        // Unmerge
        lora.unmerge()?;
        assert!(!lora.is_merged());

        // Weight should be back to original
        for i in 0..16 {
            for j in 0..8 {
                assert!(
                    (lora.weight()[[i, j]] - weight[[i, j]]).abs() < 1e-4,
                    "Unmerge did not restore weight at [{}, {}]",
                    i,
                    j
                );
            }
        }
        Ok(())
    }

    #[test]
    fn test_lora_linear_trainable_params() -> ModelResult<()> {
        let weight = make_weight(64, 32);
        let lora = LoraLinear::new(weight, 8, 16.0)?;

        // trainable = rank * (in + out) = 8 * (32 + 64) = 768
        assert_eq!(lora.trainable_params(), 768);
        // total = 64*32 + 768 = 2048 + 768 = 2816
        assert_eq!(lora.total_params(), 2816);
        Ok(())
    }

    #[test]
    fn test_lora_linear_compression_ratio() -> ModelResult<()> {
        let weight = make_weight(256, 128);
        let lora = LoraLinear::new(weight, 8, 16.0)?;

        let ratio = lora.compression_ratio();
        // trainable = 8 * (128 + 256) = 3072
        // total = 256*128 + 3072 = 32768 + 3072 = 35840
        // ratio = 3072 / 35840 ≈ 0.0857
        assert!(
            ratio < 1.0,
            "Compression ratio should be < 1.0, got {}",
            ratio
        );
        assert!(
            ratio > 0.0,
            "Compression ratio should be > 0.0, got {}",
            ratio
        );

        let expected = 3072.0 / 35840.0;
        assert!(
            (ratio - expected).abs() < 1e-5,
            "Expected ratio ~{}, got {}",
            expected,
            ratio
        );
        Ok(())
    }

    #[test]
    fn test_lora_adapter_multi_layer() -> ModelResult<()> {
        let config = LoraConfig::new(4, 8.0).with_target_modules(vec![
            "q_proj".into(),
            "k_proj".into(),
            "v_proj".into(),
        ]);

        let mut adapter = LoraAdapter::new(config);
        adapter.add_layer("q_proj".into(), make_weight(32, 16))?;
        adapter.add_layer("k_proj".into(), make_weight(32, 16))?;
        adapter.add_layer("v_proj".into(), make_weight(32, 16))?;

        assert_eq!(adapter.layer_names().len(), 3);

        // Forward through each layer
        let input = Array1::from_vec(vec![1.0; 16]);
        for name in &["q_proj", "k_proj", "v_proj"] {
            let output = adapter.forward_layer(name, &input)?;
            assert_eq!(output.len(), 32);
        }

        // Forward through nonexistent layer should fail
        let result = adapter.forward_layer("nonexistent", &input);
        assert!(result.is_err());

        Ok(())
    }

    #[test]
    fn test_lora_adapter_merge_all() -> ModelResult<()> {
        let config = LoraConfig::new(4, 8.0);
        let mut adapter = LoraAdapter::new(config);

        adapter.add_layer("layer_0".into(), make_weight(16, 8))?;
        adapter.add_layer("layer_1".into(), make_weight(16, 8))?;

        // Set non-zero B on one layer
        if let Some(layer) = adapter.get_layer_mut("layer_0") {
            let b = Array2::from_shape_fn((16, 4), |(i, j)| (i + j) as f32 * 0.01);
            layer.set_lora_b(b)?;
        }

        let input = Array1::from_vec(vec![0.5; 8]);

        // Capture output before merge
        let out_before_0 = adapter.forward_layer("layer_0", &input)?;
        let out_before_1 = adapter.forward_layer("layer_1", &input)?;

        // Merge all
        adapter.merge_all()?;

        // Outputs should match
        let out_after_0 = adapter.forward_layer("layer_0", &input)?;
        let out_after_1 = adapter.forward_layer("layer_1", &input)?;

        for i in 0..16 {
            assert!(
                (out_before_0[i] - out_after_0[i]).abs() < 1e-4,
                "layer_0 merge changed output"
            );
            assert!(
                (out_before_1[i] - out_after_1[i]).abs() < 1e-4,
                "layer_1 merge changed output"
            );
        }
        Ok(())
    }

    #[test]
    fn test_lora_adapter_summary() -> ModelResult<()> {
        let config = LoraConfig::new(8, 16.0);
        let mut adapter = LoraAdapter::new(config);

        adapter.add_layer("proj_q".into(), make_weight(64, 32))?;
        adapter.add_layer("proj_v".into(), make_weight(64, 32))?;

        let summary = adapter.summary();
        assert_eq!(summary.num_layers, 2);
        assert_eq!(summary.rank, 8);
        assert!((summary.alpha - 16.0).abs() < 1e-6);
        // Each layer: trainable = 8*(32+64) = 768; two layers = 1536
        assert_eq!(summary.total_trainable, 1536);
        // Each layer original = 64*32 = 2048; two layers = 4096
        assert_eq!(summary.total_original, 4096);
        assert!(summary.compression_ratio > 0.0);
        assert!(summary.compression_ratio < 1.0);
        Ok(())
    }

    #[test]
    fn test_lora_disable_enable() -> ModelResult<()> {
        let weight = make_weight(16, 8);
        let mut lora = LoraLinear::new(weight.clone(), 4, 8.0)?;

        // Set non-zero B
        let b = Array2::from_shape_fn((16, 4), |(i, j)| (i + j) as f32 * 0.1);
        lora.set_lora_b(b)?;

        let input = Array1::from_vec(vec![1.0; 8]);

        // Compute plain W @ x
        let mut output_plain = Array1::zeros(16);
        for i in 0..16 {
            let mut sum = 0.0_f32;
            for j in 0..8 {
                sum += weight[[i, j]] * input[j];
            }
            output_plain[i] = sum;
        }

        // With LoRA enabled, output differs
        let output_enabled = lora.forward(&input)?;
        let mut any_diff = false;
        for i in 0..16 {
            if (output_enabled[i] - output_plain[i]).abs() > 1e-6 {
                any_diff = true;
                break;
            }
        }
        assert!(any_diff, "Enabled LoRA should produce different output");

        // Disable LoRA
        lora.disable();
        assert!(!lora.is_enabled());

        let output_disabled = lora.forward(&input)?;
        for i in 0..16 {
            assert!(
                (output_disabled[i] - output_plain[i]).abs() < 1e-5,
                "Disabled LoRA should produce same output as plain W"
            );
        }

        // Re-enable
        lora.enable();
        assert!(lora.is_enabled());
        let output_reenabled = lora.forward(&input)?;
        for i in 0..16 {
            assert!(
                (output_reenabled[i] - output_enabled[i]).abs() < 1e-5,
                "Re-enabled LoRA should match original enabled output"
            );
        }
        Ok(())
    }

    #[test]
    fn test_qlora_creation() -> ModelResult<()> {
        let weight = make_weight(32, 16);
        let qlora = QLoraLinear::from_weight(weight, 4, 8.0, 64)?;

        assert_eq!(qlora.out_features(), 32);
        assert_eq!(qlora.in_features(), 16);
        assert_eq!(qlora.rank(), 4);
        assert_eq!(qlora.group_size(), 64);
        assert_eq!(qlora.trainable_params(), 4 * (16 + 32));
        Ok(())
    }

    #[test]
    fn test_qlora_forward() -> ModelResult<()> {
        let weight = make_weight(16, 8);
        let qlora = QLoraLinear::from_weight(weight, 4, 8.0, 32)?;

        let input = Array1::from_vec(vec![1.0; 8]);
        let output = qlora.forward(&input)?;

        assert_eq!(output.len(), 16);
        // Output should be finite
        for &val in output.iter() {
            assert!(
                val.is_finite(),
                "QLoRA output contains non-finite value: {}",
                val
            );
        }
        Ok(())
    }

    #[test]
    fn test_qlora_memory_savings() -> ModelResult<()> {
        let weight = make_weight(256, 128);
        let qlora = QLoraLinear::from_weight(weight, 8, 16.0, 64)?;

        let saved = qlora.memory_saved_bytes();
        assert!(
            saved > 0,
            "QLoRA should save memory compared to fp32, got saved={} bytes",
            saved
        );

        // fp32 = 256*128*4 = 131072 bytes
        // quantized ≈ (256*128)/2 + groups*8 = 16384 + ~4096 = ~20480
        // saved ≈ 110592
        assert!(
            saved > 100_000,
            "Expected significant savings for 256x128 matrix, got {} bytes",
            saved
        );
        Ok(())
    }

    #[test]
    fn test_lora_config_validation() -> ModelResult<()> {
        // Valid config
        let config = LoraConfig::new(8, 16.0);
        assert!(config.validate().is_ok());

        // Invalid rank
        let bad_rank = LoraConfig::new(0, 16.0);
        assert!(bad_rank.validate().is_err());

        // Invalid alpha
        let bad_alpha = LoraConfig::new(8, -1.0);
        assert!(bad_alpha.validate().is_err());

        // Invalid dropout
        let bad_dropout = LoraConfig::new(8, 16.0).with_dropout(1.5);
        assert!(bad_dropout.validate().is_err());

        Ok(())
    }

    #[test]
    fn test_lora_batch_forward() -> ModelResult<()> {
        let weight = make_weight(16, 8);
        let lora = LoraLinear::new(weight, 4, 8.0)?;

        let batch = Array2::from_shape_fn((3, 8), |(b, j)| (b * 8 + j) as f32 * 0.1);
        let output = lora.forward_batch(&batch)?;

        assert_eq!(output.dim(), (3, 16));

        // Each row of batch output should match single forward
        for b in 0..3 {
            let single_input = Array1::from_vec(batch.row(b).to_vec());
            let single_output = lora.forward(&single_input)?;
            for i in 0..16 {
                assert!(
                    (output[[b, i]] - single_output[i]).abs() < 1e-4,
                    "Batch output[{},{}]={} != single output[{}]={}",
                    b,
                    i,
                    output[[b, i]],
                    i,
                    single_output[i]
                );
            }
        }
        Ok(())
    }

    #[test]
    fn test_qlora_dequantize_roundtrip() -> ModelResult<()> {
        // With small values, NF4 quantization should approximately recover them
        let weight = Array2::from_shape_fn((8, 4), |(i, j)| {
            ((i as f32 - 4.0) * 0.2 + (j as f32 - 2.0) * 0.1).clamp(-0.9, 0.9)
        });

        let qlora = QLoraLinear::from_weight(weight.clone(), 2, 4.0, 16)?;
        let deq = qlora.dequantize_weight()?;

        assert_eq!(deq.dim(), (8, 4));

        // Quantization is lossy but should be in the right ballpark
        let mut max_err = 0.0_f32;
        for i in 0..8 {
            for j in 0..4 {
                let err = (weight[[i, j]] - deq[[i, j]]).abs();
                if err > max_err {
                    max_err = err;
                }
            }
        }
        // NF4 with small group sizes should have bounded error
        assert!(
            max_err < 0.5,
            "Maximum dequantization error {} is too large",
            max_err
        );
        Ok(())
    }
}