// axonml_llm/generation.rs

1//! Text Generation — Decoding Strategies for Autoregressive LMs
2//!
3//! Decoding primitives used by the LLM family. `GenerationConfig` holds the
4//! decoding knobs (`max_new_tokens`, `temperature`, `top_k`, `top_p`,
5//! `repetition_penalty`, `eos_token_ids`, `pad_token_id`, `do_sample`,
6//! `num_beams`, `length_penalty`, `early_stopping`) and exposes preset
7//! constructors: `greedy`, `sampling(T)`, `top_k_sampling(k, T)`,
8//! `nucleus_sampling(p, T)`, and `beam_search(n)` plus builder helpers
9//! `with_max_tokens`, `with_eos_token`, `with_repetition_penalty`.
10//! `TextGenerator` applies these knobs: `apply_temperature` divides logits by
11//! T, `apply_repetition_penalty` multiplicatively attenuates logits for
12//! already-seen tokens (dividing if positive, multiplying if negative),
13//! `apply_top_k` keeps the k highest logits and masks the rest to `-inf`,
14//! `apply_top_p` computes softmax probabilities and masks tokens outside the
15//! cumulative-probability nucleus. `sample` performs categorical sampling via
16//! softmax and a cumulative threshold from `rand::thread_rng()`; `argmax` is
17//! the greedy fallback. `get_next_token` chains the four modifiers and hands
18//! off to `sample`. Beam search is implemented through `Beam` (token
19//! sequence, cumulative log-prob score, finished flag, length-normalized
20//! score via `score / length^length_penalty`) and `BeamSearch`
21//! (`init_beams`, `expand_beams` that computes top-`2*num_beams` candidates
22//! per beam, marks EOS tokens finished, and keeps the best `num_beams` by
23//! normalized score; `should_stop` honors `early_stopping`; `best_sequence`
24//! picks the highest-scoring finished beam). `generate_beam_search` converts
25//! logits to log-probs via log-sum-exp, iterates up to `max_new_tokens`, and
26//! returns the best sequence. Tests cover default config, greedy, top-k
27//! filtering keeping two finite logits, temperature halving, argmax, per-
28//! token repetition penalty, beam initialization, single-step beam expansion,
29//! and end-to-end beam search with EOS termination.
30//!
31//! # File
32//! `crates/axonml-llm/src/generation.rs`
33//!
34//! # Author
35//! Andrew Jewell Sr. — AutomataNexus LLC
36//! ORCID: 0009-0005-2158-7060
37//!
38//! # Updated
39//! April 16, 2026 11:15 PM EST
40//!
41//! # Disclaimer
42//! Use at own risk. This software is provided "as is", without warranty of any
43//! kind, express or implied. The author and AutomataNexus shall not be held
44//! liable for any damages arising from the use of this software.
45
46// =============================================================================
47// Imports
48// =============================================================================
49
50use axonml_tensor::Tensor;
51use rand::Rng;
52use serde::{Deserialize, Serialize};
53
54// =============================================================================
55// GenerationConfig
56// =============================================================================
57
/// Configuration for text generation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfig {
    /// Maximum number of new tokens to generate
    pub max_new_tokens: usize,
    /// Temperature for sampling (1.0 = no change, <1.0 = more deterministic, >1.0 = more random)
    pub temperature: f32,
    /// Top-k sampling: only sample from top k tokens
    pub top_k: Option<usize>,
    /// Top-p (nucleus) sampling: keep the smallest set of highest-probability
    /// tokens whose cumulative probability exceeds p; the rest are masked out
    pub top_p: Option<f32>,
    /// Repetition penalty (1.0 = no penalty, >1.0 = penalize repetition)
    pub repetition_penalty: f32,
    /// Stop token IDs
    pub eos_token_ids: Vec<u32>,
    /// Pad token ID (not consulted by this module's decoding paths; kept for callers)
    pub pad_token_id: Option<u32>,
    /// Whether to sample stochastically (true) or decode greedily via argmax (false)
    pub do_sample: bool,
    /// Number of beams for beam search (1 = no beam search)
    pub num_beams: usize,
    /// Length penalty exponent used to normalize beam scores
    pub length_penalty: f32,
    /// Early stopping for beam search: allow stopping once all beams are finished
    pub early_stopping: bool,
}
84
85impl Default for GenerationConfig {
86    fn default() -> Self {
87        Self {
88            max_new_tokens: 50,
89            temperature: 1.0,
90            top_k: None,
91            top_p: None,
92            repetition_penalty: 1.0,
93            eos_token_ids: vec![],
94            pad_token_id: None,
95            do_sample: true,
96            num_beams: 1,
97            length_penalty: 1.0,
98            early_stopping: false,
99        }
100    }
101}
102
103impl GenerationConfig {
104    /// Creates a config for greedy decoding.
105    pub fn greedy() -> Self {
106        Self {
107            do_sample: false,
108            temperature: 1.0,
109            top_k: None,
110            top_p: None,
111            ..Default::default()
112        }
113    }
114
115    /// Creates a config for sampling with temperature.
116    pub fn sampling(temperature: f32) -> Self {
117        Self {
118            do_sample: true,
119            temperature,
120            ..Default::default()
121        }
122    }
123
124    /// Creates a config for top-k sampling.
125    pub fn top_k_sampling(k: usize, temperature: f32) -> Self {
126        Self {
127            do_sample: true,
128            temperature,
129            top_k: Some(k),
130            ..Default::default()
131        }
132    }
133
134    /// Creates a config for nucleus (top-p) sampling.
135    pub fn nucleus_sampling(p: f32, temperature: f32) -> Self {
136        Self {
137            do_sample: true,
138            temperature,
139            top_p: Some(p),
140            ..Default::default()
141        }
142    }
143
144    /// Creates a config for beam search.
145    pub fn beam_search(num_beams: usize) -> Self {
146        Self {
147            do_sample: false,
148            num_beams,
149            ..Default::default()
150        }
151    }
152
153    /// Sets the maximum number of new tokens.
154    pub fn with_max_tokens(mut self, max_new_tokens: usize) -> Self {
155        self.max_new_tokens = max_new_tokens;
156        self
157    }
158
159    /// Sets the EOS token ID.
160    pub fn with_eos_token(mut self, eos_token_id: u32) -> Self {
161        self.eos_token_ids.push(eos_token_id);
162        self
163    }
164
165    /// Sets the repetition penalty.
166    pub fn with_repetition_penalty(mut self, penalty: f32) -> Self {
167        self.repetition_penalty = penalty;
168        self
169    }
170}
171
172// =============================================================================
173// TextGenerator
174// =============================================================================
175
/// Text generator for language models: applies the configured logit
/// modifiers (temperature, repetition penalty, top-k, top-p) and the
/// selected decoding strategy to produce next tokens.
pub struct TextGenerator {
    /// Generation configuration (decoding knobs)
    pub config: GenerationConfig,
}
181
182impl TextGenerator {
183    /// Creates a new text generator.
184    pub fn new(config: GenerationConfig) -> Self {
185        Self { config }
186    }
187
188    // -------------------------------------------------------------------------
189    // Logit Modifiers
190    // -------------------------------------------------------------------------
191
192    /// Applies temperature scaling to logits.
193    pub fn apply_temperature(&self, logits: &mut [f32]) {
194        if self.config.temperature != 1.0 {
195            for logit in logits.iter_mut() {
196                *logit /= self.config.temperature;
197            }
198        }
199    }
200
201    /// Applies repetition penalty to logits.
202    pub fn apply_repetition_penalty(&self, logits: &mut [f32], generated_tokens: &[u32]) {
203        if self.config.repetition_penalty != 1.0 {
204            for &token in generated_tokens {
205                let idx = token as usize;
206                if idx < logits.len() {
207                    if logits[idx] > 0.0 {
208                        logits[idx] /= self.config.repetition_penalty;
209                    } else {
210                        logits[idx] *= self.config.repetition_penalty;
211                    }
212                }
213            }
214        }
215    }
216
217    /// Applies top-k filtering to logits.
218    pub fn apply_top_k(&self, logits: &mut [f32]) {
219        if let Some(k) = self.config.top_k {
220            if k < logits.len() {
221                // Find indices of top k values
222                let mut sorted_indices: Vec<usize> = (0..logits.len()).collect();
223                sorted_indices.sort_by(|&a, &b| logits[b].partial_cmp(&logits[a]).unwrap());
224
225                // Create a set of top-k indices
226                let top_k_indices: std::collections::HashSet<usize> =
227                    sorted_indices[..k].iter().copied().collect();
228
229                // Set all values not in top-k to -inf
230                for (i, logit) in logits.iter_mut().enumerate() {
231                    if !top_k_indices.contains(&i) {
232                        *logit = f32::NEG_INFINITY;
233                    }
234                }
235            }
236        }
237    }
238
239    /// Applies top-p (nucleus) filtering to logits.
240    pub fn apply_top_p(&self, logits: &mut [f32]) {
241        if let Some(p) = self.config.top_p {
242            // Convert to probabilities
243            let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
244            let exp_logits: Vec<f32> = logits.iter().map(|x| (x - max_logit).exp()).collect();
245            let sum_exp: f32 = exp_logits.iter().sum();
246            let probs: Vec<f32> = exp_logits.iter().map(|x| x / sum_exp).collect();
247
248            // Sort by probability
249            let mut sorted_indices: Vec<usize> = (0..probs.len()).collect();
250            sorted_indices.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap());
251
252            // Find cutoff
253            let mut cumsum = 0.0f32;
254            let mut cutoff_idx = sorted_indices.len();
255
256            for (i, &idx) in sorted_indices.iter().enumerate() {
257                cumsum += probs[idx];
258                if cumsum > p {
259                    cutoff_idx = i + 1;
260                    break;
261                }
262            }
263
264            // Set values outside nucleus to -inf
265            for (i, logit) in logits.iter_mut().enumerate() {
266                if !sorted_indices[..cutoff_idx].contains(&i) {
267                    *logit = f32::NEG_INFINITY;
268                }
269            }
270        }
271    }
272
273    // -------------------------------------------------------------------------
274    // Sampling
275    // -------------------------------------------------------------------------
276
277    /// Samples from logits distribution.
278    pub fn sample(&self, logits: &[f32]) -> u32 {
279        if !self.config.do_sample {
280            // Greedy: return argmax
281            return self.argmax(logits);
282        }
283
284        // Sample from distribution
285        let mut rng = rand::thread_rng();
286
287        // Softmax
288        let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
289        let exp_logits: Vec<f32> = logits.iter().map(|x| (x - max_logit).exp()).collect();
290        let sum_exp: f32 = exp_logits.iter().sum();
291        let probs: Vec<f32> = exp_logits.iter().map(|x| x / sum_exp).collect();
292
293        // Sample
294        let mut cumsum = 0.0f32;
295        let sample: f32 = rng.r#gen();
296
297        for (i, &p) in probs.iter().enumerate() {
298            cumsum += p;
299            if sample < cumsum {
300                return i as u32;
301            }
302        }
303
304        // Fallback to last token
305        (logits.len() - 1) as u32
306    }
307
308    /// Returns the index of the maximum value.
309    pub fn argmax(&self, logits: &[f32]) -> u32 {
310        logits
311            .iter()
312            .enumerate()
313            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
314            .map(|(i, _)| i as u32)
315            .unwrap_or(0)
316    }
317
318    // -------------------------------------------------------------------------
319    // Decoding Pipeline
320    // -------------------------------------------------------------------------
321
322    /// Processes logits and returns next token.
323    ///
324    /// If `num_beams > 1` in the config, use `generate_beam_search` instead
325    /// for proper beam search decoding.
326    pub fn get_next_token(&self, logits: &[f32], generated_tokens: &[u32]) -> u32 {
327        let mut logits = logits.to_vec();
328
329        // Apply modifiers
330        self.apply_repetition_penalty(&mut logits, generated_tokens);
331        self.apply_temperature(&mut logits);
332        self.apply_top_k(&mut logits);
333        self.apply_top_p(&mut logits);
334
335        // Sample
336        self.sample(&logits)
337    }
338
339    /// Generates a sequence using beam search.
340    ///
341    /// `get_logits_fn` takes a token sequence and returns logits [vocab_size].
342    /// Returns the best sequence found by beam search.
343    pub fn generate_beam_search<F>(&self, initial_tokens: &[u32], get_logits_fn: &mut F) -> Vec<u32>
344    where
345        F: FnMut(&[u32]) -> Vec<f32>,
346    {
347        let beam_search = BeamSearch::new(
348            self.config.num_beams,
349            self.config.length_penalty,
350            self.config.early_stopping,
351            self.config.eos_token_ids.clone(),
352        );
353
354        let mut beams = vec![Beam::new(initial_tokens.to_vec())];
355
356        for _ in 0..self.config.max_new_tokens {
357            if beam_search.should_stop(&beams) {
358                break;
359            }
360
361            // Get logits for each active beam
362            let mut all_logits = Vec::with_capacity(beams.len());
363            for beam in &beams {
364                if beam.finished {
365                    // Finished beams don't need logits, provide zeros
366                    all_logits.push(vec![0.0f32; 1]);
367                } else {
368                    let logits = get_logits_fn(&beam.tokens);
369                    all_logits.push(logits);
370                }
371            }
372
373            // Convert logits to log-probs for beam scoring
374            let log_prob_beams: Vec<Vec<f32>> = all_logits
375                .iter()
376                .map(|logits| {
377                    let max_l = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
378                    let exp_sum: f32 = logits.iter().map(|x| (x - max_l).exp()).sum();
379                    let log_sum = max_l + exp_sum.ln();
380                    logits.iter().map(|x| x - log_sum).collect()
381                })
382                .collect();
383
384            beams = beam_search.expand_beams(&beams, &log_prob_beams);
385        }
386
387        beam_search
388            .best_sequence(&beams)
389            .unwrap_or_else(|| initial_tokens.to_vec())
390    }
391
392    /// Checks if generation should stop.
393    pub fn should_stop(&self, token: u32) -> bool {
394        self.config.eos_token_ids.contains(&token)
395    }
396}
397
398// =============================================================================
399// Beam Search
400// =============================================================================
401
/// A single hypothesis in beam search: a token sequence plus its cumulative
/// log-probability score and a finished flag.
#[derive(Debug, Clone)]
pub struct Beam {
    /// Token sequence (prompt tokens plus generated continuations)
    pub tokens: Vec<u32>,
    /// Cumulative log-probability score of the generated continuations
    pub score: f32,
    /// Whether this beam has produced an EOS token and stopped extending
    pub finished: bool,
}
412
413impl Beam {
414    /// Creates a new beam.
415    pub fn new(initial_tokens: Vec<u32>) -> Self {
416        Self {
417            tokens: initial_tokens,
418            score: 0.0,
419            finished: false,
420        }
421    }
422
423    /// Returns the normalized score (for length penalty).
424    pub fn normalized_score(&self, length_penalty: f32) -> f32 {
425        let length = self.tokens.len() as f32;
426        self.score / length.powf(length_penalty)
427    }
428}
429
/// Beam search implementation: the fixed search parameters plus the
/// expand / stop / select operations over a set of `Beam`s.
pub struct BeamSearch {
    /// Number of beams kept alive after each expansion step
    pub num_beams: usize,
    /// Length penalty exponent used to normalize beam scores
    pub length_penalty: f32,
    /// When true, the search may stop once every beam is finished
    pub early_stopping: bool,
    /// Token IDs that mark a beam as finished when generated
    pub eos_token_ids: Vec<u32>,
}
441
442impl BeamSearch {
443    /// Creates a new beam search.
444    pub fn new(
445        num_beams: usize,
446        length_penalty: f32,
447        early_stopping: bool,
448        eos_token_ids: Vec<u32>,
449    ) -> Self {
450        Self {
451            num_beams,
452            length_penalty,
453            early_stopping,
454            eos_token_ids,
455        }
456    }
457
458    /// Initializes beams from input tokens.
459    pub fn init_beams(&self, input_ids: &Tensor<u32>) -> Vec<Beam> {
460        let tokens: Vec<u32> = input_ids.to_vec().to_vec();
461        vec![Beam::new(tokens)]
462    }
463
464    /// Expands beams with new tokens and scores.
465    pub fn expand_beams(&self, beams: &[Beam], next_token_logits: &[Vec<f32>]) -> Vec<Beam> {
466        let mut candidates = Vec::new();
467
468        for (beam_idx, beam) in beams.iter().enumerate() {
469            if beam.finished {
470                candidates.push(beam.clone());
471                continue;
472            }
473
474            let logits = &next_token_logits[beam_idx];
475
476            // Get top-k tokens for this beam
477            let mut indexed: Vec<(usize, f32)> =
478                logits.iter().enumerate().map(|(i, &v)| (i, v)).collect();
479            indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
480
481            for (token, log_prob) in indexed.into_iter().take(self.num_beams * 2) {
482                let mut new_beam = beam.clone();
483                new_beam.tokens.push(token as u32);
484                new_beam.score += log_prob;
485
486                if self.eos_token_ids.contains(&(token as u32)) {
487                    new_beam.finished = true;
488                }
489
490                candidates.push(new_beam);
491            }
492        }
493
494        // Sort by score and keep top beams
495        candidates.sort_by(|a, b| {
496            b.normalized_score(self.length_penalty)
497                .partial_cmp(&a.normalized_score(self.length_penalty))
498                .unwrap()
499        });
500
501        candidates.into_iter().take(self.num_beams).collect()
502    }
503
504    /// Checks if search should stop.
505    pub fn should_stop(&self, beams: &[Beam]) -> bool {
506        if self.early_stopping {
507            beams.iter().all(|b| b.finished)
508        } else {
509            false
510        }
511    }
512
513    /// Returns the best completed sequence.
514    pub fn best_sequence(&self, beams: &[Beam]) -> Option<Vec<u32>> {
515        beams
516            .iter()
517            .filter(|b| b.finished)
518            .max_by(|a, b| {
519                a.normalized_score(self.length_penalty)
520                    .partial_cmp(&b.normalized_score(self.length_penalty))
521                    .unwrap()
522            })
523            .map(|b| b.tokens.clone())
524            .or_else(|| beams.first().map(|b| b.tokens.clone()))
525    }
526}
527
528// =============================================================================
529// Tests
530// =============================================================================
531
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generation_config_defaults() {
        let cfg = GenerationConfig::default();
        assert_eq!(cfg.max_new_tokens, 50);
        assert_eq!(cfg.temperature, 1.0);
        assert!(cfg.do_sample);
    }

    #[test]
    fn test_greedy_config() {
        assert!(!GenerationConfig::greedy().do_sample);
    }

    #[test]
    fn test_top_k_filtering() {
        let generator = TextGenerator::new(GenerationConfig::top_k_sampling(2, 1.0));

        let mut logits = vec![1.0, 5.0, 3.0, 4.0, 2.0];
        generator.apply_top_k(&mut logits);

        // Exactly the top-2 logits survive; the rest become -inf.
        let survivors = logits.iter().filter(|x| x.is_finite()).count();
        assert_eq!(survivors, 2);
    }

    #[test]
    fn test_temperature_scaling() {
        let generator = TextGenerator::new(GenerationConfig::sampling(2.0));

        let mut logits = vec![2.0, 4.0, 6.0];
        generator.apply_temperature(&mut logits);

        // T = 2.0 halves every logit.
        assert_eq!(logits, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_argmax() {
        let generator = TextGenerator::new(GenerationConfig::greedy());
        assert_eq!(generator.argmax(&[1.0, 5.0, 3.0, 4.0, 2.0]), 1);
    }

    #[test]
    fn test_repetition_penalty() {
        let generator =
            TextGenerator::new(GenerationConfig::default().with_repetition_penalty(2.0));

        let mut logits = vec![1.0, 2.0, 3.0, 4.0];
        let history: Vec<u32> = vec![1, 3];
        generator.apply_repetition_penalty(&mut logits, &history);

        // Previously generated tokens 1 and 3 are attenuated.
        assert!(logits[1] < 2.0);
        assert!(logits[3] < 4.0);
    }

    #[test]
    fn test_beam_search_init() {
        let beam_search = BeamSearch::new(3, 1.0, false, vec![0]);
        let input_ids = Tensor::from_vec(vec![1u32, 2, 3], &[1, 3]).unwrap();
        let beams = beam_search.init_beams(&input_ids);

        assert_eq!(beams.len(), 1);
        assert_eq!(beams[0].tokens, vec![1, 2, 3]);
    }

    #[test]
    fn test_beam_search_expand() {
        let beam_search = BeamSearch::new(2, 1.0, false, vec![99]);

        // One seed beam, 5-token vocab, token 3 strongly favored.
        let seed = vec![Beam::new(vec![1, 2])];
        let step_logits = vec![vec![-10.0, -10.0, -10.0, 5.0, -10.0]];
        let expanded = beam_search.expand_beams(&seed, &step_logits);

        assert_eq!(expanded.len(), 2);
        // The top-ranked beam must end with token 3.
        assert_eq!(*expanded[0].tokens.last().unwrap(), 3);
    }

    #[test]
    fn test_generate_beam_search() {
        let generator = TextGenerator::new(
            GenerationConfig::beam_search(3)
                .with_max_tokens(5)
                .with_eos_token(4),
        );

        let mut step = 0;
        let result = generator.generate_beam_search(&[1, 2], &mut |_tokens| {
            step += 1;
            // Strongly prefer token 3 for the first calls, then EOS (token 4).
            if step >= 3 {
                vec![-10.0, -10.0, -10.0, -10.0, 10.0]
            } else {
                vec![-10.0, -10.0, -10.0, 10.0, -10.0]
            }
        });

        // Prompt is preserved and at least one token was generated.
        assert!(result.len() > 2);
        assert_eq!(result[0], 1);
        assert_eq!(result[1], 2);
    }
}