1use std::collections::HashMap;
9
10use candle_core::Tensor;
11
12use crate::error::{PeftError, Result};
13use crate::traits::Adapter;
14
/// A pattern describing which module names an adapter applies to.
#[derive(Debug, Clone)]
pub enum ModulePattern {
    /// Matches one module name exactly.
    Exact(String),
    /// `*.foo`: matches any module name ending with `foo`.
    Suffix(String),
    /// `foo.*`: matches any module name starting with `foo`.
    Prefix(String),
    /// `*`: matches every module name.
    All,
}

impl ModulePattern {
    /// Parses a pattern string: `"*"` is [`Self::All`], a leading `*.`
    /// yields [`Self::Suffix`], a trailing `.*` yields [`Self::Prefix`],
    /// and anything else is an [`Self::Exact`] match.
    #[must_use]
    pub fn parse(pattern: &str) -> Self {
        if pattern == "*" {
            Self::All
        } else if let Some(rest) = pattern.strip_prefix("*.") {
            Self::Suffix(rest.to_string())
        } else if let Some(rest) = pattern.strip_suffix(".*") {
            Self::Prefix(rest.to_string())
        } else {
            Self::Exact(pattern.to_string())
        }
    }

    /// Returns `true` when `module_name` satisfies this pattern.
    #[must_use]
    pub fn matches(&self, module_name: &str) -> bool {
        match self {
            Self::All => true,
            Self::Exact(name) => name.as_str() == module_name,
            Self::Prefix(prefix) => module_name.starts_with(prefix.as_str()),
            Self::Suffix(suffix) => module_name.ends_with(suffix.as_str()),
        }
    }
}
57
/// One named adapter attached to a specific module, plus its activation flag.
struct ModuleAdapter<A: Adapter> {
    /// The adapter implementation bound to this module.
    adapter: A,
    /// True when this entry is the one `forward_module` dispatches to for
    /// its owning module.
    active: bool,
}
65
/// Container mapping module names to their named PEFT adapters.
///
/// Multiple adapters can be registered per module under distinct names;
/// `set_adapter` marks one of them active per module, and `active_adapter`
/// records the most recently selected global adapter name.
pub struct PeftModel<A: Adapter> {
    /// module name -> (adapter name -> adapter entry with its active flag).
    module_adapters: HashMap<String, HashMap<String, ModuleAdapter<A>>>,
    /// Globally selected adapter name, if one has been chosen.
    active_adapter: Option<String>,
    /// Every adapter name registered so far, in registration order (unique).
    adapter_names: Vec<String>,
}
77
78impl<A: Adapter> PeftModel<A> {
79 #[must_use]
81 pub fn new() -> Self {
82 Self {
83 module_adapters: HashMap::new(),
84 active_adapter: None,
85 adapter_names: Vec::new(),
86 }
87 }
88
89 pub fn add_adapter<F>(
100 &mut self,
101 adapter_name: impl Into<String>,
102 pattern: &str,
103 module_names: &[&str],
104 adapter_factory: F,
105 ) -> Result<usize>
106 where
107 F: Fn(&str) -> Result<A>,
108 {
109 let adapter_name = adapter_name.into();
110 let pattern = ModulePattern::parse(pattern);
111 let mut count = 0;
112
113 for &module_name in module_names {
114 if pattern.matches(module_name) {
115 let adapter = adapter_factory(module_name)?;
116 let module_name_owned = module_name.to_string();
117
118 let module_entry = self.module_adapters.entry(module_name_owned).or_default();
119
120 module_entry.insert(
121 adapter_name.clone(),
122 ModuleAdapter {
123 adapter,
124 active: self.active_adapter.is_none(),
125 },
126 );
127 count += 1;
128 }
129 }
130
131 if !self.adapter_names.contains(&adapter_name) {
133 self.adapter_names.push(adapter_name.clone());
134 }
135
136 if self.active_adapter.is_none() && count > 0 {
138 self.active_adapter = Some(adapter_name);
139 }
140
141 Ok(count)
142 }
143
144 pub fn set_adapter(&mut self, module_name: &str, adapter_name: &str) -> Result<()> {
149 let adapters = self.module_adapters.get_mut(module_name).ok_or_else(|| {
150 PeftError::AdapterNotFound {
151 name: format!("module '{module_name}' not found"),
152 }
153 })?;
154
155 if !adapters.contains_key(adapter_name) {
156 return Err(PeftError::AdapterNotFound {
157 name: format!("adapter '{adapter_name}' not found in module '{module_name}'"),
158 });
159 }
160
161 for adapter_entry in adapters.values_mut() {
163 adapter_entry.active = false;
164 }
165
166 if let Some(entry) = adapters.get_mut(adapter_name) {
168 entry.active = true;
169 }
170
171 Ok(())
172 }
173
174 pub fn set_adapter_all(&mut self, adapter_name: impl Into<String>) -> Result<()> {
179 let adapter_name = adapter_name.into();
180
181 if !self.adapter_names.contains(&adapter_name) {
182 return Err(PeftError::AdapterNotFound { name: adapter_name });
183 }
184
185 for adapters in self.module_adapters.values_mut() {
186 for entry in adapters.values_mut() {
188 entry.active = false;
189 }
190 if let Some(entry) = adapters.get_mut(&adapter_name) {
192 entry.active = true;
193 }
194 }
195
196 self.active_adapter = Some(adapter_name);
197 Ok(())
198 }
199
200 #[must_use]
202 pub fn active_adapter_name(&self) -> Option<&str> {
203 self.active_adapter.as_deref()
204 }
205
206 #[must_use]
208 pub fn adapter_names(&self) -> &[String] {
209 &self.adapter_names
210 }
211
212 #[must_use]
214 pub fn module_names(&self) -> Vec<&str> {
215 self.module_adapters.keys().map(String::as_str).collect()
216 }
217
218 #[must_use]
220 pub fn has_adapter(&self, module_name: &str) -> bool {
221 self.module_adapters.contains_key(module_name)
222 }
223
224 pub fn forward_module(
234 &self,
235 module_name: &str,
236 input: &Tensor,
237 base_output: Option<&Tensor>,
238 ) -> Result<Tensor> {
239 let adapters =
240 self.module_adapters
241 .get(module_name)
242 .ok_or_else(|| PeftError::AdapterNotFound {
243 name: format!("module '{module_name}' not found"),
244 })?;
245
246 for entry in adapters.values() {
248 if entry.active {
249 return entry.adapter.forward(input, base_output);
250 }
251 }
252
253 Err(PeftError::AdapterNotFound {
254 name: format!("no active adapter for module '{module_name}'"),
255 })
256 }
257
258 pub fn get_adapter(&self, module_name: &str, adapter_name: &str) -> Result<&A> {
263 let adapters =
264 self.module_adapters
265 .get(module_name)
266 .ok_or_else(|| PeftError::AdapterNotFound {
267 name: format!("module '{module_name}' not found"),
268 })?;
269
270 adapters
271 .get(adapter_name)
272 .map(|entry| &entry.adapter)
273 .ok_or_else(|| PeftError::AdapterNotFound {
274 name: format!("adapter '{adapter_name}' not found in module '{module_name}'"),
275 })
276 }
277
278 #[must_use]
280 pub fn num_parameters(&self) -> usize {
281 self.module_adapters
282 .values()
283 .flat_map(|adapters| adapters.values())
284 .filter(|entry| entry.active)
285 .map(|entry| entry.adapter.num_parameters())
286 .sum()
287 }
288
289 #[must_use]
291 pub fn num_modules(&self) -> usize {
292 self.module_adapters.len()
293 }
294}
295
296impl<A: Adapter> Default for PeftModel<A> {
297 fn default() -> Self {
298 Self::new()
299 }
300}
301
302pub fn get_peft_model<A: Adapter, F>(
318 module_names: &[&str],
319 pattern: &str,
320 adapter_name: impl Into<String>,
321 adapter_factory: F,
322) -> Result<PeftModel<A>>
323where
324 F: Fn(&str) -> Result<A>,
325{
326 let mut model = PeftModel::new();
327 model.add_adapter(adapter_name, pattern, module_names, adapter_factory)?;
328 Ok(model)
329}
330
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LoraConfig, LoraLayer};
    use candle_core::{DType, Device, Tensor};

    /// Shared fixture: a zero-initialized 768x768 LoRA layer.
    fn lora_768(config: &LoraConfig, device: &Device) -> Result<LoraLayer> {
        LoraLayer::new_with_zeros(768, 768, config.clone(), device)
    }

    #[test]
    fn test_module_pattern_exact() {
        let pat = ModulePattern::parse("encoder.layer.0");
        assert!(pat.matches("encoder.layer.0"));
        assert!(!pat.matches("encoder.layer.1"));
        assert!(!pat.matches("decoder.layer.0"));
    }

    #[test]
    fn test_module_pattern_suffix() {
        let pat = ModulePattern::parse("*.attention");
        assert!(pat.matches("layer.0.attention"));
        assert!(pat.matches("encoder.layer.0.attention"));
        assert!(!pat.matches("attention.output"));
    }

    #[test]
    fn test_module_pattern_prefix() {
        let pat = ModulePattern::parse("encoder.*");
        assert!(pat.matches("encoder.layer.0"));
        assert!(pat.matches("encoder.attention"));
        assert!(!pat.matches("decoder.layer.0"));
    }

    #[test]
    fn test_module_pattern_all() {
        let pat = ModulePattern::parse("*");
        assert!(pat.matches("anything"));
        assert!(pat.matches("encoder.layer.0"));
        assert!(pat.matches(""));
    }

    #[test]
    fn test_peft_model_creation() {
        let model = PeftModel::<LoraLayer>::new();
        assert!(model.module_names().is_empty());
        assert!(model.active_adapter_name().is_none());
    }

    #[test]
    fn test_add_adapter_with_pattern() -> Result<()> {
        let device = Device::Cpu;
        let config = LoraConfig::default();
        let mut model = PeftModel::<LoraLayer>::new();

        let modules = [
            "encoder.layer.0.attention",
            "encoder.layer.0.mlp",
            "encoder.layer.1.attention",
            "encoder.layer.1.mlp",
            "decoder.layer.0.attention",
        ];

        let attached = model.add_adapter("lora", "*.attention", &modules, |_| {
            lora_768(&config, &device)
        })?;

        // Only the three `.attention` modules should receive the adapter.
        assert_eq!(attached, 3);
        assert_eq!(model.active_adapter_name(), Some("lora"));
        assert!(model.has_adapter("encoder.layer.0.attention"));
        assert!(model.has_adapter("encoder.layer.1.attention"));
        assert!(model.has_adapter("decoder.layer.0.attention"));
        assert!(!model.has_adapter("encoder.layer.0.mlp"));

        Ok(())
    }

    #[test]
    fn test_set_adapter() -> Result<()> {
        let device = Device::Cpu;
        let config = LoraConfig::default();
        let mut model = PeftModel::<LoraLayer>::new();
        let modules = ["layer.0"];

        model.add_adapter("adapter1", "*", &modules, |_| lora_768(&config, &device))?;
        model.add_adapter("adapter2", "*", &modules, |_| lora_768(&config, &device))?;

        // Switching the per-module adapter must succeed.
        model.set_adapter("layer.0", "adapter2")?;

        Ok(())
    }

    #[test]
    fn test_set_adapter_all() -> Result<()> {
        let device = Device::Cpu;
        let config = LoraConfig::default();
        let mut model = PeftModel::<LoraLayer>::new();
        let modules = ["layer.0", "layer.1"];

        model.add_adapter("adapter1", "*", &modules, |_| lora_768(&config, &device))?;
        model.add_adapter("adapter2", "*", &modules, |_| lora_768(&config, &device))?;

        // The first adapter added is the active one by default.
        assert_eq!(model.active_adapter_name(), Some("adapter1"));

        model.set_adapter_all("adapter2")?;
        assert_eq!(model.active_adapter_name(), Some("adapter2"));

        Ok(())
    }

    #[test]
    fn test_forward_module() -> Result<()> {
        let device = Device::Cpu;
        let config = LoraConfig::default();
        let mut model = PeftModel::<LoraLayer>::new();
        let modules = ["layer.0"];

        model.add_adapter("lora", "*", &modules, |_| lora_768(&config, &device))?;

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &device)?;
        let output = model.forward_module("layer.0", &input, None)?;

        // Output shape must match the input shape for a 768x768 layer.
        assert_eq!(output.dims(), &[1, 10, 768]);

        Ok(())
    }

    #[test]
    fn test_num_parameters() -> Result<()> {
        let device = Device::Cpu;
        let config = LoraConfig::default();
        let mut model = PeftModel::<LoraLayer>::new();
        let modules = ["layer.0", "layer.1"];

        model.add_adapter("lora", "*", &modules, |_| lora_768(&config, &device))?;

        // Each module contributes 768*8 + 8*768 parameters under the
        // default config.
        assert_eq!(model.num_parameters(), 2 * (768 * 8 + 8 * 768));

        Ok(())
    }

    #[test]
    fn test_get_peft_model() -> Result<()> {
        let device = Device::Cpu;
        let config = LoraConfig::default();
        let modules = ["layer.0.attention", "layer.0.mlp", "layer.1.attention"];

        let model = get_peft_model(&modules, "*.attention", "lora", |_| {
            lora_768(&config, &device)
        })?;

        assert_eq!(model.num_modules(), 2);
        assert!(model.has_adapter("layer.0.attention"));
        assert!(model.has_adapter("layer.1.attention"));
        assert!(!model.has_adapter("layer.0.mlp"));

        Ok(())
    }
}