1use crate::sampling::SamplingParams;
9
/// Rolling record of the tokens and entropies observed during generation.
///
/// Both history buffers are bounded: once a buffer grows past 64 entries the
/// oldest entry is dropped, so they always hold the most recent observations
/// in chronological order (oldest first).
#[derive(Debug, Clone)]
pub struct GenerationState {
    /// Number of times [`GenerationState::update`] has been called.
    pub step: usize,
    /// Most recent token ids, oldest first.
    pub recent_tokens: Vec<u32>,
    /// Entropy values recorded alongside the tokens, oldest first.
    pub recent_entropies: Vec<f32>,
    /// Number of consecutive updates whose trailing token pair had already
    /// appeared earlier in the token window; reset to zero by any update
    /// whose trailing pair is new.
    pub repetition_count: usize,
}

impl Default for GenerationState {
    /// Builds an empty state: step zero, no history, no repetitions.
    fn default() -> Self {
        Self {
            step: 0,
            recent_tokens: Vec::new(),
            recent_entropies: Vec::new(),
            repetition_count: 0,
        }
    }
}

impl GenerationState {
    /// Maximum number of entries kept in each history buffer.
    const WINDOW_CAP: usize = 64;

    /// Creates an empty state; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends `value` to `buf`, evicting the oldest entry if the buffer
    /// would otherwise exceed [`Self::WINDOW_CAP`].
    fn push_capped<T>(buf: &mut Vec<T>, value: T) {
        buf.push(value);
        if buf.len() > Self::WINDOW_CAP {
            buf.remove(0);
        }
    }

    /// Records one generated `token` together with the `entropy` observed
    /// when it was sampled.
    ///
    /// Increments `step`, appends to both bounded histories, and then updates
    /// `repetition_count`: if the pair formed by the two newest tokens also
    /// occurs as an adjacent pair among the tokens *before* those two, the
    /// counter is incremented; otherwise it is reset to zero. With fewer than
    /// two tokens in the window the counter is left untouched.
    pub fn update(&mut self, token: u32, entropy: f32) {
        self.step += 1;

        Self::push_capped(&mut self.recent_tokens, token);
        Self::push_capped(&mut self.recent_entropies, entropy);

        // Split the window into "everything before the newest pair" and the
        // newest pair itself; bail out when no full pair exists yet.
        let (earlier, prev, last) = match self.recent_tokens.as_slice() {
            [earlier @ .., prev, last] => (earlier, *prev, *last),
            _ => return,
        };

        let seen_before = earlier.windows(2).any(|pair| *pair == [prev, last]);
        self.repetition_count = if seen_before {
            self.repetition_count + 1
        } else {
            0
        };
    }

    /// Fraction of adjacent token pairs that are identical within the last
    /// `window` tokens.
    ///
    /// Returns `0.0` when fewer than two tokens fall inside the window
    /// (including `window == 0` and an empty history). Otherwise the result
    /// is `equal adjacent pairs / (tokens in window - 1)`.
    pub fn recent_repetition_rate(&self, window: usize) -> f32 {
        let available = self.recent_tokens.len();
        let tail_len = window.min(available);
        if tail_len < 2 {
            return 0.0;
        }

        let tail = &self.recent_tokens[available - tail_len..];
        let equal_pairs = tail
            .iter()
            .zip(&tail[1..])
            .filter(|(left, right)| left == right)
            .count();
        equal_pairs as f32 / (tail_len - 1) as f32
    }

    /// Arithmetic mean of the last `window` recorded entropies.
    ///
    /// Returns `0.0` when `window == 0` or no entropies have been recorded.
    /// A `window` larger than the history averages the whole history.
    pub fn mean_recent_entropy(&self, window: usize) -> f32 {
        let available = self.recent_entropies.len();
        let tail_len = window.min(available);
        if tail_len == 0 {
            return 0.0;
        }

        let tail = &self.recent_entropies[available - tail_len..];
        let total: f32 = tail.iter().sum();
        total / tail_len as f32
    }
}
104
105pub trait AdaptiveStrategy: Send + Sync {
109 fn adjust(&self, state: &GenerationState, base: &SamplingParams) -> SamplingParams;
111 fn name(&self) -> &'static str;
113}
114
/// Strategy that lowers the temperature when recent entropy runs above a target.
pub struct EntropyCooling {
    /// Mean entropy above which cooling is applied.
    pub target_entropy: f32,
    /// Multiplier applied to the entropy excess to obtain the temperature cut.
    pub cooling_rate: f32,
    /// Lower bound applied to the cooled temperature.
    pub min_temperature: f32,
}

impl EntropyCooling {
    /// Cooling rate used by [`EntropyCooling::new`].
    const DEFAULT_COOLING_RATE: f32 = 0.5;
    /// Temperature floor used by [`EntropyCooling::new`].
    const DEFAULT_MIN_TEMPERATURE: f32 = 0.1;

    /// Creates a strategy for the given `target_entropy` with a cooling rate
    /// of `0.5` and a minimum temperature of `0.1`.
    pub fn new(target_entropy: f32) -> Self {
        let cooling_rate = Self::DEFAULT_COOLING_RATE;
        let min_temperature = Self::DEFAULT_MIN_TEMPERATURE;
        Self {
            target_entropy,
            cooling_rate,
            min_temperature,
        }
    }
}
140
141impl AdaptiveStrategy for EntropyCooling {
142 fn adjust(&self, state: &GenerationState, base: &SamplingParams) -> SamplingParams {
143 let mut params = base.clone();
144 let window = 8.min(state.recent_entropies.len().max(1));
145 let mean_entropy = state.mean_recent_entropy(window);
146
147 if mean_entropy > self.target_entropy {
148 let excess = mean_entropy - self.target_entropy;
149 let reduction = self.cooling_rate * excess;
151 let new_temp = (base.temperature - reduction).max(self.min_temperature);
152 params.temperature = new_temp;
153 }
154
155 params
156 }
157
158 fn name(&self) -> &'static str {
159 "EntropyCooling"
160 }
161}
162
/// Strategy that scales the temperature according to how repetitive the
/// recent tokens are.
pub struct RepetitionAdaptation {
    /// Repetition rate above which the temperature is scaled by `cool_factor`;
    /// below half of this value it may be scaled by `heat_factor` instead.
    pub rep_threshold: f32,
    /// Multiplier applied to the temperature when repetition is high.
    pub cool_factor: f32,
    /// Multiplier applied to the temperature when repetition is low.
    pub heat_factor: f32,
}

impl Default for RepetitionAdaptation {
    /// Threshold `0.3`, cool factor `0.8`, heat factor `1.1`.
    fn default() -> Self {
        Self {
            rep_threshold: 0.3,
            cool_factor: 0.8,
            heat_factor: 1.1,
        }
    }
}

impl RepetitionAdaptation {
    /// Creates the strategy with its default settings; equivalent to
    /// [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
194
195impl AdaptiveStrategy for RepetitionAdaptation {
196 fn adjust(&self, state: &GenerationState, base: &SamplingParams) -> SamplingParams {
197 let mut params = base.clone();
198 let window = 16.min(state.recent_tokens.len().max(1));
199 let rep_rate = state.recent_repetition_rate(window);
200
201 if rep_rate > self.rep_threshold {
202 params.temperature = (base.temperature * self.cool_factor).max(0.01);
203 } else if rep_rate < self.rep_threshold / 2.0 && state.step > 4 {
204 params.temperature = (base.temperature * self.heat_factor).min(2.0);
206 }
207
208 params
209 }
210
211 fn name(&self) -> &'static str {
212 "RepetitionAdaptation"
213 }
214}
215
/// Strategy that moves the temperature linearly from an initial to a final
/// value over a fixed number of steps.
pub struct ScheduledDecay {
    /// Temperature at step zero.
    pub initial_temperature: f32,
    /// Temperature once `total_steps` steps have elapsed.
    pub final_temperature: f32,
    /// Number of steps over which the interpolation runs.
    pub total_steps: usize,
}

impl ScheduledDecay {
    /// Creates a schedule going from `initial` to `final_temp` over `steps` steps.
    pub fn new(initial: f32, final_temp: f32, steps: usize) -> Self {
        Self {
            initial_temperature: initial,
            final_temperature: final_temp,
            total_steps: steps,
        }
    }

    /// Returns the scheduled temperature at `step`.
    ///
    /// For `step < total_steps` the value is linearly interpolated between
    /// `initial_temperature` and `final_temperature`. For
    /// `step >= total_steps` (which includes every step of a zero-length
    /// schedule) the result is exactly `final_temperature`.
    pub fn temperature_at_step(&self, step: usize) -> f32 {
        // Return the endpoint directly rather than evaluating
        // `initial + 1.0 * (final - initial)`: in f32 that expression can
        // differ from `final` by a few ulps (e.g. 1.0 -> 0.1 gives 0.10000002).
        if step >= self.total_steps {
            return self.final_temperature;
        }

        // `step < total_steps` here, so `progress` lies in [0, 1].
        let progress = step as f32 / self.total_steps as f32;
        let span = self.final_temperature - self.initial_temperature;
        self.initial_temperature + progress * span
    }
}
248
249impl AdaptiveStrategy for ScheduledDecay {
250 fn adjust(&self, state: &GenerationState, base: &SamplingParams) -> SamplingParams {
251 let mut params = base.clone();
252 params.temperature = self.temperature_at_step(state.step);
253 params
254 }
255
256 fn name(&self) -> &'static str {
257 "ScheduledDecay"
258 }
259}
260
261pub struct AdaptiveSamplerChain {
267 strategies: Vec<Box<dyn AdaptiveStrategy>>,
268}
269
270impl Default for AdaptiveSamplerChain {
271 fn default() -> Self {
272 Self::new()
273 }
274}
275
276impl AdaptiveSamplerChain {
277 pub fn new() -> Self {
279 Self {
280 strategies: Vec::new(),
281 }
282 }
283
284 #[allow(clippy::should_implement_trait)]
286 pub fn add(mut self, strategy: Box<dyn AdaptiveStrategy>) -> Self {
287 self.strategies.push(strategy);
288 self
289 }
290
291 pub fn adjust(&self, state: &GenerationState, base: &SamplingParams) -> SamplingParams {
293 self.strategies
294 .iter()
295 .fold(base.clone(), |params, strategy| {
296 strategy.adjust(state, ¶ms)
297 })
298 }
299
300 pub fn len(&self) -> usize {
302 self.strategies.len()
303 }
304
305 pub fn is_empty(&self) -> bool {
307 self.strategies.is_empty()
308 }
309}
310
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a state by recording the same `(token, entropy)` pair `count` times.
    fn repeated_state(token: u32, entropy: f32, count: usize) -> GenerationState {
        let mut state = GenerationState::new();
        (0..count).for_each(|_| state.update(token, entropy));
        state
    }

    /// Default sampling parameters with only the temperature overridden.
    fn params_at(temperature: f32) -> SamplingParams {
        SamplingParams {
            temperature,
            ..Default::default()
        }
    }

    #[test]
    fn generation_state_new_empty() {
        let GenerationState {
            step,
            recent_tokens,
            recent_entropies,
            repetition_count,
        } = GenerationState::new();
        assert_eq!(step, 0);
        assert!(recent_tokens.is_empty());
        assert!(recent_entropies.is_empty());
        assert_eq!(repetition_count, 0);
    }

    #[test]
    fn generation_state_update() {
        let state = repeated_state(42, 1.5, 1);
        assert_eq!(state.step, 1);
        assert_eq!(state.recent_tokens, vec![42]);
        assert!((state.recent_entropies[0] - 1.5).abs() < 1e-6);
    }

    #[test]
    fn generation_state_repetition_rate_no_rep() {
        let mut state = GenerationState::new();
        (1u32..=5).for_each(|tok| state.update(tok, 1.0));
        let rate = state.recent_repetition_rate(5);
        assert!((rate - 0.0).abs() < 1e-6);
    }

    #[test]
    fn generation_state_repetition_rate_all_same() {
        let state = repeated_state(7, 1.0, 5);
        let rate = state.recent_repetition_rate(5);
        assert!(rate > 0.5, "expected high repetition rate, got {rate}");
    }

    #[test]
    fn generation_state_mean_entropy() {
        let mut state = GenerationState::new();
        for (tok, entropy) in [(1u32, 2.0f32), (2, 4.0), (3, 6.0)] {
            state.update(tok, entropy);
        }
        let mean = state.mean_recent_entropy(3);
        assert!((mean - 4.0).abs() < 1e-5, "expected 4.0, got {mean}");
    }

    #[test]
    fn entropy_cooling_high_entropy_reduces_temp() {
        let strategy = EntropyCooling::new(1.0);
        let base = params_at(1.0);
        // Eight samples at entropy 3.0, well above the target of 1.0.
        let state = repeated_state(1, 3.0, 8);
        let adjusted = strategy.adjust(&state, &base);
        assert!(
            adjusted.temperature < base.temperature,
            "expected temperature to decrease, got {}",
            adjusted.temperature
        );
    }

    #[test]
    fn entropy_cooling_low_entropy_no_change() {
        let strategy = EntropyCooling::new(2.0);
        let base = params_at(0.7);
        // Eight samples at entropy 0.5, below the target of 2.0.
        let state = repeated_state(1, 0.5, 8);
        let adjusted = strategy.adjust(&state, &base);
        assert!(
            (adjusted.temperature - base.temperature).abs() < 1e-6,
            "expected no change, got {}",
            adjusted.temperature
        );
    }

    #[test]
    fn entropy_cooling_min_temp_floor() {
        let strategy = EntropyCooling {
            target_entropy: 0.0,
            cooling_rate: 100.0,
            min_temperature: 0.05,
        };
        let base = params_at(1.0);
        let state = repeated_state(1, 5.0, 8);
        let adjusted = strategy.adjust(&state, &base);
        assert!(
            adjusted.temperature >= 0.05,
            "temperature below min floor: {}",
            adjusted.temperature
        );
    }

    #[test]
    fn repetition_adaptation_high_rep_cools() {
        let strategy = RepetitionAdaptation::new();
        let base = params_at(1.0);
        // Twenty identical tokens: every adjacent pair repeats.
        let state = repeated_state(42, 0.1, 20);
        let adjusted = strategy.adjust(&state, &base);
        assert!(
            adjusted.temperature < base.temperature,
            "expected cooling, got {}",
            adjusted.temperature
        );
    }

    #[test]
    fn repetition_adaptation_low_rep_unchanged() {
        let strategy = RepetitionAdaptation::new();
        let base = params_at(1.0);
        let mut state = GenerationState::new();
        // Five distinct tokens: no adjacent repeats.
        (0..5u32).for_each(|tok| state.update(tok, 1.0));
        let adjusted = strategy.adjust(&state, &base);
        // Only guards against cooling; the temperature may stay or rise.
        assert!(
            adjusted.temperature >= base.temperature - 0.01,
            "unexpected cooling: {}",
            adjusted.temperature
        );
    }

    #[test]
    fn scheduled_decay_at_step_zero() {
        let sched = ScheduledDecay::new(1.0, 0.1, 100);
        let at_start = sched.temperature_at_step(0);
        assert!((at_start - 1.0).abs() < 1e-6);
    }

    #[test]
    fn scheduled_decay_at_final_step() {
        let sched = ScheduledDecay::new(1.0, 0.1, 100);
        let at_end = sched.temperature_at_step(100);
        assert!((at_end - 0.1).abs() < 1e-6);
    }

    #[test]
    fn scheduled_decay_intermediate() {
        let sched = ScheduledDecay::new(1.0, 0.0, 100);
        let mid = sched.temperature_at_step(50);
        assert!((mid - 0.5).abs() < 1e-5, "expected 0.5, got {mid}");
    }

    #[test]
    fn adaptive_chain_empty() {
        let chain = AdaptiveSamplerChain::new();
        let base = SamplingParams::default();
        let state = GenerationState::new();
        let adjusted = chain.adjust(&state, &base);
        assert!((adjusted.temperature - base.temperature).abs() < 1e-6);
    }

    #[test]
    fn adaptive_chain_applies_all() {
        let decay: Box<dyn AdaptiveStrategy> = Box::new(ScheduledDecay::new(1.0, 0.0, 100));
        let cooling: Box<dyn AdaptiveStrategy> = Box::new(EntropyCooling::new(0.0));
        let chain = AdaptiveSamplerChain::new().add(decay).add(cooling);

        assert_eq!(chain.len(), 2);

        let base = params_at(1.0);
        // Fifty steps puts the schedule at its midpoint; entropy 5.0 is far
        // above the cooling target of 0.0.
        let state = repeated_state(1, 5.0, 50);

        let adjusted = chain.adjust(&state, &base);
        assert!(
            adjusted.temperature < 0.5 + 1e-3,
            "expected temp <= 0.5, got {}",
            adjusted.temperature
        );
    }
}