use ferrum_types::{Result, SamplingParams, TokenId};
use rand::RngCore;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Per-step mutable view over the model's logits plus the read-only state
/// that logits processors consult (previous tokens, token frequencies,
/// request sampling parameters).
#[derive(Debug)]
pub struct SamplingContext<'a> {
    /// Zero-based index of the current generation step.
    pub step: usize,
    /// Sampling parameters of the owning request.
    pub sampling_params: &'a SamplingParams,
    /// Logits for the current step; processors mutate these in place.
    pub logits: &'a mut [f32],
    /// Tokens generated so far, in generation order (may contain repeats).
    pub previous_tokens: &'a [TokenId],
    /// Occurrence count per token id (consumed by the repetition penalty).
    pub token_frequencies: &'a HashMap<TokenId, usize>,
    // NOTE(review): stored separately from `logits.len()`, so presumably the
    // two can differ (e.g. padded vocab) — confirm against callers.
    pub vocab_size: usize,
    /// Free-form per-step scratch values that processors may record.
    pub metadata: HashMap<String, f32>,
}
impl<'a> SamplingContext<'a> {
    /// Creates a context for a single sampling step with empty metadata.
    pub fn new(
        step: usize,
        sampling_params: &'a SamplingParams,
        logits: &'a mut [f32],
        previous_tokens: &'a [TokenId],
        token_frequencies: &'a HashMap<TokenId, usize>,
        vocab_size: usize,
    ) -> Self {
        Self {
            step,
            sampling_params,
            logits,
            previous_tokens,
            token_frequencies,
            vocab_size,
            metadata: HashMap::new(),
        }
    }

    /// Returns the logit for `token_id`, or `None` when the id is out of
    /// range of the logits slice.
    pub fn get_logit(&self, token_id: TokenId) -> Option<f32> {
        // `slice::get` performs the bounds check once, with no panic path.
        self.logits.get(usize::from(token_id)).copied()
    }

    /// Overwrites the logit for `token_id`; returns `false` when the id is
    /// out of range (the slice is left untouched).
    pub fn set_logit(&mut self, token_id: TokenId, value: f32) -> bool {
        match self.logits.get_mut(usize::from(token_id)) {
            Some(slot) => {
                *slot = value;
                true
            }
            None => false,
        }
    }

    /// Removes the given tokens from consideration by forcing their logits
    /// to negative infinity. Out-of-range ids are silently ignored.
    pub fn mask_tokens(&mut self, token_ids: &[TokenId]) {
        for &token_id in token_ids {
            if let Some(slot) = self.logits.get_mut(usize::from(token_id)) {
                *slot = f32::NEG_INFINITY;
            }
        }
    }
}
/// A transformation applied to the logits before sampling (temperature
/// scaling, top-k / top-p filtering, repetition penalty, ...).
/// `Send + Sync` so processor chains can be shared across worker threads.
pub trait LogitsProcessor: Send + Sync {
    /// Mutates `ctx.logits` in place; returning an error aborts the chain.
    fn process(&self, ctx: &mut SamplingContext) -> Result<()>;
    /// Short, stable identifier used for logging/introspection.
    fn name(&self) -> &str;
    /// Ordering hint: higher-priority processors run earlier in a chain.
    fn priority(&self) -> ProcessorPriority {
        ProcessorPriority::Normal
    }
}
/// Execution priority for a [`LogitsProcessor`]; chains execute
/// High → Normal → Low (descending discriminant).
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub enum ProcessorPriority {
    /// Runs first (used by the repetition penalty).
    High = 3,
    /// Default ordering.
    Normal = 2,
    /// Runs last (used by temperature scaling).
    Low = 1,
}
/// Final token-selection strategy applied to (already processed) logits.
pub trait Sampler: Send + Sync {
    /// Selects one token id from `logits`, drawing randomness from `rng`
    /// when the strategy is stochastic.
    fn sample(&self, logits: &[f32], rng: &mut dyn RngCore) -> Result<TokenId>;
    /// Context-aware variant; the default simply forwards the context's
    /// logits to [`Sampler::sample`].
    fn sample_with_context(&self, ctx: &SamplingContext, rng: &mut dyn RngCore) -> Result<TokenId> {
        self.sample(ctx.logits, rng)
    }
    /// Short, stable identifier used for logging/introspection.
    fn name(&self) -> &str;
    /// `true` when repeated calls with identical logits always yield the
    /// same token (i.e. `rng` is unused).
    fn is_deterministic(&self) -> bool;
}
/// Extension of [`Sampler`] for strategies that can emit several candidates
/// or expose the underlying probability distribution.
pub trait MultiSampler: Sampler {
    /// Draws `num_samples` tokens from the same logits.
    fn sample_multiple(
        &self,
        logits: &[f32],
        num_samples: usize,
        rng: &mut dyn RngCore,
    ) -> Result<Vec<TokenId>>;
    /// Samples one token and also returns the full probability vector used
    /// for the draw.
    fn sample_with_probabilities(
        &self,
        logits: &[f32],
        rng: &mut dyn RngCore,
    ) -> Result<(TokenId, Vec<f32>)>;
}
/// An ordered pipeline of logits processors, kept sorted by descending
/// [`ProcessorPriority`].
pub struct LogitsProcessorChain {
    // Invariant: sorted highest-priority first (re-sorted on every add).
    processors: Vec<Box<dyn LogitsProcessor>>,
}
impl LogitsProcessorChain {
    /// Creates an empty chain.
    pub fn new() -> Self {
        Self {
            processors: Vec::new(),
        }
    }

    /// Appends a processor and re-sorts the chain so higher-priority
    /// processors execute first (stable: equal priorities keep insertion
    /// order).
    pub fn add_processor(mut self, processor: Box<dyn LogitsProcessor>) -> Self {
        self.processors.push(processor);
        // Descending priority via `Reverse` — identical ordering to a
        // hand-written reversed comparator.
        self.processors
            .sort_by_key(|p| std::cmp::Reverse(p.priority()));
        self
    }

    /// Runs every processor over `ctx` in priority order, stopping at the
    /// first error.
    pub fn process(&self, ctx: &mut SamplingContext) -> Result<()> {
        self.processors
            .iter()
            .try_for_each(|processor| processor.process(ctx))
    }

    /// Names of the processors in execution order.
    pub fn processor_names(&self) -> Vec<&str> {
        self.processors.iter().map(|p| p.name()).collect()
    }
}
impl Default for LogitsProcessorChain {
    /// Equivalent to [`LogitsProcessorChain::new`]: an empty chain.
    fn default() -> Self {
        Self::new()
    }
}
/// Scales logits by `1 / temperature` (higher temperature → flatter
/// post-softmax distribution).
pub struct TemperatureProcessor {
    // Values <= 0.0 or exactly 1.0 are treated as a no-op by `process`.
    pub temperature: f32,
}

impl TemperatureProcessor {
    /// Creates a processor with the given temperature.
    pub fn new(temperature: f32) -> Self {
        Self { temperature }
    }
}
impl LogitsProcessor for TemperatureProcessor {
    /// Divides every logit by the temperature. Temperatures of exactly 1.0
    /// or any non-positive value leave the logits untouched.
    fn process(&self, ctx: &mut SamplingContext) -> Result<()> {
        let t = self.temperature;
        if t > 0.0 && t != 1.0 {
            ctx.logits.iter_mut().for_each(|logit| *logit /= t);
        }
        Ok(())
    }

    fn name(&self) -> &str {
        "temperature"
    }

    // Low priority: runs after the filtering processors in a chain.
    // NOTE(review): this means top-p sees untempered probabilities — confirm
    // that ordering is intended.
    fn priority(&self) -> ProcessorPriority {
        ProcessorPriority::Low
    }
}
/// Restricts sampling to the `k` highest-scoring tokens.
pub struct TopKProcessor {
    // `k == 0` (or `k >= vocab`) disables filtering in `process`.
    pub k: usize,
}

impl TopKProcessor {
    /// Creates a processor keeping the top `k` tokens.
    pub fn new(k: usize) -> Self {
        Self { k }
    }
}
impl LogitsProcessor for TopKProcessor {
    /// Keeps only the `k` highest logits, masking all others to `-inf`.
    ///
    /// Ties at the threshold value are all kept, so slightly more than `k`
    /// tokens may survive — same behavior as a full-sort implementation.
    ///
    /// Improvement: the original built and fully sorted an index vector
    /// (O(n log n)); only the k-th largest *value* is needed, which
    /// `select_nth_unstable_by` finds in O(n) average time.
    fn process(&self, ctx: &mut SamplingContext) -> Result<()> {
        if self.k > 0 && self.k < ctx.logits.len() {
            // Work on a copy so the logits themselves are not reordered.
            let mut values = ctx.logits.to_vec();
            // Descending comparator: position `k - 1` holds the k-th largest.
            let (_, kth, _) = values.select_nth_unstable_by(self.k - 1, |a, b| {
                b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal)
            });
            let threshold = *kth;
            for logit in ctx.logits.iter_mut() {
                if *logit < threshold {
                    *logit = f32::NEG_INFINITY;
                }
            }
        }
        Ok(())
    }

    fn name(&self) -> &str {
        "top_k"
    }
}
/// Nucleus (top-p) filtering: restricts sampling to the smallest set of
/// tokens whose cumulative probability exceeds `p`.
pub struct TopPProcessor {
    // Values outside (0.0, 1.0) disable filtering in `process`.
    pub p: f32,
}

impl TopPProcessor {
    /// Creates a processor with nucleus threshold `p`.
    pub fn new(p: f32) -> Self {
        Self { p }
    }
}
impl LogitsProcessor for TopPProcessor {
    /// Nucleus filtering: keeps the smallest prefix of probability-sorted
    /// tokens whose cumulative mass exceeds `p`, masking the rest to `-inf`.
    fn process(&self, ctx: &mut SamplingContext) -> Result<()> {
        // p outside (0, 1) — including NaN — disables filtering.
        if !(self.p < 1.0 && self.p > 0.0) {
            return Ok(());
        }
        // Softmax over the logits, max-shifted for numerical stability.
        let max_logit = ctx
            .logits
            .iter()
            .copied()
            .fold(f32::NEG_INFINITY, f32::max);
        let mut probs: Vec<f32> = ctx
            .logits
            .iter()
            .map(|&logit| (logit - max_logit).exp())
            .collect();
        let total: f32 = probs.iter().sum();
        probs.iter_mut().for_each(|prob| *prob /= total);
        // Token indices ordered most-probable first.
        let mut order: Vec<usize> = (0..probs.len()).collect();
        order.sort_by(|&a, &b| {
            probs[b]
                .partial_cmp(&probs[a])
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        // Walk the sorted list until cumulative mass crosses p; the token
        // that crosses the boundary is the last one kept.
        let mut running = 0.0;
        let mut keep = probs.len();
        for (rank, &idx) in order.iter().enumerate() {
            running += probs[idx];
            if running > self.p {
                keep = rank + 1;
                break;
            }
        }
        // Everything past the cutoff is masked out.
        for &idx in &order[keep..] {
            ctx.logits[idx] = f32::NEG_INFINITY;
        }
        Ok(())
    }

    fn name(&self) -> &str {
        "top_p"
    }
}
/// Penalizes tokens that already appeared in the generated sequence
/// (CTRL-style repetition penalty, weighted by occurrence count).
pub struct RepetitionPenaltyProcessor {
    // 1.0 disables the penalty; > 1.0 discourages repeats.
    pub penalty: f32,
}

impl RepetitionPenaltyProcessor {
    /// Creates a processor with the given penalty factor.
    pub fn new(penalty: f32) -> Self {
        Self { penalty }
    }
}
impl LogitsProcessor for RepetitionPenaltyProcessor {
fn process(&self, ctx: &mut SamplingContext) -> Result<()> {
if self.penalty != 1.0 {
for &token_id in ctx.previous_tokens {
if let Some(freq) = ctx.token_frequencies.get(&token_id) {
if usize::from(token_id) < ctx.logits.len() {
let idx = usize::from(token_id);
let current_logit = ctx.logits[idx];
let penalty_factor = self.penalty.powi(*freq as i32);
if current_logit > 0.0 {
ctx.logits[idx] = current_logit / penalty_factor;
} else {
ctx.logits[idx] = current_logit * penalty_factor;
}
}
}
}
}
Ok(())
}
fn name(&self) -> &str {
"repetition_penalty"
}
fn priority(&self) -> ProcessorPriority {
ProcessorPriority::High }
}
/// Deterministic argmax sampler (used when temperature is 0.0).
pub struct GreedySampler;

impl Sampler for GreedySampler {
    /// Returns the index of the largest logit; errors on an empty slice.
    fn sample(&self, logits: &[f32], _rng: &mut dyn RngCore) -> Result<TokenId> {
        let argmax = logits
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(idx, _)| idx);
        match argmax {
            Some(idx) => Ok(TokenId::new(idx as u32)),
            None => Err(ferrum_types::FerrumError::backend(
                "Empty logits for sampling",
            )),
        }
    }

    fn name(&self) -> &str {
        "greedy"
    }

    fn is_deterministic(&self) -> bool {
        true
    }
}
/// Stochastic sampler drawing from the softmax distribution of the logits.
pub struct MultinomialSampler;

impl Sampler for MultinomialSampler {
    /// Samples a token by inverse-CDF with a single uniform draw from `rng`.
    /// Non-finite logits (including `-inf`-masked tokens) receive zero mass.
    fn sample(&self, logits: &[f32], rng: &mut dyn RngCore) -> Result<TokenId> {
        // Max-shift the exponentials for numerical stability.
        let max_logit = logits.iter().fold(f32::NEG_INFINITY, |acc, &l| acc.max(l));
        let mut weights: Vec<f32> = logits
            .iter()
            .map(|&l| if l.is_finite() { (l - max_logit).exp() } else { 0.0 })
            .collect();
        let total: f32 = weights.iter().sum();
        if total <= 0.0 {
            return Err(ferrum_types::FerrumError::backend(
                "No valid tokens for sampling",
            ));
        }
        weights.iter_mut().for_each(|w| *w /= total);
        // Uniform draw in [0, 1]; walk the CDF until it crosses the draw.
        let u = rng.next_u32() as f32 / u32::MAX as f32;
        let mut cdf = 0.0;
        for (idx, &w) in weights.iter().enumerate() {
            cdf += w;
            if cdf >= u {
                return Ok(TokenId::new(idx as u32));
            }
        }
        // Round-off can leave the CDF just below 1.0: fall back to the last
        // token.
        Ok(TokenId::new((weights.len() - 1) as u32))
    }

    fn name(&self) -> &str {
        "multinomial"
    }

    fn is_deterministic(&self) -> bool {
        false
    }
}
/// Fluent builder assembling a [`SamplingConfig`] from individual knobs.
pub struct SamplingConfigBuilder {
    // Processors collected so far; sorted by priority at build time.
    processors: Vec<Box<dyn LogitsProcessor>>,
    // Sampler override; `None` falls back to `MultinomialSampler` in `build`.
    sampler: Option<Box<dyn Sampler>>,
}
impl SamplingConfigBuilder {
    /// Starts an empty builder: no processors, no explicit sampler.
    pub fn new() -> Self {
        Self {
            processors: Vec::new(),
            sampler: None,
        }
    }

    /// Adds a temperature processor. Exactly 1.0 or non-positive values are
    /// no-ops and are skipped entirely.
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        let is_active = temperature > 0.0 && temperature != 1.0;
        if is_active {
            self.processors
                .push(Box::new(TemperatureProcessor::new(temperature)));
        }
        self
    }

    /// Adds a top-k processor unless `k` is 0 (disabled).
    pub fn with_top_k(mut self, k: usize) -> Self {
        if k > 0 {
            self.processors.push(Box::new(TopKProcessor::new(k)));
        }
        self
    }

    /// Adds a top-p processor when `p` lies strictly inside (0, 1).
    pub fn with_top_p(mut self, p: f32) -> Self {
        if p > 0.0 && p < 1.0 {
            self.processors.push(Box::new(TopPProcessor::new(p)));
        }
        self
    }

    /// Adds a repetition-penalty processor unless the penalty is the
    /// identity value 1.0.
    pub fn with_repetition_penalty(mut self, penalty: f32) -> Self {
        if penalty != 1.0 {
            self.processors
                .push(Box::new(RepetitionPenaltyProcessor::new(penalty)));
        }
        self
    }

    /// Sets the final sampling strategy explicitly.
    pub fn with_sampler(mut self, sampler: Box<dyn Sampler>) -> Self {
        self.sampler = Some(sampler);
        self
    }

    /// Finalizes the config: processors are chained in priority order and a
    /// multinomial sampler is used when none was set.
    pub fn build(self) -> SamplingConfig {
        let chain = self
            .processors
            .into_iter()
            .fold(LogitsProcessorChain::new(), |chain, processor| {
                chain.add_processor(processor)
            });
        SamplingConfig {
            processor_chain: chain,
            sampler: self
                .sampler
                .unwrap_or_else(|| Box::new(MultinomialSampler)),
        }
    }
}
impl Default for SamplingConfigBuilder {
    /// Equivalent to [`SamplingConfigBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// A fully assembled sampling pipeline: processor chain followed by the
/// final token-selection strategy.
pub struct SamplingConfig {
    /// Logits processors, executed in priority order before sampling.
    pub processor_chain: LogitsProcessorChain,
    /// Strategy that picks the final token from the processed logits.
    pub sampler: Box<dyn Sampler>,
}
impl SamplingConfig {
    /// Builds a complete sampling pipeline from request parameters.
    /// A temperature of exactly 0.0 selects greedy decoding; any other
    /// value samples from the (processed) distribution.
    pub fn from_params(params: &SamplingParams) -> Self {
        let mut builder = SamplingConfigBuilder::new()
            .with_temperature(params.temperature)
            .with_repetition_penalty(params.repetition_penalty);
        if let Some(k) = params.top_k {
            builder = builder.with_top_k(k);
        }
        if params.top_p < 1.0 {
            builder = builder.with_top_p(params.top_p);
        }
        let sampler: Box<dyn Sampler> = match params.temperature {
            t if t == 0.0 => Box::new(GreedySampler),
            _ => Box::new(MultinomialSampler),
        };
        builder.with_sampler(sampler).build()
    }

    /// Runs the processor chain over `ctx`, then draws a token with the
    /// configured sampler.
    pub fn sample(&self, mut ctx: SamplingContext, rng: &mut dyn RngCore) -> Result<TokenId> {
        self.processor_chain.process(&mut ctx)?;
        self.sampler.sample_with_context(&ctx, rng)
    }
}
/// Aggregated sampling telemetry, serializable for metrics export.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SamplingStats {
    /// Total number of tokens sampled.
    pub total_samples: u64,
    /// Mean wall-clock time per sample, in microseconds.
    pub avg_sample_time_us: f64,
    /// How many times each token id was emitted.
    pub token_distribution: HashMap<TokenId, u64>,
    // NOTE(review): "effective" temperature — presumably the temperature
    // actually applied after clamping/defaults; confirm with the producer.
    pub effective_temperature: f32,
    /// Cumulative time spent in each logits processor, keyed by name.
    pub processor_times: HashMap<String, f64>,
}