Skip to main content

hermes_llm/
generate.rs

1//! Text generation utilities for LLM inference.
2
3use anyhow::Result;
4use candle_core::{DType, Device, Tensor};
5use rand::Rng;
6
7use crate::model::Transformer;
8
9/// Text generator for autoregressive text generation.
10pub struct TextGenerator<'a> {
11    model: &'a Transformer,
12    device: &'a Device,
13}
14
15impl<'a> TextGenerator<'a> {
16    pub fn new(model: &'a Transformer, device: &'a Device) -> Self {
17        Self { model, device }
18    }
19
20    /// Generate tokens autoregressively from a prompt.
21    ///
22    /// # Arguments
23    /// * `prompt_tokens` - Initial token sequence
24    /// * `max_new_tokens` - Maximum number of new tokens to generate
25    /// * `temperature` - Sampling temperature (1.0 = no scaling)
26    /// * `top_k` - Optional top-k filtering
27    pub fn generate(
28        &self,
29        prompt_tokens: &[u32],
30        max_new_tokens: usize,
31        temperature: f64,
32        top_k: Option<usize>,
33    ) -> Result<Vec<u32>> {
34        let mut tokens = prompt_tokens.to_vec();
35        let mut rng = rand::rng();
36
37        for _ in 0..max_new_tokens {
38            let context_len = tokens.len().min(self.model.config().max_seq_len);
39            let context: Vec<u32> = tokens[tokens.len() - context_len..].to_vec();
40
41            let input = Tensor::new(context.as_slice(), self.device)?
42                .unsqueeze(0)?
43                .to_dtype(DType::U32)?;
44
45            let logits = self.model.forward(&input, 0, false)?;
46            // Shape: [1, seq_len, vocab] -> [1, 1, vocab] -> [vocab]
47            let logits = logits
48                .narrow(1, context_len - 1, 1)?
49                .squeeze(1)?
50                .squeeze(0)?;
51
52            let logits = if temperature != 1.0 {
53                logits.affine(1.0 / temperature, 0.0)?
54            } else {
55                logits
56            };
57
58            let logits = if let Some(k) = top_k {
59                top_k_filter(&logits, k, self.device)?
60            } else {
61                logits
62            };
63
64            let next_token = sample_from_logits(&logits, &mut rng)?;
65            tokens.push(next_token);
66        }
67
68        Ok(tokens)
69    }
70}
71
72/// Apply top-k filtering to logits.
73fn top_k_filter(logits: &Tensor, k: usize, device: &Device) -> Result<Tensor> {
74    let logits_vec: Vec<f32> = logits.to_vec1()?;
75    let mut indexed: Vec<(usize, f32)> = logits_vec.iter().copied().enumerate().collect();
76    indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
77
78    let mut masked = vec![f32::NEG_INFINITY; logits_vec.len()];
79    for i in 0..k.min(indexed.len()) {
80        masked[indexed[i].0] = indexed[i].1;
81    }
82
83    Ok(Tensor::new(masked, device)?)
84}
85
86/// Sample a token from logits using the probability distribution.
87fn sample_from_logits(logits: &Tensor, rng: &mut impl Rng) -> Result<u32> {
88    let probs = candle_nn::ops::softmax_last_dim(logits)?;
89    let probs_vec: Vec<f32> = probs.to_vec1()?;
90
91    let cumsum: Vec<f32> = probs_vec
92        .iter()
93        .scan(0.0, |acc, &x| {
94            *acc += x;
95            Some(*acc)
96        })
97        .collect();
98
99    let r: f32 = rng.random();
100    let next_token = cumsum.iter().position(|&p| p > r).unwrap_or(0) as u32;
101
102    Ok(next_token)
103}
104
/// Learning rate for `step` under linear warmup followed by cosine decay.
///
/// * For `step < warmup_steps` the rate rises linearly from `0.0` at step 0 toward `max_lr`.
/// * From `warmup_steps` to `total_steps` it follows half a cosine from `max_lr` down to
///   `min_lr`.
/// * At and after `total_steps` it stays at `min_lr`.
///
/// Degenerate schedules (`total_steps <= warmup_steps`) never panic or produce NaN: once
/// warmup ends, the rate is simply `min_lr`.
pub fn get_lr_with_warmup(
    step: usize,
    warmup_steps: usize,
    max_lr: f64,
    min_lr: f64,
    total_steps: usize,
) -> f64 {
    if step < warmup_steps {
        return max_lr * (step as f64 / warmup_steps as f64);
    }

    // Past the end of the schedule, hold at the floor instead of letting the cosine climb
    // back up (decay_ratio > 1). This also covers total_steps <= warmup_steps, which would
    // otherwise underflow (usize) or divide by zero below.
    if step >= total_steps {
        return min_lr;
    }

    // Here warmup_steps <= step < total_steps, so the denominator is strictly positive.
    let decay_ratio = (step - warmup_steps) as f64 / (total_steps - warmup_steps) as f64;
    let coeff = 0.5 * (1.0 + (std::f64::consts::PI * decay_ratio).cos());
    min_lr + coeff * (max_lr - min_lr)
}
120}