1use std::collections::HashMap;
11
12use candle_core::Tensor;
13
14use crate::error::{PeftError, Result};
15use crate::traits::Adapter;
16
/// A collection of adapters keyed by name, with at most one adapter marked active.
///
/// The `Adapter` bound is deliberately placed on the `impl` blocks rather than
/// on this definition, so the type itself stays unconstrained.
pub struct AdapterRegistry<A> {
    /// Registered adapters, keyed by the name they were registered under.
    adapters: HashMap<String, A>,
    /// Name of the adapter that `forward` and the active-adapter accessors use.
    /// The registry's methods only ever store a key of `adapters` here, or `None`.
    active_adapter: Option<String>,
}
28
29impl<A: Adapter> AdapterRegistry<A> {
30 #[must_use]
32 pub fn new() -> Self {
33 Self {
34 adapters: HashMap::new(),
35 active_adapter: None,
36 }
37 }
38
39 pub fn register_adapter(&mut self, name: impl Into<String>, adapter: A) -> Result<()> {
48 let name = name.into();
49
50 if self.adapters.contains_key(&name) {
51 return Err(PeftError::AdapterExists { name });
52 }
53
54 self.adapters.insert(name.clone(), adapter);
55
56 if self.active_adapter.is_none() {
58 self.active_adapter = Some(name);
59 }
60
61 Ok(())
62 }
63
64 pub fn set_active_adapter(&mut self, name: impl Into<String>) -> Result<()> {
72 let name = name.into();
73
74 if !self.adapters.contains_key(&name) {
75 return Err(PeftError::AdapterNotFound { name });
76 }
77
78 self.active_adapter = Some(name);
79 Ok(())
80 }
81
82 pub fn get_active_adapter(&self) -> Result<&A> {
87 let name = self
88 .active_adapter
89 .as_ref()
90 .ok_or_else(|| PeftError::AdapterNotFound {
91 name: "no active adapter".to_string(),
92 })?;
93
94 self.adapters
95 .get(name)
96 .ok_or_else(|| PeftError::AdapterNotFound { name: name.clone() })
97 }
98
99 pub fn get_active_adapter_mut(&mut self) -> Result<&mut A> {
104 let name = self
105 .active_adapter
106 .as_ref()
107 .ok_or_else(|| PeftError::AdapterNotFound {
108 name: "no active adapter".to_string(),
109 })?
110 .clone();
111
112 self.adapters
113 .get_mut(&name)
114 .ok_or_else(|| PeftError::AdapterNotFound { name })
115 }
116
117 pub fn get_adapter(&self, name: &str) -> Result<&A> {
125 self.adapters
126 .get(name)
127 .ok_or_else(|| PeftError::AdapterNotFound {
128 name: name.to_string(),
129 })
130 }
131
132 pub fn get_adapter_mut(&mut self, name: &str) -> Result<&mut A> {
140 self.adapters
141 .get_mut(name)
142 .ok_or_else(|| PeftError::AdapterNotFound {
143 name: name.to_string(),
144 })
145 }
146
147 #[must_use]
149 pub fn contains_adapter(&self, name: &str) -> bool {
150 self.adapters.contains_key(name)
151 }
152
153 #[must_use]
155 pub fn active_adapter_name(&self) -> Option<&str> {
156 self.active_adapter.as_deref()
157 }
158
159 #[must_use]
161 pub fn adapter_names(&self) -> Vec<&str> {
162 self.adapters.keys().map(String::as_str).collect()
163 }
164
165 #[must_use]
167 pub fn len(&self) -> usize {
168 self.adapters.len()
169 }
170
171 #[must_use]
173 pub fn is_empty(&self) -> bool {
174 self.adapters.is_empty()
175 }
176
177 pub fn remove_adapter(&mut self, name: &str) -> Result<Option<A>> {
188 if self.active_adapter.as_deref() == Some(name) {
190 return Err(PeftError::InvalidConfig(
191 "Cannot remove the currently active adapter".to_string(),
192 ));
193 }
194
195 Ok(self.adapters.remove(name))
196 }
197
198 pub fn clear(&mut self) {
200 self.adapters.clear();
201 self.active_adapter = None;
202 }
203
204 pub fn forward(&self, input: &Tensor, base_output: Option<&Tensor>) -> Result<Tensor> {
213 let adapter = self.get_active_adapter()?;
214 adapter.forward(input, base_output)
215 }
216}
217
218impl<A: Adapter> Default for AdapterRegistry<A> {
219 fn default() -> Self {
220 Self::new()
221 }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{LoraConfig, LoraLayer};
    use candle_core::Device;

    #[test]
    fn test_registry_creation() {
        let registry: AdapterRegistry<LoraLayer> = AdapterRegistry::new();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
        assert!(registry.active_adapter_name().is_none());
    }

    #[test]
    fn test_register_adapter() -> Result<()> {
        let mut registry = AdapterRegistry::new();
        let device = Device::Cpu;
        let config = LoraConfig::default();

        let adapter1 = LoraLayer::new_with_zeros(768, 768, config.clone(), &device)?;
        let adapter2 = LoraLayer::new_with_zeros(768, 768, config, &device)?;

        registry.register_adapter("adapter1", adapter1)?;
        assert_eq!(registry.len(), 1);
        assert_eq!(registry.active_adapter_name(), Some("adapter1"));

        // Registering a second adapter must not change the active one.
        registry.register_adapter("adapter2", adapter2)?;
        assert_eq!(registry.len(), 2);
        assert_eq!(registry.active_adapter_name(), Some("adapter1"));

        Ok(())
    }

    #[test]
    fn test_register_duplicate_adapter() -> Result<()> {
        let mut registry = AdapterRegistry::new();
        let device = Device::Cpu;
        let config = LoraConfig::default();

        let adapter1 = LoraLayer::new_with_zeros(768, 768, config.clone(), &device)?;
        let adapter2 = LoraLayer::new_with_zeros(768, 768, config, &device)?;

        registry.register_adapter("adapter1", adapter1)?;
        let result = registry.register_adapter("adapter1", adapter2);

        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            PeftError::AdapterExists { .. }
        ));
        // A rejected duplicate must leave the registry unchanged.
        assert_eq!(registry.len(), 1);

        Ok(())
    }

    #[test]
    fn test_set_active_adapter() -> Result<()> {
        let mut registry = AdapterRegistry::new();
        let device = Device::Cpu;
        let config = LoraConfig::default();

        let adapter1 = LoraLayer::new_with_zeros(768, 768, config.clone(), &device)?;
        let adapter2 = LoraLayer::new_with_zeros(768, 768, config, &device)?;

        registry.register_adapter("adapter1", adapter1)?;
        registry.register_adapter("adapter2", adapter2)?;

        assert_eq!(registry.active_adapter_name(), Some("adapter1"));

        registry.set_active_adapter("adapter2")?;
        assert_eq!(registry.active_adapter_name(), Some("adapter2"));

        Ok(())
    }

    #[test]
    fn test_set_nonexistent_adapter() -> Result<()> {
        let mut registry = AdapterRegistry::new();
        let device = Device::Cpu;
        let config = LoraConfig::default();

        let adapter1 = LoraLayer::new_with_zeros(768, 768, config, &device)?;
        registry.register_adapter("adapter1", adapter1)?;

        let result = registry.set_active_adapter("nonexistent");
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            PeftError::AdapterNotFound { .. }
        ));
        // A failed switch must keep the previous active adapter.
        assert_eq!(registry.active_adapter_name(), Some("adapter1"));

        Ok(())
    }

    #[test]
    fn test_get_adapter() -> Result<()> {
        let mut registry = AdapterRegistry::new();
        let device = Device::Cpu;
        let config = LoraConfig::default();

        let adapter1 = LoraLayer::new_with_zeros(768, 768, config, &device)?;
        registry.register_adapter("adapter1", adapter1)?;

        let retrieved = registry.get_adapter("adapter1")?;
        assert_eq!(retrieved.num_parameters(), 768 * 8 + 8 * 768);

        Ok(())
    }

    #[test]
    fn test_get_active_adapter() -> Result<()> {
        let mut registry = AdapterRegistry::new();
        let device = Device::Cpu;
        let config = LoraConfig::default();

        let adapter1 = LoraLayer::new_with_zeros(768, 768, config, &device)?;
        registry.register_adapter("adapter1", adapter1)?;

        let active = registry.get_active_adapter()?;
        assert_eq!(active.num_parameters(), 768 * 8 + 8 * 768);

        Ok(())
    }

    #[test]
    fn test_get_active_adapter_mut() -> Result<()> {
        let mut registry = AdapterRegistry::new();
        let device = Device::Cpu;
        let config = LoraConfig::default();

        let adapter1 = LoraLayer::new_with_zeros(768, 768, config.clone(), &device)?;
        let adapter2 = LoraLayer::new_with_zeros(768, 768, config, &device)?;
        registry.register_adapter("adapter1", adapter1)?;
        registry.register_adapter("adapter2", adapter2)?;
        registry.set_active_adapter("adapter2")?;

        let active = registry.get_active_adapter_mut()?;
        assert_eq!(active.num_parameters(), 768 * 8 + 8 * 768);
        // A mutable lookup must not disturb the active selection.
        assert_eq!(registry.active_adapter_name(), Some("adapter2"));

        Ok(())
    }

    #[test]
    fn test_active_adapter_errors_when_empty() {
        let mut registry: AdapterRegistry<LoraLayer> = AdapterRegistry::new();

        assert!(matches!(
            registry.get_active_adapter(),
            Err(PeftError::AdapterNotFound { .. })
        ));
        assert!(matches!(
            registry.get_active_adapter_mut(),
            Err(PeftError::AdapterNotFound { .. })
        ));
    }

    #[test]
    fn test_contains_adapter() -> Result<()> {
        let mut registry = AdapterRegistry::new();
        let device = Device::Cpu;
        let config = LoraConfig::default();

        let adapter1 = LoraLayer::new_with_zeros(768, 768, config, &device)?;
        registry.register_adapter("adapter1", adapter1)?;

        assert!(registry.contains_adapter("adapter1"));
        assert!(!registry.contains_adapter("adapter2"));

        Ok(())
    }

    #[test]
    fn test_adapter_names() -> Result<()> {
        let mut registry = AdapterRegistry::new();
        let device = Device::Cpu;
        let config = LoraConfig::default();

        let adapter1 = LoraLayer::new_with_zeros(768, 768, config.clone(), &device)?;
        let adapter2 = LoraLayer::new_with_zeros(768, 768, config, &device)?;

        registry.register_adapter("adapter1", adapter1)?;
        registry.register_adapter("adapter2", adapter2)?;

        // HashMap iteration order is unspecified, so sort before comparing.
        let mut names = registry.adapter_names();
        names.sort_unstable();
        assert_eq!(names, vec!["adapter1", "adapter2"]);

        Ok(())
    }

    #[test]
    fn test_remove_adapter() -> Result<()> {
        let mut registry = AdapterRegistry::new();
        let device = Device::Cpu;
        let config = LoraConfig::default();

        let adapter1 = LoraLayer::new_with_zeros(768, 768, config.clone(), &device)?;
        let adapter2 = LoraLayer::new_with_zeros(768, 768, config, &device)?;

        registry.register_adapter("adapter1", adapter1)?;
        registry.register_adapter("adapter2", adapter2)?;

        let removed = registry.remove_adapter("adapter2")?;
        assert!(removed.is_some());
        assert_eq!(registry.len(), 1);

        // Removing an unknown name is not an error; it simply yields nothing.
        let missing = registry.remove_adapter("nonexistent")?;
        assert!(missing.is_none());
        assert_eq!(registry.len(), 1);

        // The active adapter cannot be removed and must stay registered.
        let result = registry.remove_adapter("adapter1");
        assert!(matches!(result, Err(PeftError::InvalidConfig(_))));
        assert!(registry.contains_adapter("adapter1"));

        Ok(())
    }

    #[test]
    fn test_clear() -> Result<()> {
        let mut registry = AdapterRegistry::new();
        let device = Device::Cpu;
        let config = LoraConfig::default();

        let adapter1 = LoraLayer::new_with_zeros(768, 768, config, &device)?;
        registry.register_adapter("adapter1", adapter1)?;

        assert_eq!(registry.len(), 1);
        registry.clear();
        assert_eq!(registry.len(), 0);
        assert!(registry.active_adapter_name().is_none());

        Ok(())
    }

    #[test]
    fn test_forward_with_active_adapter() -> Result<()> {
        use candle_core::{DType, Tensor};

        let mut registry = AdapterRegistry::new();
        let device = Device::Cpu;
        let config = LoraConfig::default();

        let adapter = LoraLayer::new_with_zeros(768, 768, config, &device)?;
        registry.register_adapter("test_adapter", adapter)?;

        let input = Tensor::zeros(&[1, 10, 768], DType::F32, &device)?;
        let output = registry.forward(&input, None)?;

        assert_eq!(output.shape().dims(), &[1, 10, 768]);

        Ok(())
    }
}