
trueno/inference/generate.rs

//! Token sampling and text generation loop.

use crate::error::TruenoError;
use crate::inference::model::{ForwardArena, KvCache, LlamaModel};

/// Sampling parameters for text generation.
#[derive(Debug, Clone)]
pub struct SampleParams {
    /// Softmax temperature; values `<= 0.0` select the argmax token (greedy).
    pub temperature: f32,
    /// Keep only the `top_k` highest-scoring candidates.
    pub top_k: usize,
    /// Nucleus sampling: keep the smallest candidate set whose cumulative
    /// probability reaches `top_p`.
    pub top_p: f32,
    /// Seed for the internal PRNG; identical seeds reproduce identical output.
    pub seed: u64,
}

impl Default for SampleParams {
    fn default() -> Self {
        Self { temperature: 0.7, top_k: 40, top_p: 0.9, seed: 42 }
    }
}
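
// A minimal usage sketch: struct-update syntax overrides a single field, and
// `temperature: 0.0` routes `sample_token` through the greedy argmax path below:
//
//     let greedy = SampleParams { temperature: 0.0, ..SampleParams::default() };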

/// Simple xorshift64 PRNG (no external dependency). Marsaglia's xorshift64
/// with shifts (13, 7, 17) cycles through all 2^64 - 1 nonzero states.
pub struct Rng(u64);

impl Rng {
    fn new(seed: u64) -> Self {
        // The all-zero state is a fixed point of xorshift, so avoid it.
        Self(seed.max(1))
    }

    fn next_f32(&mut self) -> f32 {
        self.0 ^= self.0 << 13;
        self.0 ^= self.0 >> 7;
        self.0 ^= self.0 << 17;
        // Map the top 24 bits to an f32 in [0, 1). Dividing the raw u64 by
        // u64::MAX as f32 loses precision and can round up to exactly 1.0.
        ((self.0 >> 40) as f32) / ((1u64 << 24) as f32)
    }
}

/// Sample a token from logits using temperature + top-k + top-p.
pub fn sample_token(logits: &[f32], params: &SampleParams, rng: &mut Rng) -> u32 {
    let vocab_size = logits.len();

    if params.temperature <= 0.0 {
        // Greedy: argmax
        return logits
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i as u32)
            .unwrap_or(0);
    }

    // Apply temperature
    let inv_temp = 1.0 / params.temperature;
    let mut scaled: Vec<(usize, f32)> =
        logits.iter().enumerate().map(|(i, &v)| (i, v * inv_temp)).collect();

    // Top-k: keep only the top-k highest logits
    let k = params.top_k.min(vocab_size);
    if k < vocab_size {
        scaled.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scaled.truncate(k);
    }

    // Softmax over remaining candidates
    let max_logit = scaled.iter().map(|x| x.1).fold(f32::NEG_INFINITY, f32::max);
    let mut probs: Vec<(usize, f32)> =
        scaled.iter().map(|&(i, v)| (i, (v - max_logit).exp())).collect();
    let sum: f32 = probs.iter().map(|x| x.1).sum();
    for p in &mut probs {
        p.1 /= sum;
    }

    // Top-p (nucleus): accumulate until cumulative prob >= top_p
    probs.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    let mut cumulative = 0.0f32;
    let mut cutoff = probs.len();
    for (i, &(_, prob)) in probs.iter().enumerate() {
        cumulative += prob;
        if cumulative >= params.top_p {
            cutoff = i + 1;
            break;
        }
    }
    probs.truncate(cutoff);

    // Renormalize
    let sum2: f32 = probs.iter().map(|x| x.1).sum();
    for p in &mut probs {
        p.1 /= sum2;
    }

    // Sample from distribution
    let r = rng.next_f32();
    let mut cum = 0.0;
    for &(idx, prob) in &probs {
        cum += prob;
        if r < cum {
            return idx as u32;
        }
    }
    // Rounding can leave the final cum slightly below 1.0; fall back to the
    // last (least likely surviving) candidate.
    probs.last().map(|&(i, _)| i as u32).unwrap_or(0)
}
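
// Worked example of the pipeline above (values rounded): for logits
// [1.0, 2.0, 3.0] at temperature 1.0 with top_k = 3, softmax yields
// probabilities of roughly [0.090, 0.245, 0.665]. With top_p = 0.9 the
// descending cumulative sums are 0.665 and then 0.910 >= 0.9, so only
// tokens {2, 1} survive and renormalize to roughly [0.731, 0.269];
// token 0 can never be drawn.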

/// Generate text tokens autoregressively.
///
/// Takes initial prompt token IDs and generates up to `max_tokens` more.
/// Returns the generated token IDs (not including the prompt).
pub fn generate(
    model: &LlamaModel,
    prompt_tokens: &[u32],
    max_tokens: usize,
    params: &SampleParams,
    eos_token: u32,
) -> Result<Vec<u32>, TruenoError> {
    let mut kv_cache = KvCache::new(&model.config);
    let mut arena = ForwardArena::new(&model.config);
    let mut rng = Rng::new(params.seed);
    let mut generated = Vec::with_capacity(max_tokens);

    // Prefill: run the prompt through the model one token at a time,
    // populating the KV cache; only the final position's logits are kept.
    let mut last_logits = Vec::new();
    for (pos, &token_id) in prompt_tokens.iter().enumerate() {
        last_logits = model.forward(token_id, pos, &mut kv_cache, &mut arena)?;
    }

    if last_logits.is_empty() {
        return Err(TruenoError::InvalidInput("Empty prompt".into()));
    }

    // Decode: generate tokens one at a time
    let mut pos = prompt_tokens.len();
    for _ in 0..max_tokens {
        let token = sample_token(&last_logits, params, &mut rng);

        if token == eos_token {
            break;
        }
        generated.push(token);

        // Stop once the KV cache has no room for another forward pass; the
        // token just sampled is still returned.
        if pos >= model.config.max_seq_len - 1 {
            break;
        }
        last_logits = model.forward(token, pos, &mut kv_cache, &mut arena)?;
        pos += 1;
    }

    Ok(generated)
}
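
// Sketch of a call site. `LlamaModel::load`, the GGUF path, the prompt IDs,
// and the EOS id are illustrative assumptions, not APIs defined in this file:
//
//     let model = LlamaModel::load("model.gguf")?;   // hypothetical loader
//     let prompt = vec![1, 15043, 29892];            // pre-tokenized prompt IDs
//     let out = generate(&model, &prompt, 128, &SampleParams::default(), 2)?;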

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_greedy_sampling() {
        let logits = vec![0.1, 0.5, 0.3, 0.9, 0.2];
        let params = SampleParams { temperature: 0.0, ..Default::default() };
        let mut rng = Rng::new(42);
        assert_eq!(sample_token(&logits, &params, &mut rng), 3); // argmax
    }

    #[test]
    fn test_temperature_sampling() {
        let logits = vec![1.0, 2.0, 3.0];
        let params = SampleParams { temperature: 1.0, top_k: 3, top_p: 1.0, seed: 42 };
        let mut rng = Rng::new(42);
        let token = sample_token(&logits, &params, &mut rng);
        assert!(token < 3);
    }

    #[test]
    fn test_top_k_reduces_candidates() {
        let mut logits = vec![0.0f32; 100];
        logits[50] = 10.0;
        logits[51] = 9.0;
        let params = SampleParams { temperature: 1.0, top_k: 2, top_p: 1.0, seed: 42 };
        let mut rng = Rng::new(42);
        let token = sample_token(&logits, &params, &mut rng);
        assert!(token == 50 || token == 51);
    }
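
    // Additional coverage (sketches grounded in the code above).
    #[test]
    fn test_rng_stays_in_unit_range() {
        let mut rng = Rng::new(123);
        for _ in 0..1000 {
            let v = rng.next_f32();
            assert!(v >= 0.0 && v < 1.0);
        }
    }

    #[test]
    fn test_top_p_collapses_to_dominant_token() {
        // Token 1 holds ~99.99% of the probability mass, so top_p = 0.5
        // keeps only it and sampling becomes deterministic.
        let logits = vec![0.0, 10.0, 0.0];
        let params = SampleParams { temperature: 1.0, top_k: 3, top_p: 0.5, seed: 7 };
        let mut rng = Rng::new(params.seed);
        for _ in 0..100 {
            assert_eq!(sample_token(&logits, &params, &mut rng), 1);
        }
    }

    #[test]
    fn test_same_seed_is_deterministic() {
        let logits = vec![1.0, 2.0, 3.0, 4.0];
        let params = SampleParams::default();
        let a = sample_token(&logits, &params, &mut Rng::new(99));
        let b = sample_token(&logits, &params, &mut Rng::new(99));
        assert_eq!(a, b);
    }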
}