1use ghostflow_core::Tensor;
10use std::collections::HashMap;
11
12#[derive(Debug, Clone)]
14pub struct PromptTuningConfig {
15 pub num_virtual_tokens: usize,
17 pub d_model: usize,
19 pub init_strategy: PromptInitStrategy,
21 pub reparameterize: bool,
23 pub hidden_dim: usize,
25}
26
/// Strategy used to initialize prompt/prefix embeddings.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PromptInitStrategy {
    /// Random initialization (the only strategy the visible code acts on).
    Random,
    /// Initialize from vocabulary embeddings — presumably sampled token
    /// embeddings; confirm against the embedding-lookup caller.
    Vocab,
    /// Initialize from the embeddings of a concrete text prompt
    /// (assumed from the variant name — TODO confirm).
    Text,
}
37
38impl Default for PromptTuningConfig {
39 fn default() -> Self {
40 PromptTuningConfig {
41 num_virtual_tokens: 20,
42 d_model: 768,
43 init_strategy: PromptInitStrategy::Random,
44 reparameterize: false,
45 hidden_dim: 512,
46 }
47 }
48}
49
50impl PromptTuningConfig {
51 pub fn short(d_model: usize) -> Self {
53 PromptTuningConfig {
54 num_virtual_tokens: 10,
55 d_model,
56 ..Default::default()
57 }
58 }
59
60 pub fn medium(d_model: usize) -> Self {
62 PromptTuningConfig {
63 num_virtual_tokens: 30,
64 d_model,
65 ..Default::default()
66 }
67 }
68
69 pub fn long(d_model: usize) -> Self {
71 PromptTuningConfig {
72 num_virtual_tokens: 80,
73 d_model,
74 ..Default::default()
75 }
76 }
77
78 pub fn with_reparameterization(mut self, hidden_dim: usize) -> Self {
80 self.reparameterize = true;
81 self.hidden_dim = hidden_dim;
82 self
83 }
84}
85
86pub struct PromptTuning {
88 config: PromptTuningConfig,
89 prompt_embeddings: Tensor,
91 reparam_encoder: Option<Tensor>,
93 reparam_decoder: Option<Tensor>,
94}
95
96impl PromptTuning {
97 pub fn new(config: PromptTuningConfig) -> Result<Self, String> {
99 let prompt_embeddings = if config.reparameterize {
100 Tensor::randn(&[config.num_virtual_tokens, config.hidden_dim])
102 } else {
103 Tensor::randn(&[config.num_virtual_tokens, config.d_model])
105 };
106
107 let (reparam_encoder, reparam_decoder) = if config.reparameterize {
108 let encoder = Tensor::randn(&[config.hidden_dim, config.d_model]);
109 let decoder = Tensor::randn(&[config.d_model, config.hidden_dim]);
110 (Some(encoder), Some(decoder))
111 } else {
112 (None, None)
113 };
114
115 Ok(PromptTuning {
116 config,
117 prompt_embeddings,
118 reparam_encoder,
119 reparam_decoder,
120 })
121 }
122
123 pub fn get_prompt_embeddings(&self) -> Result<Tensor, String> {
125 if self.config.reparameterize {
126 let encoder = self.reparam_encoder.as_ref()
128 .ok_or("Encoder not initialized")?;
129 self.prompt_embeddings.matmul(encoder)
130 .map_err(|e| format!("Failed to reparameterize: {:?}", e))
131 } else {
132 Ok(self.prompt_embeddings.clone())
133 }
134 }
135
136 pub fn prepend_prompts(&self, input_embeddings: &Tensor) -> Result<Tensor, String> {
138 let prompt_embeds = self.get_prompt_embeddings()?;
139
140 let input_dims = input_embeddings.dims();
141 let prompt_dims = prompt_embeds.dims();
142
143 if input_dims.len() != 3 || prompt_dims.len() != 2 {
144 return Err("Expected input [batch, seq_len, d_model] and prompt [num_tokens, d_model]".to_string());
145 }
146
147 let batch_size = input_dims[0];
148 let seq_len = input_dims[1];
149 let d_model = input_dims[2];
150 let num_prompts = prompt_dims[0];
151
152 let mut result = Vec::with_capacity(batch_size * (num_prompts + seq_len) * d_model);
154
155 let prompt_data = prompt_embeds.data_f32();
156 let input_data = input_embeddings.data_f32();
157
158 for b in 0..batch_size {
159 result.extend_from_slice(&prompt_data);
161
162 let start = b * seq_len * d_model;
164 let end = start + seq_len * d_model;
165 result.extend_from_slice(&input_data[start..end]);
166 }
167
168 Tensor::from_slice(&result, &[batch_size, num_prompts + seq_len, d_model])
169 .map_err(|e| format!("Failed to prepend prompts: {:?}", e))
170 }
171
172 pub fn num_parameters(&self) -> usize {
174 let prompt_params = self.prompt_embeddings.data_f32().len();
175 let reparam_params = if self.config.reparameterize {
176 self.reparam_encoder.as_ref().map(|t| t.data_f32().len()).unwrap_or(0) +
177 self.reparam_decoder.as_ref().map(|t| t.data_f32().len()).unwrap_or(0)
178 } else {
179 0
180 };
181 prompt_params + reparam_params
182 }
183
184 pub fn parameter_efficiency(&self, total_model_params: usize) -> f32 {
186 let tunable_params = self.num_parameters();
187 tunable_params as f32 / total_model_params as f32
188 }
189}
190
191#[derive(Debug, Clone)]
193pub struct PrefixTuningConfig {
194 pub num_prefix_tokens: usize,
196 pub num_layers: usize,
198 pub d_model: usize,
200 pub num_heads: usize,
202 pub init_strategy: PromptInitStrategy,
204 pub prefix_kv: bool,
206}
207
208impl Default for PrefixTuningConfig {
209 fn default() -> Self {
210 PrefixTuningConfig {
211 num_prefix_tokens: 10,
212 num_layers: 12,
213 d_model: 768,
214 num_heads: 12,
215 init_strategy: PromptInitStrategy::Random,
216 prefix_kv: true,
217 }
218 }
219}
220
221impl PrefixTuningConfig {
222 pub fn small(num_layers: usize, d_model: usize, num_heads: usize) -> Self {
224 PrefixTuningConfig {
225 num_prefix_tokens: 5,
226 num_layers,
227 d_model,
228 num_heads,
229 ..Default::default()
230 }
231 }
232
233 pub fn large(num_layers: usize, d_model: usize, num_heads: usize) -> Self {
235 PrefixTuningConfig {
236 num_prefix_tokens: 20,
237 num_layers,
238 d_model,
239 num_heads,
240 ..Default::default()
241 }
242 }
243}
244
245pub struct PrefixTuning {
247 config: PrefixTuningConfig,
248 prefix_params: HashMap<usize, LayerPrefix>,
250}
251
252#[derive(Clone)]
254pub struct LayerPrefix {
255 pub prefix_key: Tensor,
257 pub prefix_value: Tensor,
259}
260
261impl PrefixTuning {
262 pub fn new(config: PrefixTuningConfig) -> Result<Self, String> {
264 let mut prefix_params = HashMap::new();
265
266 let head_dim = config.d_model / config.num_heads;
267
268 for layer_idx in 0..config.num_layers {
269 let prefix_key = Tensor::randn(&[config.num_prefix_tokens, config.d_model]);
270 let prefix_value = if config.prefix_kv {
271 Tensor::randn(&[config.num_prefix_tokens, config.d_model])
272 } else {
273 prefix_key.clone()
274 };
275
276 prefix_params.insert(layer_idx, LayerPrefix {
277 prefix_key,
278 prefix_value,
279 });
280 }
281
282 Ok(PrefixTuning {
283 config,
284 prefix_params,
285 })
286 }
287
288 pub fn get_layer_prefix(&self, layer_idx: usize) -> Option<&LayerPrefix> {
290 self.prefix_params.get(&layer_idx)
291 }
292
293 pub fn prepend_to_kv(
295 &self,
296 layer_idx: usize,
297 key: &Tensor,
298 value: &Tensor,
299 ) -> Result<(Tensor, Tensor), String> {
300 let prefix = self.get_layer_prefix(layer_idx)
301 .ok_or(format!("No prefix for layer {}", layer_idx))?;
302
303 let new_key = self.concatenate_prefix(&prefix.prefix_key, key)?;
304 let new_value = self.concatenate_prefix(&prefix.prefix_value, value)?;
305
306 Ok((new_key, new_value))
307 }
308
309 fn concatenate_prefix(&self, prefix: &Tensor, tensor: &Tensor) -> Result<Tensor, String> {
311 let prefix_dims = prefix.dims();
312 let tensor_dims = tensor.dims();
313
314 if tensor_dims.len() != 3 {
315 return Err("Expected tensor [batch, seq_len, d_model]".to_string());
316 }
317
318 let batch_size = tensor_dims[0];
319 let seq_len = tensor_dims[1];
320 let d_model = tensor_dims[2];
321 let num_prefix = prefix_dims[0];
322
323 let mut result = Vec::with_capacity(batch_size * (num_prefix + seq_len) * d_model);
324
325 let prefix_data = prefix.data_f32();
326 let tensor_data = tensor.data_f32();
327
328 for b in 0..batch_size {
329 result.extend_from_slice(&prefix_data);
331
332 let start = b * seq_len * d_model;
334 let end = start + seq_len * d_model;
335 result.extend_from_slice(&tensor_data[start..end]);
336 }
337
338 Tensor::from_slice(&result, &[batch_size, num_prefix + seq_len, d_model])
339 .map_err(|e| format!("Failed to concatenate prefix: {:?}", e))
340 }
341
342 pub fn num_parameters(&self) -> usize {
344 let mut total = 0;
345 for prefix in self.prefix_params.values() {
346 total += prefix.prefix_key.data_f32().len();
347 total += prefix.prefix_value.data_f32().len();
348 }
349 total
350 }
351
352 pub fn parameter_efficiency(&self, total_model_params: usize) -> f32 {
354 let tunable_params = self.num_parameters();
355 tunable_params as f32 / total_model_params as f32
356 }
357}
358
359pub struct PTuningV2 {
361 prefix_tuning: PrefixTuning,
362 prefix_mlp: Option<Tensor>,
364}
365
366impl PTuningV2 {
367 pub fn new(config: PrefixTuningConfig) -> Result<Self, String> {
369 let prefix_tuning = PrefixTuning::new(config.clone())?;
370
371 let prefix_mlp = Some(Tensor::randn(&[config.d_model, config.d_model]));
373
374 Ok(PTuningV2 {
375 prefix_tuning,
376 prefix_mlp,
377 })
378 }
379
380 pub fn get_layer_prefix_transformed(&self, layer_idx: usize) -> Option<LayerPrefix> {
382 let prefix = self.prefix_tuning.get_layer_prefix(layer_idx)?;
383
384 if let Some(mlp) = &self.prefix_mlp {
385 Some(prefix.clone())
388 } else {
389 Some(prefix.clone())
390 }
391 }
392
393 pub fn num_parameters(&self) -> usize {
395 let prefix_params = self.prefix_tuning.num_parameters();
396 let mlp_params = self.prefix_mlp.as_ref()
397 .map(|t| t.data_f32().len())
398 .unwrap_or(0);
399 prefix_params + mlp_params
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_prompt_tuning_config() {
        // Default config values.
        let cfg = PromptTuningConfig::default();
        assert_eq!(cfg.num_virtual_tokens, 20);
        assert_eq!(cfg.d_model, 768);

        // The `short` preset overrides token count and width.
        let short = PromptTuningConfig::short(512);
        assert_eq!(short.num_virtual_tokens, 10);
        assert_eq!(short.d_model, 512);
    }

    #[test]
    fn test_prompt_tuning_creation() {
        let tuner = PromptTuning::new(PromptTuningConfig::default()).unwrap();

        let embeds = tuner.get_prompt_embeddings().unwrap();
        assert_eq!(embeds.dims(), &[20, 768]);
    }

    #[test]
    fn test_prompt_tuning_prepend() {
        let cfg = PromptTuningConfig {
            num_virtual_tokens: 5,
            d_model: 64,
            ..Default::default()
        };
        let tuner = PromptTuning::new(cfg).unwrap();

        // 5 virtual tokens + 10 input positions = 15.
        let input = Tensor::randn(&[2, 10, 64]);
        let out = tuner.prepend_prompts(&input).unwrap();
        assert_eq!(out.dims(), &[2, 15, 64]);
    }

    #[test]
    fn test_prompt_tuning_reparameterization() {
        let cfg = PromptTuningConfig {
            num_virtual_tokens: 10,
            d_model: 768,
            reparameterize: true,
            hidden_dim: 256,
            ..Default::default()
        };
        let tuner = PromptTuning::new(cfg).unwrap();

        // Embeddings come back projected into d_model space.
        let embeds = tuner.get_prompt_embeddings().unwrap();
        assert_eq!(embeds.dims(), &[10, 768]);
    }

    #[test]
    fn test_prompt_tuning_parameters() {
        let cfg = PromptTuningConfig {
            num_virtual_tokens: 20,
            d_model: 768,
            ..Default::default()
        };
        let tuner = PromptTuning::new(cfg).unwrap();

        assert_eq!(tuner.num_parameters(), 20 * 768);
        assert!(tuner.parameter_efficiency(100_000_000) < 0.01);
    }

    #[test]
    fn test_prefix_tuning_config() {
        let cfg = PrefixTuningConfig::default();
        assert_eq!(cfg.num_prefix_tokens, 10);
        assert_eq!(cfg.num_layers, 12);

        let small = PrefixTuningConfig::small(6, 512, 8);
        assert_eq!(small.num_prefix_tokens, 5);
        assert_eq!(small.num_layers, 6);
    }

    #[test]
    fn test_prefix_tuning_creation() {
        let cfg = PrefixTuningConfig {
            num_prefix_tokens: 5,
            num_layers: 3,
            d_model: 64,
            num_heads: 4,
            ..Default::default()
        };
        let tuner = PrefixTuning::new(cfg).unwrap();

        let prefix = tuner.get_layer_prefix(0).unwrap();
        assert_eq!(prefix.prefix_key.dims(), &[5, 64]);
        assert_eq!(prefix.prefix_value.dims(), &[5, 64]);
    }

    #[test]
    fn test_prefix_tuning_prepend() {
        let cfg = PrefixTuningConfig {
            num_prefix_tokens: 3,
            num_layers: 2,
            d_model: 32,
            num_heads: 4,
            ..Default::default()
        };
        let tuner = PrefixTuning::new(cfg).unwrap();

        let key = Tensor::randn(&[2, 8, 32]);
        let value = Tensor::randn(&[2, 8, 32]);

        // 3 prefix positions + 8 sequence positions = 11.
        let (new_key, new_value) = tuner.prepend_to_kv(0, &key, &value).unwrap();
        assert_eq!(new_key.dims(), &[2, 11, 32]);
        assert_eq!(new_value.dims(), &[2, 11, 32]);
    }

    #[test]
    fn test_prefix_tuning_parameters() {
        let cfg = PrefixTuningConfig {
            num_prefix_tokens: 10,
            num_layers: 12,
            d_model: 768,
            num_heads: 12,
            ..Default::default()
        };
        let tuner = PrefixTuning::new(cfg).unwrap();

        // key + value per layer, across all layers.
        assert_eq!(tuner.num_parameters(), 10 * 768 * 2 * 12);
    }

    #[test]
    fn test_ptuning_v2() {
        let cfg = PrefixTuningConfig {
            num_prefix_tokens: 5,
            num_layers: 3,
            d_model: 64,
            num_heads: 4,
            ..Default::default()
        };
        let ptuning = PTuningV2::new(cfg).unwrap();

        let prefix = ptuning.get_layer_prefix_transformed(0).unwrap();
        assert_eq!(prefix.prefix_key.dims(), &[5, 64]);
    }
}
550}