// god_graph/transformer/generation.rs
1//! Text generation utilities
2
3use crate::tensor::DenseTensor;
4use super::model::LlamaModel;
5
6/// Generation configuration
7#[derive(Debug, Clone)]
/// Generation configuration
///
/// Bundles every knob the decoding loops read: length limits, sampling
/// hyperparameters, special-token IDs, and the beam-search settings.
#[derive(Debug, Clone)]
pub struct GenerationConfig {
    /// Upper bound on the number of decoding steps
    pub max_length: usize,
    /// Lower bound on generated length
    pub min_length: usize,
    /// Softmax temperature; values above 1.0 flatten the distribution,
    /// values below 1.0 sharpen it
    pub temperature: f64,
    /// Keep only the k most likely tokens when sampling (0 disables)
    pub top_k: usize,
    /// Nucleus sampling threshold (0.0 disables)
    pub top_p: f64,
    /// Penalty applied to already-generated tokens
    pub repetition_penalty: f64,
    /// End-of-sequence token ID, if the vocabulary defines one
    pub eos_token_id: Option<usize>,
    /// Padding token ID, if the vocabulary defines one
    pub pad_token_id: Option<usize>,
    /// When true, sample stochastically; when false, decode greedily
    pub do_sample: bool,
    /// Beam count for beam search (1 disables beam search)
    pub num_beams: usize,
    /// Length penalty used by beam search
    pub length_penalty: f64,
}
32
33impl Default for GenerationConfig {
34    fn default() -> Self {
35        Self {
36            max_length: 256,
37            min_length: 0,
38            temperature: 1.0,
39            top_k: 0,
40            top_p: 0.0,
41            repetition_penalty: 1.0,
42            eos_token_id: None,
43            pad_token_id: None,
44            do_sample: false,
45            num_beams: 1,
46            length_penalty: 1.0,
47        }
48    }
49}
50
51impl GenerationConfig {
52    /// Create config for greedy decoding
53    pub fn greedy() -> Self {
54        Self {
55            do_sample: false,
56            ..Self::default()
57        }
58    }
59
60    /// Create config for sampling
61    pub fn sampling(temperature: f64) -> Self {
62        Self {
63            do_sample: true,
64            temperature,
65            ..Self::default()
66        }
67    }
68
69    /// Create config for beam search
70    pub fn beam_search(num_beams: usize) -> Self {
71        Self {
72            do_sample: false,
73            num_beams,
74            ..Self::default()
75        }
76    }
77
78    /// Set maximum length
79    pub fn with_max_length(mut self, max_length: usize) -> Self {
80        self.max_length = max_length;
81        self
82    }
83
84    /// Set temperature
85    pub fn with_temperature(mut self, temperature: f64) -> Self {
86        self.temperature = temperature;
87        self
88    }
89
90    /// Set top-k
91    pub fn with_top_k(mut self, top_k: usize) -> Self {
92        self.top_k = top_k;
93        self
94    }
95
96    /// Set top-p
97    pub fn with_top_p(mut self, top_p: f64) -> Self {
98        self.top_p = top_p;
99        self
100    }
101
102    /// Set EOS token ID
103    pub fn with_eos_token_id(mut self, eos_token_id: usize) -> Self {
104        self.eos_token_id = Some(eos_token_id);
105        self
106    }
107}
108
109/// Text generator for LLaMA models
110pub struct TextGenerator<'a> {
111    /// Reference to the model
112    model: &'a LlamaModel,
113    /// Generation configuration
114    config: GenerationConfig,
115}
116
117impl<'a> TextGenerator<'a> {
118    /// Create a new text generator
119    pub fn new(model: &'a LlamaModel, config: GenerationConfig) -> Self {
120        Self { model, config }
121    }
122
123    /// Generate text from input prompt
124    ///
125    /// # Arguments
126    /// * `input_ids` - Input token IDs [seq_len]
127    ///
128    /// # Returns
129    /// Generated token IDs
130    pub fn generate(&self, input_ids: &[usize]) -> Vec<usize> {
131        if self.config.num_beams > 1 {
132            self.generate_beam_search(input_ids)
133        } else if self.config.do_sample {
134            self.generate_sampling(input_ids)
135        } else {
136            self.generate_greedy(input_ids)
137        }
138    }
139
140    /// Greedy decoding (always pick the highest probability token)
141    fn generate_greedy(&self, input_ids: &[usize]) -> Vec<usize> {
142        let mut current_ids = input_ids.to_vec();
143        
144        for _ in 0..self.config.max_length {
145            // Forward pass
146            let logits = self.model.forward_single(&current_ids, None);
147            
148            // Get logits for the last position
149            let seq_len = current_ids.len();
150            let last_logits = logits.get_row(seq_len - 1);
151            
152            // Apply temperature
153            let mut probs = last_logits.clone();
154            if self.config.temperature != 1.0 {
155                probs = probs.scale(1.0 / self.config.temperature);
156            }
157            
158            // Apply softmax
159            probs = probs.softmax(-1);
160            
161            // Greedy: pick the token with highest probability
162            let next_token = self.argmax(probs.data());
163            
164            // Check for EOS
165            if Some(next_token) == self.config.eos_token_id {
166                break;
167            }
168            
169            current_ids.push(next_token);
170        }
171        
172        current_ids
173    }
174
175    /// Sampling-based generation
176    fn generate_sampling(&self, input_ids: &[usize]) -> Vec<usize> {
177        let mut current_ids = input_ids.to_vec();
178        let mut rng = rand::thread_rng();
179        
180        for _ in 0..self.config.max_length {
181            // Forward pass
182            let logits = self.model.forward_single(&current_ids, None);
183            
184            // Get logits for the last position
185            let seq_len = current_ids.len();
186            let last_logits = logits.get_row(seq_len - 1);
187            
188            // Apply temperature
189            let mut probs = last_logits.clone();
190            if self.config.temperature != 1.0 {
191                probs = probs.scale(1.0 / self.config.temperature);
192            }
193            
194            // Apply softmax
195            probs = probs.softmax(-1);
196            
197            // Apply top-k filtering
198            if self.config.top_k > 0 {
199                probs = self.top_k_filtering(&probs, self.config.top_k);
200            }
201            
202            // Apply top-p (nucleus) filtering
203            if self.config.top_p > 0.0 {
204                probs = self.top_p_filtering(&probs, self.config.top_p);
205            }
206            
207            // Sample from the distribution
208            let next_token = self.sample_from_probs(probs.data(), &mut rng);
209            
210            // Check for EOS
211            if Some(next_token) == self.config.eos_token_id {
212                break;
213            }
214            
215            current_ids.push(next_token);
216        }
217        
218        current_ids
219    }
220
221    /// Beam search generation
222    fn generate_beam_search(&self, input_ids: &[usize]) -> Vec<usize> {
223        // Simplified beam search implementation
224        // A full implementation would track multiple hypotheses
225        
226        let mut beams: Vec<(Vec<usize>, f64)> = vec![(input_ids.to_vec(), 0.0)];
227        
228        for _ in 0..self.config.max_length {
229            let mut candidates: Vec<(Vec<usize>, f64)> = Vec::new();
230            
231            for (beam_ids, beam_score) in &beams {
232                // Forward pass
233                let logits = self.model.forward_single(beam_ids, None);
234                
235                // Get logits for the last position
236                let seq_len = beam_ids.len();
237                let last_logits = logits.get_row(seq_len - 1);
238                
239                // Get top-k candidates
240                let top_indices = self.topk_indices(last_logits.data(), self.config.num_beams);
241                
242                for &next_token in &top_indices {
243                    let mut new_beam = beam_ids.clone();
244                    new_beam.push(next_token);
245                    
246                    // Update score (log probability)
247                    let token_prob = last_logits.data()[next_token];
248                    let new_score = beam_score + token_prob.ln();
249                    
250                    candidates.push((new_beam, new_score));
251                }
252            }
253
254            // Keep top-k beams
255            candidates.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
256            beams = candidates.into_iter().take(self.config.num_beams).collect();
257
258            // Check if all beams reached EOS
259            if beams.iter().all(|(ids, _)| {
260                ids.last() == self.config.eos_token_id.as_ref()
261            }) {
262                break;
263            }
264        }
265
266        // Return the best beam
267        beams.into_iter()
268            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
269            .map(|(ids, _)| ids)
270            .unwrap_or_else(|| input_ids.to_vec())
271    }
272
273    /// Argmax: find index of maximum value
274    fn argmax(&self, data: &[f64]) -> usize {
275        data.iter()
276            .enumerate()
277            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
278            .map(|(i, _)| i)
279            .unwrap_or(0)
280    }
281
282    /// Top-k indices
283    fn topk_indices(&self, data: &[f64], k: usize) -> Vec<usize> {
284        let mut indexed: Vec<(usize, &f64)> = data.iter().enumerate().collect();
285        indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
286        indexed.into_iter().take(k).map(|(i, _)| i).collect()
287    }
288
289    /// Top-k filtering: zero out probabilities outside top-k
290    fn top_k_filtering(&self, probs: &DenseTensor, k: usize) -> DenseTensor {
291        let data = probs.data();
292        let top_indices = self.topk_indices(data, k);
293        let threshold = top_indices.iter()
294            .map(|&i| data[i])
295            .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
296            .unwrap_or(0.0);
297
298        let mut filtered = probs.clone();
299        for (i, &prob) in data.iter().enumerate() {
300            if prob < threshold {
301                filtered.data_mut()[i] = 0.0;
302            }
303        }
304
305        // Re-normalize
306        let sum: f64 = filtered.data().iter().sum();
307        if sum > 0.0 {
308            filtered = filtered.scale(1.0 / sum);
309        }
310
311        filtered
312    }
313
314    /// Top-p (nucleus) filtering: keep smallest set of tokens with cumulative prob >= p
315    fn top_p_filtering(&self, probs: &DenseTensor, p: f64) -> DenseTensor {
316        let data = probs.data();
317        let mut indexed: Vec<(usize, &f64)> = data.iter().enumerate().collect();
318        indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
319        
320        let mut cumulative_prob = 0.0;
321        let mut cutoff_index = indexed.len();
322        
323        for (i, (_, &prob)) in indexed.iter().enumerate() {
324            cumulative_prob += prob;
325            if cumulative_prob >= p {
326                cutoff_index = i + 1;
327                break;
328            }
329        }
330        
331        let threshold = indexed.into_iter()
332            .take(cutoff_index)
333            .map(|(_, &prob)| prob)
334            .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
335            .unwrap_or(0.0);
336        
337        let mut filtered = probs.clone();
338        for (i, &prob) in data.iter().enumerate() {
339            if prob < threshold {
340                filtered.data_mut()[i] = 0.0;
341            }
342        }
343        
344        // Re-normalize
345        let sum: f64 = filtered.data().iter().sum();
346        if sum > 0.0 {
347            filtered = filtered.scale(1.0 / sum);
348        }
349        
350        filtered
351    }
352
353    /// Sample from probability distribution
354    fn sample_from_probs(&self, probs: &[f64], rng: &mut impl rand::Rng) -> usize {
355        let r: f64 = rng.gen();
356        let mut cumulative = 0.0;
357        
358        for (i, &prob) in probs.iter().enumerate() {
359            cumulative += prob;
360            if r < cumulative {
361                return i;
362            }
363        }
364        
365        probs.len() - 1
366    }
367}
368
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generation_config() {
        // Defaults correspond to greedy decoding.
        let base = GenerationConfig::default();
        assert_eq!(base.max_length, 256);
        assert_eq!(base.temperature, 1.0);
        assert!(!base.do_sample);

        assert!(!GenerationConfig::greedy().do_sample);

        let stochastic = GenerationConfig::sampling(0.8);
        assert!(stochastic.do_sample);
        assert_eq!(stochastic.temperature, 0.8);
    }

    #[test]
    fn test_argmax() {
        let model = create_test_model();
        let generator = TextGenerator::new(&model, GenerationConfig::default());

        // 0.5 at index 2 is the maximum.
        let scores = vec![0.1, 0.3, 0.5, 0.2, 0.4];
        assert_eq!(generator.argmax(&scores), 2);
    }

    #[test]
    fn test_topk_indices() {
        let model = create_test_model();
        let generator = TextGenerator::new(&model, GenerationConfig::default());

        // Top-2 values are 0.9 (index 3) and 0.5 (index 1), in that order.
        let scores = vec![0.1, 0.5, 0.3, 0.9, 0.2];
        assert_eq!(generator.topk_indices(&scores, 2), vec![3, 1]);
    }
}
413
#[cfg(test)]
fn create_test_model() -> LlamaModel {
    use super::model::LlamaModel;
    use super::layers::{MultiHeadAttention, FeedForward, RMSNorm};
    use super::loader::LlamaConfig;
    use crate::tensor::DenseTensor;

    // Build a 2-layer model with all-ones weights for deterministic tests.
    //
    // NOTE(review): LlamaConfig::llama_7b() presumably carries full 7B
    // dimensions (e.g. ~32k vocab x 4096 hidden); if so, these all-ones
    // tensors are enormous for a unit test -- confirm and consider a tiny
    // purpose-built test config instead.
    let config = LlamaConfig::llama_7b();
    // Token embedding table: [vocab_size, hidden_size].
    let embed_tokens = DenseTensor::ones(vec![config.vocab_size, config.hidden_size]);

    let hidden_dim = config.hidden_size;
    let num_heads = config.num_attention_heads;

    // Square Q/K/V/O projections for standard (non-GQA) attention.
    let w_q = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
    let w_k = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
    let w_v = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
    let w_o = DenseTensor::ones(vec![hidden_dim, hidden_dim]);
    let self_attn = MultiHeadAttention::standard(w_q, w_k, w_v, w_o, num_heads);

    // SwiGLU MLP: gate/up project to intermediate_size, down projects back.
    let gate_proj = DenseTensor::ones(vec![hidden_dim, config.intermediate_size]);
    let up_proj = DenseTensor::ones(vec![hidden_dim, config.intermediate_size]);
    let down_proj = DenseTensor::ones(vec![config.intermediate_size, hidden_dim]);
    let mlp = FeedForward::swiglu(gate_proj, up_proj, down_proj);

    // Pre-attention and pre-MLP normalization layers.
    let input_layernorm = RMSNorm::default(hidden_dim);
    let post_attention_layernorm = RMSNorm::default(hidden_dim);

    let layer = super::model::LlamaDecoderLayer::new(
        self_attn, mlp, input_layernorm, post_attention_layernorm
    );

    let layers = vec![layer; 2]; // Use 2 layers for testing
    // Final RMSNorm applied before the LM head.
    let norm = RMSNorm::default(hidden_dim);

    // No separate LM head weight (None) -- see LlamaModel::new for semantics.
    LlamaModel::new(config, embed_tokens, layers, norm, None)
}