Skip to main content

peft_rs/
registry.rs

//! Multi-adapter registry for managing multiple PEFT adapters.
//!
//! This module provides functionality for:
//! - Registering multiple named adapters
//! - Switching which registered adapter is active
//! - Managing adapter lifecycle (lookup, removal, clearing)
//!
//! Note: Adapter composition (combining multiple adapters) is planned for a future release.
9
10use std::collections::HashMap;
11
12use candle_core::Tensor;
13
14use crate::error::{PeftError, Result};
15use crate::traits::Adapter;
16
17/// Registry for managing multiple named adapters.
18///
19/// Allows switching between different adapters at runtime.
20///
21/// Note: Adapter composition (combining multiple adapters) is planned for a future release.
22pub struct AdapterRegistry<A: Adapter> {
23    /// Map of adapter names to adapters
24    adapters: HashMap<String, A>,
25    /// Currently active adapter name
26    active_adapter: Option<String>,
27}
28
29impl<A: Adapter> AdapterRegistry<A> {
30    /// Create a new empty adapter registry.
31    #[must_use]
32    pub fn new() -> Self {
33        Self {
34            adapters: HashMap::new(),
35            active_adapter: None,
36        }
37    }
38
39    /// Register a new adapter with the given name.
40    ///
41    /// # Arguments
42    /// * `name` - Unique name for the adapter
43    /// * `adapter` - The adapter to register
44    ///
45    /// # Errors
46    /// Returns an error if an adapter with this name already exists
47    pub fn register_adapter(&mut self, name: impl Into<String>, adapter: A) -> Result<()> {
48        let name = name.into();
49
50        if self.adapters.contains_key(&name) {
51            return Err(PeftError::AdapterExists { name });
52        }
53
54        self.adapters.insert(name.clone(), adapter);
55
56        // Set as active if it's the first adapter
57        if self.active_adapter.is_none() {
58            self.active_adapter = Some(name);
59        }
60
61        Ok(())
62    }
63
64    /// Set the active adapter by name.
65    ///
66    /// # Arguments
67    /// * `name` - Name of the adapter to activate
68    ///
69    /// # Errors
70    /// Returns an error if no adapter with this name exists
71    pub fn set_active_adapter(&mut self, name: impl Into<String>) -> Result<()> {
72        let name = name.into();
73
74        if !self.adapters.contains_key(&name) {
75            return Err(PeftError::AdapterNotFound { name });
76        }
77
78        self.active_adapter = Some(name);
79        Ok(())
80    }
81
82    /// Get a reference to the active adapter.
83    ///
84    /// # Errors
85    /// Returns an error if no adapter is currently active
86    pub fn get_active_adapter(&self) -> Result<&A> {
87        let name = self
88            .active_adapter
89            .as_ref()
90            .ok_or_else(|| PeftError::AdapterNotFound {
91                name: "no active adapter".to_string(),
92            })?;
93
94        self.adapters
95            .get(name)
96            .ok_or_else(|| PeftError::AdapterNotFound { name: name.clone() })
97    }
98
99    /// Get a mutable reference to the active adapter.
100    ///
101    /// # Errors
102    /// Returns an error if no adapter is currently active
103    pub fn get_active_adapter_mut(&mut self) -> Result<&mut A> {
104        let name = self
105            .active_adapter
106            .as_ref()
107            .ok_or_else(|| PeftError::AdapterNotFound {
108                name: "no active adapter".to_string(),
109            })?
110            .clone();
111
112        self.adapters
113            .get_mut(&name)
114            .ok_or_else(|| PeftError::AdapterNotFound { name })
115    }
116
117    /// Get a reference to an adapter by name.
118    ///
119    /// # Arguments
120    /// * `name` - Name of the adapter
121    ///
122    /// # Errors
123    /// Returns an error if no adapter with this name exists
124    pub fn get_adapter(&self, name: &str) -> Result<&A> {
125        self.adapters
126            .get(name)
127            .ok_or_else(|| PeftError::AdapterNotFound {
128                name: name.to_string(),
129            })
130    }
131
132    /// Get a mutable reference to an adapter by name.
133    ///
134    /// # Arguments
135    /// * `name` - Name of the adapter
136    ///
137    /// # Errors
138    /// Returns an error if no adapter with this name exists
139    pub fn get_adapter_mut(&mut self, name: &str) -> Result<&mut A> {
140        self.adapters
141            .get_mut(name)
142            .ok_or_else(|| PeftError::AdapterNotFound {
143                name: name.to_string(),
144            })
145    }
146
147    /// Check if an adapter with the given name exists.
148    #[must_use]
149    pub fn contains_adapter(&self, name: &str) -> bool {
150        self.adapters.contains_key(name)
151    }
152
153    /// Get the name of the currently active adapter.
154    #[must_use]
155    pub fn active_adapter_name(&self) -> Option<&str> {
156        self.active_adapter.as_deref()
157    }
158
159    /// Get a list of all registered adapter names.
160    #[must_use]
161    pub fn adapter_names(&self) -> Vec<&str> {
162        self.adapters.keys().map(String::as_str).collect()
163    }
164
165    /// Get the number of registered adapters.
166    #[must_use]
167    pub fn len(&self) -> usize {
168        self.adapters.len()
169    }
170
171    /// Check if the registry is empty.
172    #[must_use]
173    pub fn is_empty(&self) -> bool {
174        self.adapters.is_empty()
175    }
176
177    /// Remove an adapter by name.
178    ///
179    /// # Arguments
180    /// * `name` - Name of the adapter to remove
181    ///
182    /// # Returns
183    /// The removed adapter, if it existed
184    ///
185    /// # Errors
186    /// Returns an error if trying to remove the active adapter
187    pub fn remove_adapter(&mut self, name: &str) -> Result<Option<A>> {
188        // Don't allow removing the active adapter
189        if self.active_adapter.as_deref() == Some(name) {
190            return Err(PeftError::InvalidConfig(
191                "Cannot remove the currently active adapter".to_string(),
192            ));
193        }
194
195        Ok(self.adapters.remove(name))
196    }
197
198    /// Clear all adapters from the registry.
199    pub fn clear(&mut self) {
200        self.adapters.clear();
201        self.active_adapter = None;
202    }
203
204    /// Apply the active adapter to an input tensor.
205    ///
206    /// # Arguments
207    /// * `input` - Input tensor
208    /// * `base_output` - Optional base layer output
209    ///
210    /// # Errors
211    /// Returns an error if no active adapter or forward pass fails
212    pub fn forward(&self, input: &Tensor, base_output: Option<&Tensor>) -> Result<Tensor> {
213        let adapter = self.get_active_adapter()?;
214        adapter.forward(input, base_output)
215    }
216}
217
218impl<A: Adapter> Default for AdapterRegistry<A> {
219    fn default() -> Self {
220        Self::new()
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LoraConfig, LoraLayer};
    use candle_core::{DType, Device, Tensor};

    /// Build a zero-initialised 768x768 LoRA adapter from the default config on CPU.
    fn zero_adapter() -> Result<LoraLayer> {
        Ok(LoraLayer::new_with_zeros(
            768,
            768,
            LoraConfig::default(),
            &Device::Cpu,
        )?)
    }

    /// Build a registry containing one zero adapter per name, registered in order.
    fn registry_with(names: &[&str]) -> Result<AdapterRegistry<LoraLayer>> {
        let mut registry = AdapterRegistry::new();
        for &name in names {
            registry.register_adapter(name, zero_adapter()?)?;
        }
        Ok(registry)
    }

    #[test]
    fn test_registry_creation() {
        let registry: AdapterRegistry<LoraLayer> = AdapterRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.active_adapter_name().is_none());
    }

    #[test]
    fn test_register_adapter() -> Result<()> {
        let mut registry = AdapterRegistry::new();

        registry.register_adapter("adapter1", zero_adapter()?)?;
        assert_eq!(registry.len(), 1);
        assert_eq!(registry.active_adapter_name(), Some("adapter1"));

        registry.register_adapter("adapter2", zero_adapter()?)?;
        assert_eq!(registry.len(), 2);
        // Registering further adapters must not change the active one.
        assert_eq!(registry.active_adapter_name(), Some("adapter1"));

        Ok(())
    }

    #[test]
    fn test_register_duplicate_adapter() -> Result<()> {
        let mut registry = registry_with(&["adapter1"])?;

        let result = registry.register_adapter("adapter1", zero_adapter()?);
        assert!(matches!(result, Err(PeftError::AdapterExists { .. })));

        // A rejected registration leaves the registry untouched.
        assert_eq!(registry.len(), 1);
        assert_eq!(registry.active_adapter_name(), Some("adapter1"));

        Ok(())
    }

    #[test]
    fn test_set_active_adapter() -> Result<()> {
        let mut registry = registry_with(&["adapter1", "adapter2"])?;
        assert_eq!(registry.active_adapter_name(), Some("adapter1"));

        registry.set_active_adapter("adapter2")?;
        assert_eq!(registry.active_adapter_name(), Some("adapter2"));

        Ok(())
    }

    #[test]
    fn test_set_nonexistent_adapter() -> Result<()> {
        let mut registry = registry_with(&["adapter1"])?;

        let result = registry.set_active_adapter("nonexistent");
        assert!(matches!(result, Err(PeftError::AdapterNotFound { .. })));
        // The previous active adapter is kept on failure.
        assert_eq!(registry.active_adapter_name(), Some("adapter1"));

        Ok(())
    }

    #[test]
    fn test_get_adapter() -> Result<()> {
        let registry = registry_with(&["adapter1"])?;

        let retrieved = registry.get_adapter("adapter1")?;
        assert_eq!(retrieved.num_parameters(), 768 * 8 + 8 * 768);

        Ok(())
    }

    #[test]
    fn test_get_missing_adapter() -> Result<()> {
        let mut registry = registry_with(&["adapter1"])?;

        assert!(matches!(
            registry.get_adapter("missing"),
            Err(PeftError::AdapterNotFound { .. })
        ));
        assert!(matches!(
            registry.get_adapter_mut("missing"),
            Err(PeftError::AdapterNotFound { .. })
        ));

        Ok(())
    }

    #[test]
    fn test_get_active_adapter() -> Result<()> {
        let registry = registry_with(&["adapter1"])?;

        let active = registry.get_active_adapter()?;
        assert_eq!(active.num_parameters(), 768 * 8 + 8 * 768);

        Ok(())
    }

    #[test]
    fn test_get_active_adapter_mut() -> Result<()> {
        let mut registry = registry_with(&["adapter1"])?;

        let active = registry.get_active_adapter_mut()?;
        assert_eq!(active.num_parameters(), 768 * 8 + 8 * 768);

        Ok(())
    }

    #[test]
    fn test_active_adapter_on_empty_registry() {
        let mut registry: AdapterRegistry<LoraLayer> = AdapterRegistry::new();

        assert!(matches!(
            registry.get_active_adapter(),
            Err(PeftError::AdapterNotFound { .. })
        ));
        assert!(matches!(
            registry.get_active_adapter_mut(),
            Err(PeftError::AdapterNotFound { .. })
        ));
    }

    #[test]
    fn test_contains_adapter() -> Result<()> {
        let registry = registry_with(&["adapter1"])?;

        assert!(registry.contains_adapter("adapter1"));
        assert!(!registry.contains_adapter("adapter2"));

        Ok(())
    }

    #[test]
    fn test_adapter_names() -> Result<()> {
        let registry = registry_with(&["adapter1", "adapter2"])?;

        let mut names = registry.adapter_names();
        names.sort_unstable();
        assert_eq!(names, vec!["adapter1", "adapter2"]);

        Ok(())
    }

    #[test]
    fn test_remove_adapter() -> Result<()> {
        let mut registry = registry_with(&["adapter1", "adapter2"])?;

        // Can remove a non-active adapter.
        let removed = registry.remove_adapter("adapter2")?;
        assert!(removed.is_some());
        assert_eq!(registry.len(), 1);

        // Removing an unknown name is not an error; it simply yields nothing.
        assert!(registry.remove_adapter("adapter2")?.is_none());

        // Cannot remove the active adapter, and the failed attempt keeps it.
        let result = registry.remove_adapter("adapter1");
        assert!(matches!(result, Err(PeftError::InvalidConfig(_))));
        assert!(registry.contains_adapter("adapter1"));
        assert_eq!(registry.active_adapter_name(), Some("adapter1"));

        Ok(())
    }

    #[test]
    fn test_remove_after_switching_active() -> Result<()> {
        let mut registry = registry_with(&["adapter1", "adapter2"])?;

        registry.set_active_adapter("adapter2")?;
        assert!(registry.remove_adapter("adapter1")?.is_some());
        assert_eq!(registry.len(), 1);
        assert!(registry.contains_adapter("adapter2"));

        Ok(())
    }

    #[test]
    fn test_clear() -> Result<()> {
        let mut registry = registry_with(&["adapter1"])?;
        assert_eq!(registry.len(), 1);

        registry.clear();
        assert_eq!(registry.len(), 0);
        assert!(registry.active_adapter_name().is_none());

        // After clearing, the next registration becomes active again.
        registry.register_adapter("adapter2", zero_adapter()?)?;
        assert_eq!(registry.active_adapter_name(), Some("adapter2"));

        Ok(())
    }

    #[test]
    fn test_forward_with_active_adapter() -> Result<()> {
        let registry = registry_with(&["test_adapter"])?;

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &Device::Cpu)?;
        let output = registry.forward(&input, None)?;

        assert_eq!(output.shape().dims(), &[1, 10, 768]);

        Ok(())
    }

    #[test]
    fn test_forward_without_active_adapter() -> Result<()> {
        let registry: AdapterRegistry<LoraLayer> = AdapterRegistry::new();

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &Device::Cpu)?;
        assert!(matches!(
            registry.forward(&input, None),
            Err(PeftError::AdapterNotFound { .. })
        ));

        Ok(())
    }
}