//! Adaptive sampling: dynamically adjust temperature/top_p based on generation state.
//!
//! Strategies:
//! - `EntropyCooling`: lower temperature when entropy is too high (reduce randomness)
//! - `RepetitionAdaptation`: lower temp when repeating, raise when stuck
//! - `ScheduledDecay`: gradually decay temperature over the sequence

use crate::sampling::SamplingParams;

// ─── GenerationState ───────────────────────────────────────────────────────────

/// Current generation state for adaptive decisions.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct GenerationState {
    /// Current decoding step (0-indexed).
    pub step: usize,
    /// Last N generated tokens (ring-buffer style; most recent last).
    pub recent_tokens: Vec<u32>,
    /// Shannon entropy (in nats) at each recent step.
    pub recent_entropies: Vec<f32>,
    /// Number of consecutive steps where repeated n-grams were detected.
    pub repetition_count: usize,
}

impl GenerationState {
    /// Maximum number of recent tokens/entropies retained for analysis.
    const WINDOW_CAP: usize = 64;

    /// Create a fresh, empty generation state.
    pub fn new() -> Self {
        // The derived `Default` (zeros and empty vecs) is exactly the fresh state.
        Self::default()
    }

    /// Push `value`, evicting the oldest entry once `WINDOW_CAP` is exceeded.
    ///
    /// `remove(0)` is O(len), but len is capped at 64, so the cost is constant.
    fn push_capped<T>(buf: &mut Vec<T>, value: T) {
        buf.push(value);
        if buf.len() > Self::WINDOW_CAP {
            buf.remove(0);
        }
    }

    /// Record a newly generated token and the entropy at this step.
    ///
    /// `repetition_count` is incremented when the newest bigram already
    /// occurred (non-overlapping) earlier in the window, and reset to zero
    /// otherwise.
    pub fn update(&mut self, token: u32, entropy: f32) {
        self.step += 1;

        Self::push_capped(&mut self.recent_tokens, token);
        Self::push_capped(&mut self.recent_entropies, entropy);

        // Detect bigram repetition in the recent window.
        let len = self.recent_tokens.len();
        if len >= 2 {
            let last = self.recent_tokens[len - 1];
            let prev = self.recent_tokens[len - 2];
            // Only scan bigrams that end strictly before the current one, so
            // the newest pair never matches itself (or its overlap).
            let repeated = self.recent_tokens[..len - 2]
                .windows(2)
                .any(|w| w[0] == prev && w[1] == last);
            if repeated {
                self.repetition_count += 1;
            } else {
                self.repetition_count = 0;
            }
        }
    }

    /// Fraction of the last `window` tokens that are identical to the immediately
    /// preceding token (simple unigram repetition rate).
    ///
    /// Returns 0.0 when `window == 0` or fewer than two tokens fall in the window.
    pub fn recent_repetition_rate(&self, window: usize) -> f32 {
        if window == 0 || self.recent_tokens.is_empty() {
            return 0.0;
        }
        let start = self.recent_tokens.len().saturating_sub(window);
        let slice = &self.recent_tokens[start..];
        if slice.len() < 2 {
            return 0.0;
        }
        let repeats = slice.windows(2).filter(|w| w[0] == w[1]).count();
        repeats as f32 / (slice.len() - 1) as f32
    }

    /// Mean entropy over the last `window` steps, or 0.0 when there is no data.
    pub fn mean_recent_entropy(&self, window: usize) -> f32 {
        if window == 0 || self.recent_entropies.is_empty() {
            return 0.0;
        }
        // `window >= 1` and the buffer is non-empty, so the slice below is
        // never empty (the old second emptiness check was unreachable).
        let start = self.recent_entropies.len().saturating_sub(window);
        let slice = &self.recent_entropies[start..];
        slice.iter().sum::<f32>() / slice.len() as f32
    }
}
104
// ─── AdaptiveStrategy ──────────────────────────────────────────────────────────

/// Adaptive sampling strategy.
///
/// The `Send + Sync` supertraits mean implementors can be stored behind
/// `Box<dyn AdaptiveStrategy>` and shared/sent across threads.
pub trait AdaptiveStrategy: Send + Sync {
    /// Given the current generation state and base params, return adjusted params.
    ///
    /// Takes `state` and `base` by shared reference: implementations do not
    /// mutate either, they return an adjusted copy of `base`.
    fn adjust(&self, state: &GenerationState, base: &SamplingParams) -> SamplingParams;
    /// Human-readable name of this strategy.
    fn name(&self) -> &'static str;
}
114
// ─── EntropyCooling ────────────────────────────────────────────────────────────

/// Lower temperature when entropy is too high (generation is too random).
///
/// When `mean_entropy > target_entropy`, temperature is reduced by
/// `cooling_rate * excess`, floored at `min_temperature` (and never raised
/// above the base temperature, since the reduction is non-negative).
#[derive(Debug, Clone, PartialEq)]
pub struct EntropyCooling {
    /// Entropy level above which cooling begins (in nats).
    pub target_entropy: f32,
    /// Fraction of the excess entropy translated into temperature reduction (0..1).
    pub cooling_rate: f32,
    /// Minimum temperature floor.
    pub min_temperature: f32,
}

impl EntropyCooling {
    /// Create with sensible defaults (`cooling_rate = 0.5`, `min_temperature = 0.1`).
    pub fn new(target_entropy: f32) -> Self {
        Self {
            target_entropy,
            cooling_rate: 0.5,
            min_temperature: 0.1,
        }
    }
}
140
141impl AdaptiveStrategy for EntropyCooling {
142    fn adjust(&self, state: &GenerationState, base: &SamplingParams) -> SamplingParams {
143        let mut params = base.clone();
144        let window = 8.min(state.recent_entropies.len().max(1));
145        let mean_entropy = state.mean_recent_entropy(window);
146
147        if mean_entropy > self.target_entropy {
148            let excess = mean_entropy - self.target_entropy;
149            // Reduce temperature proportionally to excess entropy.
150            let reduction = self.cooling_rate * excess;
151            let new_temp = (base.temperature - reduction).max(self.min_temperature);
152            params.temperature = new_temp;
153        }
154
155        params
156    }
157
158    fn name(&self) -> &'static str {
159        "EntropyCooling"
160    }
161}
162
// ─── RepetitionAdaptation ─────────────────────────────────────────────────────

/// Adapt temperature based on repetition rate.
///
/// - High repetition → cool down (reduce temperature) to break out of loops.
/// - Low repetition with high entropy → heat up slightly to encourage diversity.
#[derive(Debug, Clone, PartialEq)]
pub struct RepetitionAdaptation {
    /// Repetition rate above which cooling is applied (0..1).
    pub rep_threshold: f32,
    /// Multiply temperature by this factor when repeating (< 1.0 to cool).
    pub cool_factor: f32,
    /// Multiply temperature by this factor when stuck (> 1.0 to heat).
    pub heat_factor: f32,
}

impl Default for RepetitionAdaptation {
    fn default() -> Self {
        Self::new()
    }
}

impl RepetitionAdaptation {
    /// Create with sensible defaults (`rep_threshold = 0.3`, `cool_factor = 0.8`,
    /// `heat_factor = 1.1`).
    pub fn new() -> Self {
        Self {
            rep_threshold: 0.3,
            cool_factor: 0.8,
            heat_factor: 1.1,
        }
    }
}
194
195impl AdaptiveStrategy for RepetitionAdaptation {
196    fn adjust(&self, state: &GenerationState, base: &SamplingParams) -> SamplingParams {
197        let mut params = base.clone();
198        let window = 16.min(state.recent_tokens.len().max(1));
199        let rep_rate = state.recent_repetition_rate(window);
200
201        if rep_rate > self.rep_threshold {
202            params.temperature = (base.temperature * self.cool_factor).max(0.01);
203        } else if rep_rate < self.rep_threshold / 2.0 && state.step > 4 {
204            // Very low repetition — gentle heating to encourage variety.
205            params.temperature = (base.temperature * self.heat_factor).min(2.0);
206        }
207
208        params
209    }
210
211    fn name(&self) -> &'static str {
212        "RepetitionAdaptation"
213    }
214}
215
// ─── ScheduledDecay ────────────────────────────────────────────────────────────

/// Linearly decay temperature from `initial_temperature` to `final_temperature`
/// over `total_steps` decoding steps.
#[derive(Debug, Clone, PartialEq)]
pub struct ScheduledDecay {
    /// Starting temperature (at step 0).
    pub initial_temperature: f32,
    /// Ending temperature (at step >= total_steps).
    pub final_temperature: f32,
    /// Number of steps over which to interpolate.
    pub total_steps: usize,
}

impl ScheduledDecay {
    /// Create a new scheduled decay.
    pub fn new(initial: f32, final_temp: f32, steps: usize) -> Self {
        Self {
            initial_temperature: initial,
            final_temperature: final_temp,
            total_steps: steps,
        }
    }

    /// Return the interpolated temperature at the given absolute step.
    ///
    /// Steps past `total_steps` clamp to `final_temperature`; a zero-length
    /// schedule always returns `final_temperature`.
    pub fn temperature_at_step(&self, step: usize) -> f32 {
        if self.total_steps == 0 {
            return self.final_temperature;
        }
        let t = (step as f32 / self.total_steps as f32).min(1.0);
        self.initial_temperature + t * (self.final_temperature - self.initial_temperature)
    }
}
248
249impl AdaptiveStrategy for ScheduledDecay {
250    fn adjust(&self, state: &GenerationState, base: &SamplingParams) -> SamplingParams {
251        let mut params = base.clone();
252        params.temperature = self.temperature_at_step(state.step);
253        params
254    }
255
256    fn name(&self) -> &'static str {
257        "ScheduledDecay"
258    }
259}
260
261// ─── AdaptiveSamplerChain ─────────────────────────────────────────────────────
262
263/// Compose multiple adaptive strategies by applying them in sequence.
264///
265/// Each strategy sees the result of the previous one's adjustment.
266pub struct AdaptiveSamplerChain {
267    strategies: Vec<Box<dyn AdaptiveStrategy>>,
268}
269
270impl Default for AdaptiveSamplerChain {
271    fn default() -> Self {
272        Self::new()
273    }
274}
275
276impl AdaptiveSamplerChain {
277    /// Create an empty chain.
278    pub fn new() -> Self {
279        Self {
280            strategies: Vec::new(),
281        }
282    }
283
284    /// Append a strategy (builder pattern).
285    #[allow(clippy::should_implement_trait)]
286    pub fn add(mut self, strategy: Box<dyn AdaptiveStrategy>) -> Self {
287        self.strategies.push(strategy);
288        self
289    }
290
291    /// Apply all strategies in order, threading params through each.
292    pub fn adjust(&self, state: &GenerationState, base: &SamplingParams) -> SamplingParams {
293        self.strategies
294            .iter()
295            .fold(base.clone(), |params, strategy| {
296                strategy.adjust(state, &params)
297            })
298    }
299
300    /// Number of strategies in this chain.
301    pub fn len(&self) -> usize {
302        self.strategies.len()
303    }
304
305    /// Whether this chain has no strategies.
306    pub fn is_empty(&self) -> bool {
307        self.strategies.is_empty()
308    }
309}
310
// ─── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn generation_state_new_empty() {
        let state = GenerationState::new();
        assert_eq!(state.step, 0);
        assert!(state.recent_tokens.is_empty());
        assert!(state.recent_entropies.is_empty());
        assert_eq!(state.repetition_count, 0);
    }

    #[test]
    fn generation_state_update() {
        let mut state = GenerationState::new();
        state.update(42, 1.5);
        assert_eq!(state.step, 1);
        assert_eq!(state.recent_tokens, vec![42]);
        assert!((state.recent_entropies[0] - 1.5).abs() < 1e-6);
    }

    #[test]
    fn generation_state_repetition_rate_no_rep() {
        let mut state = GenerationState::new();
        // All-distinct tokens: no adjacent duplicates.
        for tok in [1u32, 2, 3, 4, 5] {
            state.update(tok, 1.0);
        }
        let rate = state.recent_repetition_rate(5);
        assert!((rate - 0.0).abs() < 1e-6);
    }

    #[test]
    fn generation_state_repetition_rate_all_same() {
        let mut state = GenerationState::new();
        for _ in 0..5 {
            state.update(7, 1.0);
        }
        let rate = state.recent_repetition_rate(5);
        assert!(rate > 0.5, "expected high repetition rate, got {rate}");
    }

    #[test]
    fn generation_state_mean_entropy() {
        let mut state = GenerationState::new();
        state.update(1, 2.0);
        state.update(2, 4.0);
        state.update(3, 6.0);
        let mean = state.mean_recent_entropy(3);
        assert!((mean - 4.0).abs() < 1e-5, "expected 4.0, got {mean}");
    }

    #[test]
    fn entropy_cooling_high_entropy_reduces_temp() {
        let strategy = EntropyCooling::new(1.0);
        let base = SamplingParams {
            temperature: 1.0,
            ..Default::default()
        };
        let mut state = GenerationState::new();
        // High entropy — well above target of 1.0
        for _ in 0..8 {
            state.update(1, 3.0);
        }
        let adjusted = strategy.adjust(&state, &base);
        assert!(
            adjusted.temperature < base.temperature,
            "expected temperature to decrease, got {}",
            adjusted.temperature
        );
    }

    #[test]
    fn entropy_cooling_low_entropy_no_change() {
        let strategy = EntropyCooling::new(2.0);
        let base = SamplingParams {
            temperature: 0.7,
            ..Default::default()
        };
        let mut state = GenerationState::new();
        // Low entropy — below target of 2.0
        for _ in 0..8 {
            state.update(1, 0.5);
        }
        let adjusted = strategy.adjust(&state, &base);
        assert!(
            (adjusted.temperature - base.temperature).abs() < 1e-6,
            "expected no change, got {}",
            adjusted.temperature
        );
    }

    #[test]
    fn entropy_cooling_min_temp_floor() {
        // Extreme cooling rate: the floor must still hold.
        let strategy = EntropyCooling {
            target_entropy: 0.0,
            cooling_rate: 100.0,
            min_temperature: 0.05,
        };
        let base = SamplingParams {
            temperature: 1.0,
            ..Default::default()
        };
        let mut state = GenerationState::new();
        for _ in 0..8 {
            state.update(1, 5.0);
        }
        let adjusted = strategy.adjust(&state, &base);
        assert!(
            adjusted.temperature >= 0.05,
            "temperature below min floor: {}",
            adjusted.temperature
        );
    }

    #[test]
    fn repetition_adaptation_high_rep_cools() {
        let strategy = RepetitionAdaptation::new();
        let base = SamplingParams {
            temperature: 1.0,
            ..Default::default()
        };
        let mut state = GenerationState::new();
        // Repeated same token many times
        for _ in 0..20 {
            state.update(42, 0.1);
        }
        let adjusted = strategy.adjust(&state, &base);
        assert!(
            adjusted.temperature < base.temperature,
            "expected cooling, got {}",
            adjusted.temperature
        );
    }

    #[test]
    fn repetition_adaptation_low_rep_unchanged() {
        let strategy = RepetitionAdaptation::new();
        let base = SamplingParams {
            temperature: 1.0,
            ..Default::default()
        };
        let mut state = GenerationState::new();
        // Unique tokens only
        for i in 0..5u32 {
            state.update(i, 1.0);
        }
        // rep_rate = 0 < rep_threshold/2 but step=5, heat_factor applies
        // We just verify it doesn't go below base.
        let adjusted = strategy.adjust(&state, &base);
        // Either unchanged or slightly heated — must not cool.
        assert!(
            adjusted.temperature >= base.temperature - 0.01,
            "unexpected cooling: {}",
            adjusted.temperature
        );
    }

    #[test]
    fn scheduled_decay_at_step_zero() {
        let sched = ScheduledDecay::new(1.0, 0.1, 100);
        assert!((sched.temperature_at_step(0) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn scheduled_decay_at_final_step() {
        let sched = ScheduledDecay::new(1.0, 0.1, 100);
        assert!((sched.temperature_at_step(100) - 0.1).abs() < 1e-6);
    }

    #[test]
    fn scheduled_decay_intermediate() {
        let sched = ScheduledDecay::new(1.0, 0.0, 100);
        let mid = sched.temperature_at_step(50);
        assert!((mid - 0.5).abs() < 1e-5, "expected 0.5, got {mid}");
    }

    #[test]
    fn adaptive_chain_empty() {
        // An empty chain must be the identity on params.
        let chain = AdaptiveSamplerChain::new();
        let base = SamplingParams::default();
        let state = GenerationState::new();
        let adjusted = chain.adjust(&state, &base);
        assert!((adjusted.temperature - base.temperature).abs() < 1e-6);
    }

    #[test]
    fn adaptive_chain_applies_all() {
        // ScheduledDecay brings temp to 0.5 at step 50, then EntropyCooling may lower it further.
        let chain = AdaptiveSamplerChain::new()
            .add(Box::new(ScheduledDecay::new(1.0, 0.0, 100)))
            .add(Box::new(EntropyCooling::new(0.0)));

        assert_eq!(chain.len(), 2);

        let base = SamplingParams {
            temperature: 1.0,
            ..Default::default()
        };
        let mut state = GenerationState::new();
        for _ in 0..50 {
            state.update(1, 5.0); // high entropy
        }

        let adjusted = chain.adjust(&state, &base);
        // After ScheduledDecay at step=50: temp=0.5. EntropyCooling lowers further.
        assert!(
            adjusted.temperature < 0.5 + 1e-3,
            "expected temp <= 0.5, got {}",
            adjusted.temperature
        );
    }
}