// trueno/inference/generate.rs
use crate::error::TruenoError;
use crate::inference::model::{ForwardArena, KvCache, LlamaModel};
5
/// Sampling hyper-parameters consumed by [`sample_token`] and [`generate`].
#[derive(Debug, Clone, PartialEq)]
pub struct SampleParams {
    /// Divisor applied to the logits before the softmax. Values `<= 0.0`
    /// switch [`sample_token`] to greedy (arg-max) selection.
    pub temperature: f32,
    /// Number of highest-scoring tokens kept as candidates. Values at or
    /// above the vocabulary size keep every token.
    pub top_k: usize,
    /// Nucleus threshold: candidates are kept, most probable first, until
    /// their cumulative probability reaches this value.
    pub top_p: f32,
    /// Seed handed to the sampler's random number generator by [`generate`].
    pub seed: u64,
}

impl Default for SampleParams {
    /// Returns `temperature = 0.7`, `top_k = 40`, `top_p = 0.9`, `seed = 42`.
    fn default() -> Self {
        Self {
            temperature: 0.7,
            top_k: 40,
            top_p: 0.9,
            seed: 42,
        }
    }
}
20
/// Minimal xorshift-style pseudo-random generator used for token sampling.
///
/// The generator is fully determined by its seed, so a given seed always
/// yields the same sequence of draws.
#[derive(Debug, Clone)]
pub struct Rng(u64);

impl Rng {
    /// Creates a generator from `seed`.
    ///
    /// A seed of `0` is replaced by `1`: an all-zero state would stay zero
    /// under the xor/shift update and produce a constant stream.
    pub fn new(seed: u64) -> Self {
        Self(seed.max(1))
    }

    /// Advances the state and maps it to an `f32`.
    ///
    /// The result lies in `(0.0, 1.0]`. The state is never zero, and
    /// `u64::MAX as f32` rounds up to 2^64, which the largest states also
    /// round to, so exactly `1.0` is possible (rarely).
    fn next_f32(&mut self) -> f32 {
        let mut state = self.0;
        state ^= state << 13;
        state ^= state >> 7;
        state ^= state << 17;
        self.0 = state;
        (state as f32) / (u64::MAX as f32)
    }
}
36
37pub fn sample_token(logits: &[f32], params: &SampleParams, rng: &mut Rng) -> u32 {
39 let vocab_size = logits.len();
40
41 if params.temperature <= 0.0 {
42 return logits
44 .iter()
45 .enumerate()
46 .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
47 .map(|(i, _)| i as u32)
48 .unwrap_or(0);
49 }
50
51 let inv_temp = 1.0 / params.temperature;
53 let mut scaled: Vec<(usize, f32)> =
54 logits.iter().enumerate().map(|(i, &v)| (i, v * inv_temp)).collect();
55
56 let k = params.top_k.min(vocab_size);
58 if k < vocab_size {
59 scaled.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
60 scaled.truncate(k);
61 }
62
63 let max_logit = scaled.iter().map(|x| x.1).fold(f32::NEG_INFINITY, f32::max);
65 let mut probs: Vec<(usize, f32)> =
66 scaled.iter().map(|&(i, v)| (i, (v - max_logit).exp())).collect();
67 let sum: f32 = probs.iter().map(|x| x.1).sum();
68 for p in &mut probs {
69 p.1 /= sum;
70 }
71
72 probs.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
74 let mut cumulative = 0.0f32;
75 let mut cutoff = probs.len();
76 for (i, &(_, prob)) in probs.iter().enumerate() {
77 cumulative += prob;
78 if cumulative >= params.top_p {
79 cutoff = i + 1;
80 break;
81 }
82 }
83 probs.truncate(cutoff);
84
85 let sum2: f32 = probs.iter().map(|x| x.1).sum();
87 for p in &mut probs {
88 p.1 /= sum2;
89 }
90
91 let r = rng.next_f32();
93 let mut cum = 0.0;
94 for &(idx, prob) in &probs {
95 cum += prob;
96 if r < cum {
97 return idx as u32;
98 }
99 }
100 probs.last().map(|&(i, _)| i as u32).unwrap_or(0)
101}
102
103pub fn generate(
108 model: &LlamaModel,
109 prompt_tokens: &[u32],
110 max_tokens: usize,
111 params: &SampleParams,
112 eos_token: u32,
113) -> Result<Vec<u32>, TruenoError> {
114 let mut kv_cache = KvCache::new(&model.config);
115 let mut arena = ForwardArena::new(&model.config);
116 let mut rng = Rng::new(params.seed);
117 let mut generated = Vec::with_capacity(max_tokens);
118
119 let mut last_logits = Vec::new();
121 for (pos, &token_id) in prompt_tokens.iter().enumerate() {
122 last_logits = model.forward(token_id, pos, &mut kv_cache, &mut arena)?;
123 }
124
125 if last_logits.is_empty() {
126 return Err(TruenoError::InvalidInput("Empty prompt".into()));
127 }
128
129 let mut pos = prompt_tokens.len();
131 for _ in 0..max_tokens {
132 let token = sample_token(&last_logits, params, &mut rng);
133
134 if token == eos_token {
135 break;
136 }
137 if pos >= model.config.max_seq_len - 1 {
138 break;
139 }
140
141 generated.push(token);
142 last_logits = model.forward(token, pos, &mut kv_cache, &mut arena)?;
143 pos += 1;
144 }
145
146 Ok(generated)
147}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// A temperature of zero must return the index of the largest logit.
    #[test]
    fn test_greedy_sampling() {
        let logits = vec![0.1, 0.5, 0.3, 0.9, 0.2];
        let params = SampleParams { temperature: 0.0, ..Default::default() };
        let mut rng = Rng::new(42);
        assert_eq!(sample_token(&logits, &params, &mut rng), 3);
    }

    /// Stochastic sampling must return an index inside the vocabulary.
    #[test]
    fn test_temperature_sampling() {
        let logits = vec![1.0, 2.0, 3.0];
        let params = SampleParams { temperature: 1.0, top_k: 3, top_p: 1.0, seed: 42 };
        let mut rng = Rng::new(42);
        let token = sample_token(&logits, &params, &mut rng);
        assert!(token < 3);
    }

    /// With `top_k = 2` only the two highest-scoring tokens may be drawn.
    #[test]
    fn test_top_k_reduces_candidates() {
        let mut logits = vec![0.0f32; 100];
        logits[50] = 10.0;
        logits[51] = 9.0;
        let params = SampleParams { temperature: 1.0, top_k: 2, top_p: 1.0, seed: 42 };
        let mut rng = Rng::new(42);
        let token = sample_token(&logits, &params, &mut rng);
        assert!(token == 50 || token == 51);
    }
}