use crate::error::{InferenceError, InferenceResult};
use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};

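/// Configuration for LoRA (Low-Rank Adaptation) adapters.
///
/// A minimal usage sketch of the builder-style API (values here are
/// illustrative, not recommendations; not compiled as a doctest):
///
/// ```ignore
/// let config = LoraConfig::new()
///     .rank(16)
///     .alpha(32.0)
///     .dropout(0.05)
///     .add_target_module("k_proj");
/// ```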
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LoraConfig {
    /// Rank of the low-rank decomposition.
    pub rank: usize,
    /// Scaling numerator; the effective scale is `alpha / rank`.
    pub alpha: f32,
    /// Dropout probability for the adapter path.
    pub dropout: f32,
    /// Names of the target modules (e.g. `"q_proj"`).
    pub target_modules: Vec<String>,
}

impl Default for LoraConfig {
    fn default() -> Self {
        Self {
            rank: 8,
            alpha: 16.0,
            dropout: 0.0,
            target_modules: vec!["q_proj".to_string(), "v_proj".to_string()],
        }
    }
}

impl LoraConfig {
    /// Creates a configuration with the default values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the rank of the decomposition.
    pub fn rank(mut self, rank: usize) -> Self {
        self.rank = rank;
        self
    }

    /// Sets the scaling numerator `alpha`.
    pub fn alpha(mut self, alpha: f32) -> Self {
        self.alpha = alpha;
        self
    }

    /// Sets the dropout probability.
    pub fn dropout(mut self, dropout: f32) -> Self {
        self.dropout = dropout;
        self
    }

    /// Adds a module name to the list of target modules.
    pub fn add_target_module(mut self, module: impl Into<String>) -> Self {
        self.target_modules.push(module.into());
        self
    }

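    /// Returns the effective scaling factor, `alpha / rank`.
    ///
    /// Note that `rank` must be non-zero; a zero rank divides by zero and
    /// yields `inf`. A quick worked check (illustrative sketch; not
    /// compiled as a doctest):
    ///
    /// ```ignore
    /// let config = LoraConfig::new().rank(8).alpha(16.0);
    /// assert_eq!(config.scaling(), 2.0); // 16.0 / 8
    /// ```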
    pub fn scaling(&self) -> f32 {
        self.alpha / self.rank as f32
    }
}

/// A single LoRA adapter holding the low-rank update matrices.
#[derive(Debug, Clone)]
pub struct LoraAdapter {
    /// Matrix `A` with shape `(rank, in_features)`.
    pub lora_a: Array2<f32>,
    /// Matrix `B` with shape `(out_features, rank)`.
    pub lora_b: Array2<f32>,
    /// Scaling factor applied to the low-rank update.
    pub scaling: f32,
    /// Adapter name.
    pub name: String,
}

impl LoraAdapter {
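    /// Creates an adapter from `A` (`rank × in_features`) and `B`
    /// (`out_features × rank`), checking that the rank dimensions agree.
    ///
    /// A minimal construction sketch (shapes are illustrative; not
    /// compiled as a doctest):
    ///
    /// ```ignore
    /// let lora_a = Array2::<f32>::zeros((4, 64)); // rank 4, 64 inputs
    /// let lora_b = Array2::<f32>::zeros((64, 4)); // 64 outputs, rank 4
    /// let adapter = LoraAdapter::new(lora_a, lora_b, 0.5, "demo")?;
    /// assert_eq!(adapter.rank(), 4);
    /// ```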
    pub fn new(
        lora_a: Array2<f32>,
        lora_b: Array2<f32>,
        scaling: f32,
        name: impl Into<String>,
    ) -> InferenceResult<Self> {
        // The inner (rank) dimensions of `A` and `B` must agree.
        let rank_a = lora_a.nrows();
        let rank_b = lora_b.ncols();

        if rank_a != rank_b {
            return Err(InferenceError::DimensionMismatch {
                expected: rank_a,
                got: rank_b,
            });
        }

        Ok(Self {
            lora_a,
            lora_b,
            scaling,
            name: name.into(),
        })
    }

    /// Rank of the decomposition (rows of `A`).
    pub fn rank(&self) -> usize {
        self.lora_a.nrows()
    }

    /// Input dimensionality (columns of `A`).
    pub fn in_features(&self) -> usize {
        self.lora_a.ncols()
    }

    /// Output dimensionality (rows of `B`).
    pub fn out_features(&self) -> usize {
        self.lora_b.nrows()
    }

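    /// Applies the low-rank update to a single input vector, computing
    /// `scaling · B(A·x)` and adding `x` as a residual when the input and
    /// output dimensions match.
    ///
    /// An illustrative call (assumes an `adapter` with `in_features == 4`;
    /// not compiled as a doctest):
    ///
    /// ```ignore
    /// let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
    /// let output = adapter.apply(&input)?;
    /// assert_eq!(output.len(), adapter.out_features());
    /// ```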
    pub fn apply(&self, input: &Array1<f32>) -> InferenceResult<Array1<f32>> {
        if input.len() != self.in_features() {
            return Err(InferenceError::DimensionMismatch {
                expected: self.in_features(),
                got: input.len(),
            });
        }

        // Down-project to the rank dimension, then up-project:
        // update = B · (A · x).
        let hidden = self.lora_a.dot(input);
        let update = self.lora_b.dot(&hidden);

        // Scale the update and add the input as a residual connection
        // when the dimensions match; otherwise return the scaled update.
        let output = if update.len() == input.len() {
            &update * self.scaling + input
        } else {
            &update * self.scaling
        };

        Ok(output)
    }

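    /// Applies the adapter to each row of a batch independently.
    ///
    /// An illustrative call (shapes assumed; not compiled as a doctest):
    ///
    /// ```ignore
    /// let inputs = Array2::<f32>::zeros((3, adapter.in_features()));
    /// let outputs = adapter.apply_batch(&inputs)?;
    /// assert_eq!(outputs.nrows(), 3);
    /// ```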
    pub fn apply_batch(&self, inputs: &Array2<f32>) -> InferenceResult<Array2<f32>> {
        let batch_size = inputs.nrows();
        let mut outputs = Vec::with_capacity(batch_size);

        for i in 0..batch_size {
            let input_row = inputs.row(i).to_owned();
            let output_row = self.apply(&input_row)?;
            outputs.push(output_row);
        }

        // `apply` always returns out_features() elements; using that here
        // (rather than indexing outputs[0]) keeps empty batches from
        // panicking.
        let out_dim = self.out_features();
        let flat: Vec<f32> = outputs.into_iter().flat_map(|x| x.to_vec()).collect();

        Array2::from_shape_vec((batch_size, out_dim), flat).map_err(|e| {
            InferenceError::ForwardError(format!("Failed to stack LoRA outputs: {}", e))
        })
    }
}

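/// A registry of named LoRA adapters with at most one active at a time.
///
/// A minimal usage sketch (adapter construction elided; not compiled as
/// a doctest):
///
/// ```ignore
/// let mut manager = LoraAdapterManager::new(LoraConfig::default());
/// manager.register_adapter(adapter);
/// manager.activate("my_adapter")?;
/// let output = manager.apply(&input)?; // uses the active adapter
/// manager.deactivate();                // subsequent calls pass through
/// ```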
pub struct LoraAdapterManager {
    /// Registered adapters keyed by name.
    adapters: HashMap<String, LoraAdapter>,
    /// Name of the currently active adapter, if any.
    active_adapter: Option<String>,
    /// Shared configuration for the managed adapters.
    config: LoraConfig,
}

impl LoraAdapterManager {
    /// Creates an empty manager with the given configuration.
    pub fn new(config: LoraConfig) -> Self {
        Self {
            adapters: HashMap::new(),
            active_adapter: None,
            config,
        }
    }

    /// Registers an adapter under its own name, replacing any existing
    /// adapter with that name.
    pub fn register_adapter(&mut self, adapter: LoraAdapter) {
        let name = adapter.name.clone();
        self.adapters.insert(name, adapter);
    }

    /// Activates a registered adapter by name.
    pub fn activate(&mut self, name: impl AsRef<str>) -> InferenceResult<()> {
        let name_ref = name.as_ref();
        if !self.adapters.contains_key(name_ref) {
            return Err(InferenceError::ForwardError(format!(
                "Adapter '{}' not found",
                name_ref
            )));
        }
        self.active_adapter = Some(name_ref.to_string());
        Ok(())
    }

    /// Deactivates the currently active adapter, if any.
    pub fn deactivate(&mut self) {
        self.active_adapter = None;
    }

    /// Returns the currently active adapter, if any.
    pub fn active_adapter(&self) -> Option<&LoraAdapter> {
        self.active_adapter
            .as_ref()
            .and_then(|name| self.adapters.get(name))
    }

    /// Applies the active adapter to `input`, or passes the input
    /// through unchanged when no adapter is active.
    pub fn apply(&self, input: &Array1<f32>) -> InferenceResult<Array1<f32>> {
        if let Some(adapter) = self.active_adapter() {
            adapter.apply(input)
        } else {
            Ok(input.clone())
        }
    }

    /// Applies the active adapter to each row of `inputs`, or passes the
    /// batch through unchanged when no adapter is active.
    pub fn apply_batch(&self, inputs: &Array2<f32>) -> InferenceResult<Array2<f32>> {
        if let Some(adapter) = self.active_adapter() {
            adapter.apply_batch(inputs)
        } else {
            Ok(inputs.clone())
        }
    }

    /// Lists the names of all registered adapters.
    pub fn list_adapters(&self) -> Vec<&String> {
        self.adapters.keys().collect()
    }

    /// Returns a registered adapter by name.
    pub fn get_adapter(&self, name: impl AsRef<str>) -> Option<&LoraAdapter> {
        self.adapters.get(name.as_ref())
    }

    /// Removes an adapter by name, deactivating it first if it is
    /// currently active.
    pub fn remove_adapter(&mut self, name: impl AsRef<str>) -> Option<LoraAdapter> {
        let name_ref = name.as_ref();
        if self.active_adapter.as_deref() == Some(name_ref) {
            self.deactivate();
        }
        self.adapters.remove(name_ref)
    }

    /// Returns the manager's configuration.
    pub fn config(&self) -> &LoraConfig {
        &self.config
    }
}

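/// Builder for [`LoraAdapter`]; dimension checks are deferred to `build`.
///
/// A minimal sketch (matrix shapes are illustrative; not compiled as a
/// doctest):
///
/// ```ignore
/// let adapter = LoraAdapterBuilder::new("demo")
///     .lora_a(Array2::zeros((4, 64)))
///     .lora_b(Array2::zeros((64, 4)))
///     .scaling_from_config(&LoraConfig::default())
///     .build()?;
/// ```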
pub struct LoraAdapterBuilder {
    /// Matrix `A` (`rank × in_features`), required.
    lora_a: Option<Array2<f32>>,
    /// Matrix `B` (`out_features × rank`), required.
    lora_b: Option<Array2<f32>>,
    /// Scaling factor; defaults to `1.0`.
    scaling: f32,
    /// Adapter name.
    name: String,
}

impl LoraAdapterBuilder {
    /// Starts a builder for an adapter with the given name.
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            lora_a: None,
            lora_b: None,
            scaling: 1.0,
            name: name.into(),
        }
    }

    /// Sets matrix `A`.
    pub fn lora_a(mut self, matrix: Array2<f32>) -> Self {
        self.lora_a = Some(matrix);
        self
    }

    /// Sets matrix `B`.
    pub fn lora_b(mut self, matrix: Array2<f32>) -> Self {
        self.lora_b = Some(matrix);
        self
    }

    /// Sets the scaling factor directly.
    pub fn scaling(mut self, scaling: f32) -> Self {
        self.scaling = scaling;
        self
    }

    /// Derives the scaling factor (`alpha / rank`) from a configuration.
    pub fn scaling_from_config(mut self, config: &LoraConfig) -> Self {
        self.scaling = config.scaling();
        self
    }

    /// Builds the adapter, failing if either matrix is missing or the
    /// rank dimensions disagree.
    pub fn build(self) -> InferenceResult<LoraAdapter> {
        let lora_a = self.lora_a.ok_or_else(|| {
            InferenceError::ForwardError("LoRA matrix A not provided".to_string())
        })?;
        let lora_b = self.lora_b.ok_or_else(|| {
            InferenceError::ForwardError("LoRA matrix B not provided".to_string())
        })?;

        LoraAdapter::new(lora_a, lora_b, self.scaling, self.name)
    }
}

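/// Loads adapters from a directory tree with one subdirectory per
/// adapter, each optionally containing a `config.json`.
///
/// A minimal sketch (the path is illustrative; not compiled as a
/// doctest):
///
/// ```ignore
/// let loader = LoraAdapterLoader::new("adapters/");
/// let (adapter, config) = loader.load("my_adapter")?;
/// let available = loader.list_available()?;
/// ```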
pub struct LoraAdapterLoader {
    /// Root directory containing one subdirectory per adapter.
    base_path: PathBuf,
}

impl LoraAdapterLoader {
    /// Creates a loader rooted at `base_path`.
    pub fn new(base_path: impl AsRef<Path>) -> Self {
        Self {
            base_path: base_path.as_ref().to_path_buf(),
        }
    }

    /// Loads an adapter and its configuration by name.
    ///
    /// Reads `config.json` from the adapter's directory when present,
    /// falling back to [`LoraConfig::default`]. Weight loading is not
    /// implemented yet, so placeholder zero matrices are returned.
    pub fn load(
        &self,
        adapter_name: impl AsRef<str>,
    ) -> InferenceResult<(LoraAdapter, LoraConfig)> {
        let adapter_path = self.base_path.join(adapter_name.as_ref());

        let config_path = adapter_path.join("config.json");
        let config: LoraConfig = if config_path.exists() {
            let config_str = std::fs::read_to_string(&config_path).map_err(|e| {
                InferenceError::ForwardError(format!("Failed to read config: {}", e))
            })?;
            serde_json::from_str(&config_str).map_err(|e| {
                InferenceError::ForwardError(format!("Failed to parse config: {}", e))
            })?
        } else {
            LoraConfig::default()
        };

        // Placeholder weights: zero matrices with an assumed feature
        // dimension of 128, pending real weight loading.
        let rank = config.rank;
        let lora_a = Array2::zeros((rank, 128));
        let lora_b = Array2::zeros((128, rank));
        let scaling = config.scaling();

        let adapter = LoraAdapter::new(lora_a, lora_b, scaling, adapter_name.as_ref())?;
        Ok((adapter, config))
    }

    /// Lists the names of the adapter subdirectories under `base_path`.
    pub fn list_available(&self) -> InferenceResult<Vec<String>> {
        let mut adapters = Vec::new();

        let entries = std::fs::read_dir(&self.base_path).map_err(|e| {
            InferenceError::ForwardError(format!("Failed to read adapter directory: {}", e))
        })?;

        // Every subdirectory is treated as an adapter.
        for entry in entries.flatten() {
            if entry.path().is_dir() {
                if let Some(name) = entry.file_name().to_str() {
                    adapters.push(name.to_string());
                }
            }
        }

        Ok(adapters)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_lora_config() {
        let config = LoraConfig::new().rank(16).alpha(32.0);

        assert_eq!(config.rank, 16);
        assert_eq!(config.alpha, 32.0);
        assert_eq!(config.scaling(), 2.0);
    }

    #[test]
    fn test_lora_adapter_creation() {
        let lora_a = Array2::from_shape_vec((4, 8), vec![1.0; 32]).unwrap();
        let lora_b = Array2::from_shape_vec((8, 4), vec![0.5; 32]).unwrap();

        let adapter = LoraAdapter::new(lora_a, lora_b, 0.5, "test").unwrap();

        assert_eq!(adapter.rank(), 4);
        assert_eq!(adapter.in_features(), 8);
        assert_eq!(adapter.out_features(), 8);
    }

    #[test]
    fn test_lora_adapter_dimension_mismatch() {
        let lora_a = Array2::from_shape_vec((4, 8), vec![1.0; 32]).unwrap();
        let lora_b = Array2::from_shape_vec((8, 5), vec![0.5; 40]).unwrap();

        let result = LoraAdapter::new(lora_a, lora_b, 0.5, "test");
        assert!(result.is_err());
    }

    #[test]
    fn test_lora_adapter_apply() {
        let rank = 2;
        let in_features = 4;
        let out_features = 4;

        let lora_a = Array2::from_shape_vec((rank, in_features), vec![0.1; 8]).unwrap();
        let lora_b = Array2::from_shape_vec((out_features, rank), vec![0.2; 8]).unwrap();

        let adapter = LoraAdapter::new(lora_a, lora_b, 1.0, "test").unwrap();

        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let output = adapter.apply(&input).unwrap();

        assert_eq!(output.len(), out_features);
    }

    #[test]
    fn test_lora_manager() {
        let config = LoraConfig::new();
        let mut manager = LoraAdapterManager::new(config);

        let lora_a = Array2::from_shape_vec((2, 4), vec![0.1; 8]).unwrap();
        let lora_b = Array2::from_shape_vec((4, 2), vec![0.2; 8]).unwrap();
        let adapter = LoraAdapter::new(lora_a, lora_b, 1.0, "adapter1").unwrap();

        manager.register_adapter(adapter);
        assert_eq!(manager.list_adapters().len(), 1);

        manager.activate("adapter1").unwrap();
        assert!(manager.active_adapter().is_some());

        manager.deactivate();
        assert!(manager.active_adapter().is_none());
    }

    #[test]
    fn test_lora_manager_apply_without_adapter() {
        let config = LoraConfig::new();
        let manager = LoraAdapterManager::new(config);

        let input = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let output = manager.apply(&input).unwrap();

        assert_eq!(output, input);
    }

    #[test]
    fn test_lora_builder() {
        let lora_a = Array2::from_shape_vec((2, 4), vec![0.1; 8]).unwrap();
        let lora_b = Array2::from_shape_vec((4, 2), vec![0.2; 8]).unwrap();

        let adapter = LoraAdapterBuilder::new("test")
            .lora_a(lora_a)
            .lora_b(lora_b)
            .scaling(0.5)
            .build()
            .unwrap();

        assert_eq!(adapter.name, "test");
        assert_eq!(adapter.scaling, 0.5);
    }

    #[test]
    fn test_lora_builder_missing_matrix() {
        let lora_a = Array2::from_shape_vec((2, 4), vec![0.1; 8]).unwrap();

        let result = LoraAdapterBuilder::new("test").lora_a(lora_a).build();

        assert!(result.is_err());
    }

    #[test]
    fn test_lora_adapter_batch() {
        let lora_a = Array2::from_shape_vec((2, 4), vec![0.1; 8]).unwrap();
        let lora_b = Array2::from_shape_vec((4, 2), vec![0.2; 8]).unwrap();
        let adapter = LoraAdapter::new(lora_a, lora_b, 1.0, "test").unwrap();

        let inputs = Array2::from_shape_vec(
            (3, 4),
            vec![
                1.0, 2.0, 3.0, 4.0, // row 0
                5.0, 6.0, 7.0, 8.0, // row 1
                9.0, 10.0, 11.0, 12.0, // row 2
            ],
        )
        .unwrap();

        let outputs = adapter.apply_batch(&inputs).unwrap();
        assert_eq!(outputs.nrows(), 3);
    }
}