impl Sampler for DynTempSampler {
    fn name(&self) -> &'static str {
        "dyn_temp"
    }

    /// Overwrites `logits` with the result of `apply_dynamic_temperature`
    /// under this sampler's config. The sampling context is not consulted.
    fn apply(&self, logits: &mut Tensor<f32>, _context: &SamplerContext) {
        let adjusted = apply_dynamic_temperature(logits, &self.config);
        *logits = adjusted;
    }

    fn clone_box(&self) -> Box<dyn Sampler> {
        let copy = self.clone();
        Box::new(copy)
    }
}
/// Top-k filter: keeps the highest-scoring logits and masks the rest to `-inf`.
#[derive(Debug, Clone)]
pub struct TopKSampler {
    /// Number of logits to keep. `0` (or any value `>=` the vocabulary size)
    /// disables the filter.
    pub k: usize,
}

impl TopKSampler {
    /// Creates a top-k filter that keeps `k` logits.
    #[must_use]
    pub fn new(k: usize) -> Self {
        Self { k }
    }
}

impl Sampler for TopKSampler {
    fn name(&self) -> &'static str {
        "top_k"
    }

    /// Sets every logit strictly below the `k`-th largest value to `-inf`.
    ///
    /// This is a no-op when `k == 0` or `k >= logits.len()`. Logits that tie with
    /// the `k`-th largest value are all kept, so more than `k` entries can
    /// survive. The sampling context is not consulted.
    fn apply(&self, logits: &mut Tensor<f32>, _context: &SamplerContext) {
        let data = logits.data_mut();
        if self.k == 0 || self.k >= data.len() {
            return;
        }
        let mut values: Vec<f32> = data.to_vec();
        // `total_cmp` is a genuine total order, including NaN. A `partial_cmp`
        // comparator that maps NaN to `Equal` is not, and `select_nth_unstable_by`
        // is documented to possibly panic (or give unspecified results) then.
        values.select_nth_unstable_by(self.k - 1, |a, b| b.total_cmp(a));
        let threshold = values[self.k - 1];
        for logit in data.iter_mut() {
            if *logit < threshold {
                *logit = f32::NEG_INFINITY;
            }
        }
    }

    fn clone_box(&self) -> Box<dyn Sampler> {
        Box::new(self.clone())
    }
}
/// Nucleus (top-p) filter over the softmax of the logits.
///
/// Tokens are ranked by descending probability. Ranked tokens are kept up to
/// and including the first one at which the cumulative probability reaches `p`.
/// Every other logit is set to `-inf`.
#[derive(Debug, Clone)]
pub struct TopPSampler {
    /// Cumulative-probability threshold. `p <= 0.0` keeps only the most
    /// likely token. If the cumulative sum never reaches `p`, every token is kept.
    pub p: f32,
}

impl TopPSampler {
    /// Creates a nucleus filter with threshold `p`.
    #[must_use]
    pub fn new(p: f32) -> Self {
        Self { p }
    }
}

impl Sampler for TopPSampler {
    fn name(&self) -> &'static str {
        "top_p"
    }

    /// Masks, in place, every logit outside the top-p nucleus to `-inf`.
    ///
    /// Kept logits are left untouched. If the probabilities are not finite,
    /// the cumulative sum never reaches `p` and nothing is masked. This happens,
    /// for example, when the logits contain NaN or `+inf`, or are all `-inf`.
    /// The sampling context is not consulted.
    fn apply(&self, logits: &mut Tensor<f32>, _context: &SamplerContext) {
        let data = logits.data_mut();
        if data.is_empty() {
            return;
        }

        // Softmax, shifted by the max logit for numerical stability.
        let max_logit = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let mut probs: Vec<f32> = data.iter().map(|&x| (x - max_logit).exp()).collect();
        let exp_sum: f32 = probs.iter().sum();
        for prob in &mut probs {
            *prob /= exp_sum;
        }

        // Token indices by descending probability. `total_cmp` is a genuine total
        // order; a `partial_cmp` comparator mapping NaN to `Equal` is not, and
        // `sort_by` may panic on such comparators. The sort is stable, so equal
        // probabilities keep ascending index order.
        let mut order: Vec<usize> = (0..probs.len()).collect();
        order.sort_by(|&a, &b| probs[b].total_cmp(&probs[a]));

        let mut cumsum = 0.0;
        let mut cutoff = order.len();
        for (rank, &idx) in order.iter().enumerate() {
            cumsum += probs[idx];
            if cumsum >= self.p {
                cutoff = rank + 1;
                break;
            }
        }

        // Mask in place rather than rebuilding the tensor. The previous rebuild
        // silently dropped a construction error and left the logits unfiltered.
        for &idx in &order[cutoff..] {
            data[idx] = f32::NEG_INFINITY;
        }
    }

    fn clone_box(&self) -> Box<dyn Sampler> {
        Box::new(self.clone())
    }
}
/// `Sampler` adapter that runs `apply_repetition_penalty` over the context's tokens.
#[derive(Debug, Clone)]
pub struct RepetitionPenaltySampler {
    /// Configuration forwarded to `apply_repetition_penalty`.
    pub config: RepetitionPenaltyConfig,
}

impl RepetitionPenaltySampler {
    /// Creates a sampler that applies `config` on every `apply` call.
    #[must_use]
    pub fn new(config: RepetitionPenaltyConfig) -> Self {
        Self { config }
    }
}

impl Sampler for RepetitionPenaltySampler {
    fn name(&self) -> &'static str {
        "repetition_penalty"
    }

    /// Replaces `logits` with `apply_repetition_penalty(logits, &context.tokens, &self.config)`.
    fn apply(&self, logits: &mut Tensor<f32>, context: &SamplerContext) {
        *logits = apply_repetition_penalty(logits, &context.tokens, &self.config);
    }

    fn clone_box(&self) -> Box<dyn Sampler> {
        Box::new(self.clone())
    }
}
/// `Sampler` adapter around `apply_infill_sampling`.
#[derive(Debug, Clone)]
pub struct InfillSampler {
    /// Configuration forwarded to `apply_infill_sampling`.
    pub config: InfillConfig,
}

impl InfillSampler {
    /// Creates an infill sampler with the given configuration.
    #[must_use]
    pub fn new(config: InfillConfig) -> Self {
        Self { config }
    }
}

impl Sampler for InfillSampler {
    fn name(&self) -> &'static str {
        "infill"
    }

    /// Replaces `logits` with the `logits` field of the `apply_infill_sampling`
    /// result. NOTE(review): any other fields that result may carry are
    /// discarded here — confirm that is intended. The context is not consulted.
    fn apply(&self, logits: &mut Tensor<f32>, _context: &SamplerContext) {
        let result = apply_infill_sampling(logits, &self.config);
        *logits = result.logits;
    }

    fn clone_box(&self) -> Box<dyn Sampler> {
        Box::new(self.clone())
    }
}
/// Read-only decoding state passed to every `LogitProcessor::process` call.
#[derive(Debug, Clone)]
pub struct LogitProcessorContext<'a> {
    /// Token history. `RepetitionPenalty` treats its tail as the most recent tokens.
    pub tokens: &'a [u32],
    /// Step counter supplied by the caller.
    pub step: usize,
    /// Vocabulary size supplied by the caller.
    pub n_vocab: usize,
}

impl<'a> LogitProcessorContext<'a> {
    /// Bundles a borrowed token history with the step and vocabulary size.
    #[must_use]
    pub fn new(tokens: &'a [u32], step: usize, n_vocab: usize) -> Self {
        LogitProcessorContext { tokens, step, n_vocab }
    }
}
/// An in-place transformation of a slice of logits.
///
/// Implementors must be `Send + Sync`.
pub trait LogitProcessor: Send + Sync {
    /// Identifier for this processor; defaults to `"unnamed"`.
    fn name(&self) -> &'static str {
        "unnamed"
    }

    /// Mutates `logits` in place, optionally reading decoding state from `ctx`.
    fn process(&self, logits: &mut [f32], ctx: &LogitProcessorContext);
}
/// Logit processor that forces a fixed set of token ids to `-inf`.
#[derive(Debug, Clone)]
pub struct TokenSuppressor {
    /// Ids to suppress. Ids that do not index into the logit slice are ignored.
    suppress_ids: Vec<u32>,
}

impl TokenSuppressor {
    /// Builds a suppressor that takes ownership of `suppress_ids`.
    #[must_use]
    pub fn new(suppress_ids: Vec<u32>) -> Self {
        Self { suppress_ids }
    }

    /// Builds a suppressor from a borrowed slice, copying the ids.
    #[must_use]
    pub fn from_slice(suppress_ids: &[u32]) -> Self {
        Self::new(suppress_ids.to_vec())
    }
}

impl LogitProcessor for TokenSuppressor {
    /// Sets `logits[id]` to `-inf` for each configured, in-bounds id; `ctx` is unused.
    fn process(&self, logits: &mut [f32], _ctx: &LogitProcessorContext) {
        let indices = self.suppress_ids.iter().map(|&id| id as usize);
        for idx in indices {
            if let Some(slot) = logits.get_mut(idx) {
                *slot = f32::NEG_INFINITY;
            }
        }
    }

    fn name(&self) -> &'static str {
        "token_suppressor"
    }
}
/// Logit processor applying a multiplicative repetition penalty.
///
/// Each distinct token id in the considered history is penalized exactly once,
/// provided it indexes into the logit slice. A positive logit is divided by
/// `penalty` and any other logit is multiplied by it, so with `penalty > 1.0`
/// previously seen tokens get lower scores.
#[derive(Debug, Clone)]
pub struct RepetitionPenalty {
    /// Penalty factor. It is not validated. NOTE(review): values `<= 0.0` would
    /// flip signs or produce infinities — confirm callers pass positive values.
    penalty: f32,
    /// Number of most recent tokens (the tail of `ctx.tokens`) to consider;
    /// `0` means the entire history.
    window: usize,
}

impl RepetitionPenalty {
    /// Creates a penalty over the last `window` tokens (`0` = whole history).
    #[must_use]
    pub fn new(penalty: f32, window: usize) -> Self {
        Self { penalty, window }
    }

    /// Creates a penalty over the whole token history.
    #[must_use]
    pub fn with_penalty(penalty: f32) -> Self {
        Self::new(penalty, 0)
    }
}

impl LogitProcessor for RepetitionPenalty {
    fn process(&self, logits: &mut [f32], ctx: &LogitProcessorContext) {
        let history = ctx.tokens;
        let tokens = if self.window == 0 {
            history
        } else {
            &history[history.len().saturating_sub(self.window)..]
        };
        // Penalize each distinct id once. Otherwise a token that occurs n times
        // in the window would be scaled by penalty^n, which compounds the penalty
        // per occurrence instead of applying it once per seen token.
        let mut seen = std::collections::HashSet::with_capacity(tokens.len());
        for &token_id in tokens {
            if !seen.insert(token_id) {
                continue;
            }
            if let Some(logit) = logits.get_mut(token_id as usize) {
                *logit = if *logit > 0.0 {
                    *logit / self.penalty
                } else {
                    *logit * self.penalty
                };
            }
        }
    }

    fn name(&self) -> &'static str {
        "repetition_penalty"
    }
}
/// Logit processor that divides every logit by a fixed temperature.
#[derive(Debug, Clone)]
pub struct TemperatureScaler {
    /// Divisor applied to each logit; strictly positive when built via `new`.
    temperature: f32,
}

impl TemperatureScaler {
    /// Creates a scaler for the given `temperature`.
    ///
    /// # Panics
    ///
    /// Panics if `temperature` is not strictly greater than `0.0` (NaN included).
    #[must_use]
    pub fn new(temperature: f32) -> Self {
        assert!(temperature > 0.0, "Temperature must be positive");
        TemperatureScaler { temperature }
    }
}

impl LogitProcessor for TemperatureScaler {
    /// Divides each logit by the temperature in place. A temperature within
    /// `1e-6` of `1.0` is treated as the identity and skipped. `ctx` is unused.
    fn process(&self, logits: &mut [f32], _ctx: &LogitProcessorContext) {
        let temperature = self.temperature;
        let needs_scaling = (temperature - 1.0).abs() > 1e-6;
        if needs_scaling {
            logits.iter_mut().for_each(|logit| *logit /= temperature);
        }
    }

    fn name(&self) -> &'static str {
        "temperature_scaler"
    }
}
/// Ordered collection of boxed `LogitProcessor` trait objects.
///
/// `Default` yields an empty chain. How the processors are run is defined by
/// methods outside this struct declaration.
#[derive(Default)]
pub struct LogitProcessorChain {
    // NOTE(review): presumably applied in insertion order — confirm against
    // the chain's methods.
    processors: Vec<Box<dyn LogitProcessor>>,
}