// peft_rs/traits.rs
1//! Core traits for PEFT adapters.
2
3use candle_core::Tensor;
4use candle_nn::VarMap;
5
6use crate::Result;
7
/// Configuration trait for adapter hyperparameters.
///
/// Implementors bundle the tunable settings for one adapter kind and can
/// check them for internal consistency before an adapter is constructed.
/// The `Clone + Send + Sync` bounds let configurations be copied freely
/// and shared across threads.
pub trait AdapterConfig: Clone + Send + Sync {
    /// Validate the configuration parameters.
    ///
    /// Call this before building an adapter from the configuration.
    ///
    /// # Errors
    ///
    /// Returns an error if the configuration is invalid.
    fn validate(&self) -> Result<()>;
}
17
/// Core adapter trait for parameter-efficient fine-tuning.
///
/// An adapter is a small trainable module attached to a frozen base layer.
/// `Send + Sync` bounds allow adapters to be shared across threads; the
/// associated [`AdapterConfig`] carries the adapter's hyperparameters.
pub trait Adapter: Send + Sync {
    /// The configuration type for this adapter.
    type Config: AdapterConfig;

    /// Forward pass applying the adapter transformation.
    ///
    /// # Arguments
    /// * `input` - Input tensor
    /// * `base_output` - Optional output from the base layer (for residual adapters)
    ///
    /// # Returns
    /// Transformed tensor
    ///
    /// # Errors
    ///
    /// Returns an error if the forward pass fails.
    fn forward(&self, input: &Tensor, base_output: Option<&Tensor>) -> Result<Tensor>;

    /// Get the number of trainable parameters.
    ///
    /// Counts only the adapter's own parameters, not those of the base layer.
    #[must_use]
    fn num_parameters(&self) -> usize;

    /// Get the adapter's configuration.
    fn config(&self) -> &Self::Config;
}
44
/// Trait for adapters that can be merged into base weights.
///
/// Merging folds the adapter's learned delta into the base weight tensor so
/// inference needs no separate adapter forward pass. Implementations should
/// keep `merge` and `unmerge` as inverses: `unmerge(merge(w)) ≈ w` (up to
/// floating-point error — exact round-tripping is not guaranteed).
pub trait Mergeable: Adapter {
    /// Merge adapter weights into base model weights.
    ///
    /// # Arguments
    /// * `base_weight` - The original weight tensor to merge into
    ///
    /// # Returns
    /// New tensor with adapter weights merged
    ///
    /// # Errors
    ///
    /// Returns an error if merging fails.
    fn merge(&self, base_weight: &Tensor) -> Result<Tensor>;

    /// Unmerge adapter weights from merged weights.
    ///
    /// # Arguments
    /// * `merged_weight` - Weight tensor with adapter already merged
    ///
    /// # Returns
    /// Original base weight tensor
    ///
    /// # Errors
    ///
    /// Returns an error if unmerging fails.
    fn unmerge(&self, merged_weight: &Tensor) -> Result<Tensor>;
}
73
/// Trait for trainable adapters.
///
/// Separates training concerns (parameter registration, freezing) from the
/// inference-only [`Adapter`] surface.
pub trait Trainable: Adapter {
    /// Register trainable parameters with the variable map.
    ///
    /// `prefix` namespaces the parameter names so multiple adapters can
    /// coexist in one `VarMap` without key collisions.
    ///
    /// NOTE(review): candle's `VarMap` uses interior mutability, so `&mut`
    /// may be stricter than required here — confirm against call sites.
    ///
    /// # Errors
    ///
    /// Returns an error if parameter registration fails.
    fn register_parameters(&self, var_map: &mut VarMap, prefix: &str) -> Result<()>;

    /// Freeze all adapter parameters (disable gradients).
    fn freeze(&mut self);

    /// Unfreeze all adapter parameters (enable gradients).
    fn unfreeze(&mut self);

    /// Check if the adapter is frozen.
    ///
    /// Expected to reflect the most recent `freeze`/`unfreeze` call.
    #[must_use]
    fn is_frozen(&self) -> bool;
}
93
#[cfg(test)]
mod tests {
    use super::*;

    // Trait object safety check: this function only needs to *compile* — it
    // is never called. If `Adapter` ever gains a non-dyn-compatible item
    // (e.g. a generic method or a `Self`-returning method), `dyn Adapter`
    // stops being a valid type and this fails at compile time.
    // NOTE(review): relies on `crate::LoraConfig` implementing
    // `AdapterConfig` — defined elsewhere in the crate; confirm it stays
    // a valid associated-type witness.
    fn _assert_adapter_object_safe(_: &dyn Adapter<Config = crate::LoraConfig>) {}
}