peft_rs/adapters/
prefix_tuning.rs1use candle_core::{Device, IndexOp, Tensor};
9use serde::{Deserialize, Serialize};
10
11use crate::error::{PeftError, Result};
12use crate::traits::{Adapter, AdapterConfig};
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct PrefixTuningConfig {
17 pub num_prefix_tokens: usize,
19
20 pub prefix_dim: usize,
22
23 pub num_heads: usize,
25
26 pub num_layers: usize,
28
29 #[serde(default = "default_true")]
31 pub use_reparameterization: bool,
32}
33
/// Serde default provider that yields `true`.
///
/// Used by `#[serde(default = "default_true")]` so that an omitted boolean
/// field deserializes as enabled.
fn default_true() -> bool {
    true
}
37
38impl Default for PrefixTuningConfig {
39 fn default() -> Self {
40 Self {
41 num_prefix_tokens: 20,
42 prefix_dim: 512,
43 num_heads: 12,
44 num_layers: 12,
45 use_reparameterization: true,
46 }
47 }
48}
49
50impl AdapterConfig for PrefixTuningConfig {
51 fn validate(&self) -> Result<()> {
52 if self.num_prefix_tokens == 0 {
53 return Err(PeftError::InvalidConfig(
54 "num_prefix_tokens must be > 0".into(),
55 ));
56 }
57 if self.prefix_dim == 0 {
58 return Err(PeftError::InvalidConfig("prefix_dim must be > 0".into()));
59 }
60 Ok(())
61 }
62}
63
64pub struct PrefixTuningLayer {
68 prefix_keys: Tensor,
70 prefix_values: Tensor,
72 config: PrefixTuningConfig,
74}
75
76impl PrefixTuningLayer {
77 pub fn new(config: PrefixTuningConfig, head_dim: usize, device: &Device) -> Result<Self> {
88 config.validate()?;
89
90 let shape = (
91 config.num_layers,
92 config.num_prefix_tokens,
93 config.num_heads,
94 head_dim,
95 );
96
97 let prefix_keys = Tensor::randn(0.0f32, 0.02, shape, device)?;
99 let prefix_values = Tensor::randn(0.0f32, 0.02, shape, device)?;
100
101 Ok(Self {
102 prefix_keys,
103 prefix_values,
104 config,
105 })
106 }
107
108 pub fn get_prefix_keys(&self, layer_idx: usize) -> Result<Tensor> {
114 Ok(self.prefix_keys.i(layer_idx)?)
115 }
116
117 pub fn get_prefix_values(&self, layer_idx: usize) -> Result<Tensor> {
123 Ok(self.prefix_values.i(layer_idx)?)
124 }
125}
126
127impl Adapter for PrefixTuningLayer {
128 type Config = PrefixTuningConfig;
129
130 fn forward(&self, input: &Tensor, _base_output: Option<&Tensor>) -> Result<Tensor> {
131 Ok(input.clone())
134 }
135
136 fn num_parameters(&self) -> usize {
137 self.prefix_keys.elem_count() + self.prefix_values.elem_count()
138 }
139
140 fn config(&self) -> &Self::Config {
141 &self.config
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Small configuration shared by the shape and parameter-count tests.
    fn small_config() -> PrefixTuningConfig {
        PrefixTuningConfig {
            num_prefix_tokens: 10,
            num_heads: 8,
            num_layers: 6,
            ..Default::default()
        }
    }

    #[test]
    fn test_prefix_tuning_creation() {
        let config = PrefixTuningConfig::default();
        let device = Device::Cpu;
        let layer = PrefixTuningLayer::new(config, 64, &device);
        assert!(layer.is_ok());
    }

    #[test]
    fn test_prefix_shapes() {
        let device = Device::Cpu;
        let layer = PrefixTuningLayer::new(small_config(), 64, &device).unwrap();

        let keys = layer.get_prefix_keys(0).unwrap();
        assert_eq!(keys.shape().dims(), &[10, 8, 64]);

        // The last valid layer index must also yield a correctly shaped value prefix.
        let values = layer.get_prefix_values(5).unwrap();
        assert_eq!(values.shape().dims(), &[10, 8, 64]);
    }

    #[test]
    fn test_num_parameters_counts_keys_and_values() {
        let device = Device::Cpu;
        let layer = PrefixTuningLayer::new(small_config(), 64, &device).unwrap();
        // keys + values, each num_layers * num_prefix_tokens * num_heads * head_dim
        assert_eq!(layer.num_parameters(), 2 * 6 * 10 * 8 * 64);
    }

    #[test]
    fn test_out_of_range_layer_is_error() {
        let device = Device::Cpu;
        let layer = PrefixTuningLayer::new(small_config(), 64, &device).unwrap();
        assert!(layer.get_prefix_keys(6).is_err());
        assert!(layer.get_prefix_values(6).is_err());
    }

    #[test]
    fn test_rejects_zero_prefix_tokens() {
        let config = PrefixTuningConfig {
            num_prefix_tokens: 0,
            ..Default::default()
        };
        assert!(matches!(config.validate(), Err(PeftError::InvalidConfig(_))));
        assert!(PrefixTuningLayer::new(config, 64, &Device::Cpu).is_err());
    }

    #[test]
    fn test_rejects_zero_prefix_dim() {
        let config = PrefixTuningConfig {
            prefix_dim: 0,
            ..Default::default()
        };
        assert!(matches!(config.validate(), Err(PeftError::InvalidConfig(_))));
        assert!(PrefixTuningLayer::new(config, 64, &Device::Cpu).is_err());
    }
}