ghostflow_nn/
prompt_tuning.rs

1//! Prompt Tuning and Prefix Tuning
2//!
3//! Implements parameter-efficient fine-tuning methods:
4//! - Prompt Tuning: Learnable soft prompts prepended to input
5//! - Prefix Tuning: Learnable prefix vectors for each layer
6//! - P-Tuning v2: Prefix tuning with deep prompt optimization
7//! - Adapter-based prompt tuning
8
9use ghostflow_core::Tensor;
10use std::collections::HashMap;
11
/// Prompt tuning configuration.
///
/// Controls the size of the learnable soft prompt and whether it is stored in
/// a reparameterized (lower-dimensional) space for training stability.
#[derive(Debug, Clone)]
pub struct PromptTuningConfig {
    /// Number of virtual tokens (prompt length) prepended to every input.
    pub num_virtual_tokens: usize,
    /// Model embedding dimension the prompt must match.
    pub d_model: usize,
    /// Prompt initialization strategy.
    pub init_strategy: PromptInitStrategy,
    /// If true, store the prompt in `hidden_dim` space and project up to
    /// `d_model` through a learned encoder (reparameterization for stability).
    pub reparameterize: bool,
    /// Hidden dimension for reparameterization (only used when
    /// `reparameterize` is true).
    pub hidden_dim: usize,
}
26
/// Prompt initialization strategies.
///
/// NOTE(review): only `Random` is acted upon by the visible code
/// (`PromptTuning::new` always calls `Tensor::randn`); `Vocab` and `Text`
/// appear to be placeholders for future initialization paths — confirm.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum PromptInitStrategy {
    /// Random initialization.
    Random,
    /// Initialize from vocabulary embeddings.
    Vocab,
    /// Initialize from text tokens.
    Text,
}
37
38impl Default for PromptTuningConfig {
39    fn default() -> Self {
40        PromptTuningConfig {
41            num_virtual_tokens: 20,
42            d_model: 768,
43            init_strategy: PromptInitStrategy::Random,
44            reparameterize: false,
45            hidden_dim: 512,
46        }
47    }
48}
49
50impl PromptTuningConfig {
51    /// Short prompt configuration (5-10 tokens)
52    pub fn short(d_model: usize) -> Self {
53        PromptTuningConfig {
54            num_virtual_tokens: 10,
55            d_model,
56            ..Default::default()
57        }
58    }
59    
60    /// Medium prompt configuration (20-50 tokens)
61    pub fn medium(d_model: usize) -> Self {
62        PromptTuningConfig {
63            num_virtual_tokens: 30,
64            d_model,
65            ..Default::default()
66        }
67    }
68    
69    /// Long prompt configuration (50-100 tokens)
70    pub fn long(d_model: usize) -> Self {
71        PromptTuningConfig {
72            num_virtual_tokens: 80,
73            d_model,
74            ..Default::default()
75        }
76    }
77    
78    /// With reparameterization for stability
79    pub fn with_reparameterization(mut self, hidden_dim: usize) -> Self {
80        self.reparameterize = true;
81        self.hidden_dim = hidden_dim;
82        self
83    }
84}
85
/// Prompt Tuning implementation: a learnable soft prompt prepended to the
/// input embeddings, optionally stored in a lower-dimensional space and
/// projected up through a learned encoder.
pub struct PromptTuning {
    config: PromptTuningConfig,
    /// Learnable prompt embeddings.
    /// Shape is `[num_virtual_tokens, hidden_dim]` when reparameterized,
    /// otherwise `[num_virtual_tokens, d_model]` (see `PromptTuning::new`).
    prompt_embeddings: Tensor,
    /// Optional reparameterization projection `[hidden_dim, d_model]`.
    reparam_encoder: Option<Tensor>,
    /// Optional inverse projection `[d_model, hidden_dim]`.
    /// NOTE(review): only counted in `num_parameters`, never applied in the
    /// visible code — confirm whether it is used elsewhere.
    reparam_decoder: Option<Tensor>,
}
95
96impl PromptTuning {
97    /// Create new prompt tuning
98    pub fn new(config: PromptTuningConfig) -> Result<Self, String> {
99        let prompt_embeddings = if config.reparameterize {
100            // Initialize in lower dimension
101            Tensor::randn(&[config.num_virtual_tokens, config.hidden_dim])
102        } else {
103            // Direct initialization
104            Tensor::randn(&[config.num_virtual_tokens, config.d_model])
105        };
106        
107        let (reparam_encoder, reparam_decoder) = if config.reparameterize {
108            let encoder = Tensor::randn(&[config.hidden_dim, config.d_model]);
109            let decoder = Tensor::randn(&[config.d_model, config.hidden_dim]);
110            (Some(encoder), Some(decoder))
111        } else {
112            (None, None)
113        };
114        
115        Ok(PromptTuning {
116            config,
117            prompt_embeddings,
118            reparam_encoder,
119            reparam_decoder,
120        })
121    }
122    
123    /// Get prompt embeddings
124    pub fn get_prompt_embeddings(&self) -> Result<Tensor, String> {
125        if self.config.reparameterize {
126            // Reparameterize: prompt @ encoder
127            let encoder = self.reparam_encoder.as_ref()
128                .ok_or("Encoder not initialized")?;
129            self.prompt_embeddings.matmul(encoder)
130                .map_err(|e| format!("Failed to reparameterize: {:?}", e))
131        } else {
132            Ok(self.prompt_embeddings.clone())
133        }
134    }
135    
136    /// Prepend prompts to input embeddings
137    pub fn prepend_prompts(&self, input_embeddings: &Tensor) -> Result<Tensor, String> {
138        let prompt_embeds = self.get_prompt_embeddings()?;
139        
140        let input_dims = input_embeddings.dims();
141        let prompt_dims = prompt_embeds.dims();
142        
143        if input_dims.len() != 3 || prompt_dims.len() != 2 {
144            return Err("Expected input [batch, seq_len, d_model] and prompt [num_tokens, d_model]".to_string());
145        }
146        
147        let batch_size = input_dims[0];
148        let seq_len = input_dims[1];
149        let d_model = input_dims[2];
150        let num_prompts = prompt_dims[0];
151        
152        // Expand prompts for batch
153        let mut result = Vec::with_capacity(batch_size * (num_prompts + seq_len) * d_model);
154        
155        let prompt_data = prompt_embeds.data_f32();
156        let input_data = input_embeddings.data_f32();
157        
158        for b in 0..batch_size {
159            // Add prompts
160            result.extend_from_slice(&prompt_data);
161            
162            // Add input embeddings
163            let start = b * seq_len * d_model;
164            let end = start + seq_len * d_model;
165            result.extend_from_slice(&input_data[start..end]);
166        }
167        
168        Tensor::from_slice(&result, &[batch_size, num_prompts + seq_len, d_model])
169            .map_err(|e| format!("Failed to prepend prompts: {:?}", e))
170    }
171    
172    /// Get number of trainable parameters
173    pub fn num_parameters(&self) -> usize {
174        let prompt_params = self.prompt_embeddings.data_f32().len();
175        let reparam_params = if self.config.reparameterize {
176            self.reparam_encoder.as_ref().map(|t| t.data_f32().len()).unwrap_or(0) +
177            self.reparam_decoder.as_ref().map(|t| t.data_f32().len()).unwrap_or(0)
178        } else {
179            0
180        };
181        prompt_params + reparam_params
182    }
183    
184    /// Get parameter efficiency ratio
185    pub fn parameter_efficiency(&self, total_model_params: usize) -> f32 {
186        let tunable_params = self.num_parameters();
187        tunable_params as f32 / total_model_params as f32
188    }
189}
190
/// Prefix Tuning configuration.
///
/// Describes per-layer learnable prefix vectors injected into attention
/// key/value streams.
#[derive(Debug, Clone)]
pub struct PrefixTuningConfig {
    /// Number of prefix tokens per layer.
    pub num_prefix_tokens: usize,
    /// Number of transformer layers to attach prefixes to.
    pub num_layers: usize,
    /// Model dimension.
    pub d_model: usize,
    /// Number of attention heads.
    /// NOTE(review): not consumed by the visible code paths — confirm whether
    /// downstream code splits prefixes per head.
    pub num_heads: usize,
    /// Prefix initialization strategy.
    pub init_strategy: PromptInitStrategy,
    /// If true, learn independent key and value prefixes; if false, the value
    /// prefix is a clone of the key prefix (see `PrefixTuning::new`).
    pub prefix_kv: bool,
}
207
208impl Default for PrefixTuningConfig {
209    fn default() -> Self {
210        PrefixTuningConfig {
211            num_prefix_tokens: 10,
212            num_layers: 12,
213            d_model: 768,
214            num_heads: 12,
215            init_strategy: PromptInitStrategy::Random,
216            prefix_kv: true,
217        }
218    }
219}
220
221impl PrefixTuningConfig {
222    /// Configuration for small models
223    pub fn small(num_layers: usize, d_model: usize, num_heads: usize) -> Self {
224        PrefixTuningConfig {
225            num_prefix_tokens: 5,
226            num_layers,
227            d_model,
228            num_heads,
229            ..Default::default()
230        }
231    }
232    
233    /// Configuration for large models
234    pub fn large(num_layers: usize, d_model: usize, num_heads: usize) -> Self {
235        PrefixTuningConfig {
236            num_prefix_tokens: 20,
237            num_layers,
238            d_model,
239            num_heads,
240            ..Default::default()
241        }
242    }
243}
244
/// Prefix Tuning implementation: learnable key/value prefixes maintained per
/// transformer layer and prepended to attention keys and values.
pub struct PrefixTuning {
    config: PrefixTuningConfig,
    /// Prefix parameters for each layer, keyed by layer index
    /// (`0..config.num_layers`).
    prefix_params: HashMap<usize, LayerPrefix>,
}
251
/// Prefix parameters for a single layer.
///
/// Both tensors have shape `[num_prefix_tokens, d_model]` (see
/// `PrefixTuning::new`).
#[derive(Clone)]
pub struct LayerPrefix {
    /// Prefix prepended to attention keys.
    pub prefix_key: Tensor,
    /// Prefix prepended to attention values. May be a clone of `prefix_key`
    /// when `PrefixTuningConfig::prefix_kv` is false.
    pub prefix_value: Tensor,
}
260
261impl PrefixTuning {
262    /// Create new prefix tuning
263    pub fn new(config: PrefixTuningConfig) -> Result<Self, String> {
264        let mut prefix_params = HashMap::new();
265        
266        let head_dim = config.d_model / config.num_heads;
267        
268        for layer_idx in 0..config.num_layers {
269            let prefix_key = Tensor::randn(&[config.num_prefix_tokens, config.d_model]);
270            let prefix_value = if config.prefix_kv {
271                Tensor::randn(&[config.num_prefix_tokens, config.d_model])
272            } else {
273                prefix_key.clone()
274            };
275            
276            prefix_params.insert(layer_idx, LayerPrefix {
277                prefix_key,
278                prefix_value,
279            });
280        }
281        
282        Ok(PrefixTuning {
283            config,
284            prefix_params,
285        })
286    }
287    
288    /// Get prefix for a specific layer
289    pub fn get_layer_prefix(&self, layer_idx: usize) -> Option<&LayerPrefix> {
290        self.prefix_params.get(&layer_idx)
291    }
292    
293    /// Prepend prefix to key/value in attention
294    pub fn prepend_to_kv(
295        &self,
296        layer_idx: usize,
297        key: &Tensor,
298        value: &Tensor,
299    ) -> Result<(Tensor, Tensor), String> {
300        let prefix = self.get_layer_prefix(layer_idx)
301            .ok_or(format!("No prefix for layer {}", layer_idx))?;
302        
303        let new_key = self.concatenate_prefix(&prefix.prefix_key, key)?;
304        let new_value = self.concatenate_prefix(&prefix.prefix_value, value)?;
305        
306        Ok((new_key, new_value))
307    }
308    
309    /// Concatenate prefix to tensor
310    fn concatenate_prefix(&self, prefix: &Tensor, tensor: &Tensor) -> Result<Tensor, String> {
311        let prefix_dims = prefix.dims();
312        let tensor_dims = tensor.dims();
313        
314        if tensor_dims.len() != 3 {
315            return Err("Expected tensor [batch, seq_len, d_model]".to_string());
316        }
317        
318        let batch_size = tensor_dims[0];
319        let seq_len = tensor_dims[1];
320        let d_model = tensor_dims[2];
321        let num_prefix = prefix_dims[0];
322        
323        let mut result = Vec::with_capacity(batch_size * (num_prefix + seq_len) * d_model);
324        
325        let prefix_data = prefix.data_f32();
326        let tensor_data = tensor.data_f32();
327        
328        for b in 0..batch_size {
329            // Add prefix
330            result.extend_from_slice(&prefix_data);
331            
332            // Add tensor data
333            let start = b * seq_len * d_model;
334            let end = start + seq_len * d_model;
335            result.extend_from_slice(&tensor_data[start..end]);
336        }
337        
338        Tensor::from_slice(&result, &[batch_size, num_prefix + seq_len, d_model])
339            .map_err(|e| format!("Failed to concatenate prefix: {:?}", e))
340    }
341    
342    /// Get number of trainable parameters
343    pub fn num_parameters(&self) -> usize {
344        let mut total = 0;
345        for prefix in self.prefix_params.values() {
346            total += prefix.prefix_key.data_f32().len();
347            total += prefix.prefix_value.data_f32().len();
348        }
349        total
350    }
351    
352    /// Get parameter efficiency ratio
353    pub fn parameter_efficiency(&self, total_model_params: usize) -> f32 {
354        let tunable_params = self.num_parameters();
355        tunable_params as f32 / total_model_params as f32
356    }
357}
358
/// P-Tuning v2 (Deep Prompt Tuning): prefix tuning applied at every layer,
/// optionally with an MLP that generates the prefixes.
pub struct PTuningV2 {
    prefix_tuning: PrefixTuning,
    /// Additional MLP weight `[d_model, d_model]` for prefix generation.
    /// NOTE(review): allocated and counted in `num_parameters`, but never
    /// applied in the visible code (`get_layer_prefix_transformed` returns
    /// the prefix untransformed) — confirm intended behavior.
    prefix_mlp: Option<Tensor>,
}
365
366impl PTuningV2 {
367    /// Create new P-Tuning v2
368    pub fn new(config: PrefixTuningConfig) -> Result<Self, String> {
369        let prefix_tuning = PrefixTuning::new(config.clone())?;
370        
371        // Optional MLP for generating prefixes
372        let prefix_mlp = Some(Tensor::randn(&[config.d_model, config.d_model]));
373        
374        Ok(PTuningV2 {
375            prefix_tuning,
376            prefix_mlp,
377        })
378    }
379    
380    /// Get prefix with MLP transformation
381    pub fn get_layer_prefix_transformed(&self, layer_idx: usize) -> Option<LayerPrefix> {
382        let prefix = self.prefix_tuning.get_layer_prefix(layer_idx)?;
383        
384        if let Some(mlp) = &self.prefix_mlp {
385            // Transform prefix through MLP
386            // Simplified - in practice would be full MLP
387            Some(prefix.clone())
388        } else {
389            Some(prefix.clone())
390        }
391    }
392    
393    /// Get number of trainable parameters
394    pub fn num_parameters(&self) -> usize {
395        let prefix_params = self.prefix_tuning.num_parameters();
396        let mlp_params = self.prefix_mlp.as_ref()
397            .map(|t| t.data_f32().len())
398            .unwrap_or(0);
399        prefix_params + mlp_params
400    }
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    // Config presets expose the expected token counts and model dimension.
    #[test]
    fn test_prompt_tuning_config() {
        let config = PromptTuningConfig::default();
        assert_eq!(config.num_virtual_tokens, 20);
        assert_eq!(config.d_model, 768);

        let short = PromptTuningConfig::short(512);
        assert_eq!(short.num_virtual_tokens, 10);
        assert_eq!(short.d_model, 512);
    }

    // Without reparameterization, embeddings are stored directly in model space.
    #[test]
    fn test_prompt_tuning_creation() {
        let config = PromptTuningConfig::default();
        let prompt_tuning = PromptTuning::new(config).unwrap();

        let embeddings = prompt_tuning.get_prompt_embeddings().unwrap();
        assert_eq!(embeddings.dims(), &[20, 768]);
    }

    // Prepending extends the sequence axis by num_virtual_tokens per batch element.
    #[test]
    fn test_prompt_tuning_prepend() {
        let config = PromptTuningConfig {
            num_virtual_tokens: 5,
            d_model: 64,
            ..Default::default()
        };
        let prompt_tuning = PromptTuning::new(config).unwrap();

        let input = Tensor::randn(&[2, 10, 64]);
        let output = prompt_tuning.prepend_prompts(&input).unwrap();

        assert_eq!(output.dims(), &[2, 15, 64]); // 5 prompts + 10 input
    }

    // With reparameterization, the prompt is stored in hidden_dim space but
    // surfaced in full d_model space via the encoder projection.
    #[test]
    fn test_prompt_tuning_reparameterization() {
        let config = PromptTuningConfig {
            num_virtual_tokens: 10,
            d_model: 768,
            reparameterize: true,
            hidden_dim: 256,
            ..Default::default()
        };
        let prompt_tuning = PromptTuning::new(config).unwrap();

        let embeddings = prompt_tuning.get_prompt_embeddings().unwrap();
        assert_eq!(embeddings.dims(), &[10, 768]);
    }

    // Parameter count is tokens * d_model; efficiency is tiny vs. a 100M model.
    #[test]
    fn test_prompt_tuning_parameters() {
        let config = PromptTuningConfig {
            num_virtual_tokens: 20,
            d_model: 768,
            ..Default::default()
        };
        let prompt_tuning = PromptTuning::new(config).unwrap();

        let num_params = prompt_tuning.num_parameters();
        assert_eq!(num_params, 20 * 768);

        let efficiency = prompt_tuning.parameter_efficiency(100_000_000);
        assert!(efficiency < 0.01); // Less than 1% of model parameters
    }

    #[test]
    fn test_prefix_tuning_config() {
        let config = PrefixTuningConfig::default();
        assert_eq!(config.num_prefix_tokens, 10);
        assert_eq!(config.num_layers, 12);

        let small = PrefixTuningConfig::small(6, 512, 8);
        assert_eq!(small.num_prefix_tokens, 5);
        assert_eq!(small.num_layers, 6);
    }

    // Each layer gets key and value prefixes of shape [num_prefix_tokens, d_model].
    #[test]
    fn test_prefix_tuning_creation() {
        let config = PrefixTuningConfig {
            num_prefix_tokens: 5,
            num_layers: 3,
            d_model: 64,
            num_heads: 4,
            ..Default::default()
        };
        let prefix_tuning = PrefixTuning::new(config).unwrap();

        let prefix = prefix_tuning.get_layer_prefix(0).unwrap();
        assert_eq!(prefix.prefix_key.dims(), &[5, 64]);
        assert_eq!(prefix.prefix_value.dims(), &[5, 64]);
    }

    // prepend_to_kv extends both key and value sequence lengths by the prefix size.
    #[test]
    fn test_prefix_tuning_prepend() {
        let config = PrefixTuningConfig {
            num_prefix_tokens: 3,
            num_layers: 2,
            d_model: 32,
            num_heads: 4,
            ..Default::default()
        };
        let prefix_tuning = PrefixTuning::new(config).unwrap();

        let key = Tensor::randn(&[2, 8, 32]);
        let value = Tensor::randn(&[2, 8, 32]);

        let (new_key, new_value) = prefix_tuning.prepend_to_kv(0, &key, &value).unwrap();

        assert_eq!(new_key.dims(), &[2, 11, 32]); // 3 prefix + 8 original
        assert_eq!(new_value.dims(), &[2, 11, 32]);
    }

    // With the default prefix_kv=true, key and value prefixes both count.
    #[test]
    fn test_prefix_tuning_parameters() {
        let config = PrefixTuningConfig {
            num_prefix_tokens: 10,
            num_layers: 12,
            d_model: 768,
            num_heads: 12,
            ..Default::default()
        };
        let prefix_tuning = PrefixTuning::new(config).unwrap();

        let num_params = prefix_tuning.num_parameters();
        // 10 tokens * 768 dim * 2 (key+value) * 12 layers
        assert_eq!(num_params, 10 * 768 * 2 * 12);
    }

    // The "transformed" prefix currently has the same shape as the raw prefix.
    #[test]
    fn test_ptuning_v2() {
        let config = PrefixTuningConfig {
            num_prefix_tokens: 5,
            num_layers: 3,
            d_model: 64,
            num_heads: 4,
            ..Default::default()
        };
        let ptuning = PTuningV2::new(config).unwrap();

        let prefix = ptuning.get_layer_prefix_transformed(0).unwrap();
        assert_eq!(prefix.prefix_key.dims(), &[5, 64]);
    }
}