use crate::generation_utils::{GenerationConfig, GenerationMode, GenerationUtils, KVCache};
use crate::gpt2::model::{Gpt2LMHeadModel, Gpt2LMOutput};
use scirs2_core::ndarray::s;
use scirs2_core::random::*;
use trustformers_core::{
errors::{tensor_op_error, Result, TrustformersError},
tensor::Tensor,
traits::{Model, TokenizedInput},
};
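/// Decoding strategies (greedy, beam search, and sampling variants) for
/// autoregressive text generation.
///
/// A minimal, illustrative usage sketch (assumes `model` is an already
/// constructed `Gpt2LMHeadModel`):
///
/// ```ignore
/// let prompt = vec![1u32, 2, 3];
/// let tokens = model.generate_greedy(prompt, 20)?;
/// ```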
pub trait GenerativeModel {
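    /// Generates token sequences according to `config`. All decoding modes
    /// return `Vec<Vec<u32>>`; single-sequence modes wrap their result in a
    /// one-element vector.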
fn generate_with_config(
&self,
input_ids: Vec<u32>,
config: GenerationConfig,
) -> Result<Vec<Vec<u32>>>;
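    /// Greedy (argmax) decoding up to `max_length` tokens, prompt included.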
fn generate_greedy(&self, input_ids: Vec<u32>, max_length: usize) -> Result<Vec<u32>>;
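    /// Beam-search decoding with `num_beams` beams; returns the top-scoring hypotheses.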
fn generate_beam_search(
&self,
input_ids: Vec<u32>,
max_length: usize,
num_beams: usize,
) -> Result<Vec<Vec<u32>>>;
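    /// Top-k sampling restricted to the `k` most likely tokens, scaled by `temperature`.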
fn generate_top_k(
&self,
input_ids: Vec<u32>,
max_length: usize,
k: usize,
temperature: f32,
) -> Result<Vec<u32>>;
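    /// Nucleus (top-p) sampling over the smallest token set whose cumulative
    /// probability reaches `p`, scaled by `temperature`.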
fn generate_top_p(
&self,
input_ids: Vec<u32>,
max_length: usize,
p: f32,
temperature: f32,
) -> Result<Vec<u32>>;
}
impl GenerativeModel for Gpt2LMHeadModel {
fn generate_with_config(
&self,
input_ids: Vec<u32>,
config: GenerationConfig,
) -> Result<Vec<Vec<u32>>> {
config.validate()?;
let mut rng = thread_rng();
let max_length = if let Some(max_new_tokens) = config.max_new_tokens {
input_ids.len() + max_new_tokens
} else {
config.max_length
};
match &config.mode {
GenerationMode::Greedy => {
let result = self.generate_greedy_internal(input_ids, max_length, &config)?;
Ok(vec![result])
},
GenerationMode::BeamSearch { num_beams } => {
self.generate_beam_search_internal(input_ids, max_length, *num_beams, &config)
},
GenerationMode::TopK { k } => {
let result = self.generate_sampling_internal(
input_ids,
max_length,
&config,
&mut rng,
|logits, rng| GenerationUtils::sample_top_k(logits, *k, rng),
)?;
Ok(vec![result])
},
GenerationMode::TopP { p } => {
let result = self.generate_sampling_internal(
input_ids,
max_length,
&config,
&mut rng,
|logits, rng| GenerationUtils::sample_top_p(logits, *p, rng),
)?;
Ok(vec![result])
},
GenerationMode::MinP { p } => {
let result = self.generate_sampling_internal(
input_ids,
max_length,
&config,
&mut rng,
|logits, rng| GenerationUtils::sample_min_p(logits, *p, rng),
)?;
Ok(vec![result])
},
GenerationMode::Temperature { temperature } => {
let mut temp_config = config.clone();
temp_config.temperature = *temperature;
let result = self.generate_sampling_internal(
input_ids,
max_length,
&temp_config,
&mut rng,
|logits, rng| {
let probs = GenerationUtils::softmax(logits);
GenerationUtils::sample_from_probs(&probs, rng).map(|idx| idx as u32)
},
)?;
Ok(vec![result])
},
GenerationMode::Combined { k, p } => {
let result = self.generate_sampling_internal(
input_ids,
max_length,
&config,
&mut rng,
|logits, rng| {
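                        // Combined top-k + nucleus sampling: first keep only the k highest-scoring tokens.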
let mut indexed_logits: Vec<(usize, f32)> =
logits.iter().copied().enumerate().collect();
indexed_logits.sort_by(|a, b| {
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
});
indexed_logits.truncate(*k);
let top_k_logits: Vec<f32> =
indexed_logits.iter().map(|(_, logit)| *logit).collect();
let probs = GenerationUtils::softmax(&top_k_logits);
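                        // Within the retained tokens, find the smallest prefix whose cumulative probability reaches p.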
let mut cumsum = 0.0;
let mut cutoff = probs.len();
for (i, &prob) in probs.iter().enumerate() {
cumsum += prob;
if cumsum >= *p {
cutoff = i + 1;
break;
}
}
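                        // Sample inside the nucleus, then map back to the original vocabulary index.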
let nucleus_probs = &probs[..cutoff];
let sample_idx = GenerationUtils::sample_from_probs(nucleus_probs, rng)?;
Ok(indexed_logits[sample_idx].0 as u32)
},
)?;
Ok(vec![result])
},
GenerationMode::ContrastiveSearch { top_k: _, alpha: _ } => {
Err(TrustformersError::model_error(
"Contrastive search not yet implemented for GPT-2".to_string(),
))
},
}
}
fn generate_greedy(&self, input_ids: Vec<u32>, max_length: usize) -> Result<Vec<u32>> {
let config = GenerationConfig::greedy();
self.generate_greedy_internal(input_ids, max_length, &config)
}
fn generate_beam_search(
&self,
input_ids: Vec<u32>,
max_length: usize,
num_beams: usize,
) -> Result<Vec<Vec<u32>>> {
let config = GenerationConfig::beam_search(num_beams);
self.generate_beam_search_internal(input_ids, max_length, num_beams, &config)
}
fn generate_top_k(
&self,
input_ids: Vec<u32>,
max_length: usize,
k: usize,
temperature: f32,
) -> Result<Vec<u32>> {
let mut config = GenerationConfig::top_k(k);
config.temperature = temperature;
config.max_length = max_length;
let mut rng = thread_rng();
self.generate_sampling_internal(input_ids, max_length, &config, &mut rng, |logits, rng| {
GenerationUtils::sample_top_k(logits, k, rng)
})
}
fn generate_top_p(
&self,
input_ids: Vec<u32>,
max_length: usize,
p: f32,
temperature: f32,
) -> Result<Vec<u32>> {
let mut config = GenerationConfig::top_p(p);
config.temperature = temperature;
config.max_length = max_length;
let mut rng = thread_rng();
self.generate_sampling_internal(input_ids, max_length, &config, &mut rng, |logits, rng| {
GenerationUtils::sample_top_p(logits, p, rng)
})
}
}
impl Gpt2LMHeadModel {
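    /// Greedy decoding loop: at each step applies repetition, frequency,
    /// presence, and bad-word penalties to the next-token logits and appends
    /// the argmax token until `max_length` or a stop condition is reached.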
fn generate_greedy_internal(
&self,
input_ids: Vec<u32>,
max_length: usize,
config: &GenerationConfig,
) -> Result<Vec<u32>> {
let mut generated = input_ids.clone();
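        // The KV cache is allocated when enabled, but get_next_token_logits
        // does not consume it yet; each step re-runs the full forward pass.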
let mut kv_cache = if !config.no_kv_cache { Some(KVCache::new()) } else { None };
while generated.len() < max_length {
if GenerationUtils::should_stop(&generated, config, generated.len()) {
break;
}
let mut logits = self.get_next_token_logits(&generated, kv_cache.as_mut())?;
GenerationUtils::apply_repetition_penalty(
&mut logits,
&generated,
config.repetition_penalty,
config.repetition_penalty_decay,
);
GenerationUtils::apply_frequency_penalty(
&mut logits,
&generated,
config.frequency_penalty,
);
GenerationUtils::apply_presence_penalty(
&mut logits,
&generated,
config.presence_penalty,
);
GenerationUtils::apply_bad_words_filter(&mut logits, &generated, &config.bad_words_ids);
let next_token = GenerationUtils::sample_greedy(&logits);
generated.push(next_token);
}
Ok(generated)
}
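    /// Shared sampling loop: applies temperature scaling and the configured
    /// penalties to the next-token logits, then delegates token selection to
    /// `sample_fn`.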
fn generate_sampling_internal<F, R>(
&self,
input_ids: Vec<u32>,
max_length: usize,
config: &GenerationConfig,
rng: &mut R,
sample_fn: F,
) -> Result<Vec<u32>>
where
F: Fn(&[f32], &mut R) -> Result<u32>,
R: Rng,
{
let mut generated = input_ids.clone();
let mut kv_cache = if !config.no_kv_cache { Some(KVCache::new()) } else { None };
while generated.len() < max_length {
if GenerationUtils::should_stop(&generated, config, generated.len()) {
break;
}
let mut logits = self.get_next_token_logits(&generated, kv_cache.as_mut())?;
GenerationUtils::apply_temperature(&mut logits, config.temperature);
GenerationUtils::apply_repetition_penalty(
&mut logits,
&generated,
config.repetition_penalty,
config.repetition_penalty_decay,
);
GenerationUtils::apply_frequency_penalty(
&mut logits,
&generated,
config.frequency_penalty,
);
GenerationUtils::apply_presence_penalty(
&mut logits,
&generated,
config.presence_penalty,
);
GenerationUtils::apply_bad_words_filter(&mut logits, &generated, &config.bad_words_ids);
let next_token = sample_fn(&logits, rng)?;
generated.push(next_token);
}
Ok(generated)
}
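    /// Beam search: each step expands every unfinished beam with its top
    /// `num_beams` continuations, then keeps the `num_beams` candidates with
    /// the best length-normalized scores. Falls back to greedy decoding when
    /// `num_beams == 1`.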
fn generate_beam_search_internal(
&self,
input_ids: Vec<u32>,
max_length: usize,
num_beams: usize,
config: &GenerationConfig,
) -> Result<Vec<Vec<u32>>> {
use crate::generation_utils::BeamHypothesis;
if num_beams == 1 {
let result = self.generate_greedy_internal(input_ids, max_length, config)?;
return Ok(vec![result]);
}
let mut beams: Vec<BeamHypothesis> = vec![BeamHypothesis::new(input_ids.clone(), 0.0)];
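        // Grow all beams in lockstep; finished beams are carried forward unchanged.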
while beams[0].tokens.len() < max_length {
let mut candidates = Vec::new();
for beam in &beams {
if beam.finished {
candidates.push(beam.clone());
continue;
}
let mut logits = self.get_next_token_logits(&beam.tokens, None)?;
GenerationUtils::apply_repetition_penalty(
&mut logits,
&beam.tokens,
config.repetition_penalty,
config.repetition_penalty_decay,
);
GenerationUtils::apply_bad_words_filter(
&mut logits,
&beam.tokens,
&config.bad_words_ids,
);
let log_probs =
GenerationUtils::softmax(&logits).iter().map(|&p| p.ln()).collect::<Vec<_>>();
let mut token_scores: Vec<(f32, usize)> =
log_probs.iter().enumerate().map(|(idx, &log_prob)| (log_prob, idx)).collect();
token_scores
.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
for (log_prob, token_idx) in token_scores.iter().take(num_beams) {
let new_score = beam.score + log_prob;
let mut new_tokens = beam.tokens.clone();
new_tokens.push(*token_idx as u32);
let mut new_beam = BeamHypothesis::new(new_tokens.clone(), new_score);
if GenerationUtils::should_stop(&new_tokens, config, new_tokens.len()) {
new_beam.finished = true;
}
candidates.push(new_beam);
}
}
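            // Rank every candidate by length-normalized score and keep the best `num_beams`.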
candidates.sort_by(|a, b| {
let a_score = a.normalized_score(config.length_penalty);
let b_score = b.normalized_score(config.length_penalty);
b_score.partial_cmp(&a_score).unwrap_or(std::cmp::Ordering::Equal)
});
beams = candidates.into_iter().take(num_beams).collect();
if config.early_stopping && beams.iter().all(|b| b.finished) {
break;
}
}
beams.sort_by(|a, b| {
let a_score = a.normalized_score(config.length_penalty);
let b_score = b.normalized_score(config.length_penalty);
b_score.partial_cmp(&a_score).unwrap_or(std::cmp::Ordering::Equal)
});
let num_return = config.num_return_sequences.min(beams.len());
Ok(beams.iter().take(num_return).map(|b| b.tokens.clone()).collect())
}
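    /// Runs a full forward pass over `input_ids` and returns the logits for
    /// the final position. `_kv_cache` is accepted but not yet used, so every
    /// call re-encodes the entire sequence.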
fn get_next_token_logits(
&self,
input_ids: &[u32],
_kv_cache: Option<&mut KVCache>,
) -> Result<Vec<f32>> {
let input = TokenizedInput {
input_ids: input_ids.to_vec(),
attention_mask: vec![1u8; input_ids.len()],
token_type_ids: None,
special_tokens_mask: None,
offset_mapping: None,
overflowing_tokens: None,
};
let output: Gpt2LMOutput = self.forward(input)?;
match &output.logits {
Tensor::F32(arr) => {
let shape = arr.shape();
if shape.len() != 3 {
return Err(tensor_op_error(
"tensor_operation",
format!("Expected 3D logits tensor, got {}D", shape.len()),
));
}
let seq_len = shape[1];
let vocab_size = shape[2];
let last_logits = arr.slice(s![0, seq_len - 1, ..]);
let logits_vec: Vec<f32> = last_logits.iter().copied().collect();
if logits_vec.len() != vocab_size {
return Err(tensor_op_error(
"tensor_operation",
format!(
"Logits size mismatch: expected {}, got {}",
vocab_size,
logits_vec.len()
),
));
}
Ok(logits_vec)
},
_ => Err(tensor_op_error(
"tensor_operation",
"Unsupported tensor type for logits".to_string(),
)),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::gpt2::Gpt2Config;
#[test]
fn test_generation_config_integration() {
let config = GenerationConfig {
max_length: 50,
temperature: 0.8,
repetition_penalty: 1.2,
..Default::default()
};
assert!(config.validate().is_ok());
}
#[test]
fn test_greedy_generation_interface() -> Result<()> {
let mut gpt2_config = Gpt2Config::small();
        gpt2_config.vocab_size = 50;
        gpt2_config.n_positions = 32;
        gpt2_config.n_embd = 32;
        gpt2_config.n_layer = 1;
        gpt2_config.n_head = 2;
let model = Gpt2LMHeadModel::new(gpt2_config)?;
let input_ids = vec![1, 2];
let max_length = 5;
let result = model.generate_greedy(input_ids, max_length);
assert!(
result.is_ok(),
"Greedy generation failed: {:?}",
result.err()
);
let generated = result?;
assert!(generated.len() <= max_length);
Ok(())
}
#[test]
fn test_generation_modes() {
let greedy = GenerationMode::Greedy;
let beam = GenerationMode::BeamSearch { num_beams: 5 };
let top_k = GenerationMode::TopK { k: 50 };
let top_p = GenerationMode::TopP { p: 0.9 };
assert!(matches!(greedy, GenerationMode::Greedy));
assert!(matches!(beam, GenerationMode::BeamSearch { .. }));
assert!(matches!(top_k, GenerationMode::TopK { .. }));
assert!(matches!(top_p, GenerationMode::TopP { .. }));
}
#[test]
fn test_generation_config_default() {
let config = GenerationConfig::default();
assert!(config.max_length > 0);
assert!(config.temperature > 0.0);
assert!(config.validate().is_ok());
}
#[test]
fn test_generation_config_validation_valid() {
let config = GenerationConfig {
max_length: 100,
temperature: 1.0,
repetition_penalty: 1.0,
..Default::default()
};
assert!(config.validate().is_ok());
}
#[test]
fn test_generation_config_custom_temperature() {
let config = GenerationConfig {
temperature: 0.5,
..Default::default()
};
assert!(config.validate().is_ok());
assert!((config.temperature - 0.5).abs() < f32::EPSILON);
}
#[test]
fn test_generation_config_with_repetition_penalty() {
let config = GenerationConfig {
repetition_penalty: 1.5,
..Default::default()
};
assert!(config.validate().is_ok());
assert!((config.repetition_penalty - 1.5).abs() < f32::EPSILON);
}
#[test]
fn test_generation_mode_greedy_debug() {
let mode = GenerationMode::Greedy;
let dbg = format!("{:?}", mode);
assert!(dbg.contains("Greedy"));
}
#[test]
fn test_generation_mode_beam_search_params() {
let mode = GenerationMode::BeamSearch { num_beams: 5 };
match mode {
GenerationMode::BeamSearch { num_beams } => assert_eq!(num_beams, 5),
_ => panic!("Expected BeamSearch"),
}
}
#[test]
fn test_generation_mode_top_k_params() {
let mode = GenerationMode::TopK { k: 50 };
match mode {
GenerationMode::TopK { k } => assert_eq!(k, 50),
_ => panic!("Expected TopK"),
}
}
#[test]
fn test_generation_mode_top_p_params() {
let mode = GenerationMode::TopP { p: 0.95 };
match mode {
GenerationMode::TopP { p } => assert!((p - 0.95).abs() < f32::EPSILON),
_ => panic!("Expected TopP"),
}
}
#[test]
fn test_gpt2_small_config() {
let config = Gpt2Config::small();
assert!(config.vocab_size > 0);
assert!(config.n_embd > 0);
assert!(config.n_layer > 0);
assert!(config.n_head > 0);
}
#[test]
fn test_gpt2_model_creation_tiny() -> Result<()> {
let mut config = Gpt2Config::small();
config.vocab_size = 20;
config.n_positions = 16;
config.n_embd = 16;
config.n_layer = 1;
config.n_head = 2;
let model = Gpt2LMHeadModel::new(config)?;
assert!(model.num_parameters() > 0);
Ok(())
}
#[test]
fn test_generation_config_max_new_tokens() {
let config = GenerationConfig {
max_new_tokens: Some(10),
..Default::default()
};
assert_eq!(config.max_new_tokens, Some(10));
}
#[test]
fn test_greedy_generation_tiny_model() -> Result<()> {
let mut config = Gpt2Config::small();
config.vocab_size = 30;
config.n_positions = 16;
config.n_embd = 16;
config.n_layer = 1;
config.n_head = 2;
let model = Gpt2LMHeadModel::new(config)?;
let input_ids = vec![1, 2, 3];
let result = model.generate_greedy(input_ids, 5);
assert!(result.is_ok());
let generated = result?;
assert!(generated.len() <= 5);
Ok(())
}
#[test]
fn test_generation_config_clone() {
let config = GenerationConfig {
max_length: 200,
temperature: 0.7,
repetition_penalty: 1.3,
..Default::default()
};
let cloned = config.clone();
assert_eq!(cloned.max_length, 200);
assert!((cloned.temperature - 0.7).abs() < f32::EPSILON);
assert!((cloned.repetition_penalty - 1.3).abs() < f32::EPSILON);
}
#[test]
fn test_generation_mode_all_variants_constructable() {
let modes: Vec<GenerationMode> = vec![
GenerationMode::Greedy,
GenerationMode::BeamSearch { num_beams: 4 },
GenerationMode::TopK { k: 40 },
GenerationMode::TopP { p: 0.9 },
];
assert_eq!(modes.len(), 4);
}
#[test]
fn test_generation_config_with_beam_search() {
let config = GenerationConfig {
mode: GenerationMode::BeamSearch { num_beams: 4 },
max_length: 50,
..Default::default()
};
assert!(matches!(
config.mode,
GenerationMode::BeamSearch { num_beams: 4 }
));
}
}