Skip to main content

peft_rs/
traits.rs

1//! Core traits for PEFT adapters.
2
3use candle_core::Tensor;
4use candle_nn::VarMap;
5
6use crate::Result;
7
8/// Configuration trait for adapter hyperparameters.
9pub trait AdapterConfig: Clone + Send + Sync {
10    /// Validate the configuration parameters.
11    ///
12    /// # Errors
13    ///
14    /// Returns an error if the configuration is invalid.
15    fn validate(&self) -> Result<()>;
16}
17
18/// Core adapter trait for parameter-efficient fine-tuning.
19pub trait Adapter: Send + Sync {
20    /// The configuration type for this adapter.
21    type Config: AdapterConfig;
22
23    /// Forward pass applying the adapter transformation.
24    ///
25    /// # Arguments
26    /// * `input` - Input tensor
27    /// * `base_output` - Optional output from the base layer (for residual adapters)
28    ///
29    /// # Returns
30    /// Transformed tensor
31    ///
32    /// # Errors
33    ///
34    /// Returns an error if the forward pass fails.
35    fn forward(&self, input: &Tensor, base_output: Option<&Tensor>) -> Result<Tensor>;
36
37    /// Get the number of trainable parameters.
38    #[must_use]
39    fn num_parameters(&self) -> usize;
40
41    /// Get the adapter's configuration.
42    fn config(&self) -> &Self::Config;
43}
44
45/// Trait for adapters that can be merged into base weights.
46pub trait Mergeable: Adapter {
47    /// Merge adapter weights into base model weights.
48    ///
49    /// # Arguments
50    /// * `base_weight` - The original weight tensor to merge into
51    ///
52    /// # Returns
53    /// New tensor with adapter weights merged
54    ///
55    /// # Errors
56    ///
57    /// Returns an error if merging fails.
58    fn merge(&self, base_weight: &Tensor) -> Result<Tensor>;
59
60    /// Unmerge adapter weights from merged weights.
61    ///
62    /// # Arguments
63    /// * `merged_weight` - Weight tensor with adapter already merged
64    ///
65    /// # Returns
66    /// Original base weight tensor
67    ///
68    /// # Errors
69    ///
70    /// Returns an error if unmerging fails.
71    fn unmerge(&self, merged_weight: &Tensor) -> Result<Tensor>;
72}
73
74/// Trait for trainable adapters.
75pub trait Trainable: Adapter {
76    /// Register trainable parameters with the variable map.
77    ///
78    /// # Errors
79    ///
80    /// Returns an error if parameter registration fails.
81    fn register_parameters(&self, var_map: &mut VarMap, prefix: &str) -> Result<()>;
82
83    /// Freeze all adapter parameters (disable gradients).
84    fn freeze(&mut self);
85
86    /// Unfreeze all adapter parameters (enable gradients).
87    fn unfreeze(&mut self);
88
89    /// Check if the adapter is frozen.
90    #[must_use]
91    fn is_frozen(&self) -> bool;
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    // Compile-time object-safety guards. These functions are never called;
    // they only have to type-check. If any public adapter trait stops being
    // usable as a trait object (e.g. a generic method is added), the test
    // build fails here instead of at some distant call site. The leading
    // underscore keeps them from triggering `dead_code` warnings.

    fn _assert_adapter_object_safe(_: &dyn Adapter<Config = crate::LoraConfig>) {}

    // `Config` comes from the `Adapter` supertrait and must be pinned for the
    // trait object type to be complete.
    fn _assert_mergeable_object_safe(_: &dyn Mergeable<Config = crate::LoraConfig>) {}

    fn _assert_trainable_object_safe(_: &dyn Trainable<Config = crate::LoraConfig>) {}
}