use crate::autograd::Tensor;
/// Decoding parameters for text generation.
///
/// Start from [`GenerationConfig::new`] (or [`Default::default`]) and refine
/// with the `with_*` builder methods. The values listed below are the ones
/// produced by the `Default` impl.
#[derive(Clone, Debug)]
pub struct GenerationConfig {
    /// Upper length limit for generation (default `50`).
    pub max_length: usize,
    /// Lower length limit for generation (default `0`).
    pub min_length: usize,
    /// Sampling temperature (default `1.0`).
    pub temperature: f32,
    /// Optional top-k cutoff (default `None`, i.e. disabled).
    pub top_k: Option<usize>,
    /// Optional top-p (nucleus) cutoff (default `None`, i.e. disabled).
    pub top_p: Option<f32>,
    /// Number of beams (default `1`).
    pub num_beams: usize,
    /// Length-penalty exponent (default `1.0`).
    pub length_penalty: f32,
    /// Repetition penalty (default `1.0`).
    pub repetition_penalty: f32,
    /// Early-stopping flag (default `false`).
    pub early_stopping: bool,
    /// Optional end-of-sequence token id (default `None`).
    pub eos_token_id: Option<usize>,
    /// Optional padding token id (default `None`).
    pub pad_token_id: Option<usize>,
}

impl Default for GenerationConfig {
    fn default() -> Self {
        // Grouped by kind: optional ids/cutoffs first, then flags and counts,
        // then the floating-point knobs.
        GenerationConfig {
            eos_token_id: None,
            pad_token_id: None,
            top_k: None,
            top_p: None,
            early_stopping: false,
            num_beams: 1,
            min_length: 0,
            max_length: 50,
            temperature: 1.0,
            length_penalty: 1.0,
            repetition_penalty: 1.0,
        }
    }
}
impl GenerationConfig {
    /// Creates a configuration holding the [`Default`] values.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the configuration with `max_length` replaced.
    #[must_use]
    pub fn with_max_length(self, max_length: usize) -> Self {
        Self { max_length, ..self }
    }

    /// Returns the configuration with `temperature` replaced.
    ///
    /// The `contract_*` macros are defined outside this file. They are invoked
    /// before the assignment (pre-conditions) and after it with `&self`
    /// (post-conditions), so that ordering is kept deliberately.
    #[must_use]
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        contract_pre_temperature_bounds!();
        contract_pre_seed_determinism!();
        self.temperature = temperature;
        contract_post_temperature_bounds!(&self);
        contract_post_seed_determinism!(&self);
        self
    }

    /// Returns the configuration with the top-k cutoff set to `top_k`.
    #[must_use]
    pub fn with_top_k(self, top_k: usize) -> Self {
        Self {
            top_k: Some(top_k),
            ..self
        }
    }

    /// Returns the configuration with the top-p cutoff set to `top_p`.
    ///
    /// As with `with_temperature`, the externally defined contract macros
    /// bracket the assignment and are kept in place.
    #[must_use]
    pub fn with_top_p(mut self, top_p: f32) -> Self {
        contract_pre_top_k_top_p_interaction!();
        self.top_p = Some(top_p);
        contract_post_top_k_top_p_interaction!(&self);
        self
    }

    /// Returns the configuration with `num_beams` replaced.
    #[must_use]
    pub fn with_num_beams(self, num_beams: usize) -> Self {
        Self { num_beams, ..self }
    }

    /// Returns the configuration with the end-of-sequence token id set.
    #[must_use]
    pub fn with_eos_token_id(self, eos_token_id: usize) -> Self {
        Self {
            eos_token_id: Some(eos_token_id),
            ..self
        }
    }
}
/// A single candidate sequence tracked during beam search.
#[derive(Debug, Clone)]
pub struct BeamHypothesis {
    /// Token ids of the sequence so far.
    pub tokens: Vec<usize>,
    /// Accumulated raw (un-normalized) score; `BeamSearch::step` adds the
    /// per-token score of each appended token to its parent's score.
    pub score: f32,
    /// Set once the hypothesis has emitted the EOS token; finished hypotheses
    /// are carried over by `BeamSearch::step` instead of being expanded.
    pub is_done: bool,
}

impl BeamHypothesis {
    /// Creates an unfinished hypothesis with the given tokens and raw score.
    #[must_use]
    pub fn new(tokens: Vec<usize>, score: f32) -> Self {
        Self {
            tokens,
            score,
            is_done: false,
        }
    }

    /// Returns `score / len.powf(length_penalty)`, where `len` is the number
    /// of tokens.
    ///
    /// An empty hypothesis is normalized as if it had length 1, so the raw
    /// score is returned. Otherwise `0.0.powf(length_penalty)` would give a
    /// zero or infinite divisor. The result would then be NaN or ±∞, which
    /// corrupts any ranking built on this value.
    #[must_use]
    pub fn normalized_score(&self, length_penalty: f32) -> f32 {
        let len = self.tokens.len().max(1) as f32;
        self.score / len.powf(length_penalty)
    }
}
/// Settings for a simple beam-search decoder.
///
/// Construct with [`BeamSearch::new`] and adjust with the `with_*` builders.
pub struct BeamSearch {
    /// Number of hypotheses kept after each step.
    beam_size: usize,
    /// Token id that, when appended, marks a hypothesis as finished.
    eos_token_id: Option<usize>,
    /// Exponent applied to the sequence length when normalizing scores.
    length_penalty: f32,
    /// Early-stopping flag. NOTE(review): only surfaced through `Debug`; no
    /// search logic in this file reads it.
    early_stopping: bool,
}
impl BeamSearch {
    /// Creates a beam search that keeps `beam_size` hypotheses per step.
    ///
    /// Defaults: length penalty `1.0`, early stopping off, no EOS token.
    #[must_use]
    pub fn new(beam_size: usize) -> Self {
        Self {
            beam_size,
            length_penalty: 1.0,
            early_stopping: false,
            eos_token_id: None,
        }
    }

    /// Sets the exponent passed to [`BeamHypothesis::normalized_score`].
    #[must_use]
    pub fn with_length_penalty(mut self, penalty: f32) -> Self {
        self.length_penalty = penalty;
        self
    }

    /// Turns on the early-stopping flag.
    ///
    /// NOTE(review): no method in this impl reads `early_stopping`; `step`
    /// and `all_done` behave the same either way.
    #[must_use]
    pub fn with_early_stopping(mut self) -> Self {
        self.early_stopping = true;
        self
    }

    /// Sets the token id whose emission marks a hypothesis as finished.
    #[must_use]
    pub fn with_eos_token_id(mut self, eos_token_id: usize) -> Self {
        self.eos_token_id = Some(eos_token_id);
        self
    }

    /// Expands every unfinished beam by every token id and keeps the
    /// `beam_size` best candidates, ranked by length-normalized score in
    /// descending order.
    ///
    /// - A candidate's raw score is the parent's score plus the token's entry
    ///   in `log_probs`.
    /// - A candidate is marked finished when its appended token equals the
    ///   configured EOS id.
    /// - Finished beams are carried over unchanged and compete with the new
    ///   candidates.
    /// - The sort is stable, so tied candidates keep insertion order (beam
    ///   order, then token id).
    /// - NaN scores rank last.
    ///
    /// The same `log_probs` scores are applied to every beam.
    /// NOTE(review): this assumes `log_probs` is a 1-D tensor with one score
    /// per token id — TODO confirm against callers.
    ///
    /// # Panics
    ///
    /// Panics if `log_probs.shape()` is empty. Also panics if
    /// `log_probs.data()` holds fewer than `log_probs.shape()[0]` elements
    /// while an unfinished beam is being expanded.
    #[must_use]
    pub fn step(
        &self,
        log_probs: &Tensor,
        current_beams: &[BeamHypothesis],
    ) -> Vec<BeamHypothesis> {
        let vocab_size = log_probs.shape()[0];
        // Fetched once, outside the loops, rather than once per token.
        let token_scores = log_probs.data();
        let mut candidates: Vec<BeamHypothesis> = Vec::new();
        for beam in current_beams {
            if beam.is_done {
                candidates.push(beam.clone());
                continue;
            }
            candidates.reserve(vocab_size);
            // One bounds check per beam instead of one per token.
            let scores = &token_scores[..vocab_size];
            for (token_id, &token_score) in scores.iter().enumerate() {
                let mut tokens = Vec::with_capacity(beam.tokens.len() + 1);
                tokens.extend_from_slice(&beam.tokens);
                tokens.push(token_id);
                let mut candidate = BeamHypothesis::new(tokens, beam.score + token_score);
                candidate.is_done = self.eos_token_id == Some(token_id);
                candidates.push(candidate);
            }
        }
        // Descending order: compare `b` against `a`.
        candidates.sort_by(|a, b| self.compare_normalized(b, a));
        candidates.truncate(self.beam_size);
        candidates
    }

    /// Starts a search from one hypothesis that contains only `start_token`
    /// and has score `0.0`.
    #[must_use]
    pub fn init(&self, start_token: usize) -> Vec<BeamHypothesis> {
        vec![BeamHypothesis::new(vec![start_token], 0.0)]
    }

    /// Returns `true` when every hypothesis in `beams` is finished.
    /// This is vacuously `true` for an empty slice.
    #[must_use]
    pub fn all_done(&self, beams: &[BeamHypothesis]) -> bool {
        beams.iter().all(|b| b.is_done)
    }

    /// Returns a clone of the hypothesis with the highest length-normalized
    /// score, or `None` if `beams` is empty.
    ///
    /// On ties the later element wins (`Iterator::max_by` semantics). NaN
    /// scores rank below every real score.
    #[must_use]
    pub fn best(&self, beams: &[BeamHypothesis]) -> Option<BeamHypothesis> {
        beams
            .iter()
            .max_by(|a, b| self.compare_normalized(a, b))
            .cloned()
    }

    /// Number of hypotheses kept per step.
    #[must_use]
    pub fn beam_size(&self) -> usize {
        self.beam_size
    }

    /// Exponent used when length-normalizing scores.
    #[must_use]
    pub fn length_penalty(&self) -> f32 {
        self.length_penalty
    }

    /// Compares two hypotheses by length-normalized score, in ascending order.
    ///
    /// NaN scores are treated as `f32::NEG_INFINITY`, so they rank below every
    /// real score. Without this mapping, `partial_cmp(..).unwrap_or(Equal)`
    /// treats NaN as equal to everything, which breaks transitivity. Results
    /// then depend on input order, and since Rust 1.81 `sort_by` may panic on
    /// such a comparator.
    fn compare_normalized(&self, a: &BeamHypothesis, b: &BeamHypothesis) -> std::cmp::Ordering {
        let rank = |h: &BeamHypothesis| {
            let score = h.normalized_score(self.length_penalty);
            if score.is_nan() {
                f32::NEG_INFINITY
            } else {
                score
            }
        };
        rank(a)
            .partial_cmp(&rank(b))
            // Never `None`: `rank` has already mapped NaN away.
            .unwrap_or(std::cmp::Ordering::Equal)
    }
}
impl std::fmt::Debug for BeamSearch {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `BeamSearch` without
        // updating this impl becomes a compile error.
        let Self {
            beam_size,
            length_penalty,
            early_stopping,
            eos_token_id,
        } = self;
        let mut out = f.debug_struct("BeamSearch");
        out.field("beam_size", beam_size);
        out.field("length_penalty", length_penalty);
        out.field("early_stopping", early_stopping);
        out.field("eos_token_id", eos_token_id);
        out.finish()
    }
}
/// State for a nucleus (top-p) sampler.
///
/// The fields are crate-visible. The sampler's methods are not defined in
/// this view; presumably they live in `nucleus_sampler.rs` — TODO confirm.
pub struct NucleusSampler {
    /// Minimum number of candidate tokens to retain.
    pub(crate) min_tokens_to_keep: usize,
    /// Top-p threshold.
    pub(crate) top_p: f32,
    /// Sampling temperature.
    pub(crate) temperature: f32,
}
// Submodule loaded from `nucleus_sampler.rs`; its public items are
// glob re-exported from this module.
#[path = "nucleus_sampler.rs"]
mod nucleus_sampler;
pub use nucleus_sampler::*;
// Submodule loaded from `teacher_forcing.rs`. It is private and is not
// re-exported here.
#[path = "teacher_forcing.rs"]
mod teacher_forcing;