1use axonml_tensor::Tensor;
6use serde::{Serialize, Deserialize};
7use rand::Rng;
8
/// Settings that control how the next token is chosen during text generation.
///
/// Build one with [`Default`] or a preset constructor such as
/// `GenerationConfig::greedy()`, then refine it with the `with_*` builders.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfig {
    /// Upper bound on the number of newly generated tokens.
    /// NOTE(review): not enforced anywhere in this file; presumably read by
    /// the generation loop of the caller.
    pub max_new_tokens: usize,
    /// Divisor applied to every logit before sampling; `1.0` leaves the
    /// logits unchanged.
    pub temperature: f32,
    /// When `Some(k)`, only the `k` highest logits stay eligible; `None`
    /// disables top-k filtering.
    pub top_k: Option<usize>,
    /// When `Some(p)`, nucleus (top-p) filtering is applied with cumulative
    /// probability threshold `p`; `None` disables it.
    pub top_p: Option<f32>,
    /// Penalty applied to the logits of already generated tokens; `1.0`
    /// disables it.
    pub repetition_penalty: f32,
    /// Token ids that mark the end of a sequence.
    pub eos_token_ids: Vec<u32>,
    /// Optional padding token id.
    /// NOTE(review): not read anywhere in this file.
    pub pad_token_id: Option<u32>,
    /// `true` samples from the softmax distribution; `false` always picks
    /// the highest logit.
    pub do_sample: bool,
    /// Number of beams for beam search.
    /// NOTE(review): presumably forwarded to `BeamSearch` by the caller.
    pub num_beams: usize,
    /// Exponent used to length-normalize beam scores.
    /// NOTE(review): presumably forwarded to `BeamSearch` by the caller.
    pub length_penalty: f32,
    /// Whether beam search may stop once every beam has finished.
    /// NOTE(review): presumably forwarded to `BeamSearch` by the caller.
    pub early_stopping: bool,
}
35
36impl Default for GenerationConfig {
37 fn default() -> Self {
38 Self {
39 max_new_tokens: 50,
40 temperature: 1.0,
41 top_k: None,
42 top_p: None,
43 repetition_penalty: 1.0,
44 eos_token_ids: vec![],
45 pad_token_id: None,
46 do_sample: true,
47 num_beams: 1,
48 length_penalty: 1.0,
49 early_stopping: false,
50 }
51 }
52}
53
54impl GenerationConfig {
55 pub fn greedy() -> Self {
57 Self {
58 do_sample: false,
59 temperature: 1.0,
60 top_k: None,
61 top_p: None,
62 ..Default::default()
63 }
64 }
65
66 pub fn sampling(temperature: f32) -> Self {
68 Self {
69 do_sample: true,
70 temperature,
71 ..Default::default()
72 }
73 }
74
75 pub fn top_k_sampling(k: usize, temperature: f32) -> Self {
77 Self {
78 do_sample: true,
79 temperature,
80 top_k: Some(k),
81 ..Default::default()
82 }
83 }
84
85 pub fn nucleus_sampling(p: f32, temperature: f32) -> Self {
87 Self {
88 do_sample: true,
89 temperature,
90 top_p: Some(p),
91 ..Default::default()
92 }
93 }
94
95 pub fn beam_search(num_beams: usize) -> Self {
97 Self {
98 do_sample: false,
99 num_beams,
100 ..Default::default()
101 }
102 }
103
104 pub fn with_max_tokens(mut self, max_new_tokens: usize) -> Self {
106 self.max_new_tokens = max_new_tokens;
107 self
108 }
109
110 pub fn with_eos_token(mut self, eos_token_id: u32) -> Self {
112 self.eos_token_ids.push(eos_token_id);
113 self
114 }
115
116 pub fn with_repetition_penalty(mut self, penalty: f32) -> Self {
118 self.repetition_penalty = penalty;
119 self
120 }
121}
122
123pub struct TextGenerator {
125 pub config: GenerationConfig,
127}
128
129impl TextGenerator {
130 pub fn new(config: GenerationConfig) -> Self {
132 Self { config }
133 }
134
135 pub fn apply_temperature(&self, logits: &mut [f32]) {
137 if self.config.temperature != 1.0 {
138 for logit in logits.iter_mut() {
139 *logit /= self.config.temperature;
140 }
141 }
142 }
143
144 pub fn apply_repetition_penalty(&self, logits: &mut [f32], generated_tokens: &[u32]) {
146 if self.config.repetition_penalty != 1.0 {
147 for &token in generated_tokens {
148 let idx = token as usize;
149 if idx < logits.len() {
150 if logits[idx] > 0.0 {
151 logits[idx] /= self.config.repetition_penalty;
152 } else {
153 logits[idx] *= self.config.repetition_penalty;
154 }
155 }
156 }
157 }
158 }
159
160 pub fn apply_top_k(&self, logits: &mut [f32]) {
162 if let Some(k) = self.config.top_k {
163 if k < logits.len() {
164 let mut sorted_indices: Vec<usize> = (0..logits.len()).collect();
166 sorted_indices.sort_by(|&a, &b| {
167 logits[b].partial_cmp(&logits[a]).unwrap()
168 });
169
170 let top_k_indices: std::collections::HashSet<usize> =
172 sorted_indices[..k].iter().copied().collect();
173
174 for (i, logit) in logits.iter_mut().enumerate() {
176 if !top_k_indices.contains(&i) {
177 *logit = f32::NEG_INFINITY;
178 }
179 }
180 }
181 }
182 }
183
184 pub fn apply_top_p(&self, logits: &mut [f32]) {
186 if let Some(p) = self.config.top_p {
187 let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
189 let exp_logits: Vec<f32> = logits.iter().map(|x| (x - max_logit).exp()).collect();
190 let sum_exp: f32 = exp_logits.iter().sum();
191 let probs: Vec<f32> = exp_logits.iter().map(|x| x / sum_exp).collect();
192
193 let mut sorted_indices: Vec<usize> = (0..probs.len()).collect();
195 sorted_indices.sort_by(|&a, &b| {
196 probs[b].partial_cmp(&probs[a]).unwrap()
197 });
198
199 let mut cumsum = 0.0f32;
201 let mut cutoff_idx = sorted_indices.len();
202
203 for (i, &idx) in sorted_indices.iter().enumerate() {
204 cumsum += probs[idx];
205 if cumsum > p {
206 cutoff_idx = i + 1;
207 break;
208 }
209 }
210
211 for (i, logit) in logits.iter_mut().enumerate() {
213 if !sorted_indices[..cutoff_idx].contains(&i) {
214 *logit = f32::NEG_INFINITY;
215 }
216 }
217 }
218 }
219
220 pub fn sample(&self, logits: &[f32]) -> u32 {
222 if !self.config.do_sample {
223 return self.argmax(logits);
225 }
226
227 let mut rng = rand::thread_rng();
229
230 let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
232 let exp_logits: Vec<f32> = logits.iter().map(|x| (x - max_logit).exp()).collect();
233 let sum_exp: f32 = exp_logits.iter().sum();
234 let probs: Vec<f32> = exp_logits.iter().map(|x| x / sum_exp).collect();
235
236 let mut cumsum = 0.0f32;
238 let sample: f32 = rng.gen();
239
240 for (i, &p) in probs.iter().enumerate() {
241 cumsum += p;
242 if sample < cumsum {
243 return i as u32;
244 }
245 }
246
247 (logits.len() - 1) as u32
249 }
250
251 pub fn argmax(&self, logits: &[f32]) -> u32 {
253 logits
254 .iter()
255 .enumerate()
256 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
257 .map(|(i, _)| i as u32)
258 .unwrap_or(0)
259 }
260
261 pub fn get_next_token(&self, logits: &[f32], generated_tokens: &[u32]) -> u32 {
263 let mut logits = logits.to_vec();
264
265 self.apply_repetition_penalty(&mut logits, generated_tokens);
267 self.apply_temperature(&mut logits);
268 self.apply_top_k(&mut logits);
269 self.apply_top_p(&mut logits);
270
271 self.sample(&logits)
273 }
274
275 pub fn should_stop(&self, token: u32) -> bool {
277 self.config.eos_token_ids.contains(&token)
278 }
279}
280
/// One hypothesis tracked during beam search.
#[derive(Debug, Clone)]
pub struct Beam {
    /// Token ids of the hypothesis so far.
    pub tokens: Vec<u32>,
    /// Accumulated score of the hypothesis; starts at `0.0`.
    pub score: f32,
    /// Set once an end-of-sequence token has been appended.
    pub finished: bool,
}
291
292impl Beam {
293 pub fn new(initial_tokens: Vec<u32>) -> Self {
295 Self {
296 tokens: initial_tokens,
297 score: 0.0,
298 finished: false,
299 }
300 }
301
302 pub fn normalized_score(&self, length_penalty: f32) -> f32 {
304 let length = self.tokens.len() as f32;
305 self.score / length.powf(length_penalty)
306 }
307}
308
/// Parameters of a beam-search decoder.
///
/// Derives `Debug` and `Clone` like [`Beam`] so a search setup can be logged
/// and duplicated.
#[derive(Debug, Clone)]
pub struct BeamSearch {
    /// Number of beams kept after every expansion step.
    pub num_beams: usize,
    /// Exponent applied to the beam length when normalizing scores.
    pub length_penalty: f32,
    /// When `true`, the search reports completion once every beam has
    /// finished.
    pub early_stopping: bool,
    /// Token ids that mark a beam as finished when appended.
    pub eos_token_ids: Vec<u32>,
}
320
321impl BeamSearch {
322 pub fn new(num_beams: usize, length_penalty: f32, early_stopping: bool, eos_token_ids: Vec<u32>) -> Self {
324 Self {
325 num_beams,
326 length_penalty,
327 early_stopping,
328 eos_token_ids,
329 }
330 }
331
332 pub fn init_beams(&self, input_ids: &Tensor<u32>) -> Vec<Beam> {
334 let tokens: Vec<u32> = input_ids.to_vec().to_vec();
335 vec![Beam::new(tokens)]
336 }
337
338 pub fn expand_beams(&self, beams: &[Beam], next_token_logits: &[Vec<f32>]) -> Vec<Beam> {
340 let mut candidates = Vec::new();
341
342 for (beam_idx, beam) in beams.iter().enumerate() {
343 if beam.finished {
344 candidates.push(beam.clone());
345 continue;
346 }
347
348 let logits = &next_token_logits[beam_idx];
349
350 let mut indexed: Vec<(usize, f32)> = logits.iter().enumerate()
352 .map(|(i, &v)| (i, v))
353 .collect();
354 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
355
356 for (token, log_prob) in indexed.into_iter().take(self.num_beams * 2) {
357 let mut new_beam = beam.clone();
358 new_beam.tokens.push(token as u32);
359 new_beam.score += log_prob;
360
361 if self.eos_token_ids.contains(&(token as u32)) {
362 new_beam.finished = true;
363 }
364
365 candidates.push(new_beam);
366 }
367 }
368
369 candidates.sort_by(|a, b| {
371 b.normalized_score(self.length_penalty)
372 .partial_cmp(&a.normalized_score(self.length_penalty))
373 .unwrap()
374 });
375
376 candidates.into_iter().take(self.num_beams).collect()
377 }
378
379 pub fn should_stop(&self, beams: &[Beam]) -> bool {
381 if self.early_stopping {
382 beams.iter().all(|b| b.finished)
383 } else {
384 false
385 }
386 }
387
388 pub fn best_sequence(&self, beams: &[Beam]) -> Option<Vec<u32>> {
390 beams
391 .iter()
392 .filter(|b| b.finished)
393 .max_by(|a, b| {
394 a.normalized_score(self.length_penalty)
395 .partial_cmp(&b.normalized_score(self.length_penalty))
396 .unwrap()
397 })
398 .map(|b| b.tokens.clone())
399 .or_else(|| beams.first().map(|b| b.tokens.clone()))
400 }
401}
402
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration samples at temperature 1.0 with a budget of
    /// 50 new tokens.
    #[test]
    fn test_generation_config_defaults() {
        let GenerationConfig {
            max_new_tokens,
            temperature,
            do_sample,
            ..
        } = GenerationConfig::default();

        assert_eq!(max_new_tokens, 50);
        assert_eq!(temperature, 1.0);
        assert!(do_sample);
    }

    /// The greedy preset switches sampling off.
    #[test]
    fn test_greedy_config() {
        assert!(!GenerationConfig::greedy().do_sample);
    }

    /// Top-k filtering with k = 2 leaves exactly two finite logits.
    #[test]
    fn test_top_k_filtering() {
        let generator = TextGenerator::new(GenerationConfig::top_k_sampling(2, 1.0));
        let mut logits = vec![1.0, 5.0, 3.0, 4.0, 2.0];

        generator.apply_top_k(&mut logits);

        let finite_count = logits.iter().filter(|value| value.is_finite()).count();
        assert_eq!(finite_count, 2);
    }

    /// A temperature of 2.0 halves every logit.
    #[test]
    fn test_temperature_scaling() {
        let generator = TextGenerator::new(GenerationConfig::sampling(2.0));
        let mut logits = vec![2.0, 4.0, 6.0];

        generator.apply_temperature(&mut logits);

        assert_eq!(logits, vec![1.0, 2.0, 3.0]);
    }

    /// `argmax` returns the index of the largest logit.
    #[test]
    fn test_argmax() {
        let generator = TextGenerator::new(GenerationConfig::greedy());

        let result = generator.argmax(&[1.0, 5.0, 3.0, 4.0, 2.0]);

        assert_eq!(result, 1);
    }

    /// Already generated tokens end up with lower logits.
    #[test]
    fn test_repetition_penalty() {
        let generator =
            TextGenerator::new(GenerationConfig::default().with_repetition_penalty(2.0));
        let mut logits = vec![1.0, 2.0, 3.0, 4.0];

        generator.apply_repetition_penalty(&mut logits, &[1, 3]);

        assert!(logits[1] < 2.0);
        assert!(logits[3] < 4.0);
    }

    /// Beam search starts from a single beam that holds the prompt tokens.
    #[test]
    fn test_beam_search_init() {
        let input_ids = Tensor::from_vec(vec![1u32, 2, 3], &[1, 3]).unwrap();
        let beam_search = BeamSearch::new(3, 1.0, false, vec![0]);

        let beams = beam_search.init_beams(&input_ids);

        assert_eq!(beams.len(), 1);
        assert_eq!(beams[0].tokens, vec![1, 2, 3]);
    }
}