// peft_rs/model.rs

1//! Model integration for PEFT adapters.
2//!
3//! This module provides functionality for:
4//! - Wrapping models with PEFT adapter management
5//! - Pattern matching for module names (e.g., `*.attention`, `layer.*`)
6//! - Per-module adapter injection and switching
7
8use std::collections::HashMap;
9
10use candle_core::Tensor;
11
12use crate::error::{PeftError, Result};
13use crate::traits::Adapter;
14
/// Pattern for matching module names.
#[derive(Debug, Clone)]
pub enum ModulePattern {
    /// Match exact module name
    Exact(String),
    /// Match modules whose final dot-separated component(s) equal the suffix
    /// (e.g., `*.attention`)
    Suffix(String),
    /// Match modules whose leading dot-separated component(s) equal the prefix
    /// (e.g., `layer.*`)
    Prefix(String),
    /// Match all modules
    All,
}

impl ModulePattern {
    /// Parse a pattern string into a `ModulePattern`.
    ///
    /// # Examples
    /// - `"encoder.layer.0"` -> `Exact`
    /// - `"*.attention"` -> `Suffix`
    /// - `"layer.*"` -> `Prefix`
    /// - `"*"` -> `All`
    #[must_use]
    pub fn parse(pattern: &str) -> Self {
        match pattern {
            "*" => Self::All,
            // Strip the "*." / ".*" wildcard marker; the stored string is the
            // bare suffix/prefix without the adjoining dot.
            s if s.starts_with("*.") => Self::Suffix(s[2..].to_string()),
            s if s.ends_with(".*") => Self::Prefix(s[..s.len() - 2].to_string()),
            s => Self::Exact(s.to_string()),
        }
    }

    /// Check if a module name matches this pattern.
    ///
    /// Suffix and prefix patterns match on whole dot-separated path
    /// components: `*.attention` matches `layer.0.attention` (and the bare
    /// name `attention`), but not `selfattention`. Likewise `encoder.*`
    /// matches `encoder.layer.0` but not `encoders.layer.0`.
    #[must_use]
    pub fn matches(&self, module_name: &str) -> bool {
        match self {
            Self::Exact(name) => module_name == name,
            // Component-boundary match: the name is exactly the suffix, or it
            // ends with ".{suffix}". A plain `ends_with(suffix)` would also
            // accept names like "selfattention" for the pattern "*.attention".
            Self::Suffix(suffix) => module_name
                .strip_suffix(suffix.as_str())
                .map_or(false, |head| head.is_empty() || head.ends_with('.')),
            // Mirror logic for prefixes: exactly the prefix, or "{prefix}.…".
            Self::Prefix(prefix) => module_name
                .strip_prefix(prefix.as_str())
                .map_or(false, |tail| tail.is_empty() || tail.starts_with('.')),
            Self::All => true,
        }
    }
}
57
/// Adapter entry for a specific module.
///
/// Pairs one adapter instance with its activation flag. `PeftModel` keeps at
/// most one entry active per module at any time.
struct ModuleAdapter<A: Adapter> {
    /// The adapter instance
    adapter: A,
    /// Whether this adapter is the one currently used by `forward_module`
    active: bool,
}
65
/// PEFT model wrapper for managing adapters across modules.
///
/// Provides module-level adapter management with pattern-based targeting.
/// Multiple named adapters may be registered per module; one adapter name
/// is tracked as the global default.
pub struct PeftModel<A: Adapter> {
    /// Map of module names to their adapters, keyed by adapter name
    module_adapters: HashMap<String, HashMap<String, ModuleAdapter<A>>>,
    /// Currently active adapter name (global default); `None` until the
    /// first adapter is successfully injected
    active_adapter: Option<String>,
    /// List of all registered adapter names, in registration order
    adapter_names: Vec<String>,
}
77
78impl<A: Adapter> PeftModel<A> {
79    /// Create a new PEFT model wrapper.
80    #[must_use]
81    pub fn new() -> Self {
82        Self {
83            module_adapters: HashMap::new(),
84            active_adapter: None,
85            adapter_names: Vec::new(),
86        }
87    }
88
89    /// Add an adapter to modules matching the given pattern.
90    ///
91    /// # Arguments
92    /// * `adapter_name` - Unique name for the adapter
93    /// * `pattern` - Pattern to match module names
94    /// * `module_names` - List of all module names in the model
95    /// * `adapter_factory` - Function to create adapter instances
96    ///
97    /// # Errors
98    /// Returns an error if adapter creation fails
99    pub fn add_adapter<F>(
100        &mut self,
101        adapter_name: impl Into<String>,
102        pattern: &str,
103        module_names: &[&str],
104        adapter_factory: F,
105    ) -> Result<usize>
106    where
107        F: Fn(&str) -> Result<A>,
108    {
109        let adapter_name = adapter_name.into();
110        let pattern = ModulePattern::parse(pattern);
111        let mut count = 0;
112
113        for &module_name in module_names {
114            if pattern.matches(module_name) {
115                let adapter = adapter_factory(module_name)?;
116                let module_name_owned = module_name.to_string();
117
118                let module_entry = self.module_adapters.entry(module_name_owned).or_default();
119
120                module_entry.insert(
121                    adapter_name.clone(),
122                    ModuleAdapter {
123                        adapter,
124                        active: self.active_adapter.is_none(),
125                    },
126                );
127                count += 1;
128            }
129        }
130
131        // Track adapter name
132        if !self.adapter_names.contains(&adapter_name) {
133            self.adapter_names.push(adapter_name.clone());
134        }
135
136        // Set as active if first adapter
137        if self.active_adapter.is_none() && count > 0 {
138            self.active_adapter = Some(adapter_name);
139        }
140
141        Ok(count)
142    }
143
144    /// Set the active adapter for a specific module.
145    ///
146    /// # Errors
147    /// Returns an error if the module or adapter doesn't exist
148    pub fn set_adapter(&mut self, module_name: &str, adapter_name: &str) -> Result<()> {
149        let adapters = self.module_adapters.get_mut(module_name).ok_or_else(|| {
150            PeftError::AdapterNotFound {
151                name: format!("module '{module_name}' not found"),
152            }
153        })?;
154
155        if !adapters.contains_key(adapter_name) {
156            return Err(PeftError::AdapterNotFound {
157                name: format!("adapter '{adapter_name}' not found in module '{module_name}'"),
158            });
159        }
160
161        // Deactivate all adapters for this module
162        for adapter_entry in adapters.values_mut() {
163            adapter_entry.active = false;
164        }
165
166        // Activate the requested adapter
167        if let Some(entry) = adapters.get_mut(adapter_name) {
168            entry.active = true;
169        }
170
171        Ok(())
172    }
173
174    /// Set the active adapter for all modules.
175    ///
176    /// # Errors
177    /// Returns an error if the adapter doesn't exist in any module
178    pub fn set_adapter_all(&mut self, adapter_name: impl Into<String>) -> Result<()> {
179        let adapter_name = adapter_name.into();
180
181        if !self.adapter_names.contains(&adapter_name) {
182            return Err(PeftError::AdapterNotFound { name: adapter_name });
183        }
184
185        for adapters in self.module_adapters.values_mut() {
186            // Deactivate all
187            for entry in adapters.values_mut() {
188                entry.active = false;
189            }
190            // Activate the requested one if it exists
191            if let Some(entry) = adapters.get_mut(&adapter_name) {
192                entry.active = true;
193            }
194        }
195
196        self.active_adapter = Some(adapter_name);
197        Ok(())
198    }
199
200    /// Get the active adapter name.
201    #[must_use]
202    pub fn active_adapter_name(&self) -> Option<&str> {
203        self.active_adapter.as_deref()
204    }
205
206    /// Get all registered adapter names.
207    #[must_use]
208    pub fn adapter_names(&self) -> &[String] {
209        &self.adapter_names
210    }
211
212    /// Get module names that have adapters.
213    #[must_use]
214    pub fn module_names(&self) -> Vec<&str> {
215        self.module_adapters.keys().map(String::as_str).collect()
216    }
217
218    /// Check if a module has any adapters.
219    #[must_use]
220    pub fn has_adapter(&self, module_name: &str) -> bool {
221        self.module_adapters.contains_key(module_name)
222    }
223
224    /// Forward pass for a specific module.
225    ///
226    /// # Arguments
227    /// * `module_name` - Name of the module
228    /// * `input` - Input tensor
229    /// * `base_output` - Optional base layer output
230    ///
231    /// # Errors
232    /// Returns an error if module not found or no active adapter
233    pub fn forward_module(
234        &self,
235        module_name: &str,
236        input: &Tensor,
237        base_output: Option<&Tensor>,
238    ) -> Result<Tensor> {
239        let adapters =
240            self.module_adapters
241                .get(module_name)
242                .ok_or_else(|| PeftError::AdapterNotFound {
243                    name: format!("module '{module_name}' not found"),
244                })?;
245
246        // Find active adapter
247        for entry in adapters.values() {
248            if entry.active {
249                return entry.adapter.forward(input, base_output);
250            }
251        }
252
253        Err(PeftError::AdapterNotFound {
254            name: format!("no active adapter for module '{module_name}'"),
255        })
256    }
257
258    /// Get a reference to an adapter for a module.
259    ///
260    /// # Errors
261    /// Returns an error if module or adapter not found
262    pub fn get_adapter(&self, module_name: &str, adapter_name: &str) -> Result<&A> {
263        let adapters =
264            self.module_adapters
265                .get(module_name)
266                .ok_or_else(|| PeftError::AdapterNotFound {
267                    name: format!("module '{module_name}' not found"),
268                })?;
269
270        adapters
271            .get(adapter_name)
272            .map(|entry| &entry.adapter)
273            .ok_or_else(|| PeftError::AdapterNotFound {
274                name: format!("adapter '{adapter_name}' not found in module '{module_name}'"),
275            })
276    }
277
278    /// Get the total number of trainable parameters across all active adapters.
279    #[must_use]
280    pub fn num_parameters(&self) -> usize {
281        self.module_adapters
282            .values()
283            .flat_map(|adapters| adapters.values())
284            .filter(|entry| entry.active)
285            .map(|entry| entry.adapter.num_parameters())
286            .sum()
287    }
288
289    /// Get the number of modules with adapters.
290    #[must_use]
291    pub fn num_modules(&self) -> usize {
292        self.module_adapters.len()
293    }
294}
295
impl<A: Adapter> Default for PeftModel<A> {
    /// Equivalent to [`PeftModel::new`]: an empty wrapper with no adapters.
    fn default() -> Self {
        Self::new()
    }
}
301
302/// Create a PEFT model with adapters injected into matching modules.
303///
304/// This is a convenience function for common use cases.
305///
306/// # Arguments
307/// * `module_names` - List of all module names in the model
308/// * `pattern` - Pattern to match module names for adapter injection
309/// * `adapter_name` - Name for the adapter
310/// * `adapter_factory` - Function to create adapter instances
311///
312/// # Returns
313/// A `PeftModel` with adapters injected into matching modules
314///
315/// # Errors
316/// Returns an error if adapter creation fails
317pub fn get_peft_model<A: Adapter, F>(
318    module_names: &[&str],
319    pattern: &str,
320    adapter_name: impl Into<String>,
321    adapter_factory: F,
322) -> Result<PeftModel<A>>
323where
324    F: Fn(&str) -> Result<A>,
325{
326    let mut model = PeftModel::new();
327    model.add_adapter(adapter_name, pattern, module_names, adapter_factory)?;
328    Ok(model)
329}
330
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LoraConfig, LoraLayer};
    use candle_core::{DType, Device, Tensor};

    #[test]
    fn test_module_pattern_exact() {
        let p = ModulePattern::parse("encoder.layer.0");
        assert!(p.matches("encoder.layer.0"));
        assert!(!p.matches("encoder.layer.1"));
        assert!(!p.matches("decoder.layer.0"));
    }

    #[test]
    fn test_module_pattern_suffix() {
        let p = ModulePattern::parse("*.attention");
        assert!(p.matches("layer.0.attention"));
        assert!(p.matches("encoder.layer.0.attention"));
        assert!(!p.matches("attention.output"));
    }

    #[test]
    fn test_module_pattern_prefix() {
        let p = ModulePattern::parse("encoder.*");
        assert!(p.matches("encoder.layer.0"));
        assert!(p.matches("encoder.attention"));
        assert!(!p.matches("decoder.layer.0"));
    }

    #[test]
    fn test_module_pattern_all() {
        let p = ModulePattern::parse("*");
        assert!(p.matches("anything"));
        assert!(p.matches("encoder.layer.0"));
        assert!(p.matches(""));
    }

    #[test]
    fn test_peft_model_creation() {
        let peft: PeftModel<LoraLayer> = PeftModel::new();
        assert!(peft.module_names().is_empty());
        assert!(peft.active_adapter_name().is_none());
    }

    #[test]
    fn test_add_adapter_with_pattern() -> Result<()> {
        let mut peft: PeftModel<LoraLayer> = PeftModel::new();
        let dev = Device::Cpu;
        let cfg = LoraConfig::default();

        let modules = [
            "encoder.layer.0.attention",
            "encoder.layer.0.mlp",
            "encoder.layer.1.attention",
            "encoder.layer.1.mlp",
            "decoder.layer.0.attention",
        ];

        let injected = peft.add_adapter("lora", "*.attention", &modules, |_| {
            LoraLayer::new_with_zeros(768, 768, cfg.clone(), &dev)
        })?;

        // Only the three attention modules match the suffix pattern.
        assert_eq!(injected, 3);
        assert_eq!(peft.active_adapter_name(), Some("lora"));
        assert!(peft.has_adapter("encoder.layer.0.attention"));
        assert!(peft.has_adapter("encoder.layer.1.attention"));
        assert!(peft.has_adapter("decoder.layer.0.attention"));
        assert!(!peft.has_adapter("encoder.layer.0.mlp"));

        Ok(())
    }

    #[test]
    fn test_set_adapter() -> Result<()> {
        let mut peft: PeftModel<LoraLayer> = PeftModel::new();
        let dev = Device::Cpu;
        let cfg = LoraConfig::default();

        let modules = ["layer.0"];

        peft.add_adapter("adapter1", "*", &modules, |_| {
            LoraLayer::new_with_zeros(768, 768, cfg.clone(), &dev)
        })?;

        peft.add_adapter("adapter2", "*", &modules, |_| {
            LoraLayer::new_with_zeros(768, 768, cfg.clone(), &dev)
        })?;

        // Switching the adapter for one specific module succeeds.
        peft.set_adapter("layer.0", "adapter2")?;

        Ok(())
    }

    #[test]
    fn test_set_adapter_all() -> Result<()> {
        let mut peft: PeftModel<LoraLayer> = PeftModel::new();
        let dev = Device::Cpu;
        let cfg = LoraConfig::default();

        let modules = ["layer.0", "layer.1"];

        peft.add_adapter("adapter1", "*", &modules, |_| {
            LoraLayer::new_with_zeros(768, 768, cfg.clone(), &dev)
        })?;

        peft.add_adapter("adapter2", "*", &modules, |_| {
            LoraLayer::new_with_zeros(768, 768, cfg.clone(), &dev)
        })?;

        // First registered adapter becomes the global default...
        assert_eq!(peft.active_adapter_name(), Some("adapter1"));

        // ...and a global switch updates the tracked name.
        peft.set_adapter_all("adapter2")?;
        assert_eq!(peft.active_adapter_name(), Some("adapter2"));

        Ok(())
    }

    #[test]
    fn test_forward_module() -> Result<()> {
        let mut peft: PeftModel<LoraLayer> = PeftModel::new();
        let dev = Device::Cpu;
        let cfg = LoraConfig::default();

        let modules = ["layer.0"];

        peft.add_adapter("lora", "*", &modules, |_| {
            LoraLayer::new_with_zeros(768, 768, cfg.clone(), &dev)
        })?;

        let x = Tensor::zeros(&[1, 10, 768], DType::F32, &dev)?;
        let y = peft.forward_module("layer.0", &x, None)?;

        // Shape is preserved by the adapter forward pass.
        assert_eq!(y.dims(), &[1, 10, 768]);

        Ok(())
    }

    #[test]
    fn test_num_parameters() -> Result<()> {
        let mut peft: PeftModel<LoraLayer> = PeftModel::new();
        let dev = Device::Cpu;
        let cfg = LoraConfig::default();

        let modules = ["layer.0", "layer.1"];

        peft.add_adapter("lora", "*", &modules, |_| {
            LoraLayer::new_with_zeros(768, 768, cfg.clone(), &dev)
        })?;

        // 2 modules, each with 768*8 + 8*768 = 12,288 parameters
        assert_eq!(peft.num_parameters(), 2 * (768 * 8 + 8 * 768));

        Ok(())
    }

    #[test]
    fn test_get_peft_model() -> Result<()> {
        let dev = Device::Cpu;
        let cfg = LoraConfig::default();

        let modules = ["layer.0.attention", "layer.0.mlp", "layer.1.attention"];

        let peft = get_peft_model(&modules, "*.attention", "lora", |_| {
            LoraLayer::new_with_zeros(768, 768, cfg.clone(), &dev)
        })?;

        assert_eq!(peft.num_modules(), 2);
        assert!(peft.has_adapter("layer.0.attention"));
        assert!(peft.has_adapter("layer.1.attention"));
        assert!(!peft.has_adapter("layer.0.mlp"));

        Ok(())
    }
}