kizzasi_inference/lora.rs

//! LoRA (Low-Rank Adaptation) adapter loading for inference
//!
//! This module provides support for loading and applying LoRA adapters at
//! inference time, so a base model can be adapted efficiently without
//! modifying its original weights.
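//!
//! # Example
//!
//! A minimal usage sketch (illustrative only; it assumes this module is
//! exposed as `kizzasi_inference::lora` and uses zero-initialized matrices in
//! place of real adapter weights):
//!
//! ```ignore
//! use kizzasi_inference::lora::{LoraAdapterBuilder, LoraAdapterManager, LoraConfig};
//! use scirs2_core::ndarray::{Array1, Array2};
//!
//! // Configure a rank-4 adapter; scaling = alpha / rank = 2.0.
//! let config = LoraConfig::new().rank(4).alpha(8.0);
//!
//! // Build an adapter for a layer with 16 input and 16 output features.
//! let adapter = LoraAdapterBuilder::new("demo")
//!     .lora_a(Array2::zeros((4, 16)))
//!     .lora_b(Array2::zeros((16, 4)))
//!     .scaling_from_config(&config)
//!     .build()
//!     .unwrap();
//!
//! // Register and activate it, then apply it to an activation vector.
//! let mut manager = LoraAdapterManager::new(config);
//! manager.register_adapter(adapter);
//! manager.activate("demo").unwrap();
//! let output = manager.apply(&Array1::zeros(16)).unwrap();
//! ```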

use crate::error::{InferenceError, InferenceResult};
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};

/// LoRA adapter configuration
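///
/// Defaults (see the `Default` impl below): `rank = 8`, `alpha = 16.0`,
/// `dropout = 0.0`, and target modules `["q_proj", "v_proj"]`.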
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoraConfig {
    /// Rank of the low-rank matrices
    pub rank: usize,
    /// Alpha parameter; the effective scaling factor is `alpha / rank`
    pub alpha: f32,
    /// Dropout rate for LoRA layers
    pub dropout: f32,
    /// Target modules to apply LoRA to
    pub target_modules: Vec<String>,
}

impl Default for LoraConfig {
    fn default() -> Self {
        Self {
            rank: 8,
            alpha: 16.0,
            dropout: 0.0,
            target_modules: vec!["q_proj".to_string(), "v_proj".to_string()],
        }
    }
}

impl LoraConfig {
    /// Create a new LoRA configuration
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the rank
    pub fn rank(mut self, rank: usize) -> Self {
        self.rank = rank;
        self
    }

    /// Set alpha (used together with rank to derive the scaling factor)
    pub fn alpha(mut self, alpha: f32) -> Self {
        self.alpha = alpha;
        self
    }

    /// Set dropout rate
    pub fn dropout(mut self, dropout: f32) -> Self {
        self.dropout = dropout;
        self
    }

    /// Add target module
    pub fn add_target_module(mut self, module: impl Into<String>) -> Self {
        self.target_modules.push(module.into());
        self
    }

    /// Get the effective scaling factor
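    ///
    /// For example (an illustrative sketch; it assumes this module is exposed
    /// as `kizzasi_inference::lora`):
    ///
    /// ```ignore
    /// use kizzasi_inference::lora::LoraConfig;
    ///
    /// // scaling = alpha / rank = 16.0 / 8 = 2.0
    /// let config = LoraConfig::new().rank(8).alpha(16.0);
    /// assert_eq!(config.scaling(), 2.0);
    /// ```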
    pub fn scaling(&self) -> f32 {
        self.alpha / self.rank as f32
    }
}

/// A LoRA adapter consisting of two low-rank matrices
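///
/// Together the two matrices parameterize the low-rank weight update
/// `ΔW = B · A` (with `A: rank × in_features` and `B: out_features × rank`);
/// `apply` computes `scaling * ΔW · input`, adding the input back as a
/// residual when the dimensions match.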
#[derive(Debug, Clone)]
pub struct LoraAdapter {
    /// Low-rank matrix A (rank × in_features)
    pub lora_a: Array2<f32>,
    /// Low-rank matrix B (out_features × rank)
    pub lora_b: Array2<f32>,
    /// Scaling factor
    pub scaling: f32,
    /// Adapter name/identifier
    pub name: String,
}

impl LoraAdapter {
    /// Create a new LoRA adapter
    pub fn new(
        lora_a: Array2<f32>,
        lora_b: Array2<f32>,
        scaling: f32,
        name: impl Into<String>,
    ) -> InferenceResult<Self> {
        // Validate dimensions: A is (rank, in_features), B is (out_features, rank)
        let rank_a = lora_a.nrows();
        let rank_b = lora_b.ncols();

        if rank_a != rank_b {
            return Err(InferenceError::DimensionMismatch {
                expected: rank_a,
                got: rank_b,
            });
        }

        Ok(Self {
            lora_a,
            lora_b,
            scaling,
            name: name.into(),
        })
    }

    /// Get the rank of this adapter
    pub fn rank(&self) -> usize {
        self.lora_a.nrows()
    }

    /// Get input features dimension
    pub fn in_features(&self) -> usize {
        self.lora_a.ncols()
    }

    /// Get output features dimension
    pub fn out_features(&self) -> usize {
        self.lora_b.nrows()
    }

    /// Apply the LoRA adapter to an input vector
    ///
    /// Computes `scaling * (input @ A^T @ B^T)` and, when `out_features`
    /// equals `in_features`, adds the original input back as an identity
    /// residual: `output = input + scaling * (input @ A^T @ B^T)`.
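    ///
    /// A small numeric sketch (illustrative; it assumes this module is
    /// exposed as `kizzasi_inference::lora`):
    ///
    /// ```ignore
    /// use kizzasi_inference::lora::LoraAdapter;
    /// use scirs2_core::ndarray::{Array1, Array2};
    ///
    /// // rank = 1, in_features = out_features = 2, scaling = 1.0
    /// let a = Array2::from_shape_vec((1, 2), vec![1.0, 1.0]).unwrap();
    /// let b = Array2::from_shape_vec((2, 1), vec![1.0, 1.0]).unwrap();
    /// let adapter = LoraAdapter::new(a, b, 1.0, "demo").unwrap();
    ///
    /// // hidden = A·x = [3.0]; delta = B·hidden = [3.0, 3.0]
    /// // output = x + 1.0 * delta = [4.0, 5.0]
    /// let x = Array1::from_vec(vec![1.0, 2.0]);
    /// let y = adapter.apply(&x).unwrap();
    /// assert_eq!(y, Array1::from_vec(vec![4.0, 5.0]));
    /// ```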
    pub fn apply(&self, input: &Array1<f32>) -> InferenceResult<Array1<f32>> {
        if input.len() != self.in_features() {
            return Err(InferenceError::DimensionMismatch {
                expected: self.in_features(),
                got: input.len(),
            });
        }

        // Compute input @ A^T
        let mut hidden = Array1::zeros(self.rank());
        for i in 0..self.rank() {
            hidden[i] = input.dot(&self.lora_a.row(i));
        }

        // Compute hidden @ B^T
        let mut output = Array1::zeros(self.out_features());
        for i in 0..self.out_features() {
            output[i] = hidden.dot(&self.lora_b.row(i));
        }

        // Scale and add to original input (identity residual)
        // For dimension matching, we assume output has same dim as input for residual
        // In practice, output dimension might differ - this is simplified
        if output.len() == input.len() {
            output = &output * self.scaling + input;
        } else {
            output = &output * self.scaling;
        }

        Ok(output)
    }

    /// Apply the LoRA adapter to a batch of inputs (one row per sample)
    pub fn apply_batch(&self, inputs: &Array2<f32>) -> InferenceResult<Array2<f32>> {
        let batch_size = inputs.nrows();

        // An empty batch has no first row to infer the output width from,
        // so return an empty (0, out_features) array directly.
        if batch_size == 0 {
            return Ok(Array2::zeros((0, self.out_features())));
        }

        let mut outputs = Vec::with_capacity(batch_size);

        for i in 0..batch_size {
            let input_row = inputs.row(i).to_owned();
            let output_row = self.apply(&input_row)?;
            outputs.push(output_row);
        }

        // Stack outputs into a 2D array
        let out_dim = outputs[0].len();
        let flat: Vec<f32> = outputs.into_iter().flat_map(|x| x.to_vec()).collect();

        Array2::from_shape_vec((batch_size, out_dim), flat).map_err(|e| {
            InferenceError::ForwardError(format!("Failed to stack LoRA outputs: {}", e))
        })
    }
}

/// Manager for multiple LoRA adapters
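///
/// A minimal switching sketch (illustrative; it assumes this module is
/// exposed as `kizzasi_inference::lora` and uses zero-initialized weights):
///
/// ```ignore
/// use kizzasi_inference::lora::{LoraAdapter, LoraAdapterManager, LoraConfig};
/// use scirs2_core::ndarray::{Array1, Array2};
///
/// let mut manager = LoraAdapterManager::new(LoraConfig::default());
/// let adapter = LoraAdapter::new(
///     Array2::zeros((2, 4)), // A: (rank, in_features)
///     Array2::zeros((4, 2)), // B: (out_features, rank)
///     1.0,
///     "style-a",
/// )
/// .unwrap();
/// manager.register_adapter(adapter);
///
/// // With no active adapter, apply() is a pass-through.
/// let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
/// assert_eq!(manager.apply(&x).unwrap(), x);
///
/// // After activation, the named adapter is applied instead.
/// manager.activate("style-a").unwrap();
/// let y = manager.apply(&x).unwrap();
/// assert_eq!(y.len(), 4);
/// ```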
pub struct LoraAdapterManager {
    /// Map of adapter names to adapters
    adapters: HashMap<String, LoraAdapter>,
    /// Active adapter name (if any)
    active_adapter: Option<String>,
    /// Configuration
    config: LoraConfig,
}

impl LoraAdapterManager {
    /// Create a new adapter manager
    pub fn new(config: LoraConfig) -> Self {
        Self {
            adapters: HashMap::new(),
            active_adapter: None,
            config,
        }
    }

    /// Register a new adapter
    pub fn register_adapter(&mut self, adapter: LoraAdapter) {
        let name = adapter.name.clone();
        self.adapters.insert(name, adapter);
    }

    /// Activate an adapter by name
    pub fn activate(&mut self, name: impl AsRef<str>) -> InferenceResult<()> {
        let name_ref = name.as_ref();
        if !self.adapters.contains_key(name_ref) {
            return Err(InferenceError::ForwardError(format!(
                "Adapter '{}' not found",
                name_ref
            )));
        }
        self.active_adapter = Some(name_ref.to_string());
        Ok(())
    }

    /// Deactivate the current adapter
    pub fn deactivate(&mut self) {
        self.active_adapter = None;
    }

    /// Get the active adapter
    pub fn active_adapter(&self) -> Option<&LoraAdapter> {
        self.active_adapter
            .as_ref()
            .and_then(|name| self.adapters.get(name))
    }

    /// Apply the active adapter (if any) to input
    pub fn apply(&self, input: &Array1<f32>) -> InferenceResult<Array1<f32>> {
        if let Some(adapter) = self.active_adapter() {
            adapter.apply(input)
        } else {
            // No active adapter, return input unchanged
            Ok(input.clone())
        }
    }

    /// Apply the active adapter to a batch
    pub fn apply_batch(&self, inputs: &Array2<f32>) -> InferenceResult<Array2<f32>> {
        if let Some(adapter) = self.active_adapter() {
            adapter.apply_batch(inputs)
        } else {
            Ok(inputs.clone())
        }
    }

    /// List all registered adapters
    pub fn list_adapters(&self) -> Vec<&String> {
        self.adapters.keys().collect()
    }

    /// Get adapter by name
    pub fn get_adapter(&self, name: impl AsRef<str>) -> Option<&LoraAdapter> {
        self.adapters.get(name.as_ref())
    }

    /// Remove an adapter
    pub fn remove_adapter(&mut self, name: impl AsRef<str>) -> Option<LoraAdapter> {
        let name_ref = name.as_ref();
        // Deactivate if it's the active one
        if self.active_adapter.as_deref() == Some(name_ref) {
            self.deactivate();
        }
        self.adapters.remove(name_ref)
    }

    /// Get configuration
    pub fn config(&self) -> &LoraConfig {
        &self.config
    }
}

/// Builder for creating LoRA adapters from components
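///
/// An illustrative sketch (assumes this module is exposed as
/// `kizzasi_inference::lora`):
///
/// ```ignore
/// use kizzasi_inference::lora::{LoraAdapterBuilder, LoraConfig};
/// use scirs2_core::ndarray::Array2;
///
/// let config = LoraConfig::new().rank(2).alpha(4.0);
/// let adapter = LoraAdapterBuilder::new("demo")
///     .lora_a(Array2::zeros((2, 8))) // (rank, in_features)
///     .lora_b(Array2::zeros((8, 2))) // (out_features, rank)
///     .scaling_from_config(&config)  // scaling = 4.0 / 2 = 2.0
///     .build()
///     .unwrap();
/// assert_eq!(adapter.rank(), 2);
/// assert_eq!(adapter.scaling, 2.0);
/// ```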
pub struct LoraAdapterBuilder {
    lora_a: Option<Array2<f32>>,
    lora_b: Option<Array2<f32>>,
    scaling: f32,
    name: String,
}

impl LoraAdapterBuilder {
    /// Create a new builder
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            lora_a: None,
            lora_b: None,
            scaling: 1.0,
            name: name.into(),
        }
    }

    /// Set matrix A
    pub fn lora_a(mut self, matrix: Array2<f32>) -> Self {
        self.lora_a = Some(matrix);
        self
    }

    /// Set matrix B
    pub fn lora_b(mut self, matrix: Array2<f32>) -> Self {
        self.lora_b = Some(matrix);
        self
    }

    /// Set scaling factor
    pub fn scaling(mut self, scaling: f32) -> Self {
        self.scaling = scaling;
        self
    }

    /// Set scaling from config
    pub fn scaling_from_config(mut self, config: &LoraConfig) -> Self {
        self.scaling = config.scaling();
        self
    }

    /// Build the adapter
    pub fn build(self) -> InferenceResult<LoraAdapter> {
        let lora_a = self.lora_a.ok_or_else(|| {
            InferenceError::ForwardError("LoRA matrix A not provided".to_string())
        })?;
        let lora_b = self.lora_b.ok_or_else(|| {
            InferenceError::ForwardError("LoRA matrix B not provided".to_string())
        })?;

        LoraAdapter::new(lora_a, lora_b, self.scaling, self.name)
    }
}

/// LoRA adapter loader for reading from disk
pub struct LoraAdapterLoader {
    /// Base path for adapter files
    base_path: PathBuf,
}

impl LoraAdapterLoader {
    /// Create a new loader with base path
    pub fn new(base_path: impl AsRef<Path>) -> Self {
        Self {
            base_path: base_path.as_ref().to_path_buf(),
        }
    }

    /// Load an adapter from a directory
    ///
    /// Expected structure:
    /// - adapter_name/
    ///   - config.json
    ///   - lora_a.safetensors (or .npy)
    ///   - lora_b.safetensors (or .npy)
    ///
    /// Note: weight loading is currently a placeholder; the returned adapter
    /// holds zero-initialized matrices sized from the loaded config.
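    ///
    /// An illustrative sketch (assumes this module is exposed as
    /// `kizzasi_inference::lora` and that an `adapters/my_adapter/` directory
    /// with a `config.json` exists):
    ///
    /// ```ignore
    /// use kizzasi_inference::lora::LoraAdapterLoader;
    ///
    /// let loader = LoraAdapterLoader::new("adapters");
    /// for name in loader.list_available().unwrap() {
    ///     println!("found adapter: {}", name);
    /// }
    /// let (adapter, config) = loader.load("my_adapter").unwrap();
    /// assert_eq!(adapter.rank(), config.rank);
    /// ```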
    pub fn load(
        &self,
        adapter_name: impl AsRef<str>,
    ) -> InferenceResult<(LoraAdapter, LoraConfig)> {
        let adapter_path = self.base_path.join(adapter_name.as_ref());

        // Load config
        let config_path = adapter_path.join("config.json");
        let config: LoraConfig = if config_path.exists() {
            let config_str = std::fs::read_to_string(&config_path).map_err(|e| {
                InferenceError::ForwardError(format!("Failed to read config: {}", e))
            })?;
            serde_json::from_str(&config_str).map_err(|e| {
                InferenceError::ForwardError(format!("Failed to parse config: {}", e))
            })?
        } else {
            LoraConfig::default()
        };

        // For now, return a placeholder adapter since we don't have actual file loading
        // In a real implementation, you'd load from safetensors or numpy files
        let rank = config.rank;
        let lora_a = Array2::zeros((rank, 128)); // Placeholder dimensions
        let lora_b = Array2::zeros((128, rank));
        let scaling = config.scaling();

        let adapter = LoraAdapter::new(lora_a, lora_b, scaling, adapter_name.as_ref())?;
        Ok((adapter, config))
    }

    /// List available adapters in the base path
    pub fn list_available(&self) -> InferenceResult<Vec<String>> {
        let mut adapters = Vec::new();

        let entries = std::fs::read_dir(&self.base_path).map_err(|e| {
            InferenceError::ForwardError(format!("Failed to read adapter directory: {}", e))
        })?;

        for entry in entries.flatten() {
            if entry.path().is_dir() {
                if let Some(name) = entry.file_name().to_str() {
                    adapters.push(name.to_string());
                }
            }
        }

        Ok(adapters)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lora_config() {
        let config = LoraConfig::new().rank(16).alpha(32.0);

        assert_eq!(config.rank, 16);
        assert_eq!(config.alpha, 32.0);
        assert_eq!(config.scaling(), 2.0); // alpha / rank = 32 / 16
    }

    #[test]
    fn test_lora_adapter_creation() {
        let lora_a = Array2::from_shape_vec((4, 8), vec![1.0; 32]).unwrap();
        let lora_b = Array2::from_shape_vec((8, 4), vec![0.5; 32]).unwrap();

        let adapter = LoraAdapter::new(lora_a, lora_b, 0.5, "test").unwrap();

        assert_eq!(adapter.rank(), 4);
        assert_eq!(adapter.in_features(), 8);
        assert_eq!(adapter.out_features(), 8);
    }

    #[test]
    fn test_lora_adapter_dimension_mismatch() {
        let lora_a = Array2::from_shape_vec((4, 8), vec![1.0; 32]).unwrap();
        let lora_b = Array2::from_shape_vec((8, 5), vec![0.5; 40]).unwrap(); // Rank mismatch

        let result = LoraAdapter::new(lora_a, lora_b, 0.5, "test");
        assert!(result.is_err());
    }

    #[test]
    fn test_lora_adapter_apply() {
        let rank = 2;
        let in_features = 4;
        let out_features = 4;

        let lora_a = Array2::from_shape_vec((rank, in_features), vec![0.1; 8]).unwrap();
        let lora_b = Array2::from_shape_vec((out_features, rank), vec![0.2; 8]).unwrap();

        let adapter = LoraAdapter::new(lora_a, lora_b, 1.0, "test").unwrap();

        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let output = adapter.apply(&input).unwrap();

        assert_eq!(output.len(), out_features);
        // Output should be input + LoRA modification
    }

    #[test]
    fn test_lora_manager() {
        let config = LoraConfig::new();
        let mut manager = LoraAdapterManager::new(config);

        let lora_a = Array2::from_shape_vec((2, 4), vec![0.1; 8]).unwrap();
        let lora_b = Array2::from_shape_vec((4, 2), vec![0.2; 8]).unwrap();
        let adapter = LoraAdapter::new(lora_a, lora_b, 1.0, "adapter1").unwrap();

        manager.register_adapter(adapter);
        assert_eq!(manager.list_adapters().len(), 1);

        manager.activate("adapter1").unwrap();
        assert!(manager.active_adapter().is_some());

        manager.deactivate();
        assert!(manager.active_adapter().is_none());
    }

    #[test]
    fn test_lora_manager_apply_without_adapter() {
        let config = LoraConfig::new();
        let manager = LoraAdapterManager::new(config);

        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let output = manager.apply(&input).unwrap();

        // Without adapter, output should equal input
        assert_eq!(output, input);
    }

    #[test]
    fn test_lora_builder() {
        let lora_a = Array2::from_shape_vec((2, 4), vec![0.1; 8]).unwrap();
        let lora_b = Array2::from_shape_vec((4, 2), vec![0.2; 8]).unwrap();

        let adapter = LoraAdapterBuilder::new("test")
            .lora_a(lora_a)
            .lora_b(lora_b)
            .scaling(0.5)
            .build()
            .unwrap();

        assert_eq!(adapter.name, "test");
        assert_eq!(adapter.scaling, 0.5);
    }

    #[test]
    fn test_lora_builder_missing_matrix() {
        let lora_a = Array2::from_shape_vec((2, 4), vec![0.1; 8]).unwrap();

        let result = LoraAdapterBuilder::new("test")
            .lora_a(lora_a)
            // Missing lora_b
            .build();

        assert!(result.is_err());
    }

    #[test]
    fn test_lora_adapter_batch() {
        let lora_a = Array2::from_shape_vec((2, 4), vec![0.1; 8]).unwrap();
        let lora_b = Array2::from_shape_vec((4, 2), vec![0.2; 8]).unwrap();
        let adapter = LoraAdapter::new(lora_a, lora_b, 1.0, "test").unwrap();

        let inputs = Array2::from_shape_vec(
            (3, 4),
            vec![
                1.0, 2.0, 3.0, 4.0, // Sample 1
                5.0, 6.0, 7.0, 8.0, // Sample 2
                9.0, 10.0, 11.0, 12.0, // Sample 3
            ],
        )
        .unwrap();

        let outputs = adapter.apply_batch(&inputs).unwrap();
        assert_eq!(outputs.nrows(), 3);
    }
}