pub mod json_mode;
pub use ferrum_interfaces::sampler::{
GreedySampler, LogitsProcessor, LogitsProcessorChain, MultiSampler, MultinomialSampler,
ProcessorPriority, RepetitionPenaltyProcessor, Sampler, SamplingConfig, SamplingConfigBuilder,
SamplingContext, SamplingStats, TemperatureProcessor, TopKProcessor, TopPProcessor,
};
pub use ferrum_types::{Result, SamplingParams, TokenId};
use rand::RngCore;
use std::collections::HashMap;
/// Stateless factory that turns [`SamplingParams`] into samplers, sampling
/// configs and [`SamplingPipeline`]s.
#[derive(Debug, Clone, Default)]
pub struct DefaultSamplerFactory;

impl DefaultSamplerFactory {
    /// Creates a factory. The type carries no state, so this is free.
    pub fn new() -> Self {
        DefaultSamplerFactory
    }

    /// Derives the [`SamplingConfig`] for `params` by delegating to
    /// [`SamplingConfig::from_params`].
    pub fn build_config(&self, params: &SamplingParams) -> SamplingConfig {
        SamplingConfig::from_params(params)
    }

    /// Chooses a sampler from `params.temperature` alone; no other field is read.
    ///
    /// Returns a [`MultinomialSampler`] whenever the temperature compares
    /// unequal to `0.0`. Under IEEE float comparison that includes NaN and
    /// negative values. Otherwise it returns a [`GreedySampler`].
    pub fn create_sampler(&self, params: &SamplingParams) -> Box<dyn Sampler + Send + Sync> {
        let stochastic = params.temperature != 0.0;
        if stochastic {
            Box::new(MultinomialSampler)
        } else {
            Box::new(GreedySampler)
        }
    }

    /// Builds a [`SamplingPipeline`] whose config comes from
    /// [`DefaultSamplerFactory::build_config`].
    pub fn build_pipeline(&self, params: &SamplingParams) -> SamplingPipeline {
        SamplingPipeline {
            config: self.build_config(params),
        }
    }
}
/// A reusable sampling pipeline wrapping a [`SamplingConfig`].
///
/// The config is derived once from [`SamplingParams`] at construction time.
/// It is then applied on every call to [`SamplingPipeline::sample_next`].
pub struct SamplingPipeline {
    config: SamplingConfig,
}

impl SamplingPipeline {
    /// Builds a pipeline whose config is derived from `params` via
    /// [`SamplingConfig::from_params`].
    pub fn new(params: &SamplingParams) -> Self {
        Self {
            config: SamplingConfig::from_params(params),
        }
    }

    /// Returns the sampling config this pipeline applies.
    pub fn config(&self) -> &SamplingConfig {
        &self.config
    }

    /// Samples the next token for decoding step `step`.
    ///
    /// Builds a [`SamplingContext`] from the step, `sampling_params`, `logits`,
    /// the generation history (`previous_tokens`, `token_frequencies`) and a
    /// vocabulary size of `logits.len()`. It then hands that context and `rng`
    /// to the pipeline's config.
    ///
    /// `logits` is borrowed mutably because the context takes it mutably.
    /// NOTE(review): whether the processors rewrite it in place is decided by
    /// `SamplingConfig::sample`. Confirm before reusing the buffer afterwards.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `SamplingConfig::sample`.
    pub fn sample_next(
        &self,
        step: usize,
        logits: &mut [f32],
        previous_tokens: &[TokenId],
        token_frequencies: &HashMap<TokenId, usize>,
        sampling_params: &SamplingParams,
        rng: &mut dyn RngCore,
    ) -> Result<TokenId> {
        // The vocabulary size is taken from the logits slice itself.
        let vocab_size = logits.len();
        let ctx = SamplingContext::new(
            step,
            sampling_params,
            logits,
            previous_tokens,
            token_frequencies,
            vocab_size,
        );
        self.config.sample(ctx, rng)
    }

    /// Samples a token with no generation history.
    ///
    /// Equivalent to [`SamplingPipeline::sample_next`] at step 0 with no
    /// previous tokens, no token frequencies and `SamplingParams::default()`.
    ///
    /// NOTE(review): the context receives *default* params, not the params this
    /// pipeline was built from. Anything that reads params from the context
    /// would see the defaults. Confirm this is intended.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`SamplingPipeline::sample_next`].
    pub fn sample_simple(&self, logits: &mut [f32], rng: &mut dyn RngCore) -> Result<TokenId> {
        let params = SamplingParams::default();
        // `HashMap::new` does not allocate; `&[]` avoids building an empty Vec.
        let no_freqs = HashMap::new();
        self.sample_next(0, logits, &[], &no_freqs, &params, rng)
    }
}
pub fn build_sampling_config(params: &SamplingParams) -> SamplingConfig {
SamplingConfig::from_params(params)
}
/// Picks a sampler for `params` using [`DefaultSamplerFactory::create_sampler`].
pub fn sampler_from_params(params: &SamplingParams) -> Box<dyn Sampler + Send + Sync> {
    let factory = DefaultSamplerFactory::default();
    factory.create_sampler(params)
}
pub fn pipeline_from_params(params: &SamplingParams) -> SamplingPipeline {
DefaultSamplerFactory::new().build_pipeline(params)
}
/// Returns a boxed [`GreedySampler`]. No sampling parameters are consulted.
pub fn greedy_sampler() -> Box<dyn Sampler + Send + Sync> {
    let sampler: Box<dyn Sampler + Send + Sync> = Box::new(GreedySampler);
    sampler
}
/// Returns a boxed [`MultinomialSampler`]. No sampling parameters are consulted.
pub fn multinomial_sampler() -> Box<dyn Sampler + Send + Sync> {
    let sampler: Box<dyn Sampler + Send + Sync> = Box::new(MultinomialSampler);
    sampler
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    #[test]
    fn test_factory_creates_greedy_for_zero_temp() {
        let factory = DefaultSamplerFactory::new();
        let params = SamplingParams {
            temperature: 0.0,
            ..Default::default()
        };
        let sampler = factory.create_sampler(&params);
        assert!(sampler.is_deterministic());
    }

    #[test]
    fn test_factory_creates_multinomial_for_nonzero_temp() {
        let factory = DefaultSamplerFactory::new();
        let params = SamplingParams {
            temperature: 1.0,
            ..Default::default()
        };
        let sampler = factory.create_sampler(&params);
        assert!(!sampler.is_deterministic());
    }

    #[test]
    fn test_build_sampling_config() {
        let params = SamplingParams {
            temperature: 0.8,
            top_k: Some(50),
            top_p: 0.95,
            repetition_penalty: 1.1,
            ..Default::default()
        };
        let config = build_sampling_config(&params);
        // Expects one processor each for temperature, top-k, top-p and
        // repetition penalty.
        assert_eq!(config.processor_chain.processor_names().len(), 4);
    }

    #[test]
    fn test_pipeline_sample_simple() {
        let params = SamplingParams::greedy();
        let pipeline = pipeline_from_params(&params);
        let mut rng = StdRng::seed_from_u64(42);
        // Arrays suffice: the logits are only ever borrowed as `&mut [f32]`.
        let mut logits = [1.0, 5.0, 2.0, 0.5];
        let token = pipeline.sample_simple(&mut logits, &mut rng).unwrap();
        // Greedy sampling must pick the arg-max index.
        assert_eq!(token.get(), 1);
    }

    #[test]
    fn test_greedy_sampler_deterministic() {
        let sampler = greedy_sampler();
        assert!(sampler.is_deterministic());
        assert_eq!(sampler.name(), "greedy");
    }

    #[test]
    fn test_multinomial_sampler_stochastic() {
        let sampler = multinomial_sampler();
        assert!(!sampler.is_deterministic());
        assert_eq!(sampler.name(), "multinomial");
    }

    #[test]
    fn test_pipeline_with_context() {
        let params = SamplingParams {
            temperature: 1.0,
            repetition_penalty: 1.2,
            ..Default::default()
        };
        let pipeline = SamplingPipeline::new(&params);
        let mut rng = StdRng::seed_from_u64(42);
        let mut logits = [1.0, 2.0, 3.0, 2.0];
        let previous_tokens = [TokenId::new(2)];
        let mut freqs = HashMap::new();
        freqs.insert(TokenId::new(2), 1);
        let token = pipeline
            .sample_next(0, &mut logits, &previous_tokens, &freqs, &params, &mut rng)
            .unwrap();
        // Stochastic: only assert the token is inside the vocabulary.
        assert!(token.get() < 4);
    }
}