use crate::error::Result;
/// Hyper-parameters for the entropy-bounded denoising sampler.
///
/// Entropies are measured in nats (natural logarithm) over the
/// temperature-scaled softmax of each position's logits.
#[derive(Clone, Debug)]
pub struct EbParams {
    /// Step budget for a run; values below 1 are treated as 1.
    pub max_denoising_steps: u32,
    /// Temperature the linear schedule approaches on the final steps.
    pub t_min: f32,
    /// Starting temperature of the linear schedule (first step).
    pub t_max: f32,
    /// Per-step entropy budget: positions are taken in ascending-entropy order
    /// while the summed entropy of the positions already taken is at most this.
    pub entropy_bound: f32,
    /// Consecutive steps with an unchanged argmax canvas required (together
    /// with `confidence_threshold`) before the run stops early.
    pub stability_threshold: u32,
    /// Early stop additionally requires the mean per-position entropy to be
    /// strictly below this value.
    pub confidence_threshold: f32,
}

impl Default for EbParams {
    /// 48 steps annealing from temperature 0.8 towards 0.4, a 0.1-nat entropy
    /// budget, and early stop after one stable step with mean entropy < 0.005.
    fn default() -> Self {
        let (t_min, t_max) = (0.4, 0.8);
        EbParams {
            max_denoising_steps: 48,
            t_min,
            t_max,
            entropy_bound: 0.1,
            stability_threshold: 1,
            confidence_threshold: 0.005,
        }
    }
}
/// Source of randomness for the denoising sampler.
pub trait SamplerRng {
    /// Returns a uniform draw; the sampler uses it as an inverse-CDF quantile
    /// and expects it to lie in `[0, 1)`.
    fn uniform01(&mut self) -> f32;
    /// Returns a token id, expected to lie in `0..n_vocab`.
    fn token(&mut self, n_vocab: u32) -> u32;
}

/// Marsaglia xorshift64 generator (shift triple 13 / 7 / 17).
///
/// Small, fast and reproducible; not suitable for cryptographic use. The
/// wrapped `u64` is the raw generator state. State 0 is a fixed point of
/// xorshift (it would emit zeros forever), so a zero state is replaced by a
/// fixed non-zero constant before stepping. Non-zero seeds are unaffected.
#[derive(Clone, Debug)]
pub struct XorShiftRng(pub u64);

impl XorShiftRng {
    /// Replacement for an all-zero state (golden-ratio constant; any non-zero
    /// value would do).
    const ZERO_STATE_REPLACEMENT: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Advances the state by one xorshift64 step and returns the new state.
    fn next_u64(&mut self) -> u64 {
        if self.0 == 0 {
            self.0 = Self::ZERO_STATE_REPLACEMENT;
        }
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }
}

impl SamplerRng for XorShiftRng {
    /// Uses the top 24 bits of the next state, so every result is an exactly
    /// representable `f32` in `[0, 1)`.
    fn uniform01(&mut self) -> f32 {
        ((self.next_u64() >> 40) as f32) / ((1u64 << 24) as f32)
    }

    /// Scales a uniform draw to `0..n_vocab`. The trailing `%` keeps the
    /// result in range even if float rounding of the product reaches `n_vocab`
    /// (possible for very large vocabularies).
    ///
    /// # Panics
    /// Panics if `n_vocab` is 0 (remainder by zero).
    fn token(&mut self, n_vocab: u32) -> u32 {
        (self.uniform01() * n_vocab as f32) as u32 % n_vocab
    }
}
/// Per-step progress report handed to the `step_cb` callback of
/// `generate_entropy_bound`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct StepInfo<'a> {
    /// Zero-based index of the step that just completed.
    pub step_idx: u32,
    /// Step budget of the run (`max_denoising_steps`, clamped to at least 1).
    pub total_steps: u32,
    /// Per-position argmax tokens of the logits ingested on this step.
    pub argmax_canvas: &'a [u32],
    /// Mean per-position entropy (nats) of this step's tempered distributions.
    pub mean_entropy: f32,
    /// Number of positions whose sampled token was kept on this step.
    pub n_accepted: usize,
}
/// A model that maps a token canvas to per-position logits.
pub trait CanvasForward {
/// Runs one forward pass over `canvas`.
///
/// `generate_entropy_bound` passes `prev = None` on the first step and, on
/// later steps, the previous step's logits together with the inverse
/// temperature that was applied to them (self-conditioning input).
///
/// The returned vector must hold `canvas.len() * n_vocab()` logits: one row
/// of `n_vocab()` values per canvas position, in canvas order. Errors are
/// propagated unchanged by `generate_entropy_bound`.
fn forward(&mut self, canvas: &[u32], prev: Option<(&[f32], f32)>) -> Result<Vec<f32>>;
/// Vocabulary size, i.e. the length of each logits row returned by `forward`.
fn n_vocab(&self) -> usize;
}
/// Summary returned by one `DenoiseState::ingest` call.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct StepOutcome {
    /// Zero-based index of the step that was just performed.
    pub step_idx: u32,
    /// Step budget of the run (`max_denoising_steps`, clamped to at least 1).
    pub total_steps: u32,
    /// Mean per-position entropy (nats) of this step's tempered distributions;
    /// NaN for an empty canvas.
    pub mean_entropy: f32,
    /// Number of positions whose sampled token was kept.
    pub n_accepted: usize,
    /// Whether the run has finished (early stop or step budget exhausted).
    pub done: bool,
}
/// Incremental state of one entropy-bounded denoising run.
///
/// Drive it with `generate_entropy_bound`, or by hand: while
/// [`DenoiseState::is_done`] is false, run the model on
/// [`DenoiseState::forward_inputs`] (or [`DenoiseState::input_canvas`] plus
/// [`DenoiseState::take_prev`]) and hand the logits to [`DenoiseState::ingest`].
#[derive(Clone, Debug)]
pub struct DenoiseState {
    /// Sampler hyper-parameters, owned copy.
    params: EbParams,
    /// Number of token positions on the canvas.
    canvas_len: usize,
    /// Vocabulary size; each logits row has this many entries.
    n_vocab: usize,
    /// Step budget: `params.max_denoising_steps`, clamped to at least 1.
    s_total: u32,
    /// Steps remaining; counts down from `s_total` to 0.
    cur_step: u32,
    /// Canvas fed to the next forward pass (accepted samples + re-noised tokens).
    current: Vec<u32>,
    /// Per-position argmax of the most recently ingested logits (zeros before
    /// the first step).
    argmax_canvas: Vec<u32>,
    /// Argmax canvas of the previous step, used to detect stability.
    prev_argmax: Vec<u32>,
    /// Previous step's logits, offered back to the model for self-conditioning.
    prev_logits: Option<Vec<f32>>,
    /// Inverse temperature that was applied to `prev_logits`.
    prev_temp_inv: f32,
    /// Consecutive steps whose argmax canvas matched the step before.
    held: u32,
    /// Set once the early-stop criterion fires or the budget is exhausted.
    finished: bool,
}

impl DenoiseState {
    /// Starts a run whose initial canvas is `canvas_len` random tokens.
    ///
    /// Consumes exactly `canvas_len` calls to `rng.token(n_vocab)`. The step
    /// budget is `params.max_denoising_steps`, clamped to at least 1.
    pub fn new(
        canvas_len: usize,
        n_vocab: usize,
        params: EbParams,
        rng: &mut dyn SamplerRng,
    ) -> Self {
        let s_total = params.max_denoising_steps.max(1);
        let current: Vec<u32> = (0..canvas_len).map(|_| rng.token(n_vocab as u32)).collect();
        Self {
            params,
            canvas_len,
            n_vocab,
            s_total,
            cur_step: s_total,
            current,
            argmax_canvas: vec![0u32; canvas_len],
            // Sentinel no real argmax (< n_vocab) equals, so a non-empty canvas
            // never counts as "stable" on the first step.
            prev_argmax: vec![u32::MAX; canvas_len],
            prev_logits: None,
            prev_temp_inv: 1.0,
            held: 0,
            finished: false,
        }
    }

    /// True once the run has stopped early or has no steps left.
    pub fn is_done(&self) -> bool {
        self.finished || self.cur_step == 0
    }

    /// Borrowed inputs for the next forward pass: the current canvas and, after
    /// the first step, the previous logits with their inverse temperature.
    pub fn forward_inputs(&self) -> (&[u32], Option<(&[f32], f32)>) {
        (
            &self.current,
            self.prev_logits.as_deref().map(|l| (l, self.prev_temp_inv)),
        )
    }

    /// Per-position argmax of the most recently ingested logits.
    pub fn argmax_canvas(&self) -> &[u32] {
        &self.argmax_canvas
    }

    /// Owned copy of the canvas to feed to the next forward pass.
    pub fn input_canvas(&self) -> Vec<u32> {
        self.current.clone()
    }

    /// Moves the previous step's logits out (leaving none until the next
    /// `ingest`) together with the inverse temperature applied to them.
    pub fn take_prev(&mut self) -> Option<(Vec<f32>, f32)> {
        let ti = self.prev_temp_inv;
        self.prev_logits.take().map(|l| (l, ti))
    }

    /// Consumes one forward pass worth of logits and advances the run one step.
    ///
    /// `logits` holds `canvas_len * n_vocab` values, one row of `n_vocab`
    /// logits per position. For every position the row is scaled by the
    /// current inverse temperature; its argmax is recorded, a token is sampled
    /// by inverse CDF, and the tempered distribution's entropy (nats) is
    /// computed. Positions are then accepted in ascending-entropy order under
    /// `params.entropy_bound`. Accepted positions keep their sample and all
    /// others are replaced by a fresh random token.
    ///
    /// RNG consumption per call is fixed: `canvas_len` uniform draws, then
    /// `canvas_len` token draws, regardless of which positions are accepted.
    ///
    /// Calling this after the budget is exhausted is tolerated: the step
    /// counter saturates at 0 and the outcome reports `done`.
    ///
    /// # Panics
    /// Panics if `logits` is shorter than `canvas_len * n_vocab` (slice
    /// indexing); debug builds also assert the exact length.
    pub fn ingest(&mut self, logits: Vec<f32>, rng: &mut dyn SamplerRng) -> StepOutcome {
        let canvas_len = self.canvas_len;
        let n_vocab = self.n_vocab;
        debug_assert_eq!(logits.len(), canvas_len * n_vocab);
        let step_idx = self.s_total - self.cur_step;
        // Linear anneal on the remaining-step fraction: `t_max` on the first
        // step, approaching (but not reaching) `t_min` on the last.
        let t = self.params.t_min
            + (self.params.t_max - self.params.t_min)
                * (self.cur_step as f32 / self.s_total as f32);
        let temp_inv = 1.0 / t;
        let us: Vec<f32> = (0..canvas_len).map(|_| rng.uniform01()).collect();
        let renoise: Vec<u32> = (0..canvas_len).map(|_| rng.token(n_vocab as u32)).collect();

        let mut entropy = vec![0f32; canvas_len];
        let mut denoised = vec![0u32; canvas_len];
        for pos in 0..canvas_len {
            let row = &logits[pos * n_vocab..(pos + 1) * n_vocab];
            let (amax, sampled, h) = Self::sample_row(row, temp_inv, us[pos]);
            entropy[pos] = h;
            self.argmax_canvas[pos] = amax as u32;
            denoised[pos] = sampled as u32;
        }

        let accepted = Self::entropy_bounded_accept(&entropy, self.params.entropy_bound);
        let mut entropy_sum = 0f32;
        let mut n_accepted = 0usize;
        for pos in 0..canvas_len {
            self.current[pos] = if accepted[pos] {
                n_accepted += 1;
                denoised[pos]
            } else {
                renoise[pos]
            };
            entropy_sum += entropy[pos];
        }

        // Early stop needs the argmax canvas to have held steady for
        // `stability_threshold` consecutive steps AND low mean entropy.
        self.held = if self.prev_argmax == self.argmax_canvas {
            self.held.saturating_add(1)
        } else {
            0
        };
        let mean_entropy = entropy_sum / canvas_len as f32;
        let confident = mean_entropy < self.params.confidence_threshold;
        self.prev_argmax.copy_from_slice(&self.argmax_canvas);
        // Keep this step's logits for self-conditioning the next forward pass.
        self.prev_logits = Some(logits);
        self.prev_temp_inv = temp_inv;
        // Saturate: a call after the budget is exhausted must not wrap the
        // counter (which would also break `s_total - cur_step` next time).
        self.cur_step = self.cur_step.saturating_sub(1);
        if (self.held >= self.params.stability_threshold && confident) || self.cur_step == 0 {
            self.finished = true;
        }
        StepOutcome {
            step_idx,
            total_steps: self.s_total,
            mean_entropy,
            n_accepted,
            done: self.finished,
        }
    }

    /// Tempers one logits row and returns `(argmax, sampled, entropy)`.
    ///
    /// `u` (expected in `[0, 1)`) picks the sample by inverse CDF over the
    /// tempered softmax. If no cumulative prefix reaches the target (e.g.
    /// NaN logits), the last index is returned. Entropy is in nats.
    fn sample_row(row: &[f32], temp_inv: f32, u: f32) -> (usize, usize, f32) {
        // Max tempered logit (and its index) for a numerically stable softmax.
        let mut m = f32::NEG_INFINITY;
        let mut amax = 0usize;
        for (v, &z) in row.iter().enumerate() {
            let zt = z * temp_inv;
            if zt > m {
                m = zt;
                amax = v;
            }
        }
        let mut z_sum = 0f32;
        for &z in row {
            z_sum += (z * temp_inv - m).exp();
        }
        let target = u * z_sum;
        let mut cum = 0f32;
        let mut h = 0f32;
        let mut sampled: Option<usize> = None;
        for (v, &z) in row.iter().enumerate() {
            let e = (z * temp_inv - m).exp();
            let p = e / z_sum;
            if p > 0.0 {
                h -= p * p.ln();
            }
            cum += e;
            if sampled.is_none() && cum >= target {
                sampled = Some(v);
            }
        }
        (amax, sampled.unwrap_or(row.len().saturating_sub(1)), h)
    }

    /// Marks which positions keep their sampled token under the entropy budget.
    ///
    /// Positions are visited in ascending entropy (ties by index). A position
    /// is accepted while the entropy of the positions visited before it sums
    /// to at most `bound`. For `bound >= 0` the lowest-entropy position is
    /// therefore always accepted. NaN entropies (from non-finite logits) sort
    /// last instead of panicking the comparison.
    fn entropy_bounded_accept(entropy: &[f32], bound: f32) -> Vec<bool> {
        let key = |h: f32| if h.is_nan() { f32::INFINITY } else { h };
        let mut order: Vec<usize> = (0..entropy.len()).collect();
        order.sort_unstable_by(|&a, &b| {
            key(entropy[a]).total_cmp(&key(entropy[b])).then(a.cmp(&b))
        });
        let bound = f64::from(bound);
        let mut accepted = vec![false; entropy.len()];
        let mut cum_e = 0f64;
        for &pos in &order {
            if cum_e <= bound {
                accepted[pos] = true;
            }
            cum_e += f64::from(entropy[pos]);
        }
        accepted
    }
}
/// Runs the entropy-bounded denoising sampler and returns the final argmax
/// canvas (`canvas_len` token ids).
///
/// Each step calls `model.forward` on the current canvas. From the second
/// step on, it also passes the previous step's logits and inverse
/// temperature. It then advances a [`DenoiseState`]. The loop ends when the
/// state reports done (early stop or step budget exhausted) or when
/// `step_cb` returns `false`.
///
/// # Errors
/// Returns the first error produced by `model.forward`; no further steps run.
pub fn generate_entropy_bound(
    model: &mut dyn CanvasForward,
    canvas_len: usize,
    params: &EbParams,
    rng: &mut dyn SamplerRng,
    mut step_cb: Option<&mut dyn FnMut(&StepInfo) -> bool>,
) -> Result<Vec<u32>> {
    let mut state = DenoiseState::new(canvas_len, model.n_vocab(), params.clone(), rng);
    loop {
        if state.is_done() {
            break;
        }
        let (canvas, prev) = state.forward_inputs();
        let logits = model.forward(canvas, prev)?;
        let outcome = state.ingest(logits, rng);
        // Without a callback nothing can request an early abort.
        let keep_going = match step_cb.as_deref_mut() {
            None => true,
            Some(cb) => cb(&StepInfo {
                step_idx: outcome.step_idx,
                total_steps: outcome.total_steps,
                argmax_canvas: state.argmax_canvas(),
                mean_entropy: outcome.mean_entropy,
                n_accepted: outcome.n_accepted,
            }),
        };
        if !keep_going || outcome.done {
            break;
        }
    }
    Ok(state.argmax_canvas().to_vec())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Model that always puts a logit of 50 on a fixed target token per
    /// position (0 elsewhere) and records whether self-conditioning inputs
    /// were supplied after the first call.
    struct FixedTarget {
        target: Vec<u32>,
        n_vocab: usize,
        calls: u32,
        saw_self_cond: bool,
    }

    impl CanvasForward for FixedTarget {
        fn forward(&mut self, canvas: &[u32], prev: Option<(&[f32], f32)>) -> Result<Vec<f32>> {
            assert_eq!(canvas.len(), self.target.len());
            if self.calls == 0 {
                assert!(prev.is_none(), "step 0 must not self-condition");
            } else {
                self.saw_self_cond |= prev.is_some();
            }
            self.calls += 1;
            let mut out = vec![0f32; canvas.len() * self.n_vocab];
            for (pos, &t) in self.target.iter().enumerate() {
                out[pos * self.n_vocab + t as usize] = 50.0;
            }
            Ok(out)
        }
        fn n_vocab(&self) -> usize {
            self.n_vocab
        }
    }

    #[test]
    fn converges_to_confident_target_and_stops_early() {
        let target = vec![3u32, 1, 4, 1, 5, 9, 2, 6];
        let mut model = FixedTarget {
            target: target.clone(),
            n_vocab: 11,
            calls: 0,
            saw_self_cond: false,
        };
        let mut rng = XorShiftRng(0x5EED);
        let mut steps_seen = 0u32;
        let out = generate_entropy_bound(
            &mut model,
            target.len(),
            &EbParams::default(),
            &mut rng,
            Some(&mut |info: &StepInfo| {
                steps_seen = info.step_idx + 1;
                true
            }),
        )
        .unwrap();
        assert_eq!(out, target, "argmax canvas must converge to the target");
        assert!(
            model.calls <= 4,
            "expected early stop, ran {} steps",
            model.calls
        );
        assert_eq!(
            steps_seen, model.calls,
            "callback must observe every forward step"
        );
        assert!(model.saw_self_cond, "later steps must pass prev logits");
    }

    #[test]
    fn manual_step_loop_matches_driver() {
        let target = vec![3u32, 1, 4, 1, 5, 9, 2, 6];
        let params = EbParams {
            max_denoising_steps: 12,
            ..Default::default()
        };
        let mut m1 = FixedTarget {
            target: target.clone(),
            n_vocab: 11,
            calls: 0,
            saw_self_cond: false,
        };
        let mut rng1 = XorShiftRng(0xABCD);
        let driver =
            generate_entropy_bound(&mut m1, target.len(), &params, &mut rng1, None).unwrap();
        let mut m2 = FixedTarget {
            target: target.clone(),
            n_vocab: 11,
            calls: 0,
            saw_self_cond: false,
        };
        let mut rng2 = XorShiftRng(0xABCD);
        let mut st = DenoiseState::new(target.len(), 11, params.clone(), &mut rng2);
        while !st.is_done() {
            let canvas = st.input_canvas();
            let prev = st.take_prev();
            let logits = m2
                .forward(&canvas, prev.as_ref().map(|(l, t)| (l.as_slice(), *t)))
                .unwrap();
            st.ingest(logits, &mut rng2);
        }
        let manual = st.argmax_canvas().to_vec();
        assert_eq!(driver, manual, "step-driven loop must match the driver");
        assert_eq!(driver, target);
    }

    /// Model emitting all-zero logits, i.e. a uniform distribution everywhere.
    struct UniformModel {
        n_vocab: usize,
        calls: u32,
    }

    impl CanvasForward for UniformModel {
        fn forward(&mut self, canvas: &[u32], _prev: Option<(&[f32], f32)>) -> Result<Vec<f32>> {
            self.calls += 1;
            Ok(vec![0f32; canvas.len() * self.n_vocab])
        }
        fn n_vocab(&self) -> usize {
            self.n_vocab
        }
    }

    #[test]
    fn uniform_logits_run_to_step_cap() {
        let mut model = UniformModel {
            n_vocab: 7,
            calls: 0,
        };
        let mut rng = XorShiftRng(42);
        let params = EbParams {
            max_denoising_steps: 5,
            ..Default::default()
        };
        let _ = generate_entropy_bound(&mut model, 4, &params, &mut rng, None).unwrap();
        assert_eq!(model.calls, 5, "no early stop without confidence");
    }

    #[test]
    fn acceptance_respects_entropy_budget_ordering() {
        /// Two-token model; `confidences[i]` is the probability of token 0 at
        /// position `i` (before temperature scaling).
        struct TwoTok {
            confidences: Vec<f32>,
        }
        impl CanvasForward for TwoTok {
            fn forward(&mut self, _c: &[u32], _p: Option<(&[f32], f32)>) -> Result<Vec<f32>> {
                let mut out = Vec::new();
                for &p in &self.confidences {
                    let l0 = (p / (1.0 - p)).ln();
                    out.extend_from_slice(&[l0, 0.0]);
                }
                Ok(out)
            }
            fn n_vocab(&self) -> usize {
                2
            }
        }
        let mut model = TwoTok {
            confidences: vec![0.999_999, 0.6, 0.999_99, 0.55],
        };
        let mut rng = XorShiftRng(7);
        let params = EbParams {
            max_denoising_steps: 1,
            ..Default::default()
        };
        let mut accepted_count = 0usize;
        let _ = generate_entropy_bound(
            &mut model,
            4,
            &params,
            &mut rng,
            Some(&mut |info: &StepInfo| {
                accepted_count = info.n_accepted;
                true
            }),
        )
        .unwrap();
        assert_eq!(
            accepted_count, 3,
            "entropy budget must cut the least-confident position"
        );
    }
}