Skip to main content

deep_delta_learning/
generation.rs

1use std::fmt::{Display, Formatter};
2
3use burn::prelude::*;
4use burn::tensor::{Int, TensorData};
5use serde::{Deserialize, Serialize};
6
7use crate::utils::tensor_to_vec;
8
/// Minimal interface a decoder-only model must expose for token generation.
pub trait AutoregressiveModel<B: Backend> {
    /// Runs a forward pass over `input_ids` (shape `[batch, seq]`, built as
    /// `[1, seq]` by `prompt_tensor`) and returns per-position logits of
    /// shape `[batch, seq, vocab]` (see `last_token_logits`).
    ///
    /// `mask` is an optional attention mask; `generate_tokens` passes `None`.
    fn forward_logits(
        &self,
        input_ids: Tensor<B, 2, Int>,
        mask: Option<&Tensor<B, 3>>,
    ) -> Tensor<B, 3>;

    /// Maximum context length the model accepts; `generate_tokens` windows
    /// the running token sequence down to this many trailing tokens.
    fn max_seq_len(&self) -> usize;
}
18
/// Decoding knobs consumed by `generate_tokens`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GenerationConfig {
    /// Upper bound on tokens appended after the prompt.
    pub max_new_tokens: usize,
    /// `true` = sample from the (filtered) distribution; `false` = greedy arg-max.
    pub do_sample: bool,
    /// Softmax temperature; logits are divided by this. Must be finite and > 0
    /// (enforced by `validate`).
    pub temperature: f32,
    /// Keep only the k highest-logit candidates before sampling (`None` = off).
    pub top_k: Option<usize>,
    /// Nucleus sampling: keep the smallest candidate prefix whose cumulative
    /// probability reaches this value; must lie in (0, 1] when set.
    pub top_p: Option<f32>,
    /// Token id that ends generation early once it is produced.
    pub eos_token_id: Option<usize>,
    /// Seed for the deterministic sampler RNG (0 is remapped to a fixed
    /// non-zero constant by `SamplerRng::new`).
    pub seed: u64,
}
29
30impl GenerationConfig {
31    pub fn new(max_new_tokens: usize) -> Self {
32        Self {
33            max_new_tokens,
34            do_sample: false,
35            temperature: 1.0,
36            top_k: None,
37            top_p: None,
38            eos_token_id: None,
39            seed: 0,
40        }
41    }
42
43    pub fn with_do_sample(mut self, do_sample: bool) -> Self {
44        self.do_sample = do_sample;
45        self
46    }
47
48    pub fn with_temperature(mut self, temperature: f32) -> Self {
49        self.temperature = temperature;
50        self
51    }
52
53    pub fn with_top_k(mut self, top_k: Option<usize>) -> Self {
54        self.top_k = top_k;
55        self
56    }
57
58    pub fn with_top_p(mut self, top_p: Option<f32>) -> Self {
59        self.top_p = top_p;
60        self
61    }
62
63    pub fn with_eos_token_id(mut self, eos_token_id: Option<usize>) -> Self {
64        self.eos_token_id = eos_token_id;
65        self
66    }
67
68    pub fn with_seed(mut self, seed: u64) -> Self {
69        self.seed = seed;
70        self
71    }
72
73    pub fn validate(&self) -> Result<(), GenerationError> {
74        if !(self.temperature.is_finite() && self.temperature > 0.0) {
75            return Err(GenerationError::InvalidTemperature(self.temperature));
76        }
77        if matches!(self.top_k, Some(0)) {
78            return Err(GenerationError::InvalidTopK(0));
79        }
80        if let Some(top_p) = self.top_p
81            && !(top_p.is_finite() && top_p > 0.0 && top_p <= 1.0)
82        {
83            return Err(GenerationError::InvalidTopP(top_p));
84        }
85
86        Ok(())
87    }
88}
89
/// Why `generate_tokens` stopped emitting tokens.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum FinishReason {
    /// The `max_new_tokens` budget was exhausted.
    MaxNewTokens,
    /// The configured EOS token was produced.
    EosToken,
}
95
/// Per-token record captured for each generation step.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GenerationStep {
    /// Index of the token within the full `tokens` sequence (prompt included).
    pub position: usize,
    /// Id of the selected token.
    pub token_id: usize,
    /// Probability of the token within the candidate set it was drawn from
    /// (the filtered, renormalized set when top-k/top-p are active).
    pub probability: f32,
    /// The token's raw (pre-temperature) logit.
    pub logit: f32,
}
103
/// Output of `generate_tokens`: full token sequence plus step metadata.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct GenerationResult {
    /// Number of leading tokens in `tokens` that came from the prompt.
    pub prompt_len: usize,
    /// Prompt tokens followed by generated tokens.
    pub tokens: Vec<usize>,
    /// One entry per generated token, in emission order.
    pub steps: Vec<GenerationStep>,
    /// Why generation stopped.
    pub finish_reason: FinishReason,
}
111
112impl GenerationResult {
113    pub fn prompt_tokens(&self) -> &[usize] {
114        &self.tokens[..self.prompt_len]
115    }
116
117    pub fn generated_tokens(&self) -> &[usize] {
118        &self.tokens[self.prompt_len..]
119    }
120
121    pub fn generated_len(&self) -> usize {
122        self.tokens.len().saturating_sub(self.prompt_len)
123    }
124}
125
/// Errors produced by `GenerationConfig::validate` and `generate_tokens`.
#[derive(Debug, Clone, PartialEq)]
pub enum GenerationError {
    /// The prompt slice was empty.
    EmptyPrompt,
    /// The model reported a zero-length context window (`max_seq_len() == 0`).
    InvalidContextWindow(usize),
    /// Temperature was not finite and positive.
    InvalidTemperature(f32),
    /// `top_k` was `Some(0)`.
    InvalidTopK(usize),
    /// `top_p` was outside (0, 1].
    InvalidTopP(f32),
    /// The configured EOS token id does not fit in the vocabulary implied
    /// by the model's logits.
    EosTokenOutOfRange {
        eos_token_id: usize,
        vocab_size: usize,
    },
}
138
139impl Display for GenerationError {
140    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
141        match self {
142            Self::EmptyPrompt => write!(f, "generation prompt must contain at least one token"),
143            Self::InvalidContextWindow(window) => {
144                write!(
145                    f,
146                    "generation context window must be positive, got {window}"
147                )
148            }
149            Self::InvalidTemperature(temperature) => write!(
150                f,
151                "generation temperature must be finite and positive, got {temperature}"
152            ),
153            Self::InvalidTopK(top_k) => {
154                write!(
155                    f,
156                    "generation top_k must be positive when provided, got {top_k}"
157                )
158            }
159            Self::InvalidTopP(top_p) => {
160                write!(f, "generation top_p must be in (0, 1], got {top_p}")
161            }
162            Self::EosTokenOutOfRange {
163                eos_token_id,
164                vocab_size,
165            } => write!(
166                f,
167                "generation eos token {eos_token_id} is outside the model vocabulary size {vocab_size}"
168            ),
169        }
170    }
171}
172
// Marker impl so GenerationError can flow through `?` and `Box<dyn Error>`.
impl std::error::Error for GenerationError {}
174
/// Autoregressively generates up to `config.max_new_tokens` tokens after
/// `prompt_tokens` using greedy or sampled decoding.
///
/// Each step re-runs the model over the full (windowed) token sequence —
/// no KV cache is used — takes the logits of the last position, and selects
/// one token via `select_next_token`. Sampling is deterministic for a given
/// `config.seed`.
///
/// # Errors
/// * [`GenerationError::EmptyPrompt`] if `prompt_tokens` is empty.
/// * Any error from [`GenerationConfig::validate`].
/// * [`GenerationError::InvalidContextWindow`] if the model reports a
///   zero-length `max_seq_len`.
/// * [`GenerationError::EosTokenOutOfRange`] if `eos_token_id` does not fit
///   the vocabulary implied by the logits (only detectable once the first
///   forward pass has run; with `max_new_tokens == 0` it is never checked).
pub fn generate_tokens<M, B>(
    model: &M,
    prompt_tokens: &[usize],
    config: &GenerationConfig,
    device: &B::Device,
) -> Result<GenerationResult, GenerationError>
where
    M: AutoregressiveModel<B>,
    B: Backend,
{
    if prompt_tokens.is_empty() {
        return Err(GenerationError::EmptyPrompt);
    }
    config.validate()?;

    let context_window = model.max_seq_len();
    if context_window == 0 {
        return Err(GenerationError::InvalidContextWindow(context_window));
    }

    // `tokens` accumulates prompt + generated ids; `steps` records one
    // entry per generated token.
    let mut tokens = prompt_tokens.to_vec();
    let mut steps = Vec::with_capacity(config.max_new_tokens);
    let mut rng = SamplerRng::new(config.seed);
    let mut finish_reason = FinishReason::MaxNewTokens;

    for _ in 0..config.max_new_tokens {
        // Feed only the trailing `context_window` tokens to the model.
        let context_start = tokens.len().saturating_sub(context_window);
        let context = &tokens[context_start..];
        let input_ids = prompt_tensor::<B>(context, device);
        let logits = model.forward_logits(input_ids, None);
        let next_logits = last_token_logits(logits);

        // The EOS id can only be validated against the vocabulary size once
        // logits are available, hence the check lives inside the loop.
        if let Some(eos_token_id) = config.eos_token_id
            && eos_token_id >= next_logits.len()
        {
            return Err(GenerationError::EosTokenOutOfRange {
                eos_token_id,
                vocab_size: next_logits.len(),
            });
        }

        let sampled = select_next_token(&next_logits, config, &mut rng);
        tokens.push(sampled.token_id);
        steps.push(GenerationStep {
            // Position within the full sequence, prompt included.
            position: tokens.len() - 1,
            token_id: sampled.token_id,
            probability: sampled.probability,
            logit: sampled.logit,
        });

        // Stop early once the EOS token has been emitted (it is still
        // included in `tokens` and `steps`).
        if Some(sampled.token_id) == config.eos_token_id {
            finish_reason = FinishReason::EosToken;
            break;
        }
    }

    Ok(GenerationResult {
        prompt_len: prompt_tokens.len(),
        tokens,
        steps,
        finish_reason,
    })
}
238
239fn prompt_tensor<B: Backend>(tokens: &[usize], device: &B::Device) -> Tensor<B, 2, Int> {
240    let data = TensorData::new(
241        tokens.iter().map(|token| *token as i64).collect::<Vec<_>>(),
242        [1, tokens.len()],
243    );
244    Tensor::<B, 2, Int>::from_data(data, device)
245}
246
/// Extracts the logits of the final sequence position from a
/// `[batch, seq, vocab]` tensor as a flat `Vec<f32>`.
///
/// Only batch index 0 is read; callers build single-row batches via
/// `prompt_tensor`.
fn last_token_logits<B: Backend>(logits: Tensor<B, 3>) -> Vec<f32> {
    let [_, seq_len, vocab_size] = logits.dims();
    // Slice down to [1, 1, vocab], then flatten to a 1-D [vocab] tensor.
    let last = logits
        .slice([0..1, seq_len - 1..seq_len, 0..vocab_size])
        .reshape([vocab_size]);
    tensor_to_vec(last)
}
254
/// A vocabulary entry considered during token selection.
#[derive(Debug, Clone, Copy)]
struct Candidate {
    /// Index of this entry in the model's vocabulary.
    token_id: usize,
    /// Logit exactly as produced by the model.
    raw_logit: f32,
    /// Logit divided by the temperature; used for ordering and softmax.
    adjusted_logit: f32,
}
261
/// The token chosen for a single generation step.
#[derive(Debug, Clone, Copy)]
struct SampledToken {
    /// Id of the chosen token.
    token_id: usize,
    /// Probability of the token within the candidate set it was drawn from.
    probability: f32,
    /// The token's raw (pre-temperature) logit.
    logit: f32,
}
268
269fn select_next_token(
270    logits: &[f32],
271    config: &GenerationConfig,
272    rng: &mut SamplerRng,
273) -> SampledToken {
274    let mut candidates = logits
275        .iter()
276        .enumerate()
277        .map(|(token_id, raw_logit)| Candidate {
278            token_id,
279            raw_logit: *raw_logit,
280            adjusted_logit: *raw_logit / config.temperature,
281        })
282        .collect::<Vec<_>>();
283
284    if !config.do_sample {
285        let probabilities = normalized_probabilities(&candidates);
286        let best = candidates
287            .iter()
288            .max_by(|left, right| left.adjusted_logit.total_cmp(&right.adjusted_logit))
289            .copied()
290            .expect("generation candidates must not be empty");
291        let probability = probabilities
292            .iter()
293            .find(|(token_id, _)| *token_id == best.token_id)
294            .map(|(_, probability)| *probability)
295            .unwrap_or(1.0);
296
297        return SampledToken {
298            token_id: best.token_id,
299            probability,
300            logit: best.raw_logit,
301        };
302    }
303
304    if let Some(top_k) = config.top_k {
305        candidates.sort_by(|left, right| right.adjusted_logit.total_cmp(&left.adjusted_logit));
306        candidates.truncate(top_k.min(candidates.len()));
307    }
308
309    candidates.sort_by(|left, right| right.adjusted_logit.total_cmp(&left.adjusted_logit));
310    if let Some(top_p) = config.top_p {
311        let probabilities = normalized_probabilities(&candidates);
312        let mut cumulative = 0.0;
313        let mut keep = 0usize;
314        for (_, probability) in probabilities {
315            keep += 1;
316            cumulative += probability;
317            if cumulative >= top_p {
318                break;
319            }
320        }
321        candidates.truncate(keep.max(1));
322    }
323
324    let probabilities = normalized_probabilities(&candidates);
325    let draw = rng.next_f32();
326    let mut cumulative = 0.0;
327    for (idx, (token_id, probability)) in probabilities.iter().enumerate() {
328        cumulative += *probability;
329        if draw <= cumulative || idx + 1 == probabilities.len() {
330            let candidate = candidates
331                .iter()
332                .find(|candidate| candidate.token_id == *token_id)
333                .copied()
334                .expect("sampled token must be present in candidate set");
335            return SampledToken {
336                token_id: candidate.token_id,
337                probability: *probability,
338                logit: candidate.raw_logit,
339            };
340        }
341    }
342
343    unreachable!("generation sampling should always select a token");
344}
345
346fn normalized_probabilities(candidates: &[Candidate]) -> Vec<(usize, f32)> {
347    let max_logit = candidates
348        .iter()
349        .map(|candidate| candidate.adjusted_logit)
350        .fold(f32::NEG_INFINITY, f32::max);
351    let mut sum = 0.0;
352    let exps = candidates
353        .iter()
354        .map(|candidate| {
355            let value = (candidate.adjusted_logit - max_logit).exp();
356            sum += value;
357            value
358        })
359        .collect::<Vec<_>>();
360
361    candidates
362        .iter()
363        .zip(exps)
364        .map(|(candidate, value)| (candidate.token_id, value / sum))
365        .collect()
366}
367
/// Small deterministic xorshift64*-style generator used for sampling, so
/// generation is reproducible from `GenerationConfig::seed` alone.
#[derive(Debug, Clone, Copy)]
struct SamplerRng {
    state: u64,
}

impl SamplerRng {
    /// Builds a generator from `seed`, substituting a fixed non-zero
    /// constant for zero (an all-zero xorshift state never leaves zero).
    fn new(seed: u64) -> Self {
        match seed {
            0 => Self { state: 0x9E37_79B9_7F4A_7C15 },
            nonzero => Self { state: nonzero },
        }
    }

    /// Advances the xorshift state and returns the next 64-bit output
    /// (state scrambled by a final odd-constant multiply).
    fn next_u64(&mut self) -> u64 {
        let mut x = self.state;
        x ^= x >> 12;
        x ^= x << 25;
        x ^= x >> 27;
        self.state = x;
        x.wrapping_mul(0x2545_F491_4F6C_DD1D)
    }

    /// Returns a uniform value in [0, 1) built from the top 24 bits of the
    /// next 64-bit output (24 bits = full f32 mantissa precision).
    fn next_f32(&mut self) -> f32 {
        let bits = self.next_u64() >> 40;
        bits as f32 / ((1u32 << 24) as f32)
    }
}
396}