
axonml_llm/generation.rs

//! Text Generation Utilities
//!
//! Sampling strategies and generation configuration for language models.

use axonml_tensor::Tensor;
use serde::{Serialize, Deserialize};
use rand::Rng;

/// Configuration for text generation.
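///
/// # Examples
///
/// An illustrative sketch of building a config from a preset and the builder
/// methods below; the EOS token ID `2` is a placeholder, not tied to any
/// particular tokenizer.
///
/// ```ignore
/// let config = GenerationConfig::nucleus_sampling(0.9, 0.8)
///     .with_max_tokens(128)
///     .with_eos_token(2)
///     .with_repetition_penalty(1.1);
/// assert!(config.do_sample);
/// ```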
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfig {
    /// Maximum number of new tokens to generate
    pub max_new_tokens: usize,
    /// Temperature for sampling (1.0 = no change, <1.0 = more deterministic, >1.0 = more random)
    pub temperature: f32,
    /// Top-k sampling: only sample from the k highest-probability tokens
    pub top_k: Option<usize>,
    /// Top-p (nucleus) sampling: sample from the smallest set of tokens whose cumulative probability exceeds p
    pub top_p: Option<f32>,
    /// Repetition penalty (1.0 = no penalty, >1.0 = penalize repetition)
    pub repetition_penalty: f32,
    /// Stop token IDs
    pub eos_token_ids: Vec<u32>,
    /// Pad token ID
    pub pad_token_id: Option<u32>,
    /// Whether to sample from the distribution (false = greedy decoding)
    pub do_sample: bool,
    /// Number of beams for beam search (1 = no beam search)
    pub num_beams: usize,
    /// Length penalty for beam search
    pub length_penalty: f32,
    /// Early stopping for beam search
    pub early_stopping: bool,
}

impl Default for GenerationConfig {
    fn default() -> Self {
        Self {
            max_new_tokens: 50,
            temperature: 1.0,
            top_k: None,
            top_p: None,
            repetition_penalty: 1.0,
            eos_token_ids: vec![],
            pad_token_id: None,
            do_sample: true,
            num_beams: 1,
            length_penalty: 1.0,
            early_stopping: false,
        }
    }
}

impl GenerationConfig {
    /// Creates a config for greedy decoding.
    pub fn greedy() -> Self {
        Self {
            do_sample: false,
            temperature: 1.0,
            top_k: None,
            top_p: None,
            ..Default::default()
        }
    }

    /// Creates a config for sampling with temperature.
    pub fn sampling(temperature: f32) -> Self {
        Self {
            do_sample: true,
            temperature,
            ..Default::default()
        }
    }

    /// Creates a config for top-k sampling.
    pub fn top_k_sampling(k: usize, temperature: f32) -> Self {
        Self {
            do_sample: true,
            temperature,
            top_k: Some(k),
            ..Default::default()
        }
    }

    /// Creates a config for nucleus (top-p) sampling.
    pub fn nucleus_sampling(p: f32, temperature: f32) -> Self {
        Self {
            do_sample: true,
            temperature,
            top_p: Some(p),
            ..Default::default()
        }
    }

    /// Creates a config for beam search.
    pub fn beam_search(num_beams: usize) -> Self {
        Self {
            do_sample: false,
            num_beams,
            ..Default::default()
        }
    }

    /// Sets the maximum number of new tokens.
    pub fn with_max_tokens(mut self, max_new_tokens: usize) -> Self {
        self.max_new_tokens = max_new_tokens;
        self
    }

    /// Adds an EOS token ID to the set of stop tokens.
    pub fn with_eos_token(mut self, eos_token_id: u32) -> Self {
        self.eos_token_ids.push(eos_token_id);
        self
    }

    /// Sets the repetition penalty.
    pub fn with_repetition_penalty(mut self, penalty: f32) -> Self {
        self.repetition_penalty = penalty;
        self
    }
}

/// Text generator for language models.
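///
/// # Examples
///
/// A minimal, self-contained sketch: pick the next token from a raw logits
/// slice with greedy decoding (the logit values are made up).
///
/// ```ignore
/// let generator = TextGenerator::new(GenerationConfig::greedy());
/// let logits = vec![0.1_f32, 2.5, -1.0, 0.7];
/// let next = generator.get_next_token(&logits, &[]);
/// assert_eq!(next, 1); // index of the largest logit
/// ```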
pub struct TextGenerator {
    /// Generation configuration
    pub config: GenerationConfig,
}

impl TextGenerator {
    /// Creates a new text generator.
    pub fn new(config: GenerationConfig) -> Self {
        Self { config }
    }

    /// Applies temperature scaling to logits.
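    ///
    /// A quick numeric sketch: with temperature 2.0 every logit is halved,
    /// which flattens the softmax distribution.
    ///
    /// ```ignore
    /// let generator = TextGenerator::new(GenerationConfig::sampling(2.0));
    /// let mut logits = vec![2.0_f32, 4.0];
    /// generator.apply_temperature(&mut logits);
    /// assert_eq!(logits, vec![1.0, 2.0]);
    /// ```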
    pub fn apply_temperature(&self, logits: &mut [f32]) {
        if self.config.temperature != 1.0 {
            for logit in logits.iter_mut() {
                *logit /= self.config.temperature;
            }
        }
    }

    /// Applies repetition penalty to logits.
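    ///
    /// Positive logits of previously generated tokens are divided by the
    /// penalty and negative ones multiplied by it, so both move toward lower
    /// probability. A small sketch with made-up values:
    ///
    /// ```ignore
    /// let config = GenerationConfig::default().with_repetition_penalty(2.0);
    /// let generator = TextGenerator::new(config);
    /// let mut logits = vec![1.0_f32, 2.0, -1.0];
    /// generator.apply_repetition_penalty(&mut logits, &[1, 2]);
    /// assert_eq!(logits, vec![1.0, 1.0, -2.0]);
    /// ```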
    pub fn apply_repetition_penalty(&self, logits: &mut [f32], generated_tokens: &[u32]) {
        if self.config.repetition_penalty != 1.0 {
            for &token in generated_tokens {
                let idx = token as usize;
                if idx < logits.len() {
                    if logits[idx] > 0.0 {
                        logits[idx] /= self.config.repetition_penalty;
                    } else {
                        logits[idx] *= self.config.repetition_penalty;
                    }
                }
            }
        }
    }

    /// Applies top-k filtering to logits.
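    ///
    /// Everything outside the k largest logits is set to negative infinity so
    /// it receives zero probability after the softmax. A small sketch:
    ///
    /// ```ignore
    /// let generator = TextGenerator::new(GenerationConfig::top_k_sampling(2, 1.0));
    /// let mut logits = vec![1.0_f32, 5.0, 3.0, 4.0];
    /// generator.apply_top_k(&mut logits);
    /// assert_eq!(logits.iter().filter(|x| x.is_finite()).count(), 2);
    /// ```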
    pub fn apply_top_k(&self, logits: &mut [f32]) {
        if let Some(k) = self.config.top_k {
            if k < logits.len() {
                // Find indices of top k values
                let mut sorted_indices: Vec<usize> = (0..logits.len()).collect();
                sorted_indices.sort_by(|&a, &b| {
                    logits[b].partial_cmp(&logits[a]).unwrap()
                });

                // Create a set of top-k indices
                let top_k_indices: std::collections::HashSet<usize> =
                    sorted_indices[..k].iter().copied().collect();

                // Set all values not in top-k to -inf
                for (i, logit) in logits.iter_mut().enumerate() {
                    if !top_k_indices.contains(&i) {
                        *logit = f32::NEG_INFINITY;
                    }
                }
            }
        }
    }

    /// Applies top-p (nucleus) filtering to logits.
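    ///
    /// Logits are converted to probabilities, sorted, and everything outside
    /// the smallest prefix whose cumulative probability exceeds `p` is set to
    /// negative infinity. A sketch with a sharply peaked distribution:
    ///
    /// ```ignore
    /// let generator = TextGenerator::new(GenerationConfig::nucleus_sampling(0.5, 1.0));
    /// let mut logits = vec![10.0_f32, 0.0, 0.0, 0.0];
    /// generator.apply_top_p(&mut logits);
    /// // The first token alone already exceeds p = 0.5, so only it survives.
    /// assert_eq!(logits.iter().filter(|x| x.is_finite()).count(), 1);
    /// ```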
    pub fn apply_top_p(&self, logits: &mut [f32]) {
        if let Some(p) = self.config.top_p {
            // Convert to probabilities
            let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exp_logits: Vec<f32> = logits.iter().map(|x| (x - max_logit).exp()).collect();
            let sum_exp: f32 = exp_logits.iter().sum();
            let probs: Vec<f32> = exp_logits.iter().map(|x| x / sum_exp).collect();

            // Sort by probability
            let mut sorted_indices: Vec<usize> = (0..probs.len()).collect();
            sorted_indices.sort_by(|&a, &b| {
                probs[b].partial_cmp(&probs[a]).unwrap()
            });

            // Find cutoff
            let mut cumsum = 0.0f32;
            let mut cutoff_idx = sorted_indices.len();

            for (i, &idx) in sorted_indices.iter().enumerate() {
                cumsum += probs[idx];
                if cumsum > p {
                    cutoff_idx = i + 1;
                    break;
                }
            }

            // Set values outside nucleus to -inf
            for (i, logit) in logits.iter_mut().enumerate() {
                if !sorted_indices[..cutoff_idx].contains(&i) {
                    *logit = f32::NEG_INFINITY;
                }
            }
        }
    }

    /// Samples a token from the logits distribution.
    pub fn sample(&self, logits: &[f32]) -> u32 {
        if !self.config.do_sample {
            // Greedy: return argmax
            return self.argmax(logits);
        }

        // Sample from distribution
        let mut rng = rand::thread_rng();

        // Softmax
        let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exp_logits: Vec<f32> = logits.iter().map(|x| (x - max_logit).exp()).collect();
        let sum_exp: f32 = exp_logits.iter().sum();
        let probs: Vec<f32> = exp_logits.iter().map(|x| x / sum_exp).collect();

        // Sample by walking the cumulative distribution
        let mut cumsum = 0.0f32;
        let sample: f32 = rng.gen();

        for (i, &p) in probs.iter().enumerate() {
            cumsum += p;
            if sample < cumsum {
                return i as u32;
            }
        }

        // Fallback to last token (guards against floating-point rounding)
        (logits.len() - 1) as u32
    }

    /// Returns the index of the maximum value.
    pub fn argmax(&self, logits: &[f32]) -> u32 {
        logits
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(i, _)| i as u32)
            .unwrap_or(0)
    }

    /// Processes logits and returns the next token.
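    ///
    /// A sketch of the intended decoding loop, assuming a hypothetical
    /// `model_forward` function that maps the tokens generated so far to
    /// next-token logits (not part of this module):
    ///
    /// ```ignore
    /// let generator = TextGenerator::new(GenerationConfig::top_k_sampling(50, 0.8));
    /// let mut tokens: Vec<u32> = vec![1]; // prompt token IDs (placeholder)
    /// for _ in 0..generator.config.max_new_tokens {
    ///     let logits: Vec<f32> = model_forward(&tokens);
    ///     let next = generator.get_next_token(&logits, &tokens);
    ///     tokens.push(next);
    ///     if generator.should_stop(next) {
    ///         break;
    ///     }
    /// }
    /// ```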
    pub fn get_next_token(&self, logits: &[f32], generated_tokens: &[u32]) -> u32 {
        let mut logits = logits.to_vec();

        // Apply modifiers
        self.apply_repetition_penalty(&mut logits, generated_tokens);
        self.apply_temperature(&mut logits);
        self.apply_top_k(&mut logits);
        self.apply_top_p(&mut logits);

        // Sample
        self.sample(&logits)
    }

    /// Checks if generation should stop.
    pub fn should_stop(&self, token: u32) -> bool {
        self.config.eos_token_ids.contains(&token)
    }
}

/// Beam for beam search.
#[derive(Debug, Clone)]
pub struct Beam {
    /// Token sequence
    pub tokens: Vec<u32>,
    /// Log probability score
    pub score: f32,
    /// Whether this beam has finished
    pub finished: bool,
}

impl Beam {
    /// Creates a new beam.
    pub fn new(initial_tokens: Vec<u32>) -> Self {
        Self {
            tokens: initial_tokens,
            score: 0.0,
            finished: false,
        }
    }

    /// Returns the length-normalized score, `score / length.powf(length_penalty)`.
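    ///
    /// A quick numeric sketch: a beam of four tokens with a cumulative
    /// log-probability of -2.0 and a length penalty of 1.0 normalizes to -0.5.
    ///
    /// ```ignore
    /// let mut beam = Beam::new(vec![1, 2, 3, 4]);
    /// beam.score = -2.0;
    /// assert_eq!(beam.normalized_score(1.0), -0.5);
    /// ```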
    pub fn normalized_score(&self, length_penalty: f32) -> f32 {
        let length = self.tokens.len() as f32;
        self.score / length.powf(length_penalty)
    }
}

/// Beam search implementation.
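///
/// # Examples
///
/// A sketch of the intended search loop; `input_ids`, `max_new_tokens`, and
/// the per-beam `log_probs_for` function are assumptions for illustration and
/// not part of this module:
///
/// ```ignore
/// let search = BeamSearch::new(3, 1.0, true, vec![2]);
/// let mut beams = search.init_beams(&input_ids);
/// for _ in 0..max_new_tokens {
///     let scores: Vec<Vec<f32>> = beams.iter().map(|b| log_probs_for(&b.tokens)).collect();
///     beams = search.expand_beams(&beams, &scores);
///     if search.should_stop(&beams) {
///         break;
///     }
/// }
/// let best = search.best_sequence(&beams);
/// ```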
pub struct BeamSearch {
    /// Number of beams
    pub num_beams: usize,
    /// Length penalty
    pub length_penalty: f32,
    /// Early stopping
    pub early_stopping: bool,
    /// EOS token IDs
    pub eos_token_ids: Vec<u32>,
}

impl BeamSearch {
    /// Creates a new beam search.
    pub fn new(num_beams: usize, length_penalty: f32, early_stopping: bool, eos_token_ids: Vec<u32>) -> Self {
        Self {
            num_beams,
            length_penalty,
            early_stopping,
            eos_token_ids,
        }
    }

    /// Initializes beams from input tokens.
    pub fn init_beams(&self, input_ids: &Tensor<u32>) -> Vec<Beam> {
        let tokens: Vec<u32> = input_ids.to_vec().to_vec();
        vec![Beam::new(tokens)]
    }

    /// Expands beams with new tokens and scores.
    ///
    /// Each entry of `next_token_logits` is added directly to the beam score,
    /// so the values are expected to be log-probabilities.
    pub fn expand_beams(&self, beams: &[Beam], next_token_logits: &[Vec<f32>]) -> Vec<Beam> {
        let mut candidates = Vec::new();

        for (beam_idx, beam) in beams.iter().enumerate() {
            if beam.finished {
                candidates.push(beam.clone());
                continue;
            }

            let logits = &next_token_logits[beam_idx];

            // Get top-k tokens for this beam
            let mut indexed: Vec<(usize, f32)> = logits.iter().enumerate()
                .map(|(i, &v)| (i, v))
                .collect();
            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());

            for (token, log_prob) in indexed.into_iter().take(self.num_beams * 2) {
                let mut new_beam = beam.clone();
                new_beam.tokens.push(token as u32);
                new_beam.score += log_prob;

                if self.eos_token_ids.contains(&(token as u32)) {
                    new_beam.finished = true;
                }

                candidates.push(new_beam);
            }
        }

        // Sort by score and keep top beams
        candidates.sort_by(|a, b| {
            b.normalized_score(self.length_penalty)
                .partial_cmp(&a.normalized_score(self.length_penalty))
                .unwrap()
        });

        candidates.into_iter().take(self.num_beams).collect()
    }

    /// Checks if search should stop.
    pub fn should_stop(&self, beams: &[Beam]) -> bool {
        if self.early_stopping {
            beams.iter().all(|b| b.finished)
        } else {
            false
        }
    }

    /// Returns the best finished sequence, falling back to the first beam if none has finished.
    pub fn best_sequence(&self, beams: &[Beam]) -> Option<Vec<u32>> {
        beams
            .iter()
            .filter(|b| b.finished)
            .max_by(|a, b| {
                a.normalized_score(self.length_penalty)
                    .partial_cmp(&b.normalized_score(self.length_penalty))
                    .unwrap()
            })
            .map(|b| b.tokens.clone())
            .or_else(|| beams.first().map(|b| b.tokens.clone()))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generation_config_defaults() {
        let config = GenerationConfig::default();
        assert_eq!(config.max_new_tokens, 50);
        assert_eq!(config.temperature, 1.0);
        assert!(config.do_sample);
    }

    #[test]
    fn test_greedy_config() {
        let config = GenerationConfig::greedy();
        assert!(!config.do_sample);
    }

    #[test]
    fn test_top_k_filtering() {
        let config = GenerationConfig::top_k_sampling(2, 1.0);
        let generator = TextGenerator::new(config);

        let mut logits = vec![1.0, 5.0, 3.0, 4.0, 2.0];
        generator.apply_top_k(&mut logits);

        // Only top 2 should remain finite
        let finite_count = logits.iter().filter(|x| x.is_finite()).count();
        assert_eq!(finite_count, 2);
    }

    #[test]
    fn test_temperature_scaling() {
        let config = GenerationConfig::sampling(2.0);
        let generator = TextGenerator::new(config);

        let mut logits = vec![2.0, 4.0, 6.0];
        generator.apply_temperature(&mut logits);

        assert_eq!(logits, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_argmax() {
        let config = GenerationConfig::greedy();
        let generator = TextGenerator::new(config);

        let logits = vec![1.0, 5.0, 3.0, 4.0, 2.0];
        let result = generator.argmax(&logits);

        assert_eq!(result, 1);
    }

    #[test]
    fn test_repetition_penalty() {
        let config = GenerationConfig::default().with_repetition_penalty(2.0);
        let generator = TextGenerator::new(config);

        let mut logits = vec![1.0, 2.0, 3.0, 4.0];
        let generated = vec![1, 3];
        generator.apply_repetition_penalty(&mut logits, &generated);

        // Tokens 1 and 3 should be penalized
        assert!(logits[1] < 2.0);
        assert!(logits[3] < 4.0);
    }

    #[test]
    fn test_beam_search_init() {
        let beam_search = BeamSearch::new(3, 1.0, false, vec![0]);
        let input_ids = Tensor::from_vec(vec![1u32, 2, 3], &[1, 3]).unwrap();
        let beams = beam_search.init_beams(&input_ids);

        assert_eq!(beams.len(), 1);
        assert_eq!(beams[0].tokens, vec![1, 2, 3]);
    }
}