impl LogitProcessorChain {
    /// Creates a chain that contains no processors.
    #[must_use]
    pub fn new() -> Self {
        let processors = Vec::new();
        Self { processors }
    }

    /// Boxes `processor`, appends it to the end of the chain, and hands the
    /// chain back so calls can be strung together builder-style.
    #[must_use]
    pub fn with_processor<P: LogitProcessor + 'static>(self, processor: P) -> Self {
        self.with_boxed_processor(Box::new(processor))
    }

    /// Appends an already-boxed `processor` to the end of the chain and
    /// returns the chain.
    #[must_use]
    pub fn with_boxed_processor(mut self, processor: Box<dyn LogitProcessor>) -> Self {
        self.processors.push(processor);
        self
    }

    /// Runs every processor over `logits` in the order they were added.
    ///
    /// Each processor receives the same `ctx` and the same mutable slice, so
    /// later processors observe whatever earlier ones wrote.
    pub fn process(&self, logits: &mut [f32], ctx: &LogitProcessorContext) {
        self.processors
            .iter()
            .for_each(|stage| stage.process(logits, ctx));
    }

    /// Number of processors currently in the chain.
    #[must_use]
    pub fn len(&self) -> usize {
        self.processors.len()
    }

    /// Returns `true` when the chain holds no processors.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.processors.is_empty()
    }

    /// Collects the `name()` of each processor, in chain order.
    #[must_use]
    pub fn processor_names(&self) -> Vec<&str> {
        let mut names = Vec::with_capacity(self.processors.len());
        for stage in &self.processors {
            names.push(stage.name());
        }
        names
    }
}
impl LogitProcessor for LogitProcessorChain {
    /// Lets a whole chain stand in wherever a single processor is expected by
    /// forwarding to the chain's inherent `process` method.
    fn process(&self, logits: &mut [f32], ctx: &LogitProcessorContext) {
        // Path syntax selects the inherent method, exactly as spelling out the
        // type name would; it does not re-enter this trait method.
        Self::process(self, logits, ctx);
    }

    /// Fixed identifier reported for the chain itself (not for its members).
    fn name(&self) -> &'static str {
        const CHAIN_NAME: &str = "processor_chain";
        CHAIN_NAME
    }
}
/// Interface a model must provide to be driven by [`GenerationPipeline`].
pub trait GenerativeModel {
/// Runs the model over `tokens` and returns a vector of `f32` scores, or an error.
///
/// `GenerationPipeline::generate` calls this once per step with the entire
/// token sequence accumulated so far and treats the returned vector as the
/// logits from which the next token is sampled.
// NOTE(review): presumably the returned vector has `vocab_size()` entries;
// nothing visible here enforces that — confirm against implementors.
fn forward(&mut self, tokens: &[u32]) -> Result<Vec<f32>>;
/// Vocabulary size reported by the model.
///
/// `GenerationPipeline::generate` reads it once per call and passes it to
/// `LogitProcessorContext::new`.
fn vocab_size(&self) -> usize;
/// Optional hook with an empty default body.
// NOTE(review): presumably meant for clearing per-sequence state; no call to
// it is visible in this part of the file — confirm the intended caller.
fn reset(&mut self) {}
}
/// Bundles a model with a logit-processor chain and a generation config so
/// that tokens can be generated through a single entry point.
pub struct GenerationPipeline<M: GenerativeModel> {
/// Model whose `forward` produces the logits for each step.
model: M,
/// Processors applied, in order, to the logits before sampling.
processors: LogitProcessorChain,
/// Settings read during generation (token limit, EOS id, seed) and passed to
/// the sampler.
config: GenerationConfig,
}
impl<M: GenerativeModel> GenerationPipeline<M> {
    /// Wraps `model` in a pipeline with an empty processor chain and
    /// `GenerationConfig::default()`.
    #[must_use]
    pub fn new(model: M) -> Self {
        Self {
            model,
            processors: LogitProcessorChain::new(),
            config: GenerationConfig::default(),
        }
    }

    /// Appends `processor` to the end of the pipeline's processor chain and
    /// returns the pipeline (builder style).
    #[must_use]
    pub fn add_processor<P: LogitProcessor + 'static>(mut self, processor: P) -> Self {
        self.processors = self.processors.with_processor(processor);
        self
    }

    /// Replaces the pipeline's generation config with `config`.
    #[must_use]
    pub fn with_config(mut self, config: GenerationConfig) -> Self {
        self.config = config;
        self
    }

    /// Generates tokens starting from `initial_tokens`.
    ///
    /// For each of at most `config.max_tokens` steps this:
    /// 1. calls `model.forward` with the full token sequence so far,
    /// 2. lets the processor chain modify the logits in place,
    /// 3. draws a pseudo-random value in `[0, 1)` and passes it to
    ///    `sample_token` together with the logits and the config,
    /// 4. appends the sampled token, stopping early if it equals
    ///    `config.eos_token_id` (the EOS token is kept in the output).
    ///
    /// The random stream is re-seeded from `config.seed` (or `42` when no seed
    /// is set) on every call, so repeated calls draw the same sequence of
    /// random values. This method does not call `model.reset()`.
    ///
    /// # Returns
    ///
    /// `initial_tokens` followed by every token generated.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `model.forward`, `Tensor::from_vec`,
    /// or `sample_token`.
    pub fn generate(&mut self, initial_tokens: &[u32]) -> Result<Vec<u32>> {
        // Largest `f32` strictly below 1.0 (1 - 2^-24). The int-to-float cast
        // below rounds to nearest, and with a 24-bit significand any 31-bit
        // value >= 2^31 - 64 rounds up to 2^31, which would make the quotient
        // exactly 1.0. Clamping keeps the draw inside the half-open unit
        // interval while leaving every other draw bit-identical.
        const MAX_UNIT_SAMPLE: f32 = 1.0 - f32::EPSILON / 2.0;

        let mut tokens = initial_tokens.to_vec();
        let n_vocab = self.model.vocab_size();
        let eos_token = self.config.eos_token_id;
        let mut rng_state = self.config.seed.unwrap_or(42);

        for step in 0..self.config.max_tokens {
            let mut logits = self.model.forward(&tokens)?;

            // Processors see the tokens so far, the step index and the vocab size.
            let ctx = LogitProcessorContext::new(&tokens, step, n_vocab);
            self.processors.process(&mut logits, &ctx);
            let logits_tensor = Tensor::from_vec(vec![logits.len()], logits)?;

            // Linear congruential step; the high bits are used for the draw.
            rng_state = rng_state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1);
            let rng_value =
                ((rng_state >> 33) as f32 / (1u64 << 31) as f32).min(MAX_UNIT_SAMPLE);

            let next_token = sample_token(&logits_tensor, &self.config, rng_value)? as u32;
            tokens.push(next_token);

            // The EOS token is pushed before the check, so it is part of the output.
            if let Some(eos) = eos_token {
                if next_token == eos as u32 {
                    break;
                }
            }
        }

        Ok(tokens)
    }

    /// Shared reference to the wrapped model.
    #[must_use]
    pub fn model(&self) -> &M {
        &self.model
    }

    /// Mutable reference to the wrapped model.
    pub fn model_mut(&mut self) -> &mut M {
        &mut self.model
    }

    /// The processor chain applied to logits during generation.
    #[must_use]
    pub fn processors(&self) -> &LogitProcessorChain {
        &self.processors
    }

    /// The generation config currently in use.
    #[must_use]
    pub fn config(&self) -> &GenerationConfig {
        &self.config
    }
}
// Unit tests for this module live in `sampler_tests.rs`, located via the
// `#[path]` attribute, and are compiled only in test builds.
#[cfg(test)]
#[path = "sampler_tests.rs"]
mod sampler_tests;