// peft_rs/adapters/prompt_tuning.rs
use candle_core::{Device, Tensor};
9use serde::{Deserialize, Serialize};
10
11use crate::error::{PeftError, Result};
12use crate::traits::{Adapter, AdapterConfig};
13
/// Configuration for a prompt-tuning adapter: a small matrix of trainable
/// "virtual token" embeddings prepended to the model input.
///
/// Checked by `AdapterConfig::validate`; both sizes must be non-zero.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PromptTuningConfig {
    /// Number of virtual tokens prepended to the input sequence.
    pub num_virtual_tokens: usize,

    /// Embedding dimension of the soft prompt (must match the embeddings
    /// it is concatenated with).
    pub hidden_size: usize,

    /// Soft-prompt initialization strategy; missing field deserializes to
    /// `PromptInit::Random` via `#[serde(default)]`.
    #[serde(default)]
    pub init_strategy: PromptInit,
}
27
28#[derive(Debug, Clone, Default, Serialize, Deserialize)]
30pub enum PromptInit {
31 #[default]
33 Random,
34 Text(String),
36}
37
38impl Default for PromptTuningConfig {
39 fn default() -> Self {
40 Self {
41 num_virtual_tokens: 20,
42 hidden_size: 768,
43 init_strategy: PromptInit::Random,
44 }
45 }
46}
47
48impl AdapterConfig for PromptTuningConfig {
49 fn validate(&self) -> Result<()> {
50 if self.num_virtual_tokens == 0 {
51 return Err(PeftError::InvalidConfig(
52 "num_virtual_tokens must be > 0".into(),
53 ));
54 }
55 if self.hidden_size == 0 {
56 return Err(PeftError::InvalidConfig("hidden_size must be > 0".into()));
57 }
58 Ok(())
59 }
60}
61
/// A soft prompt: a matrix of virtual-token embeddings that is prepended
/// to the input embeddings of a base model.
pub struct PromptTuningLayer {
    /// Soft-prompt embeddings, shape `(num_virtual_tokens, hidden_size)`.
    soft_prompt: Tensor,
    /// The (validated) configuration this layer was built from.
    config: PromptTuningConfig,
}
71
72impl PromptTuningLayer {
73 pub fn new(config: PromptTuningConfig, device: &Device) -> Result<Self> {
83 config.validate()?;
84
85 let soft_prompt = Tensor::randn(
86 0.0f32,
87 0.02,
88 (config.num_virtual_tokens, config.hidden_size),
89 device,
90 )?;
91
92 Ok(Self {
93 soft_prompt,
94 config,
95 })
96 }
97
98 #[must_use]
100 pub fn soft_prompt(&self) -> &Tensor {
101 &self.soft_prompt
102 }
103
104 pub fn prepend_to_input(&self, input_embeds: &Tensor) -> Result<Tensor> {
116 let batch_size = input_embeds.dim(0)?;
117
118 let expanded_prompt = self.soft_prompt.unsqueeze(0)?.expand((
120 batch_size,
121 self.config.num_virtual_tokens,
122 self.config.hidden_size,
123 ))?;
124
125 Ok(Tensor::cat(&[&expanded_prompt, input_embeds], 1)?)
127 }
128}
129
130impl Adapter for PromptTuningLayer {
131 type Config = PromptTuningConfig;
132
133 fn forward(&self, input: &Tensor, _base_output: Option<&Tensor>) -> Result<Tensor> {
134 self.prepend_to_input(input)
135 }
136
137 fn num_parameters(&self) -> usize {
138 self.config.num_virtual_tokens * self.config.hidden_size
139 }
140
141 fn config(&self) -> &Self::Config {
142 &self.config
143 }
144}
145
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::DType;

    /// A default config should construct a layer without error.
    #[test]
    fn test_prompt_tuning_creation() {
        let config = PromptTuningConfig::default();
        let device = Device::Cpu;
        let layer = PromptTuningLayer::new(config, &device);
        assert!(layer.is_ok());
    }

    /// Degenerate configs must be rejected — both by `validate` directly
    /// and through `PromptTuningLayer::new` (which calls it).
    #[test]
    fn test_invalid_config_rejected() {
        let no_tokens = PromptTuningConfig {
            num_virtual_tokens: 0,
            ..Default::default()
        };
        assert!(no_tokens.validate().is_err());

        let no_hidden = PromptTuningConfig {
            hidden_size: 0,
            ..Default::default()
        };
        assert!(PromptTuningLayer::new(no_hidden, &Device::Cpu).is_err());
    }

    /// Prepending 10 virtual tokens to a (2, 20, 768) batch → (2, 30, 768).
    #[test]
    fn test_prepend_to_input() {
        let config = PromptTuningConfig {
            num_virtual_tokens: 10,
            hidden_size: 768,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = PromptTuningLayer::new(config, &device).unwrap();

        let input = Tensor::zeros(&[2, 20, 768], DType::F32, &device).unwrap();
        let output = layer.prepend_to_input(&input).unwrap();

        assert_eq!(output.shape().dims(), &[2, 30, 768]);
    }

    /// Parameter count is exactly tokens × hidden size.
    #[test]
    fn test_num_parameters() {
        let config = PromptTuningConfig {
            num_virtual_tokens: 20,
            hidden_size: 768,
            ..Default::default()
        };
        let device = Device::Cpu;
        let layer = PromptTuningLayer::new(config, &device).unwrap();

        assert_eq!(layer.num_parameters(), 20 * 768);
    }
}