1use std::fmt::{Display, Formatter};
2
3use burn::prelude::*;
4use burn::tensor::{Int, TensorData};
5use serde::{Deserialize, Serialize};
6
7use crate::utils::tensor_to_vec;
8
9pub trait AutoregressiveModel<B: Backend> {
10 fn forward_logits(
11 &self,
12 input_ids: Tensor<B, 2, Int>,
13 mask: Option<&Tensor<B, 3>>,
14 ) -> Tensor<B, 3>;
15
16 fn max_seq_len(&self) -> usize;
17}
18
/// Decoding options for [`generate_tokens`].
///
/// [`GenerationConfig::new`] produces the following defaults, and the `with_*`
/// methods adjust individual fields:
/// - greedy decoding;
/// - temperature `1.0`;
/// - no top-k or top-p filtering;
/// - no EOS token;
/// - seed `0`.
///
/// [`GenerationConfig::validate`] checks the numeric ranges; `generate_tokens` calls
/// it before decoding.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GenerationConfig {
    /// Upper bound on the number of tokens appended to the prompt.
    pub max_new_tokens: usize,
    /// `false`: pick the highest-scoring token (greedy). `true`: sample from the
    /// temperature-scaled distribution, optionally filtered by `top_k` / `top_p`.
    pub do_sample: bool,
    /// Logits are divided by this value before ranking and sampling; must be finite
    /// and strictly positive.
    pub temperature: f32,
    /// When sampling, keep only the `k` highest-scoring tokens; `Some(0)` is invalid.
    pub top_k: Option<usize>,
    /// When sampling, keep the shortest prefix of the best-first candidates whose
    /// cumulative probability reaches this value; must lie in `(0, 1]`.
    pub top_p: Option<f32>,
    /// Generation stops right after this token is produced; it must be a valid index
    /// into the model's logits row.
    pub eos_token_id: Option<usize>,
    /// Seed for the sampling RNG; `0` is replaced internally by a fixed non-zero state.
    pub seed: u64,
}
29
30impl GenerationConfig {
31 pub fn new(max_new_tokens: usize) -> Self {
32 Self {
33 max_new_tokens,
34 do_sample: false,
35 temperature: 1.0,
36 top_k: None,
37 top_p: None,
38 eos_token_id: None,
39 seed: 0,
40 }
41 }
42
43 pub fn with_do_sample(mut self, do_sample: bool) -> Self {
44 self.do_sample = do_sample;
45 self
46 }
47
48 pub fn with_temperature(mut self, temperature: f32) -> Self {
49 self.temperature = temperature;
50 self
51 }
52
53 pub fn with_top_k(mut self, top_k: Option<usize>) -> Self {
54 self.top_k = top_k;
55 self
56 }
57
58 pub fn with_top_p(mut self, top_p: Option<f32>) -> Self {
59 self.top_p = top_p;
60 self
61 }
62
63 pub fn with_eos_token_id(mut self, eos_token_id: Option<usize>) -> Self {
64 self.eos_token_id = eos_token_id;
65 self
66 }
67
68 pub fn with_seed(mut self, seed: u64) -> Self {
69 self.seed = seed;
70 self
71 }
72
73 pub fn validate(&self) -> Result<(), GenerationError> {
74 if !(self.temperature.is_finite() && self.temperature > 0.0) {
75 return Err(GenerationError::InvalidTemperature(self.temperature));
76 }
77 if matches!(self.top_k, Some(0)) {
78 return Err(GenerationError::InvalidTopK(0));
79 }
80 if let Some(top_p) = self.top_p
81 && !(top_p.is_finite() && top_p > 0.0 && top_p <= 1.0)
82 {
83 return Err(GenerationError::InvalidTopP(top_p));
84 }
85
86 Ok(())
87 }
88}
89
/// Why [`generate_tokens`] stopped.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FinishReason {
    /// All `max_new_tokens` steps ran without producing the EOS token. This is also
    /// the reason reported when `max_new_tokens == 0`.
    MaxNewTokens,
    /// The configured `eos_token_id` was generated; it is kept as the last entry of
    /// the result's `tokens`.
    EosToken,
}
95
/// Record of a single decoding step.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GenerationStep {
    /// Index of the generated token within `GenerationResult::tokens` (prompt included).
    pub position: usize,
    /// The chosen token id.
    pub token_id: usize,
    /// Probability of `token_id` in the distribution it was chosen from. That
    /// distribution is taken after temperature scaling and any top-k / top-p
    /// truncation.
    pub probability: f32,
    /// Raw model logit of `token_id`, before temperature scaling.
    pub logit: f32,
}
103
/// Output of [`generate_tokens`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GenerationResult {
    /// Number of leading entries of `tokens` that came from the prompt.
    pub prompt_len: usize,
    /// The prompt followed by every generated token (including a final EOS token, if any).
    pub tokens: Vec<usize>,
    /// One entry per generated token, in generation order.
    pub steps: Vec<GenerationStep>,
    /// Why generation stopped.
    pub finish_reason: FinishReason,
}
111
112impl GenerationResult {
113 pub fn prompt_tokens(&self) -> &[usize] {
114 &self.tokens[..self.prompt_len]
115 }
116
117 pub fn generated_tokens(&self) -> &[usize] {
118 &self.tokens[self.prompt_len..]
119 }
120
121 pub fn generated_len(&self) -> usize {
122 self.tokens.len().saturating_sub(self.prompt_len)
123 }
124}
125
/// Reasons [`generate_tokens`] can reject a request.
#[derive(Debug, Clone, PartialEq)]
pub enum GenerationError {
    /// The prompt slice contained no tokens.
    EmptyPrompt,
    /// The model reported this context window (`max_seq_len`); only `0` is rejected.
    InvalidContextWindow(usize),
    /// `temperature` was not finite and strictly positive.
    InvalidTemperature(f32),
    /// `top_k` was `Some(0)`.
    InvalidTopK(usize),
    /// `top_p` was outside `(0, 1]` (NaN included).
    InvalidTopP(f32),
    /// `eos_token_id` is not a valid index into the model's logits row.
    EosTokenOutOfRange {
        eos_token_id: usize,
        vocab_size: usize,
    },
}

impl Display for GenerationError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Every payload is `Copy`, so matching on `*self` binds plain values.
        match *self {
            Self::EmptyPrompt => f.write_str("generation prompt must contain at least one token"),
            Self::InvalidContextWindow(size) => {
                write!(f, "generation context window must be positive, got {size}")
            }
            Self::InvalidTemperature(value) => write!(
                f,
                "generation temperature must be finite and positive, got {value}"
            ),
            Self::InvalidTopK(value) => write!(
                f,
                "generation top_k must be positive when provided, got {value}"
            ),
            Self::InvalidTopP(value) => write!(f, "generation top_p must be in (0, 1], got {value}"),
            Self::EosTokenOutOfRange {
                eos_token_id: token,
                vocab_size: size,
            } => write!(
                f,
                "generation eos token {token} is outside the model vocabulary size {size}"
            ),
        }
    }
}

impl std::error::Error for GenerationError {}
174
175pub fn generate_tokens<M, B>(
176 model: &M,
177 prompt_tokens: &[usize],
178 config: &GenerationConfig,
179 device: &B::Device,
180) -> Result<GenerationResult, GenerationError>
181where
182 M: AutoregressiveModel<B>,
183 B: Backend,
184{
185 if prompt_tokens.is_empty() {
186 return Err(GenerationError::EmptyPrompt);
187 }
188 config.validate()?;
189
190 let context_window = model.max_seq_len();
191 if context_window == 0 {
192 return Err(GenerationError::InvalidContextWindow(context_window));
193 }
194
195 let mut tokens = prompt_tokens.to_vec();
196 let mut steps = Vec::with_capacity(config.max_new_tokens);
197 let mut rng = SamplerRng::new(config.seed);
198 let mut finish_reason = FinishReason::MaxNewTokens;
199
200 for _ in 0..config.max_new_tokens {
201 let context_start = tokens.len().saturating_sub(context_window);
202 let context = &tokens[context_start..];
203 let input_ids = prompt_tensor::<B>(context, device);
204 let logits = model.forward_logits(input_ids, None);
205 let next_logits = last_token_logits(logits);
206
207 if let Some(eos_token_id) = config.eos_token_id
208 && eos_token_id >= next_logits.len()
209 {
210 return Err(GenerationError::EosTokenOutOfRange {
211 eos_token_id,
212 vocab_size: next_logits.len(),
213 });
214 }
215
216 let sampled = select_next_token(&next_logits, config, &mut rng);
217 tokens.push(sampled.token_id);
218 steps.push(GenerationStep {
219 position: tokens.len() - 1,
220 token_id: sampled.token_id,
221 probability: sampled.probability,
222 logit: sampled.logit,
223 });
224
225 if Some(sampled.token_id) == config.eos_token_id {
226 finish_reason = FinishReason::EosToken;
227 break;
228 }
229 }
230
231 Ok(GenerationResult {
232 prompt_len: prompt_tokens.len(),
233 tokens,
234 steps,
235 finish_reason,
236 })
237}
238
239fn prompt_tensor<B: Backend>(tokens: &[usize], device: &B::Device) -> Tensor<B, 2, Int> {
240 let data = TensorData::new(
241 tokens.iter().map(|token| *token as i64).collect::<Vec<_>>(),
242 [1, tokens.len()],
243 );
244 Tensor::<B, 2, Int>::from_data(data, device)
245}
246
247fn last_token_logits<B: Backend>(logits: Tensor<B, 3>) -> Vec<f32> {
248 let [_, seq_len, vocab_size] = logits.dims();
249 let last = logits
250 .slice([0..1, seq_len - 1..seq_len, 0..vocab_size])
251 .reshape([vocab_size]);
252 tensor_to_vec(last)
253}
254
/// A vocabulary entry considered during next-token selection.
#[derive(Debug, Clone, Copy)]
struct Candidate {
    /// Index of the token in the logits row.
    token_id: usize,
    /// Logit exactly as produced by the model.
    raw_logit: f32,
    /// `raw_logit / temperature`; used for ranking and for probabilities.
    adjusted_logit: f32,
}
261
/// The token picked by `select_next_token`, together with the values that
/// `generate_tokens` records in a `GenerationStep`.
#[derive(Debug, Clone, Copy)]
struct SampledToken {
    /// Chosen token id.
    token_id: usize,
    /// Probability in the (possibly truncated) distribution the token was picked from.
    probability: f32,
    /// Raw, temperature-independent logit of the chosen token.
    logit: f32,
}
268
269fn select_next_token(
270 logits: &[f32],
271 config: &GenerationConfig,
272 rng: &mut SamplerRng,
273) -> SampledToken {
274 let mut candidates = logits
275 .iter()
276 .enumerate()
277 .map(|(token_id, raw_logit)| Candidate {
278 token_id,
279 raw_logit: *raw_logit,
280 adjusted_logit: *raw_logit / config.temperature,
281 })
282 .collect::<Vec<_>>();
283
284 if !config.do_sample {
285 let probabilities = normalized_probabilities(&candidates);
286 let best = candidates
287 .iter()
288 .max_by(|left, right| left.adjusted_logit.total_cmp(&right.adjusted_logit))
289 .copied()
290 .expect("generation candidates must not be empty");
291 let probability = probabilities
292 .iter()
293 .find(|(token_id, _)| *token_id == best.token_id)
294 .map(|(_, probability)| *probability)
295 .unwrap_or(1.0);
296
297 return SampledToken {
298 token_id: best.token_id,
299 probability,
300 logit: best.raw_logit,
301 };
302 }
303
304 if let Some(top_k) = config.top_k {
305 candidates.sort_by(|left, right| right.adjusted_logit.total_cmp(&left.adjusted_logit));
306 candidates.truncate(top_k.min(candidates.len()));
307 }
308
309 candidates.sort_by(|left, right| right.adjusted_logit.total_cmp(&left.adjusted_logit));
310 if let Some(top_p) = config.top_p {
311 let probabilities = normalized_probabilities(&candidates);
312 let mut cumulative = 0.0;
313 let mut keep = 0usize;
314 for (_, probability) in probabilities {
315 keep += 1;
316 cumulative += probability;
317 if cumulative >= top_p {
318 break;
319 }
320 }
321 candidates.truncate(keep.max(1));
322 }
323
324 let probabilities = normalized_probabilities(&candidates);
325 let draw = rng.next_f32();
326 let mut cumulative = 0.0;
327 for (idx, (token_id, probability)) in probabilities.iter().enumerate() {
328 cumulative += *probability;
329 if draw <= cumulative || idx + 1 == probabilities.len() {
330 let candidate = candidates
331 .iter()
332 .find(|candidate| candidate.token_id == *token_id)
333 .copied()
334 .expect("sampled token must be present in candidate set");
335 return SampledToken {
336 token_id: candidate.token_id,
337 probability: *probability,
338 logit: candidate.raw_logit,
339 };
340 }
341 }
342
343 unreachable!("generation sampling should always select a token");
344}
345
346fn normalized_probabilities(candidates: &[Candidate]) -> Vec<(usize, f32)> {
347 let max_logit = candidates
348 .iter()
349 .map(|candidate| candidate.adjusted_logit)
350 .fold(f32::NEG_INFINITY, f32::max);
351 let mut sum = 0.0;
352 let exps = candidates
353 .iter()
354 .map(|candidate| {
355 let value = (candidate.adjusted_logit - max_logit).exp();
356 sum += value;
357 value
358 })
359 .collect::<Vec<_>>();
360
361 candidates
362 .iter()
363 .zip(exps)
364 .map(|(candidate, value)| (candidate.token_id, value / sum))
365 .collect()
366}
367
/// Small deterministic xorshift64* generator used for token sampling.
///
/// A zero state is a fixed point of the xorshift steps, so seed `0` is remapped to a
/// fixed non-zero constant.
#[derive(Debug, Clone, Copy)]
struct SamplerRng {
    state: u64,
}

impl SamplerRng {
    /// State used in place of a zero seed.
    const ZERO_SEED_REPLACEMENT: u64 = 0x9E37_79B9_7F4A_7C15;

    fn new(seed: u64) -> Self {
        match seed {
            0 => Self {
                state: Self::ZERO_SEED_REPLACEMENT,
            },
            nonzero => Self { state: nonzero },
        }
    }

    /// Advances the state by one xorshift step and returns the multiplied output.
    fn next_u64(&mut self) -> u64 {
        let x = self.state;
        let x = x ^ (x >> 12);
        let x = x ^ (x << 25);
        let x = x ^ (x >> 27);
        self.state = x;
        x.wrapping_mul(0x2545_F491_4F6C_DD1D)
    }

    /// Uniform value in `[0, 1)` built from the top 24 bits of [`Self::next_u64`].
    fn next_f32(&mut self) -> f32 {
        // 24 bits fit exactly in an f32 mantissa; 16_777_216 = 2^24.
        let top_bits = self.next_u64() >> 40;
        top_bits as f32 / 16_777_216.0
    }
}