peft_rs/adapters/prefix_tuning.rs

//! Prefix Tuning implementation.
//!
//! Prefix tuning prepends trainable "prefix" vectors to the keys and values
//! in attention layers, without modifying the original model weights.
//!
//! Reference: <https://arxiv.org/abs/2101.00190>
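//!
//! # Example
//!
//! Illustrative usage sketch; the `peft_rs::...` import paths are assumed
//! from the file layout shown above, not confirmed by this file alone.
//!
//! ```ignore
//! use candle_core::Device;
//! use peft_rs::adapters::prefix_tuning::{PrefixTuningConfig, PrefixTuningLayer};
//!
//! let config = PrefixTuningConfig::default();
//! let layer = PrefixTuningLayer::new(config, 64, &Device::Cpu).unwrap();
//! // Per-layer prefixes, to be concatenated onto attention keys and values.
//! let k0 = layer.get_prefix_keys(0).unwrap();
//! let v0 = layer.get_prefix_values(0).unwrap();
//! ```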

use candle_core::{Device, IndexOp, Tensor};
use serde::{Deserialize, Serialize};

use crate::error::{PeftError, Result};
use crate::traits::{Adapter, AdapterConfig};

/// Configuration for prefix tuning.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrefixTuningConfig {
    /// Number of prefix tokens to prepend.
    pub num_prefix_tokens: usize,

    /// Hidden dimension of prefix vectors.
    pub prefix_dim: usize,

    /// Number of attention heads.
    pub num_heads: usize,

    /// Number of layers to apply prefix to.
    pub num_layers: usize,

    /// Whether to reparameterize the prefix through a small MLP during
    /// training, as in the original paper (see the sketch after this
    /// struct); `PrefixTuningLayer` below stores this flag but does not
    /// itself apply the MLP.
    #[serde(default = "default_true")]
    pub use_reparameterization: bool,
}
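
// Illustrative sketch only, not part of this adapter's API: the paper
// (Li & Liang, 2021) trains a smaller embedding and expands it through a
// two-layer MLP when reparameterization is enabled. `w1` and `w2` are
// hypothetical trainable weight tensors of shape [small_dim, hidden] and
// [hidden, prefix_dim] respectively.
#[allow(dead_code)]
fn reparameterize_sketch(embedding: &Tensor, w1: &Tensor, w2: &Tensor) -> Result<Tensor> {
    // embedding: [num_prefix_tokens, small_dim] -> [num_prefix_tokens, prefix_dim]
    let hidden = embedding.matmul(w1)?.tanh()?;
    Ok(hidden.matmul(w2)?)
}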

fn default_true() -> bool {
    true
}

impl Default for PrefixTuningConfig {
    fn default() -> Self {
        Self {
            num_prefix_tokens: 20,
            prefix_dim: 512,
            num_heads: 12,
            num_layers: 12,
            use_reparameterization: true,
        }
    }
}

impl AdapterConfig for PrefixTuningConfig {
    fn validate(&self) -> Result<()> {
        if self.num_prefix_tokens == 0 {
            return Err(PeftError::InvalidConfig(
                "num_prefix_tokens must be > 0".into(),
            ));
        }
        if self.prefix_dim == 0 {
            return Err(PeftError::InvalidConfig("prefix_dim must be > 0".into()));
        }
        if self.num_heads == 0 {
            return Err(PeftError::InvalidConfig("num_heads must be > 0".into()));
        }
        if self.num_layers == 0 {
            return Err(PeftError::InvalidConfig("num_layers must be > 0".into()));
        }
        Ok(())
    }
}

/// Prefix tuning layer.
///
/// Stores trainable prefix embeddings for keys and values.
pub struct PrefixTuningLayer {
    /// Prefix embeddings for keys, shaped
    /// `[num_layers, num_prefix_tokens, num_heads, head_dim]`.
    prefix_keys: Tensor,
    /// Prefix embeddings for values, shaped
    /// `[num_layers, num_prefix_tokens, num_heads, head_dim]`.
    prefix_values: Tensor,
    /// Configuration
    config: PrefixTuningConfig,
}

impl PrefixTuningLayer {
    /// Create a new prefix tuning layer.
    ///
    /// # Arguments
    /// * `config` - Prefix tuning configuration
    /// * `head_dim` - Dimension per attention head
    /// * `device` - Device to create tensors on
    ///
    /// # Errors
    ///
    /// Returns an error if configuration validation fails or layer construction fails.
    pub fn new(config: PrefixTuningConfig, head_dim: usize, device: &Device) -> Result<Self> {
        config.validate()?;

        let shape = (
            config.num_layers,
            config.num_prefix_tokens,
            config.num_heads,
            head_dim,
        );

        // Initialize with small random values: mean 0.0, standard deviation 0.02.
        let prefix_keys = Tensor::randn(0.0f32, 0.02, shape, device)?;
        let prefix_values = Tensor::randn(0.0f32, 0.02, shape, device)?;

        Ok(Self {
            prefix_keys,
            prefix_values,
            config,
        })
    }

    /// Get prefix keys for a specific layer, shaped
    /// `[num_prefix_tokens, num_heads, head_dim]`.
    ///
    /// # Errors
    ///
    /// Returns an error if the layer index is out of bounds.
    pub fn get_prefix_keys(&self, layer_idx: usize) -> Result<Tensor> {
        Ok(self.prefix_keys.i(layer_idx)?)
    }

    /// Get prefix values for a specific layer, shaped
    /// `[num_prefix_tokens, num_heads, head_dim]`.
    ///
    /// # Errors
    ///
    /// Returns an error if the layer index is out of bounds.
    pub fn get_prefix_values(&self, layer_idx: usize) -> Result<Tensor> {
        Ok(self.prefix_values.i(layer_idx)?)
    }
}
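
// Illustrative sketch only, not part of this adapter's API: how an attention
// implementation might splice a layer's prefix into its runtime keys or
// values. `keys` is assumed to be shaped [batch, num_heads, seq_len, head_dim];
// the per-layer prefix returned by the getters above is
// [num_prefix_tokens, num_heads, head_dim].
#[allow(dead_code)]
fn prepend_prefix_sketch(prefix: &Tensor, keys: &Tensor) -> Result<Tensor> {
    let (batch, num_heads, _seq_len, head_dim) = keys.dims4()?;
    let num_prefix = prefix.dim(0)?;
    // [num_prefix, num_heads, head_dim] -> [1, num_heads, num_prefix, head_dim],
    // then broadcast across the batch dimension.
    let prefix = prefix
        .transpose(0, 1)?
        .unsqueeze(0)?
        .expand((batch, num_heads, num_prefix, head_dim))?
        .contiguous()?;
    // Prepend along the sequence axis (dim 2): the effective attention
    // length becomes num_prefix + seq_len.
    Ok(Tensor::cat(&[&prefix, keys], 2)?)
}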

impl Adapter for PrefixTuningLayer {
    type Config = PrefixTuningConfig;

    fn forward(&self, input: &Tensor, _base_output: Option<&Tensor>) -> Result<Tensor> {
        // Prefix tuning doesn't modify the hidden states directly; the
        // prefixes are consumed inside attention (see the splice sketch
        // above), so the input passes through unchanged.
        Ok(input.clone())
    }

    fn num_parameters(&self) -> usize {
        self.prefix_keys.elem_count() + self.prefix_values.elem_count()
    }

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prefix_tuning_creation() {
        let config = PrefixTuningConfig::default();
        let device = Device::Cpu;
        let layer = PrefixTuningLayer::new(config, 64, &device);
        assert!(layer.is_ok());
    }
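
    // Added check (illustrative; assumes `serde_json` is available as a
    // dev-dependency): a missing `use_reparameterization` field falls back
    // to `default_true` via the `#[serde(default)]` attribute on the config.
    #[test]
    fn test_serde_default_reparameterization() {
        let json = r#"{
            "num_prefix_tokens": 20,
            "prefix_dim": 512,
            "num_heads": 12,
            "num_layers": 12
        }"#;
        let config: PrefixTuningConfig = serde_json::from_str(json).unwrap();
        assert!(config.use_reparameterization);
    }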

    #[test]
    fn test_prefix_shapes() {
        let config = PrefixTuningConfig {
            num_prefix_tokens: 10,
            num_heads: 8,
            num_layers: 6,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = PrefixTuningLayer::new(config, 64, &device).unwrap();

        let keys = layer.get_prefix_keys(0).unwrap();
        assert_eq!(keys.shape().dims(), &[10, 8, 64]);
    }
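
    // Added check (straightforward consequence of the shapes above): the
    // trainable parameter count is keys plus values, i.e.
    // 2 * num_layers * num_prefix_tokens * num_heads * head_dim.
    #[test]
    fn test_num_parameters() {
        let config = PrefixTuningConfig::default();
        let layer = PrefixTuningLayer::new(config, 64, &Device::Cpu).unwrap();
        assert_eq!(layer.num_parameters(), 2 * 12 * 20 * 12 * 64);
    }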
}