1use anyhow::Result;
4use candle_core::{DType, Device, Tensor};
5use rand::Rng;
6
7use crate::model::Transformer;
8
9pub struct TextGenerator<'a> {
11 model: &'a Transformer,
12 device: &'a Device,
13}
14
15impl<'a> TextGenerator<'a> {
16 pub fn new(model: &'a Transformer, device: &'a Device) -> Self {
17 Self { model, device }
18 }
19
20 pub fn generate(
28 &self,
29 prompt_tokens: &[u32],
30 max_new_tokens: usize,
31 temperature: f64,
32 top_k: Option<usize>,
33 ) -> Result<Vec<u32>> {
34 let mut tokens = prompt_tokens.to_vec();
35 let mut rng = rand::rng();
36
37 for _ in 0..max_new_tokens {
38 let context_len = tokens.len().min(self.model.config().max_seq_len);
39 let context: Vec<u32> = tokens[tokens.len() - context_len..].to_vec();
40
41 let input = Tensor::new(context.as_slice(), self.device)?
42 .unsqueeze(0)?
43 .to_dtype(DType::U32)?;
44
45 let logits = self.model.forward(&input, 0, false)?;
46 let logits = logits
48 .narrow(1, context_len - 1, 1)?
49 .squeeze(1)?
50 .squeeze(0)?;
51
52 let logits = if temperature != 1.0 {
53 logits.affine(1.0 / temperature, 0.0)?
54 } else {
55 logits
56 };
57
58 let logits = if let Some(k) = top_k {
59 top_k_filter(&logits, k, self.device)?
60 } else {
61 logits
62 };
63
64 let next_token = sample_from_logits(&logits, &mut rng)?;
65 tokens.push(next_token);
66 }
67
68 Ok(tokens)
69 }
70}
71
72fn top_k_filter(logits: &Tensor, k: usize, device: &Device) -> Result<Tensor> {
74 let logits_vec: Vec<f32> = logits.to_vec1()?;
75 let mut indexed: Vec<(usize, f32)> = logits_vec.iter().copied().enumerate().collect();
76 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
77
78 let mut masked = vec![f32::NEG_INFINITY; logits_vec.len()];
79 for i in 0..k.min(indexed.len()) {
80 masked[indexed[i].0] = indexed[i].1;
81 }
82
83 Ok(Tensor::new(masked, device)?)
84}
85
86fn sample_from_logits(logits: &Tensor, rng: &mut impl Rng) -> Result<u32> {
88 let probs = candle_nn::ops::softmax_last_dim(logits)?;
89 let probs_vec: Vec<f32> = probs.to_vec1()?;
90
91 let cumsum: Vec<f32> = probs_vec
92 .iter()
93 .scan(0.0, |acc, &x| {
94 *acc += x;
95 Some(*acc)
96 })
97 .collect();
98
99 let r: f32 = rng.random();
100 let next_token = cumsum.iter().position(|&p| p > r).unwrap_or(0) as u32;
101
102 Ok(next_token)
103}
104
/// Learning-rate schedule: linear warmup followed by cosine decay.
///
/// * `step < warmup_steps`: rises linearly from `0.0` (at step 0) toward `max_lr`.
/// * `warmup_steps <= step < total_steps`: follows half a cosine from `max_lr`
///   down to `min_lr`.
/// * `step >= total_steps`: stays at `min_lr` and does not climb back up.
///
/// If `total_steps <= warmup_steps`, there is no decay phase, and the rate is
/// `min_lr` as soon as warmup ends.
pub fn get_lr_with_warmup(
    step: usize,
    warmup_steps: usize,
    max_lr: f64,
    min_lr: f64,
    total_steps: usize,
) -> f64 {
    if step < warmup_steps {
        return max_lr * (step as f64 / warmup_steps as f64);
    }
    // Past the end of the schedule (or no decay phase at all). Checking here
    // also guarantees `total_steps > step >= warmup_steps` below, so the
    // subtraction cannot underflow and the divisor is non-zero.
    if step >= total_steps {
        return min_lr;
    }
    let decay_ratio = (step - warmup_steps) as f64 / (total_steps - warmup_steps) as f64;
    let coeff = 0.5 * (1.0 + (std::f64::consts::PI * decay_ratio).cos());
    min_lr + coeff * (max_lr - min_lr)
}