use crate::sampling::SamplingParams;
/// Rolling summary of a generation run: a step counter plus bounded histories
/// of the most recently produced tokens and their entropies.
#[derive(Debug, Clone)]
pub struct GenerationState {
    /// Number of times [`GenerationState::update`] has been called.
    pub step: usize,
    /// Most recent tokens, oldest first; only the newest 64 entries are kept.
    pub recent_tokens: Vec<u32>,
    /// Entropy recorded with each token, oldest first; only the newest 64 entries are kept.
    pub recent_entropies: Vec<f32>,
    /// Length of the current streak of updates whose trailing token pair had
    /// already occurred earlier in the token history.
    pub repetition_count: usize,
}

impl Default for GenerationState {
    /// Returns an empty state: step zero, no history, no repetition streak.
    fn default() -> Self {
        Self {
            step: 0,
            recent_tokens: Vec::new(),
            recent_entropies: Vec::new(),
            repetition_count: 0,
        }
    }
}

impl GenerationState {
    /// Maximum number of entries retained in each history buffer.
    const WINDOW_CAP: usize = 64;

    /// Creates an empty state (identical to [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `value` to `history`, dropping the oldest entry once the buffer
    /// grows past [`Self::WINDOW_CAP`].
    fn push_capped<T>(history: &mut Vec<T>, value: T) {
        history.push(value);
        if history.len() > Self::WINDOW_CAP {
            history.remove(0);
        }
    }

    /// Records one generated `token` together with its `entropy`.
    ///
    /// Increments `step`, appends to both bounded histories and refreshes
    /// `repetition_count`: once at least two tokens are held, the counter grows
    /// by one if the trailing pair of tokens also appears as an adjacent pair
    /// among the tokens before that pair, and is reset to zero otherwise.
    pub fn update(&mut self, token: u32, entropy: f32) {
        self.step += 1;
        Self::push_capped(&mut self.recent_tokens, token);
        Self::push_capped(&mut self.recent_entropies, entropy);

        // With fewer than two tokens there is no trailing pair, so the counter
        // is left untouched.
        if let [earlier @ .., prev, last] = self.recent_tokens.as_slice() {
            let seen_before = earlier
                .windows(2)
                .any(|pair| pair[0] == *prev && pair[1] == *last);
            self.repetition_count = if seen_before {
                self.repetition_count + 1
            } else {
                0
            };
        }
    }

    /// Fraction of adjacent token pairs that are equal within the last
    /// `window` tokens.
    ///
    /// Returns `0.0` when `window` is zero or fewer than two tokens fall inside
    /// the window.
    pub fn recent_repetition_rate(&self, window: usize) -> f32 {
        let skip = self.recent_tokens.len().saturating_sub(window);
        let tail = &self.recent_tokens[skip..];
        // A zero window or an empty history both produce an empty tail here.
        let pair_count = tail.len().saturating_sub(1);
        if pair_count == 0 {
            return 0.0;
        }
        let equal_pairs = tail
            .iter()
            .zip(tail.iter().skip(1))
            .filter(|(left, right)| left == right)
            .count();
        equal_pairs as f32 / pair_count as f32
    }

    /// Arithmetic mean of the last `window` recorded entropies.
    ///
    /// Returns `0.0` when `window` is zero or no entropies have been recorded.
    pub fn mean_recent_entropy(&self, window: usize) -> f32 {
        let skip = self.recent_entropies.len().saturating_sub(window);
        let tail = &self.recent_entropies[skip..];
        // A zero window or an empty history both produce an empty tail here.
        if tail.is_empty() {
            return 0.0;
        }
        let total = tail.iter().sum::<f32>();
        total / tail.len() as f32
    }
}
/// A rule that derives sampling parameters from the current generation state.
///
/// Implementors must be `Send + Sync`, as required by the supertrait bounds.
pub trait AdaptiveStrategy: Send + Sync {
/// Returns the sampling parameters to use, computed from `state` and the
/// caller-supplied `base` parameters. `base` is borrowed and not modified.
fn adjust(&self, state: &GenerationState, base: &SamplingParams) -> SamplingParams;
/// Returns a static name identifying the strategy.
fn name(&self) -> &'static str;
}
/// Configuration for a strategy that lowers the sampling temperature when the
/// recent mean entropy exceeds a target value.
pub struct EntropyCooling {
    /// Mean entropy above which cooling is applied.
    pub target_entropy: f32,
    /// Temperature reduction per unit of entropy above the target.
    pub cooling_rate: f32,
    /// Lowest temperature that cooling is allowed to reduce to.
    pub min_temperature: f32,
}

impl EntropyCooling {
    /// Cooling rate assigned by [`EntropyCooling::new`].
    const DEFAULT_COOLING_RATE: f32 = 0.5;
    /// Minimum temperature assigned by [`EntropyCooling::new`].
    const DEFAULT_MIN_TEMPERATURE: f32 = 0.1;

    /// Builds a strategy for the given `target_entropy` with a cooling rate of
    /// `0.5` and a minimum temperature of `0.1`.
    pub fn new(target_entropy: f32) -> Self {
        let cooling_rate = Self::DEFAULT_COOLING_RATE;
        let min_temperature = Self::DEFAULT_MIN_TEMPERATURE;
        Self {
            target_entropy,
            cooling_rate,
            min_temperature,
        }
    }
}
impl AdaptiveStrategy for EntropyCooling {
    /// Lowers the temperature in proportion to how far the recent mean entropy
    /// exceeds `target_entropy`.
    ///
    /// The mean is taken over the last eight recorded entropies (fewer if the
    /// history is shorter). When it is above the target, the temperature is
    /// reduced by `cooling_rate * (mean - target)` and floored at
    /// `min_temperature`. The floor is never allowed to *raise* the
    /// temperature: if `base.temperature` is already below `min_temperature`
    /// (for example a greedy `0.0`), it is returned unchanged rather than being
    /// bumped up to the floor. All other fields of `base` are copied as-is.
    fn adjust(&self, state: &GenerationState, base: &SamplingParams) -> SamplingParams {
        // Number of most recent entropy samples that are averaged;
        // `mean_recent_entropy` already copes with shorter histories.
        const ENTROPY_WINDOW: usize = 8;

        let mut params = base.clone();
        let mean_entropy = state.mean_recent_entropy(ENTROPY_WINDOW);
        if mean_entropy > self.target_entropy {
            let reduction = self.cooling_rate * (mean_entropy - self.target_entropy);
            // Cap the floor at the incoming temperature so the clamp can only
            // limit cooling, never heat.
            let floor = self.min_temperature.min(base.temperature);
            params.temperature = (base.temperature - reduction).max(floor);
        }
        params
    }

    /// Returns `"EntropyCooling"`.
    fn name(&self) -> &'static str {
        "EntropyCooling"
    }
}
/// Configuration for a strategy that scales the sampling temperature based on
/// how repetitive the recent tokens are.
pub struct RepetitionAdaptation {
    /// Repetition rate above which the temperature is cooled; heating is
    /// considered only below half of this value.
    pub rep_threshold: f32,
    /// Multiplier applied to the temperature when cooling.
    pub cool_factor: f32,
    /// Multiplier applied to the temperature when heating.
    pub heat_factor: f32,
}

impl Default for RepetitionAdaptation {
    /// Returns the stock configuration: threshold `0.3`, cool factor `0.8`,
    /// heat factor `1.1`.
    fn default() -> Self {
        Self {
            rep_threshold: 0.3,
            cool_factor: 0.8,
            heat_factor: 1.1,
        }
    }
}

impl RepetitionAdaptation {
    /// Creates the stock configuration (identical to [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }
}
impl AdaptiveStrategy for RepetitionAdaptation {
    /// Scales the temperature according to the repetition rate of the last
    /// sixteen tokens (fewer if the history is shorter).
    ///
    /// * Rate above `rep_threshold`: the temperature is multiplied by
    ///   `cool_factor` and floored at `0.01`.
    /// * Rate below half of `rep_threshold` and more than four steps taken: the
    ///   temperature is multiplied by `heat_factor` and capped at `2.0`.
    /// * Otherwise the temperature is left unchanged.
    ///
    /// The floor and cap only limit the scaled value; they never push the
    /// temperature the opposite way. A base temperature already below `0.01`
    /// is not raised by the cooling branch, and one already above `2.0` is not
    /// lowered by the heating branch. All other fields of `base` are copied
    /// as-is.
    fn adjust(&self, state: &GenerationState, base: &SamplingParams) -> SamplingParams {
        // Number of most recent tokens inspected; `recent_repetition_rate`
        // already copes with shorter histories.
        const REPETITION_WINDOW: usize = 16;
        // Bounds applied to the scaled temperature.
        const MIN_TEMPERATURE: f32 = 0.01;
        const MAX_TEMPERATURE: f32 = 2.0;
        // Heating is suppressed until more than this many steps have run.
        const WARMUP_STEPS: usize = 4;

        let mut params = base.clone();
        let rep_rate = state.recent_repetition_rate(REPETITION_WINDOW);
        if rep_rate > self.rep_threshold {
            // Cap the floor at the incoming temperature so the clamp cannot heat.
            let floor = MIN_TEMPERATURE.min(base.temperature);
            params.temperature = (base.temperature * self.cool_factor).max(floor);
        } else if rep_rate < self.rep_threshold / 2.0 && state.step > WARMUP_STEPS {
            // Lift the cap to the incoming temperature so the clamp cannot cool.
            let ceiling = MAX_TEMPERATURE.max(base.temperature);
            params.temperature = (base.temperature * self.heat_factor).min(ceiling);
        }
        params
    }

    /// Returns `"RepetitionAdaptation"`.
    fn name(&self) -> &'static str {
        "RepetitionAdaptation"
    }
}
/// Linear temperature schedule running from an initial to a final value over a
/// fixed number of steps.
pub struct ScheduledDecay {
    /// Temperature at step zero.
    pub initial_temperature: f32,
    /// Temperature reached at `total_steps` and held afterwards.
    pub final_temperature: f32,
    /// Number of steps over which the interpolation runs.
    pub total_steps: usize,
}

impl ScheduledDecay {
    /// Creates a schedule going from `initial` to `final_temp` across `steps` steps.
    pub fn new(initial: f32, final_temp: f32, steps: usize) -> Self {
        Self {
            total_steps: steps,
            final_temperature: final_temp,
            initial_temperature: initial,
        }
    }

    /// Returns the scheduled temperature for `step`.
    ///
    /// The value is linearly interpolated between the initial and final
    /// temperature using `step / total_steps`, with that ratio capped at `1.0`.
    /// A schedule with zero total steps always yields the final temperature.
    pub fn temperature_at_step(&self, step: usize) -> f32 {
        let (start, end) = (self.initial_temperature, self.final_temperature);
        match self.total_steps {
            // Nothing to interpolate over; also avoids dividing by zero.
            0 => end,
            total => {
                let progress = (step as f32 / total as f32).min(1.0);
                start + progress * (end - start)
            }
        }
    }
}
impl AdaptiveStrategy for ScheduledDecay {
fn adjust(&self, state: &GenerationState, base: &SamplingParams) -> SamplingParams {
let mut params = base.clone();
params.temperature = self.temperature_at_step(state.step);
params
}
fn name(&self) -> &'static str {
"ScheduledDecay"
}
}
/// Ordered collection of [`AdaptiveStrategy`] objects that are applied one
/// after another.
pub struct AdaptiveSamplerChain {
    /// Strategies in application order.
    strategies: Vec<Box<dyn AdaptiveStrategy>>,
}

impl Default for AdaptiveSamplerChain {
    /// Returns an empty chain.
    fn default() -> Self {
        Self::new()
    }
}

impl AdaptiveSamplerChain {
    /// Creates a chain with no strategies.
    pub fn new() -> Self {
        Self {
            strategies: Vec::new(),
        }
    }

    /// Appends `strategy` to the end of the chain and returns the chain, so
    /// calls can be chained builder-style.
    // This is a by-value builder method, unrelated to `std::ops::Add`, so the
    // lint suggesting a trait impl does not apply.
    #[allow(clippy::should_implement_trait)]
    pub fn add(mut self, strategy: Box<dyn AdaptiveStrategy>) -> Self {
        self.strategies.push(strategy);
        self
    }

    /// Runs every strategy in insertion order and returns the final parameters.
    ///
    /// The first strategy receives a clone of `base`; each later strategy
    /// receives the output of the previous one as its base. An empty chain
    /// returns a clone of `base`.
    pub fn adjust(&self, state: &GenerationState, base: &SamplingParams) -> SamplingParams {
        self.strategies
            .iter()
            .fold(base.clone(), |params, strategy| {
                strategy.adjust(state, &params)
            })
    }

    /// Number of strategies in the chain.
    pub fn len(&self) -> usize {
        self.strategies.len()
    }

    /// Returns `true` when the chain holds no strategies.
    pub fn is_empty(&self) -> bool {
        self.strategies.is_empty()
    }
}
// Unit tests for the generation state, the individual strategies and the
// strategy chain defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// A freshly constructed state has a zero step counter and empty histories.
#[test]
fn generation_state_new_empty() {
let state = GenerationState::new();
assert_eq!(state.step, 0);
assert!(state.recent_tokens.is_empty());
assert!(state.recent_entropies.is_empty());
assert_eq!(state.repetition_count, 0);
}
// A single update bumps the step and records the token and its entropy.
#[test]
fn generation_state_update() {
let mut state = GenerationState::new();
state.update(42, 1.5);
assert_eq!(state.step, 1);
assert_eq!(state.recent_tokens, vec![42]);
assert!((state.recent_entropies[0] - 1.5).abs() < 1e-6);
}
// All-distinct tokens contain no equal adjacent pair, so the rate is zero.
#[test]
fn generation_state_repetition_rate_no_rep() {
let mut state = GenerationState::new();
for tok in [1u32, 2, 3, 4, 5] {
state.update(tok, 1.0);
}
let rate = state.recent_repetition_rate(5);
assert!((rate - 0.0).abs() < 1e-6);
}
// Identical tokens make every adjacent pair equal, giving a high rate.
#[test]
fn generation_state_repetition_rate_all_same() {
let mut state = GenerationState::new();
for _ in 0..5 {
state.update(7, 1.0);
}
let rate = state.recent_repetition_rate(5);
assert!(rate > 0.5, "expected high repetition rate, got {rate}");
}
// The mean over the last three entropies (2, 4, 6) is 4.
#[test]
fn generation_state_mean_entropy() {
let mut state = GenerationState::new();
state.update(1, 2.0);
state.update(2, 4.0);
state.update(3, 6.0);
let mean = state.mean_recent_entropy(3);
assert!((mean - 4.0).abs() < 1e-5, "expected 4.0, got {mean}");
}
// Mean entropy (3.0) above the target (1.0) must lower the temperature.
#[test]
fn entropy_cooling_high_entropy_reduces_temp() {
let strategy = EntropyCooling::new(1.0);
let base = SamplingParams {
temperature: 1.0,
..Default::default()
};
let mut state = GenerationState::new();
for _ in 0..8 {
state.update(1, 3.0);
}
let adjusted = strategy.adjust(&state, &base);
assert!(
adjusted.temperature < base.temperature,
"expected temperature to decrease, got {}",
adjusted.temperature
);
}
// Mean entropy (0.5) below the target (2.0) leaves the temperature untouched.
#[test]
fn entropy_cooling_low_entropy_no_change() {
let strategy = EntropyCooling::new(2.0);
let base = SamplingParams {
temperature: 0.7,
..Default::default()
};
let mut state = GenerationState::new();
for _ in 0..8 {
state.update(1, 0.5);
}
let adjusted = strategy.adjust(&state, &base);
assert!(
(adjusted.temperature - base.temperature).abs() < 1e-6,
"expected no change, got {}",
adjusted.temperature
);
}
// An extreme cooling rate must still respect the configured minimum.
#[test]
fn entropy_cooling_min_temp_floor() {
let strategy = EntropyCooling {
target_entropy: 0.0,
cooling_rate: 100.0,
min_temperature: 0.05,
};
let base = SamplingParams {
temperature: 1.0,
..Default::default()
};
let mut state = GenerationState::new();
for _ in 0..8 {
state.update(1, 5.0);
}
let adjusted = strategy.adjust(&state, &base);
assert!(
adjusted.temperature >= 0.05,
"temperature below min floor: {}",
adjusted.temperature
);
}
// Twenty identical tokens exceed the repetition threshold and trigger cooling.
#[test]
fn repetition_adaptation_high_rep_cools() {
let strategy = RepetitionAdaptation::new();
let base = SamplingParams {
temperature: 1.0,
..Default::default()
};
let mut state = GenerationState::new();
for _ in 0..20 {
state.update(42, 0.1);
}
let adjusted = strategy.adjust(&state, &base);
assert!(
adjusted.temperature < base.temperature,
"expected cooling, got {}",
adjusted.temperature
);
}
// Distinct tokens must not cool; the assertion allows the temperature to
// stay the same or rise.
#[test]
fn repetition_adaptation_low_rep_unchanged() {
let strategy = RepetitionAdaptation::new();
let base = SamplingParams {
temperature: 1.0,
..Default::default()
};
let mut state = GenerationState::new();
for i in 0..5u32 {
state.update(i, 1.0);
}
let adjusted = strategy.adjust(&state, &base);
assert!(
adjusted.temperature >= base.temperature - 0.01,
"unexpected cooling: {}",
adjusted.temperature
);
}
// Step zero yields the initial temperature.
#[test]
fn scheduled_decay_at_step_zero() {
let sched = ScheduledDecay::new(1.0, 0.1, 100);
assert!((sched.temperature_at_step(0) - 1.0).abs() < 1e-6);
}
// The last scheduled step yields the final temperature.
#[test]
fn scheduled_decay_at_final_step() {
let sched = ScheduledDecay::new(1.0, 0.1, 100);
assert!((sched.temperature_at_step(100) - 0.1).abs() < 1e-6);
}
// Halfway through a 1.0 -> 0.0 schedule the temperature is 0.5.
#[test]
fn scheduled_decay_intermediate() {
let sched = ScheduledDecay::new(1.0, 0.0, 100);
let mid = sched.temperature_at_step(50);
assert!((mid - 0.5).abs() < 1e-5, "expected 0.5, got {mid}");
}
// An empty chain returns the base temperature unchanged.
#[test]
fn adaptive_chain_empty() {
let chain = AdaptiveSamplerChain::new();
let base = SamplingParams::default();
let state = GenerationState::new();
let adjusted = chain.adjust(&state, &base);
assert!((adjusted.temperature - base.temperature).abs() < 1e-6);
}
// ScheduledDecay gives 0.5 at step 50; EntropyCooling then runs on that
// value with mean entropy 5.0 above its target of 0.0, so the result must
// not exceed 0.5.
#[test]
fn adaptive_chain_applies_all() {
let chain = AdaptiveSamplerChain::new()
.add(Box::new(ScheduledDecay::new(1.0, 0.0, 100)))
.add(Box::new(EntropyCooling::new(0.0)));
assert_eq!(chain.len(), 2);
let base = SamplingParams {
temperature: 1.0,
..Default::default()
};
let mut state = GenerationState::new();
for _ in 0..50 {
state.update(1, 5.0); }
let adjusted = chain.adjust(&state, &base);
assert!(
adjusted.temperature < 0.5 + 1e-3,
"expected temp <= 0.5, got {}",
adjusted.temperature
);
}
}