1use axonml_tensor::Tensor;
51use rand::Rng;
52use serde::{Deserialize, Serialize};
53
/// Configuration for autoregressive text generation.
///
/// Controls the decoding strategy (greedy, temperature/top-k/top-p
/// sampling, or beam search), repetition handling, and stopping criteria.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationConfig {
    /// Maximum number of tokens generated beyond the prompt.
    pub max_new_tokens: usize,
    /// Softmax temperature; logits are divided by this value (1.0 = no-op).
    pub temperature: f32,
    /// If set, restrict sampling to the `k` highest-scoring tokens.
    pub top_k: Option<usize>,
    /// If set, restrict sampling to the smallest token set whose cumulative
    /// probability exceeds `p` (nucleus sampling).
    pub top_p: Option<f32>,
    /// Penalty applied to logits of already-generated tokens (1.0 = disabled).
    pub repetition_penalty: f32,
    /// Token ids that terminate generation when produced.
    pub eos_token_ids: Vec<u32>,
    /// Optional padding token id.
    pub pad_token_id: Option<u32>,
    /// When true, sample stochastically; when false, pick the argmax.
    pub do_sample: bool,
    /// Number of beams for beam search (1 = no beam search).
    pub num_beams: usize,
    /// Exponent used to length-normalize beam scores.
    pub length_penalty: f32,
    /// Stop beam search as soon as every beam is finished.
    pub early_stopping: bool,
}
84
85impl Default for GenerationConfig {
86 fn default() -> Self {
87 Self {
88 max_new_tokens: 50,
89 temperature: 1.0,
90 top_k: None,
91 top_p: None,
92 repetition_penalty: 1.0,
93 eos_token_ids: vec![],
94 pad_token_id: None,
95 do_sample: true,
96 num_beams: 1,
97 length_penalty: 1.0,
98 early_stopping: false,
99 }
100 }
101}
102
103impl GenerationConfig {
104 pub fn greedy() -> Self {
106 Self {
107 do_sample: false,
108 temperature: 1.0,
109 top_k: None,
110 top_p: None,
111 ..Default::default()
112 }
113 }
114
115 pub fn sampling(temperature: f32) -> Self {
117 Self {
118 do_sample: true,
119 temperature,
120 ..Default::default()
121 }
122 }
123
124 pub fn top_k_sampling(k: usize, temperature: f32) -> Self {
126 Self {
127 do_sample: true,
128 temperature,
129 top_k: Some(k),
130 ..Default::default()
131 }
132 }
133
134 pub fn nucleus_sampling(p: f32, temperature: f32) -> Self {
136 Self {
137 do_sample: true,
138 temperature,
139 top_p: Some(p),
140 ..Default::default()
141 }
142 }
143
144 pub fn beam_search(num_beams: usize) -> Self {
146 Self {
147 do_sample: false,
148 num_beams,
149 ..Default::default()
150 }
151 }
152
153 pub fn with_max_tokens(mut self, max_new_tokens: usize) -> Self {
155 self.max_new_tokens = max_new_tokens;
156 self
157 }
158
159 pub fn with_eos_token(mut self, eos_token_id: u32) -> Self {
161 self.eos_token_ids.push(eos_token_id);
162 self
163 }
164
165 pub fn with_repetition_penalty(mut self, penalty: f32) -> Self {
167 self.repetition_penalty = penalty;
168 self
169 }
170}
171
/// Turns model logits into token ids according to a [`GenerationConfig`]:
/// temperature scaling, top-k/top-p filtering, repetition penalty,
/// greedy/stochastic sampling, and beam search.
pub struct TextGenerator {
    /// Settings consulted by every decoding step.
    pub config: GenerationConfig,
}
181
182impl TextGenerator {
183 pub fn new(config: GenerationConfig) -> Self {
185 Self { config }
186 }
187
188 pub fn apply_temperature(&self, logits: &mut [f32]) {
194 if self.config.temperature != 1.0 {
195 for logit in logits.iter_mut() {
196 *logit /= self.config.temperature;
197 }
198 }
199 }
200
201 pub fn apply_repetition_penalty(&self, logits: &mut [f32], generated_tokens: &[u32]) {
203 if self.config.repetition_penalty != 1.0 {
204 for &token in generated_tokens {
205 let idx = token as usize;
206 if idx < logits.len() {
207 if logits[idx] > 0.0 {
208 logits[idx] /= self.config.repetition_penalty;
209 } else {
210 logits[idx] *= self.config.repetition_penalty;
211 }
212 }
213 }
214 }
215 }
216
217 pub fn apply_top_k(&self, logits: &mut [f32]) {
219 if let Some(k) = self.config.top_k {
220 if k < logits.len() {
221 let mut sorted_indices: Vec<usize> = (0..logits.len()).collect();
223 sorted_indices.sort_by(|&a, &b| logits[b].partial_cmp(&logits[a]).unwrap());
224
225 let top_k_indices: std::collections::HashSet<usize> =
227 sorted_indices[..k].iter().copied().collect();
228
229 for (i, logit) in logits.iter_mut().enumerate() {
231 if !top_k_indices.contains(&i) {
232 *logit = f32::NEG_INFINITY;
233 }
234 }
235 }
236 }
237 }
238
239 pub fn apply_top_p(&self, logits: &mut [f32]) {
241 if let Some(p) = self.config.top_p {
242 let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
244 let exp_logits: Vec<f32> = logits.iter().map(|x| (x - max_logit).exp()).collect();
245 let sum_exp: f32 = exp_logits.iter().sum();
246 let probs: Vec<f32> = exp_logits.iter().map(|x| x / sum_exp).collect();
247
248 let mut sorted_indices: Vec<usize> = (0..probs.len()).collect();
250 sorted_indices.sort_by(|&a, &b| probs[b].partial_cmp(&probs[a]).unwrap());
251
252 let mut cumsum = 0.0f32;
254 let mut cutoff_idx = sorted_indices.len();
255
256 for (i, &idx) in sorted_indices.iter().enumerate() {
257 cumsum += probs[idx];
258 if cumsum > p {
259 cutoff_idx = i + 1;
260 break;
261 }
262 }
263
264 for (i, logit) in logits.iter_mut().enumerate() {
266 if !sorted_indices[..cutoff_idx].contains(&i) {
267 *logit = f32::NEG_INFINITY;
268 }
269 }
270 }
271 }
272
273 pub fn sample(&self, logits: &[f32]) -> u32 {
279 if !self.config.do_sample {
280 return self.argmax(logits);
282 }
283
284 let mut rng = rand::thread_rng();
286
287 let max_logit = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
289 let exp_logits: Vec<f32> = logits.iter().map(|x| (x - max_logit).exp()).collect();
290 let sum_exp: f32 = exp_logits.iter().sum();
291 let probs: Vec<f32> = exp_logits.iter().map(|x| x / sum_exp).collect();
292
293 let mut cumsum = 0.0f32;
295 let sample: f32 = rng.r#gen();
296
297 for (i, &p) in probs.iter().enumerate() {
298 cumsum += p;
299 if sample < cumsum {
300 return i as u32;
301 }
302 }
303
304 (logits.len() - 1) as u32
306 }
307
308 pub fn argmax(&self, logits: &[f32]) -> u32 {
310 logits
311 .iter()
312 .enumerate()
313 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
314 .map(|(i, _)| i as u32)
315 .unwrap_or(0)
316 }
317
318 pub fn get_next_token(&self, logits: &[f32], generated_tokens: &[u32]) -> u32 {
327 let mut logits = logits.to_vec();
328
329 self.apply_repetition_penalty(&mut logits, generated_tokens);
331 self.apply_temperature(&mut logits);
332 self.apply_top_k(&mut logits);
333 self.apply_top_p(&mut logits);
334
335 self.sample(&logits)
337 }
338
339 pub fn generate_beam_search<F>(&self, initial_tokens: &[u32], get_logits_fn: &mut F) -> Vec<u32>
344 where
345 F: FnMut(&[u32]) -> Vec<f32>,
346 {
347 let beam_search = BeamSearch::new(
348 self.config.num_beams,
349 self.config.length_penalty,
350 self.config.early_stopping,
351 self.config.eos_token_ids.clone(),
352 );
353
354 let mut beams = vec![Beam::new(initial_tokens.to_vec())];
355
356 for _ in 0..self.config.max_new_tokens {
357 if beam_search.should_stop(&beams) {
358 break;
359 }
360
361 let mut all_logits = Vec::with_capacity(beams.len());
363 for beam in &beams {
364 if beam.finished {
365 all_logits.push(vec![0.0f32; 1]);
367 } else {
368 let logits = get_logits_fn(&beam.tokens);
369 all_logits.push(logits);
370 }
371 }
372
373 let log_prob_beams: Vec<Vec<f32>> = all_logits
375 .iter()
376 .map(|logits| {
377 let max_l = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
378 let exp_sum: f32 = logits.iter().map(|x| (x - max_l).exp()).sum();
379 let log_sum = max_l + exp_sum.ln();
380 logits.iter().map(|x| x - log_sum).collect()
381 })
382 .collect();
383
384 beams = beam_search.expand_beams(&beams, &log_prob_beams);
385 }
386
387 beam_search
388 .best_sequence(&beams)
389 .unwrap_or_else(|| initial_tokens.to_vec())
390 }
391
392 pub fn should_stop(&self, token: u32) -> bool {
394 self.config.eos_token_ids.contains(&token)
395 }
396}
397
/// One hypothesis in a beam search.
#[derive(Debug, Clone)]
pub struct Beam {
    /// Token ids accumulated so far (prompt plus generated tokens).
    pub tokens: Vec<u32>,
    /// Sum of log-probabilities of the generated tokens.
    pub score: f32,
    /// Set once an EOS token has been appended.
    pub finished: bool,
}

impl Beam {
    /// Start a fresh, unfinished beam with score 0 over `initial_tokens`.
    pub fn new(initial_tokens: Vec<u32>) -> Self {
        Beam {
            finished: false,
            score: 0.0,
            tokens: initial_tokens,
        }
    }

    /// Score divided by `len.powf(length_penalty)`, so longer sequences
    /// are not unfairly punished for accumulating more negative log-probs.
    pub fn normalized_score(&self, length_penalty: f32) -> f32 {
        let len = self.tokens.len() as f32;
        self.score / len.powf(length_penalty)
    }
}
429
/// Beam-search bookkeeping: expansion, pruning, stopping, and final
/// sequence selection.
pub struct BeamSearch {
    /// Number of beams kept alive after each expansion step.
    pub num_beams: usize,
    /// Exponent for length normalization of beam scores.
    pub length_penalty: f32,
    /// When true, stop as soon as every beam is finished.
    pub early_stopping: bool,
    /// Token ids that mark a beam as finished.
    pub eos_token_ids: Vec<u32>,
}
441
442impl BeamSearch {
443 pub fn new(
445 num_beams: usize,
446 length_penalty: f32,
447 early_stopping: bool,
448 eos_token_ids: Vec<u32>,
449 ) -> Self {
450 Self {
451 num_beams,
452 length_penalty,
453 early_stopping,
454 eos_token_ids,
455 }
456 }
457
458 pub fn init_beams(&self, input_ids: &Tensor<u32>) -> Vec<Beam> {
460 let tokens: Vec<u32> = input_ids.to_vec().to_vec();
461 vec![Beam::new(tokens)]
462 }
463
464 pub fn expand_beams(&self, beams: &[Beam], next_token_logits: &[Vec<f32>]) -> Vec<Beam> {
466 let mut candidates = Vec::new();
467
468 for (beam_idx, beam) in beams.iter().enumerate() {
469 if beam.finished {
470 candidates.push(beam.clone());
471 continue;
472 }
473
474 let logits = &next_token_logits[beam_idx];
475
476 let mut indexed: Vec<(usize, f32)> =
478 logits.iter().enumerate().map(|(i, &v)| (i, v)).collect();
479 indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
480
481 for (token, log_prob) in indexed.into_iter().take(self.num_beams * 2) {
482 let mut new_beam = beam.clone();
483 new_beam.tokens.push(token as u32);
484 new_beam.score += log_prob;
485
486 if self.eos_token_ids.contains(&(token as u32)) {
487 new_beam.finished = true;
488 }
489
490 candidates.push(new_beam);
491 }
492 }
493
494 candidates.sort_by(|a, b| {
496 b.normalized_score(self.length_penalty)
497 .partial_cmp(&a.normalized_score(self.length_penalty))
498 .unwrap()
499 });
500
501 candidates.into_iter().take(self.num_beams).collect()
502 }
503
504 pub fn should_stop(&self, beams: &[Beam]) -> bool {
506 if self.early_stopping {
507 beams.iter().all(|b| b.finished)
508 } else {
509 false
510 }
511 }
512
513 pub fn best_sequence(&self, beams: &[Beam]) -> Option<Vec<u32>> {
515 beams
516 .iter()
517 .filter(|b| b.finished)
518 .max_by(|a, b| {
519 a.normalized_score(self.length_penalty)
520 .partial_cmp(&b.normalized_score(self.length_penalty))
521 .unwrap()
522 })
523 .map(|b| b.tokens.clone())
524 .or_else(|| beams.first().map(|b| b.tokens.clone()))
525 }
526}
527
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_generation_config_defaults() {
        let cfg = GenerationConfig::default();
        assert_eq!(cfg.max_new_tokens, 50);
        assert_eq!(cfg.temperature, 1.0);
        assert!(cfg.do_sample);
    }

    #[test]
    fn test_greedy_config() {
        // Greedy decoding must disable sampling.
        assert!(!GenerationConfig::greedy().do_sample);
    }

    #[test]
    fn test_top_k_filtering() {
        let generator = TextGenerator::new(GenerationConfig::top_k_sampling(2, 1.0));

        let mut scores = vec![1.0, 5.0, 3.0, 4.0, 2.0];
        generator.apply_top_k(&mut scores);

        // Exactly k = 2 entries survive; the rest are -inf.
        let surviving = scores.iter().filter(|x| x.is_finite()).count();
        assert_eq!(surviving, 2);
    }

    #[test]
    fn test_temperature_scaling() {
        let generator = TextGenerator::new(GenerationConfig::sampling(2.0));

        let mut scores = vec![2.0, 4.0, 6.0];
        generator.apply_temperature(&mut scores);

        // Each logit is divided by the temperature of 2.0.
        assert_eq!(scores, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_argmax() {
        let generator = TextGenerator::new(GenerationConfig::greedy());
        let scores = vec![1.0, 5.0, 3.0, 4.0, 2.0];
        assert_eq!(generator.argmax(&scores), 1);
    }

    #[test]
    fn test_repetition_penalty() {
        let generator =
            TextGenerator::new(GenerationConfig::default().with_repetition_penalty(2.0));

        let mut scores = vec![1.0, 2.0, 3.0, 4.0];
        let history = vec![1, 3];
        generator.apply_repetition_penalty(&mut scores, &history);

        // Penalized positions must have shrunk.
        assert!(scores[1] < 2.0);
        assert!(scores[3] < 4.0);
    }

    #[test]
    fn test_beam_search_init() {
        let searcher = BeamSearch::new(3, 1.0, false, vec![0]);
        let input_ids = Tensor::from_vec(vec![1u32, 2, 3], &[1, 3]).unwrap();
        let beams = searcher.init_beams(&input_ids);

        assert_eq!(beams.len(), 1);
        assert_eq!(beams[0].tokens, vec![1, 2, 3]);
    }

    #[test]
    fn test_beam_search_expand() {
        let searcher = BeamSearch::new(2, 1.0, false, vec![99]);

        let seed = vec![Beam::new(vec![1, 2])];
        let log_probs = vec![vec![-10.0, -10.0, -10.0, 5.0, -10.0]];
        let expanded = searcher.expand_beams(&seed, &log_probs);

        // Pruned back to num_beams, with the best continuation (token 3) first.
        assert_eq!(expanded.len(), 2);
        assert_eq!(*expanded[0].tokens.last().unwrap(), 3);
    }

    #[test]
    fn test_generate_beam_search() {
        let config = GenerationConfig::beam_search(3)
            .with_max_tokens(5)
            .with_eos_token(4);
        let generator = TextGenerator::new(config);

        // Favor token 3 for two steps, then token 4 (the EOS).
        let mut step = 0;
        let result = generator.generate_beam_search(&[1, 2], &mut |_tokens| {
            step += 1;
            if step >= 3 {
                vec![-10.0, -10.0, -10.0, -10.0, 10.0]
            } else {
                vec![-10.0, -10.0, -10.0, 10.0, -10.0]
            }
        });

        // The prompt must be preserved and something must be generated.
        assert!(result.len() > 2);
        assert_eq!(result[0], 1);
        assert_eq!(result[1], 2);
    }
}