peft_rs/adapters/prompt_tuning.rs

//! Prompt Tuning implementation.
//!
//! Prompt tuning prepends learnable "soft prompt" embeddings to the input,
//! allowing the model to be steered without modifying weights.
//!
//! Reference: <https://arxiv.org/abs/2104.08691>
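//!
//! # Example
//!
//! A minimal end-to-end sketch; the `use` path assumes this module is
//! reachable at `peft_rs::adapters::prompt_tuning` in the crate layout.
//!
//! ```no_run
//! use candle_core::{DType, Device, Tensor};
//! use peft_rs::adapters::prompt_tuning::{PromptTuningConfig, PromptTuningLayer};
//!
//! let device = Device::Cpu;
//! let config = PromptTuningConfig {
//!     num_virtual_tokens: 10,
//!     ..Default::default()
//! };
//! let layer = PromptTuningLayer::new(config, &device).unwrap();
//!
//! // [batch = 2, seq_len = 16, hidden = 768] input embeddings.
//! let input = Tensor::zeros((2, 16, 768), DType::F32, &device).unwrap();
//! let with_prompt = layer.prepend_to_input(&input).unwrap();
//! assert_eq!(with_prompt.dims(), &[2, 10 + 16, 768]);
//! ```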

use candle_core::{Device, Tensor};
use serde::{Deserialize, Serialize};

use crate::error::{PeftError, Result};
use crate::traits::{Adapter, AdapterConfig};

/// Configuration for prompt tuning.
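///
/// # Example
///
/// A JSON round-trip sketch via the `Serialize`/`Deserialize` derives;
/// assumes `serde_json` is available as a (dev-)dependency and that the
/// module path matches the crate layout.
///
/// ```no_run
/// use peft_rs::adapters::prompt_tuning::PromptTuningConfig;
///
/// let config = PromptTuningConfig::default();
/// let json = serde_json::to_string(&config).unwrap();
/// let restored: PromptTuningConfig = serde_json::from_str(&json).unwrap();
/// assert_eq!(restored.num_virtual_tokens, config.num_virtual_tokens);
/// ```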
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptTuningConfig {
    /// Number of virtual tokens (soft prompt length).
    pub num_virtual_tokens: usize,

    /// Hidden size of the model embeddings.
    pub hidden_size: usize,

    /// Initialization strategy.
    #[serde(default)]
    pub init_strategy: PromptInit,
}

/// Initialization strategy for soft prompts.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub enum PromptInit {
    /// Random initialization from a normal distribution.
    #[default]
    Random,
    /// Initialize from the embeddings of text tokens (requires a tokenizer).
    Text(String),
}

impl Default for PromptTuningConfig {
    fn default() -> Self {
        Self {
            num_virtual_tokens: 20,
            hidden_size: 768,
            init_strategy: PromptInit::Random,
        }
    }
}

impl AdapterConfig for PromptTuningConfig {
    fn validate(&self) -> Result<()> {
        if self.num_virtual_tokens == 0 {
            return Err(PeftError::InvalidConfig(
                "num_virtual_tokens must be > 0".into(),
            ));
        }
        if self.hidden_size == 0 {
            return Err(PeftError::InvalidConfig("hidden_size must be > 0".into()));
        }
        Ok(())
    }
}

/// Prompt tuning layer.
///
/// Maintains soft prompt embeddings that are prepended to input embeddings.
pub struct PromptTuningLayer {
    /// Soft prompt embeddings of shape `[num_virtual_tokens, hidden_size]`.
    soft_prompt: Tensor,
    /// Configuration.
    config: PromptTuningConfig,
}

impl PromptTuningLayer {
    /// Create a new prompt tuning layer with random initialization.
    ///
    /// # Arguments
    /// * `config` - Prompt tuning configuration
    /// * `device` - Device to create tensors on
    ///
    /// # Errors
    ///
    /// Returns an error if configuration validation or tensor allocation fails.
    pub fn new(config: PromptTuningConfig, device: &Device) -> Result<Self> {
        config.validate()?;

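        // N(0, 0.02^2) initialization; 0.02 is a scale commonly used for
        // transformer embedding tables (e.g. GPT-2).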
        let soft_prompt = Tensor::randn(
            0.0f32,
            0.02,
            (config.num_virtual_tokens, config.hidden_size),
            device,
        )?;

        Ok(Self {
            soft_prompt,
            config,
        })
    }

    /// Get the soft prompt embeddings.
    #[must_use]
    pub fn soft_prompt(&self) -> &Tensor {
        &self.soft_prompt
    }

    /// Prepend soft prompts to input embeddings.
    ///
    /// # Arguments
    /// * `input_embeds` - Input embeddings `[batch, seq_len, hidden]`
    ///
    /// # Returns
    /// Concatenated embeddings `[batch, num_virtual_tokens + seq_len, hidden]`
    ///
    /// # Errors
    ///
    /// Returns an error if tensor operations fail.
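    ///
    /// # Example
    ///
    /// A shape-focused sketch on CPU (module path assumed as in the
    /// crate-level example):
    ///
    /// ```no_run
    /// use candle_core::{DType, Device, Tensor};
    /// use peft_rs::adapters::prompt_tuning::{PromptTuningConfig, PromptTuningLayer};
    ///
    /// let device = Device::Cpu;
    /// let config = PromptTuningConfig {
    ///     num_virtual_tokens: 4,
    ///     hidden_size: 8,
    ///     ..Default::default()
    /// };
    /// let layer = PromptTuningLayer::new(config, &device).unwrap();
    ///
    /// let input = Tensor::zeros((3, 5, 8), DType::F32, &device).unwrap();
    /// let out = layer.prepend_to_input(&input).unwrap();
    /// assert_eq!(out.dims(), &[3, 4 + 5, 8]);
    /// ```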
    pub fn prepend_to_input(&self, input_embeds: &Tensor) -> Result<Tensor> {
        let batch_size = input_embeds.dim(0)?;

        // Expand the soft prompt across the batch: [num_virtual_tokens, hidden]
        // -> [1, num_virtual_tokens, hidden] -> [batch, num_virtual_tokens, hidden]
        let expanded_prompt = self.soft_prompt.unsqueeze(0)?.expand((
            batch_size,
            self.config.num_virtual_tokens,
            self.config.hidden_size,
        ))?;

        // Concatenate along the sequence dimension (dim 1)
        Ok(Tensor::cat(&[&expanded_prompt, input_embeds], 1)?)
    }
}

impl Adapter for PromptTuningLayer {
    type Config = PromptTuningConfig;

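    /// Prepends the soft prompt to `input`. The `_base_output` argument is
    /// ignored: prompt tuning modifies the input embeddings rather than a
    /// layer's output.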
    fn forward(&self, input: &Tensor, _base_output: Option<&Tensor>) -> Result<Tensor> {
        self.prepend_to_input(input)
    }

    fn num_parameters(&self) -> usize {
        self.config.num_virtual_tokens * self.config.hidden_size
    }

    fn config(&self) -> &Self::Config {
        &self.config
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::DType;

    #[test]
    fn test_prompt_tuning_creation() {
        let config = PromptTuningConfig::default();
        let device = Device::Cpu;
        let layer = PromptTuningLayer::new(config, &device);
        assert!(layer.is_ok());
    }

    #[test]
    fn test_prepend_to_input() {
        let config = PromptTuningConfig {
            num_virtual_tokens: 10,
            hidden_size: 768,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = PromptTuningLayer::new(config, &device).unwrap();

        let input = Tensor::zeros(&[2, 20, 768], DType::F32, &device).unwrap();
        let output = layer.prepend_to_input(&input).unwrap();

        // Output should be [2, 10 + 20, 768] = [2, 30, 768]
        assert_eq!(output.shape().dims(), &[2, 30, 768]);
    }

    #[test]
    fn test_num_parameters() {
        let config = PromptTuningConfig {
            num_virtual_tokens: 20,
            hidden_size: 768,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = PromptTuningLayer::new(config, &device).unwrap();

        assert_eq!(layer.num_parameters(), 20 * 768);
    }
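
    #[test]
    fn test_invalid_config_rejected() {
        // A zero-length soft prompt is rejected by `AdapterConfig::validate`
        // before any tensor is allocated.
        let config = PromptTuningConfig {
            num_virtual_tokens: 0,
            ..Default::default()
        };
        let device = Device::Cpu;
        assert!(PromptTuningLayer::new(config, &device).is_err());
    }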
}