/// Settings for entropy-based dynamic temperature (consumed by `apply_dynamic_temperature`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DynTempConfig {
    /// Base temperature: the centre of the dynamic range, and the only temperature
    /// used when `delta <= 0.0`.
    pub temp: f32,
    /// Half-width of the temperature range `[max(temp - delta, 0.0), temp + delta]`.
    /// A value `<= 0.0` disables the dynamic behaviour.
    pub delta: f32,
    /// Exponent applied to the normalized entropy before interpolating within the range.
    pub exponent: f32,
}
impl Default for DynTempConfig {
    /// Neutral settings: temperature 1.0, no dynamic range (`delta = 0.0`) and a
    /// linear entropy exponent of 1.0.
    fn default() -> Self {
        let (temp, delta, exponent) = (1.0, 0.0, 1.0);
        Self {
            temp,
            delta,
            exponent,
        }
    }
}
impl DynTempConfig {
pub fn new(temp: f32, delta: f32, exponent: f32) -> Self {
Self {
temp,
delta,
exponent,
}
}
pub fn static_temp(temp: f32) -> Self {
Self {
temp,
delta: 0.0,
exponent: 1.0,
}
}
}
/// Applies entropy-scaled ("dynamic") temperature to `logits`.
///
/// If `config.delta <= 0.0`, this is plain temperature scaling with `config.temp`.
/// Otherwise the temperature is interpolated inside
/// `[max(temp - delta, 0.0), temp + delta]` using the normalized Shannon entropy
/// `h = H(softmax(logits)) / ln(n)`, clamped to `[0, 1]`:
///
/// `dyn_temp = min_temp + (max_temp - min_temp) * h.powf(exponent)`
///
/// With a positive exponent, flatter (higher-entropy) distributions get a higher
/// temperature and peaked ones a lower temperature.
///
/// On the dynamic path, tensors with zero or one element are returned unchanged.
/// If `apply_temperature` fails, a warning is printed to stderr and a clone of the
/// raw logits is returned.
///
/// NOTE(review): when `temp <= delta` and the distribution is effectively one-hot
/// (entropy 0), `dyn_temp` becomes 0.0. Whether `apply_temperature` accepts that is
/// not visible here; if it errors, the raw logits are returned.
pub fn apply_dynamic_temperature(logits: &Tensor<f32>, config: &DynTempConfig) -> Tensor<f32> {
    if config.delta <= 0.0 {
        return apply_temperature(logits, config.temp).unwrap_or_else(|e| {
            eprintln!("[WARN] Temperature application failed ({e}), using raw logits");
            logits.clone()
        });
    }
    let data = logits.data();
    if data.len() <= 1 {
        // A single candidate (or none) has no entropy to scale by.
        return logits.clone();
    }
    // Numerically stable softmax: shift by the max logit before exponentiating.
    // Each exponential is computed once and reused for both the partition sum
    // and the entropy.
    let max_logit = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let shifted_exps: Vec<f32> = data.iter().map(|x| (x - max_logit).exp()).collect();
    let exp_sum: f32 = shifted_exps.iter().sum();
    // Shannon entropy in nats; zero-probability terms are skipped (0 * ln 0 := 0).
    let entropy: f32 = shifted_exps
        .iter()
        .map(|&e| e / exp_sum)
        .filter(|&p| p > 0.0)
        .map(|p| -p * p.ln())
        .sum();
    // ln(n) is the entropy of the uniform distribution, i.e. the maximum possible.
    let max_entropy = (data.len() as f32).ln();
    let normalized_entropy = if max_entropy > 0.0 {
        (entropy / max_entropy).clamp(0.0, 1.0)
    } else {
        0.0
    };
    let min_temp = (config.temp - config.delta).max(0.0);
    let max_temp = config.temp + config.delta;
    let dyn_temp = min_temp + (max_temp - min_temp) * normalized_entropy.powf(config.exponent);
    apply_temperature(logits, dyn_temp).unwrap_or_else(|e| {
        eprintln!("[WARN] Dynamic temperature application failed ({e}), using raw logits");
        logits.clone()
    })
}
/// Settings for infill end-of-generation forcing (consumed by `apply_infill_sampling`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InfillConfig {
    /// Token ids treated as end-of-generation. An empty list disables the check.
    pub eog_tokens: Vec<usize>,
    /// Multiplier in the forcing condition `threshold * p_eog * n > p_txt`
    /// (3.0 in both `Default` and `InfillConfig::new`).
    pub eog_ratio_threshold: f32,
}
impl Default for InfillConfig {
fn default() -> Self {
Self {
eog_tokens: vec![],
eog_ratio_threshold: 3.0,
}
}
}
impl InfillConfig {
    /// Creates a configuration for the given EOG token ids, using the default
    /// ratio threshold of 3.0.
    pub fn new(eog_tokens: Vec<usize>) -> Self {
        let eog_ratio_threshold = 3.0;
        Self {
            eog_tokens,
            eog_ratio_threshold,
        }
    }

    /// Returns this configuration with `eog_ratio_threshold` replaced by `threshold`.
    #[must_use]
    pub fn with_threshold(self, threshold: f32) -> Self {
        Self {
            eog_ratio_threshold: threshold,
            ..self
        }
    }
}
/// Outcome of `apply_infill_sampling`.
#[derive(Debug, Clone)]
pub struct InfillResult {
    /// Logits to sample from: the input unchanged, or EOG-only logits when
    /// `force_eog` is `true`.
    pub logits: Tensor<f32>,
    /// `true` when the EOG probability mass was large enough to force end of generation.
    pub force_eog: bool,
    /// Total softmax probability of non-EOG tokens (reported as 1.0 when the check is skipped).
    pub p_txt: f32,
    /// Total softmax probability of EOG tokens (reported as 0.0 when the check is skipped).
    pub p_eog: f32,
}
/// Splits the probability mass in `probs` into `(p_eog, p_txt)`.
///
/// Index `i` counts as EOG when `i` appears in `eog_tokens`. Ids beyond
/// `probs.len()` are ignored, and each index is counted at most once, even if
/// its id is duplicated in `eog_tokens`.
fn compute_eog_txt_probs(probs: &[f32], eog_tokens: &[usize]) -> (f32, f32) {
    probs
        .iter()
        .enumerate()
        .fold((0.0_f32, 0.0_f32), |(eog_mass, txt_mass), (idx, &prob)| {
            if eog_tokens.contains(&idx) {
                (eog_mass + prob, txt_mass)
            } else {
                (eog_mass, txt_mass + prob)
            }
        })
}
/// Builds a logits tensor in which only the EOG (end-of-generation) tokens remain
/// selectable.
///
/// Every position starts at `f32::NEG_INFINITY`. Each in-range id in `eog_tokens`
/// (ids `>= data.len()` are ignored) is then set to the log of its probability
/// renormalized over the EOG set, `ln(probs[id] / sum_eog)`. EOG positions whose raw
/// logit is `-inf` stay `-inf`. If the EOG probabilities sum to zero, the raw logits
/// are kept for the EOG positions instead.
///
/// Duplicate ids in `eog_tokens` are counted once, so the renormalized probabilities
/// sum to 1. This matches `compute_eog_txt_probs`, which also counts each index once.
///
/// # Panics
/// Panics if `Tensor::from_vec` rejects `shape`. Callers pass the shape of the tensor
/// `data` came from, so that indicates a bug. Also panics if `probs` is shorter than
/// `data` and an EOG id falls in the gap.
fn create_eog_only_logits(
    data: &[f32],
    probs: &[f32],
    eog_tokens: &[usize],
    shape: &[usize],
) -> Tensor<f32> {
    let len = data.len();
    // In-range EOG ids, first occurrence only, in the caller's order. This keeps the
    // summation order (and thus the result) identical for duplicate-free input.
    // `contains` is quadratic in the EOG count, which is expected to be small.
    let mut unique_ids: Vec<usize> = Vec::with_capacity(eog_tokens.len());
    for &eog_id in eog_tokens {
        if eog_id < len && !unique_ids.contains(&eog_id) {
            unique_ids.push(eog_id);
        }
    }
    let mut new_data = vec![f32::NEG_INFINITY; len];
    let mut eog_sum = 0.0_f32;
    for &eog_id in &unique_ids {
        new_data[eog_id] = data[eog_id];
        eog_sum += probs[eog_id];
    }
    if eog_sum > 0.0 {
        for &eog_id in &unique_ids {
            // Leave masked (-inf) EOG logits masked.
            if new_data[eog_id] > f32::NEG_INFINITY {
                new_data[eog_id] = (probs[eog_id] / eog_sum).ln();
            }
        }
    }
    Tensor::from_vec(shape.to_vec(), new_data)
        .expect("BUG: EOG logits shape/data mismatch (same shape as input tensor)")
}
/// Infill end-of-generation check.
///
/// Computes the softmax of `logits` and splits its mass into EOG tokens (`p_eog`)
/// and everything else (`p_txt`). If `eog_ratio_threshold * p_eog * n > p_txt`,
/// where `n` is the number of logits, generation should end. In that case the
/// returned logits keep only the EOG tokens, renormalized among themselves, and
/// `force_eog` is `true`. Otherwise a clone of the input logits is returned.
///
/// With empty logits or no configured EOG tokens, the input is returned as-is with
/// `p_txt = 1.0` and `p_eog = 0.0`.
///
/// NOTE(review): `n` counts every entry in the tensor, including entries an earlier
/// sampler may have masked to `-inf`. For a full-vocabulary tensor this makes the
/// forcing condition very permissive. Confirm this is intended.
pub fn apply_infill_sampling(logits: &Tensor<f32>, config: &InfillConfig) -> InfillResult {
    let data = logits.data();
    let nothing_to_check = data.is_empty() || config.eog_tokens.is_empty();
    if nothing_to_check {
        return InfillResult {
            logits: logits.clone(),
            force_eog: false,
            p_txt: 1.0,
            p_eog: 0.0,
        };
    }
    // Numerically stable softmax over the raw logits.
    let peak = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let denom: f32 = data.iter().map(|v| (v - peak).exp()).sum();
    let probs: Vec<f32> = data.iter().map(|v| (v - peak).exp() / denom).collect();

    let (p_eog, p_txt) = compute_eog_txt_probs(&probs, &config.eog_tokens);
    let candidate_count = data.len() as f32;
    let force_eog = config.eog_ratio_threshold * p_eog * candidate_count > p_txt;

    let out_logits = if force_eog {
        create_eog_only_logits(data, &probs, &config.eog_tokens, logits.shape())
    } else {
        logits.clone()
    };
    InfillResult {
        logits: out_logits,
        force_eog,
        p_txt,
        p_eog,
    }
}
/// A single logit-transforming stage that can be composed into a `SamplerChain`.
///
/// Implementors must be `Send + Sync`. `clone_box` allows `Box<dyn Sampler>` to be
/// cloned, which `SamplerChain`'s `Clone` impl relies on.
pub trait Sampler: Send + Sync {
    /// Short identifier for this sampler (reported by `SamplerChain::names`).
    fn name(&self) -> &'static str;
    /// Modifies `logits` in place; `context` carries caller-supplied per-step state.
    fn apply(&self, logits: &mut Tensor<f32>, context: &SamplerContext);
    /// Returns a boxed copy of this sampler.
    fn clone_box(&self) -> Box<dyn Sampler>;
}
/// Per-step state handed to every `Sampler` in a chain.
#[derive(Debug, Clone, Default)]
pub struct SamplerContext {
    /// Token ids supplied by the caller (presumably the tokens generated so far; confirm
    /// against callers).
    pub tokens: Vec<usize>,
    /// Random value supplied by the caller for stochastic stages.
    /// NOTE(review): its range is not enforced here; confirm the expected interval.
    pub rng_value: f32,
    /// Step counter set by the caller.
    pub step: usize,
}
impl SamplerContext {
    /// Creates an empty context: no tokens, `rng_value` of 0.0 and step 0.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns this context with `tokens` replaced.
    #[must_use]
    pub fn with_tokens(self, tokens: Vec<usize>) -> Self {
        Self { tokens, ..self }
    }

    /// Returns this context with `rng_value` replaced.
    #[must_use]
    pub fn with_rng(self, rng_value: f32) -> Self {
        Self { rng_value, ..self }
    }

    /// Returns this context with `step` replaced.
    #[must_use]
    pub fn with_step(self, step: usize) -> Self {
        Self { step, ..self }
    }
}
/// An ordered pipeline of `Sampler`s, each applied in turn to the same logits.
pub struct SamplerChain {
    /// Stages in application order.
    samplers: Vec<Box<dyn Sampler>>,
}
impl Default for SamplerChain {
fn default() -> Self {
Self::new()
}
}
impl SamplerChain {
pub fn new() -> Self {
Self { samplers: vec![] }
}
#[must_use]
pub fn with_sampler<S: Sampler + 'static>(mut self, sampler: S) -> Self {
self.samplers.push(Box::new(sampler));
self
}
pub fn push(&mut self, sampler: Box<dyn Sampler>) {
self.samplers.push(sampler);
}
pub fn len(&self) -> usize {
self.samplers.len()
}
pub fn is_empty(&self) -> bool {
self.samplers.is_empty()
}
pub fn names(&self) -> Vec<&'static str> {
self.samplers.iter().map(|s| s.name()).collect()
}
pub fn apply(&self, logits: &mut Tensor<f32>, context: &SamplerContext) {
for sampler in &self.samplers {
sampler.apply(logits, context);
}
}
pub fn sample(&self, logits: &Tensor<f32>, context: &SamplerContext) -> Result<usize> {
let mut modified = logits.clone();
self.apply(&mut modified, context);
sample_greedy(&modified)
}
}
impl Clone for SamplerChain {
fn clone(&self) -> Self {
Self {
samplers: self.samplers.iter().map(|s| s.clone_box()).collect(),
}
}
}
/// Sampler stage that rescales logits with a fixed temperature via `apply_temperature`.
#[derive(Debug, Clone)]
pub struct TemperatureSampler {
    /// Temperature passed to `apply_temperature`.
    pub temp: f32,
}
impl TemperatureSampler {
pub fn new(temp: f32) -> Self {
Self { temp }
}
}
impl Sampler for TemperatureSampler {
    fn name(&self) -> &'static str {
        "temperature"
    }

    /// Replaces `logits` with the temperature-scaled tensor. If `apply_temperature`
    /// returns an error, the logits are left untouched and the error is discarded.
    /// The context is not used.
    fn apply(&self, logits: &mut Tensor<f32>, _context: &SamplerContext) {
        let scaled = apply_temperature(logits, self.temp);
        if let Ok(tensor) = scaled {
            *logits = tensor;
        }
    }

    fn clone_box(&self) -> Box<dyn Sampler> {
        Box::new(self.clone())
    }
}
/// Sampler-side holder for a `DynTempConfig` (dynamic-temperature settings).
#[derive(Debug, Clone)]
pub struct DynTempSampler {
    /// Dynamic-temperature settings.
    pub config: DynTempConfig,
}
impl DynTempSampler {
pub fn new(config: DynTempConfig) -> Self {
Self { config }
}
}