#![allow(clippy::needless_range_loop)]
use crate::model::Decoder;
use crate::tree::DraftTree;
use crate::{Error, Result};
/// Shape parameters shared by a bundle of Medusa heads.
#[derive(Debug, Clone)]
pub struct MedusaConfig {
    /// Number of extra decoding heads in the bundle.
    pub n_heads: usize,
    /// Width of the hidden state each head consumes (input and output size of its
    /// residual blocks).
    pub hidden_size: usize,
    /// Number of logits each head produces.
    pub vocab_size: usize,
    /// Number of residual blocks in each head before the vocabulary projection.
    pub residual_layers: usize,
}
impl MedusaConfig {
    /// Preset with 4 heads, a 4096-wide hidden state, a 32000-token vocabulary and a
    /// single residual block per head.
    pub fn vicuna_7b_defaults() -> Self {
        let (n_heads, residual_layers) = (4, 1);
        MedusaConfig {
            vocab_size: 32_000,
            hidden_size: 4_096,
            n_heads,
            residual_layers,
        }
    }
}
/// One head in the tree-building skeleton.
#[derive(Debug, Clone)]
pub struct MedusaHead {
    /// 1-based look-ahead offset: `MedusaHeads::from_config` gives head `k`
    /// (0-based) the offset `k + 1`.
    pub offset: usize,
}
/// Weight-free skeleton for a bundle of Medusa heads.
///
/// Used to turn per-head candidate lists into a draft tree. The learned modules live
/// in `MedusaHeadsCandle`.
#[derive(Debug, Clone)]
pub struct MedusaHeads {
    /// Configuration the skeleton was built from.
    config: MedusaConfig,
    /// One entry per head, in head order.
    heads: Vec<MedusaHead>,
}
impl MedusaHeads {
pub fn from_config(config: MedusaConfig) -> Self {
let heads = (1..=config.n_heads)
.map(|offset| MedusaHead { offset })
.collect();
Self { config, heads }
}
pub fn len(&self) -> usize {
self.heads.len()
}
pub fn is_empty(&self) -> bool {
self.heads.is_empty()
}
pub fn config(&self) -> &MedusaConfig {
&self.config
}
pub fn build_draft_tree(
&self,
committed_root: u32,
head_top_k: &[Vec<u32>],
topology: TreeTopology,
) -> Result<DraftTree> {
if head_top_k.len() != self.heads.len() {
return Err(Error::Sampling(format!(
"head_top_k has {} entries, expected {} (one per head)",
head_top_k.len(),
self.heads.len()
)));
}
for (h, candidates) in head_top_k.iter().enumerate() {
if candidates.is_empty() {
return Err(Error::Sampling(format!("head {h} has no candidates")));
}
}
match topology {
TreeTopology::Greedy => Ok(build_greedy_chain(committed_root, head_top_k)),
TreeTopology::CartesianProduct => Ok(build_cartesian_tree(committed_root, head_top_k)),
}
}
}
/// Shape of the draft tree built from the per-head candidate lists.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TreeTopology {
    /// A single chain that uses only the first (best) candidate of every head.
    Greedy,
    /// Every combination of candidates.
    ///
    /// Each node at depth `k` gets one child per candidate of head `k`, so the tree
    /// has `1 + c0 + c0*c1 + ...` nodes.
    CartesianProduct,
}
/// Builds a linear draft tree: `root` followed by the first candidate of each head,
/// in head order.
///
/// Every list in `head_top_k` must be non-empty (`MedusaHeads::build_draft_tree`
/// checks this). An empty list makes the `[0]` index panic.
fn build_greedy_chain(root: u32, head_top_k: &[Vec<u32>]) -> DraftTree {
    let mut best_per_head = Vec::with_capacity(head_top_k.len());
    for candidates in head_top_k {
        best_per_head.push(candidates[0]);
    }
    DraftTree::linear(root, &best_per_head)
}
/// Builds the full Cartesian-product draft tree from `root` and every head's
/// candidates.
///
/// Nodes are emitted breadth-first as a parent table of `(parent_index, token)`
/// entries, with entry 0 being the root. Layer `k + 1` holds, for each node of
/// layer `k` in order, one child per candidate of head `k` in order. The node count
/// therefore grows multiplicatively with the number of heads.
fn build_cartesian_tree(root: u32, head_top_k: &[Vec<u32>]) -> DraftTree {
    let mut table: Vec<(usize, u32)> = vec![(0, root)];
    // Indices (into `table`) of the nodes on the most recently built layer.
    let mut frontier: Vec<usize> = vec![0];
    for layer_candidates in head_top_k {
        let layer_start = table.len();
        table.extend(frontier.iter().flat_map(move |&parent| {
            layer_candidates.iter().map(move |&tok| (parent, tok))
        }));
        frontier = (layer_start..table.len()).collect();
    }
    DraftTree::from_parent_table(&table).expect("Cartesian builder produces valid tree")
}
/// Per-run settings for Medusa decoding.
#[derive(Debug, Clone)]
pub struct MedusaRunConfig {
    /// How per-head candidates are arranged into a draft tree.
    pub topology: TreeTopology,
    /// Number of candidates taken from each head by `run_medusa_real`.
    ///
    /// `run_medusa` does not read this; its candidates come from the `HeadDraftFn`.
    pub top_k_per_head: usize,
    /// Rule used to accept draft tokens against the target's logits.
    pub acceptance: Acceptance,
}
impl Default for MedusaRunConfig {
fn default() -> Self {
Self {
topology: TreeTopology::CartesianProduct,
top_k_per_head: 2,
acceptance: Acceptance::Greedy,
}
}
}
/// Rule for accepting a drafted token, judged against the target model's logits at
/// the token's parent node.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Acceptance {
    /// Accept only the target's argmax token.
    Greedy,
    /// Medusa-style "typical acceptance".
    ///
    /// A token is accepted when its softmax probability clears a threshold built from
    /// the hard threshold `epsilon` and the entropy-dependent term
    /// `delta * exp(-H)`.
    Typical {
        /// Hard probability threshold.
        epsilon: f32,
        /// Scale of the entropy-dependent threshold `delta * exp(-H)`.
        delta: f32,
    },
}
/// Callback used by `run_medusa` in place of real Medusa heads.
///
/// Given the target's current token history, it returns one candidate list per
/// head, in head order.
pub type HeadDraftFn = Box<dyn FnMut(&[u32]) -> Vec<Vec<u32>>>;
/// Runs Medusa-style speculative decoding against `target`, with `head_draft`
/// standing in for the Medusa heads.
///
/// The target is reset and fed `prompt`. Each round then:
/// 1. asks `head_draft` for one candidate list per head, given the current history;
/// 2. builds a draft tree rooted at the last committed token, shaped by
///    `config.topology`;
/// 3. scores every tree node with the target (`evaluate_tree`);
/// 4. keeps the longest root-to-leaf prefix accepted under `config.acceptance`
///    (ties keep the first path reported by `tree.paths()`);
/// 5. appends a bonus token, picked from the target's logits at the deepest
///    accepted node, while there is still room under `max_new_tokens`;
/// 6. commits those tokens to `target`.
///
/// Returns the generated tokens (the prompt is not included), truncated to
/// `max_new_tokens`.
///
/// # Errors
/// - `Error::Sampling` if the target history is empty after observing `prompt`.
/// - `Error::Sampling` if `head_draft` returns the wrong number of lists or an
///   empty list.
/// - Any error from the target's `observe`, `next_logits` or `rollback_to`.
pub fn run_medusa<T, R>(
    target: &mut T,
    heads: &MedusaHeads,
    mut head_draft: HeadDraftFn,
    prompt: &[u32],
    max_new_tokens: usize,
    config: &MedusaRunConfig,
    rng: &mut R,
) -> Result<Vec<u32>>
where
    T: Decoder + ?Sized,
    R: rand::Rng + ?Sized,
{
    target.reset();
    target.observe(prompt)?;
    let mut generated: Vec<u32> = Vec::with_capacity(max_new_tokens);
    while generated.len() < max_new_tokens {
        let root_token = *target
            .history()
            .last()
            .ok_or_else(|| Error::Sampling("Medusa requires non-empty prompt".into()))?;
        let head_top_k = head_draft(target.history());
        let tree = heads.build_draft_tree(root_token, &head_top_k, config.topology)?;
        let pre_target_len = target.history_len();
        let per_node_logits = evaluate_tree(target, &tree, pre_target_len)?;

        // Longest accepted prefix over all root-to-leaf paths. `[0]` (root only)
        // means no draft token was accepted.
        let mut best_path: Vec<usize> = vec![0];
        for path in tree.paths() {
            let accepted_len =
                walk_and_accept(&path, &tree, &per_node_logits, &config.acceptance, rng);
            if accepted_len + 1 > best_path.len() {
                best_path = path[..=accepted_len].to_vec();
            }
        }

        let mut committed: Vec<u32> = best_path
            .iter()
            .skip(1)
            .map(|&i| tree.token_at(i))
            .collect();
        let deepest_idx = *best_path
            .last()
            .expect("best_path always contains the root index");
        if generated.len() + committed.len() < max_new_tokens {
            let bonus_logits = &per_node_logits[deepest_idx];
            let bonus = sample_argmax_or_categorical(bonus_logits, &config.acceptance, rng)?;
            committed.push(bonus);
        }
        // Check before touching the target so a failing round leaves its history
        // exactly as `evaluate_tree` restored it.
        if committed.is_empty() {
            return Err(Error::Sampling(
                "Medusa round committed zero tokens — would loop forever".into(),
            ));
        }
        debug_assert_eq!(target.history_len(), pre_target_len);
        target.observe(&committed)?;
        generated.extend_from_slice(&committed);
    }
    generated.truncate(max_new_tokens);
    Ok(generated)
}
/// Scores every node of `tree` with the target model, one node at a time.
///
/// `out[0]` holds the target's next-token logits in its current state. For every
/// other node `i`:
/// - the tokens on the root-to-`i` path (excluding the root) are observed;
/// - the resulting next-token logits are stored in `out[i]`;
/// - the history is rolled back to `pre_target_len`.
///
/// This costs one `observe` plus one `next_logits` call per non-root node.
///
/// # Errors
/// Propagates any error from `observe`, `next_logits` or `rollback_to`. The rollback
/// is attempted even when scoring a node fails, so speculative tokens are not left
/// in the target's history. If both fail, the scoring error is the one returned.
///
/// # Panics
/// Panics if `tree` has no nodes.
fn evaluate_tree<T: Decoder + ?Sized>(
    target: &mut T,
    tree: &DraftTree,
    pre_target_len: usize,
) -> Result<Vec<Vec<f32>>> {
    let n = tree.len();
    let mut out: Vec<Vec<f32>> = vec![Vec::new(); n];
    out[0] = target.next_logits()?;
    for i in 1..n {
        let path_tokens_after_root: Vec<u32> = tree
            .path_to(i)
            .iter()
            .skip(1)
            .map(|&idx| tree.token_at(idx))
            .collect();
        let evaluated = target
            .observe(&path_tokens_after_root)
            .and_then(|_| target.next_logits());
        // Always restore the history, even if scoring this node failed.
        let restored = target.rollback_to(pre_target_len);
        out[i] = evaluated?;
        restored?;
    }
    Ok(out)
}
/// Counts how many draft tokens along `path` are accepted, stopping at the first
/// rejection.
///
/// `path` is a root-to-leaf list of node indices. Each child's token is judged
/// (via `accept_one`) against the target logits recorded at its parent. `rng` is
/// currently unused because `accept_one` is deterministic.
fn walk_and_accept<R: rand::Rng + ?Sized>(
    path: &[usize],
    tree: &DraftTree,
    per_node_logits: &[Vec<f32>],
    acceptance: &Acceptance,
    rng: &mut R,
) -> usize {
    let _ = rng;
    path.windows(2)
        .take_while(|edge| {
            let (parent, child) = (edge[0], edge[1]);
            let candidate = tree.token_at(child);
            accept_one(&per_node_logits[parent], candidate, acceptance)
        })
        .count()
}
/// Decides whether the target model accepts `candidate` as its next token.
///
/// - [`Acceptance::Greedy`]: accept iff `candidate` is the argmax of
///   `target_logits` (the first index wins ties).
/// - [`Acceptance::Typical`]: Medusa typical acceptance. With
///   `p = softmax(target_logits)` and `H` its entropy, accept iff
///   `p[candidate] >= min(epsilon, delta * exp(-H))`. The entropy term lowers the
///   bar when the target is uncertain, and `epsilon` caps it when the target is
///   confident.
///
/// Returns `false` (reject) in three cases:
/// - `target_logits` is empty;
/// - `candidate` is outside the vocabulary;
/// - the softmax normaliser is not finite and positive.
fn accept_one(target_logits: &[f32], candidate: u32, acceptance: &Acceptance) -> bool {
    let candidate = candidate as usize;
    // A token outside the target's vocabulary can never be accepted. Checking this
    // up front also keeps the `probs[candidate]` index below panic-free.
    if candidate >= target_logits.len() {
        return false;
    }
    match acceptance {
        Acceptance::Greedy => {
            let argmax = target_logits
                .iter()
                .enumerate()
                .fold((0usize, f32::NEG_INFINITY), |(bi, bv), (i, &v)| {
                    if v > bv {
                        (i, v)
                    } else {
                        (bi, bv)
                    }
                })
                .0;
            argmax == candidate
        }
        Acceptance::Typical { epsilon, delta } => {
            // Numerically stable softmax: shift by the max logit before exponentiating.
            let max = target_logits
                .iter()
                .copied()
                .fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = target_logits.iter().map(|&l| (l - max).exp()).collect();
            let sum: f32 = exps.iter().sum();
            if sum <= 0.0 || !sum.is_finite() {
                return false;
            }
            let probs: Vec<f32> = exps.iter().map(|p| p / sum).collect();
            let entropy: f32 = probs
                .iter()
                .filter(|&&p| p > 0.0)
                .map(|&p| -p * p.ln())
                .sum();
            // Medusa (Cai et al., 2024) typical acceptance:
            // threshold = min(epsilon, delta * exp(-H)).
            let threshold = epsilon.min(delta * (-entropy).exp());
            probs[candidate] >= threshold
        }
    }
}
/// Picks the bonus token appended after the accepted draft prefix.
///
/// Currently this is always the argmax of `logits`: the first index wins ties and
/// `NaN` entries are never chosen. If no entry exceeds `-inf`, index 0 is returned.
/// `acceptance` and the RNG are accepted but not used.
///
/// # Errors
/// `Error::Sampling` if `logits` is empty.
fn sample_argmax_or_categorical<R: rand::Rng + ?Sized>(
    logits: &[f32],
    acceptance: &Acceptance,
    _rng: &mut R,
) -> Result<u32> {
    if logits.is_empty() {
        return Err(Error::Sampling("empty logits for bonus token".into()));
    }
    let _ = acceptance;
    let mut best_idx = 0usize;
    let mut best_val = f32::NEG_INFINITY;
    for (idx, &val) in logits.iter().enumerate() {
        if val > best_val {
            best_idx = idx;
            best_val = val;
        }
    }
    Ok(best_idx as u32)
}
/// Returns the indices of the `k` largest entries of `logits`, best first.
///
/// - Equal values are ordered by ascending index, so the result is deterministic.
/// - `NaN` values rank like `-inf` (i.e. last).
/// - `k` is clamped to `logits.len()`.
///
/// Runs in `O(V + k log k)`: the top `k` are partitioned out with
/// `select_nth_unstable_by` and only they are sorted, instead of sorting the whole
/// vocabulary.
pub fn top_k_indices(logits: &[f32], k: usize) -> Vec<usize> {
    let k = k.min(logits.len());
    if k == 0 {
        return Vec::new();
    }
    // Map NaN to -inf so `partial_cmp` below is always `Some`. With NaN compared as
    // "Equal" to everything, the comparator is not transitive. The std sort/select
    // routines then give an unspecified order and are allowed to panic.
    let mut indexed: Vec<(usize, f32)> = logits
        .iter()
        .enumerate()
        .map(|(i, &v)| (i, if v.is_nan() { f32::NEG_INFINITY } else { v }))
        .collect();
    // Strict total order: larger value first, then smaller index.
    let best_first = |a: &(usize, f32), b: &(usize, f32)| {
        b.1.partial_cmp(&a.1)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.0.cmp(&b.0))
    };
    if k < indexed.len() {
        indexed.select_nth_unstable_by(k - 1, best_first);
        indexed.truncate(k);
    }
    indexed.sort_unstable_by(best_first);
    indexed.into_iter().map(|(i, _)| i).collect()
}
use crate::model::TreeDecoder;
use candle_core::{DType, Device, Module, Tensor};
use candle_nn::{linear, linear_no_bias, Linear, VarBuilder};
use std::path::Path;
/// Learned weights for one Medusa head.
///
/// A head is `residual_layers` residual blocks followed by a bias-free projection
/// to vocabulary logits.
#[derive(Debug, Clone)]
pub struct MedusaHeadModule {
    /// Residual blocks, applied in order as `x = silu(block(x)) + x`.
    res_blocks: Vec<Linear>,
    /// Final `hidden_size -> vocab_size` projection, without bias.
    output_proj: Linear,
}
impl MedusaHeadModule {
pub fn from_var_builder(cfg: &MedusaConfig, vb: VarBuilder<'_>) -> Result<Self> {
let mut res_blocks = Vec::with_capacity(cfg.residual_layers);
for i in 0..cfg.residual_layers {
let l = linear(cfg.hidden_size, cfg.hidden_size, vb.pp(format!("res.{i}")))
.map_err(Error::Candle)?;
res_blocks.push(l);
}
let output_proj = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("output"))
.map_err(Error::Candle)?;
Ok(Self {
res_blocks,
output_proj,
})
}
pub fn from_random(cfg: &MedusaConfig, device: &Device, dtype: DType) -> Result<Self> {
let mut res_blocks = Vec::with_capacity(cfg.residual_layers);
for _ in 0..cfg.residual_layers {
let w = Tensor::randn(0f32, 0.02, (cfg.hidden_size, cfg.hidden_size), device)
.map_err(Error::Candle)?
.to_dtype(dtype)
.map_err(Error::Candle)?;
let b = Tensor::zeros(cfg.hidden_size, dtype, device).map_err(Error::Candle)?;
res_blocks.push(Linear::new(w, Some(b)));
}
let w = Tensor::randn(0f32, 0.02, (cfg.vocab_size, cfg.hidden_size), device)
.map_err(Error::Candle)?
.to_dtype(dtype)
.map_err(Error::Candle)?;
let output_proj = Linear::new(w, None);
Ok(Self {
res_blocks,
output_proj,
})
}
pub fn forward(&self, hidden: &Tensor) -> Result<Tensor> {
let needs_squeeze = hidden.dims().len() == 1;
let mut x = if needs_squeeze {
hidden.unsqueeze(0).map_err(Error::Candle)?
} else {
hidden.clone()
};
for rb in &self.res_blocks {
let y = candle_nn::ops::silu(&rb.forward(&x).map_err(Error::Candle)?)
.map_err(Error::Candle)?;
x = (y + &x).map_err(Error::Candle)?;
}
let logits = self.output_proj.forward(&x).map_err(Error::Candle)?;
if needs_squeeze {
logits.squeeze(0).map_err(Error::Candle)
} else {
Ok(logits)
}
}
}
/// A bundle of learned Medusa heads that share one `MedusaConfig`.
#[derive(Debug, Clone)]
pub struct MedusaHeadsCandle {
    /// Configuration shared by every head.
    config: MedusaConfig,
    /// Head modules, in head order.
    heads: Vec<MedusaHeadModule>,
}
impl MedusaHeadsCandle {
pub fn from_random(cfg: &MedusaConfig, device: &Device, dtype: DType) -> Result<Self> {
let mut heads = Vec::with_capacity(cfg.n_heads);
for _ in 0..cfg.n_heads {
heads.push(MedusaHeadModule::from_random(cfg, device, dtype)?);
}
Ok(Self {
config: cfg.clone(),
heads,
})
}
pub fn from_safetensors(
cfg: &MedusaConfig,
paths: &[impl AsRef<Path>],
device: &Device,
dtype: DType,
) -> Result<Self> {
let owned: Vec<_> = paths.iter().map(|p| p.as_ref().to_path_buf()).collect();
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&owned, dtype, device).map_err(Error::Candle)?
};
let mut heads = Vec::with_capacity(cfg.n_heads);
for i in 0..cfg.n_heads {
heads.push(MedusaHeadModule::from_var_builder(
cfg,
vb.pp(format!("medusa_head.{i}")),
)?);
}
Ok(Self {
config: cfg.clone(),
heads,
})
}
pub fn from_fasterdecoding_pt(
cfg: &MedusaConfig,
path: impl AsRef<Path>,
device: &Device,
dtype: DType,
) -> Result<Self> {
let vb = VarBuilder::from_pth(path.as_ref(), dtype, device).map_err(Error::Candle)?;
Self::from_fasterdecoding_var_builder(cfg, vb)
}
pub fn from_fasterdecoding_var_builder(cfg: &MedusaConfig, vb: VarBuilder<'_>) -> Result<Self> {
let mut heads = Vec::with_capacity(cfg.n_heads);
for i in 0..cfg.n_heads {
let head_vb = vb.pp(i.to_string());
let mut res_blocks = Vec::with_capacity(cfg.residual_layers);
for j in 0..cfg.residual_layers {
let l = linear(
cfg.hidden_size,
cfg.hidden_size,
head_vb.pp(j.to_string()).pp("linear"),
)
.map_err(Error::Candle)?;
res_blocks.push(l);
}
let output_proj = linear_no_bias(
cfg.hidden_size,
cfg.vocab_size,
head_vb.pp(cfg.residual_layers.to_string()),
)
.map_err(Error::Candle)?;
heads.push(MedusaHeadModule {
res_blocks,
output_proj,
});
}
Ok(Self {
config: cfg.clone(),
heads,
})
}
pub fn config(&self) -> &MedusaConfig {
&self.config
}
pub fn forward(&self, hidden: &Tensor) -> Result<Vec<Tensor>> {
self.heads.iter().map(|h| h.forward(hidden)).collect()
}
pub fn top_k_per_head(&self, hidden: &Tensor, k: usize) -> Result<Vec<Vec<u32>>> {
let logits_per_head = self.forward(hidden)?;
let mut out = Vec::with_capacity(self.heads.len());
for logits in logits_per_head {
let v = logits
.to_dtype(DType::F32)
.map_err(Error::Candle)?
.to_vec1::<f32>()
.map_err(Error::Candle)?;
let top: Vec<u32> = top_k_indices(&v, k).into_iter().map(|i| i as u32).collect();
out.push(top);
}
Ok(out)
}
}
pub fn run_medusa_real<T, R>(
target: &mut T,
heads: &MedusaHeadsCandle,
skeleton: &MedusaHeads,
prompt: &[u32],
max_new_tokens: usize,
config: &MedusaRunConfig,
rng: &mut R,
) -> Result<Vec<u32>>
where
T: TreeDecoder + ?Sized,
R: rand::Rng + ?Sized,
{
if heads.config().n_heads != skeleton.len() {
return Err(Error::Sampling(format!(
"head bundle size ({}) does not match skeleton ({})",
heads.config().n_heads,
skeleton.len()
)));
}
target.reset();
Decoder::observe(target, prompt)?;
let mut generated: Vec<u32> = Vec::with_capacity(max_new_tokens);
while generated.len() < max_new_tokens {
let root = *Decoder::history(target)
.last()
.ok_or_else(|| Error::Sampling("Medusa requires non-empty prompt".into()))?;
let hidden = target.last_hidden_state()?;
let head_top_k = heads.top_k_per_head(&hidden, config.top_k_per_head)?;
let tree = skeleton.build_draft_tree(root, &head_top_k, config.topology)?;
let per_node_logits = target.tree_logits(&tree)?;
let mut best_path: Vec<usize> = vec![0];
for path in tree.paths() {
let accepted_len =
walk_and_accept(&path, &tree, &per_node_logits, &config.acceptance, rng);
if accepted_len + 1 > best_path.len() {
best_path = path[..=accepted_len].to_vec();
}
}
let mut committed: Vec<u32> = best_path
.iter()
.skip(1)
.map(|&i| tree.token_at(i))
.collect();
let deepest_idx = *best_path.last().unwrap();
if generated.len() + committed.len() < max_new_tokens {
let bonus = sample_argmax_or_categorical(
&per_node_logits[deepest_idx],
&config.acceptance,
rng,
)?;
committed.push(bonus);
}
if committed.is_empty() {
return Err(Error::Sampling(
"Medusa round committed zero tokens — would loop forever".into(),
));
}
Decoder::observe(target, &committed)?;
generated.extend_from_slice(&committed);
}
generated.truncate(max_new_tokens);
Ok(generated)
}
// Unit tests for the Medusa skeleton, tree builders, top-k selection and the
// mock-target decoding loop.
#[cfg(test)]
mod tests {
    use super::*;

    // Spot-checks the Vicuna-7B preset values.
    #[test]
    fn config_defaults_are_sensible() {
        let c = MedusaConfig::vicuna_7b_defaults();
        assert_eq!(c.n_heads, 4);
        assert_eq!(c.hidden_size, 4096);
    }

    // `from_config` creates one head per `n_heads`, with 1-based offsets.
    #[test]
    fn heads_bundle_matches_n_heads() {
        let h = MedusaHeads::from_config(MedusaConfig {
            n_heads: 5,
            hidden_size: 256,
            vocab_size: 1000,
            residual_layers: 1,
        });
        assert_eq!(h.len(), 5);
        for (i, head) in h.heads.iter().enumerate() {
            assert_eq!(
                head.offset,
                i + 1,
                "head {i} should target offset {}",
                i + 1
            );
        }
    }

    // Equal logits are returned in ascending index order.
    #[test]
    fn top_k_picks_highest_with_stable_tie_break() {
        let logits = [0.1, 0.5, 0.5, 0.3, 0.5];
        let idx = top_k_indices(&logits, 3);
        assert_eq!(idx, vec![1, 2, 4]);
    }

    // Asking for more entries than exist returns every index.
    #[test]
    fn top_k_clamps_to_vocab_size() {
        let logits = [0.5, 0.4];
        assert_eq!(top_k_indices(&logits, 100), vec![0, 1]);
    }

    // Greedy topology: root followed by each head's first candidate, as one path.
    #[test]
    fn greedy_topology_makes_linear_tree() {
        let h = MedusaHeads::from_config(MedusaConfig {
            n_heads: 3,
            hidden_size: 16,
            vocab_size: 100,
            residual_layers: 1,
        });
        let tree = h
            .build_draft_tree(7, &[vec![10], vec![20], vec![30]], TreeTopology::Greedy)
            .unwrap();
        assert_eq!(tree.tokens(), &[7, 10, 20, 30]);
        assert_eq!(tree.paths(), vec![vec![0, 1, 2, 3]]);
    }

    // 2 heads x 2 candidates: root + 2 + 4 nodes, and 4 full-depth paths.
    #[test]
    fn cartesian_topology_branches_at_each_head() {
        let h = MedusaHeads::from_config(MedusaConfig {
            n_heads: 2,
            hidden_size: 16,
            vocab_size: 100,
            residual_layers: 1,
        });
        let tree = h
            .build_draft_tree(
                0,
                &[vec![10, 11], vec![20, 21]],
                TreeTopology::CartesianProduct,
            )
            .unwrap();
        assert_eq!(tree.len(), 1 + 2 + 4);
        let mut paths = tree.paths();
        paths.sort();
        assert_eq!(paths.len(), 4);
        for p in &paths {
            assert_eq!(p.len(), 3);
        }
    }

    // Node layout, in breadth-first builder order:
    // - 0: root
    // - 1: token 10
    // - 2: token 11
    // - 3: token 20 under node 1
    // - 4: token 20 under node 2
    // `mask[i][j]` is expected to be true exactly when `j` is `i` or one of its
    // ancestors.
    #[test]
    fn cartesian_tree_attention_mask_blocks_cross_branches() {
        let h = MedusaHeads::from_config(MedusaConfig {
            n_heads: 2,
            hidden_size: 16,
            vocab_size: 100,
            residual_layers: 1,
        });
        let tree = h
            .build_draft_tree(0, &[vec![10, 11], vec![20]], TreeTopology::CartesianProduct)
            .unwrap();
        let mask = tree.attention_mask_bool();
        assert!(mask[3][0] && mask[3][1] && mask[3][3]);
        assert!(!mask[3][2], "node 3 must not see sibling-branch ancestor");
        assert!(!mask[3][4]);
        assert!(mask[4][0] && mask[4][2] && mask[4][4]);
        assert!(!mask[4][1] && !mask[4][3]);
    }

    // Passing fewer candidate lists than heads is a `Sampling` error.
    #[test]
    fn build_rejects_wrong_head_count() {
        let h = MedusaHeads::from_config(MedusaConfig {
            n_heads: 3,
            hidden_size: 16,
            vocab_size: 100,
            residual_layers: 1,
        });
        let err = h
            .build_draft_tree(0, &[vec![1], vec![2]], TreeTopology::Greedy)
            .unwrap_err();
        assert!(matches!(err, Error::Sampling(_)));
    }

    // An empty candidate list for any head is a `Sampling` error.
    #[test]
    fn build_rejects_empty_candidate_list() {
        let h = MedusaHeads::from_config(MedusaConfig {
            n_heads: 2,
            hidden_size: 16,
            vocab_size: 100,
            residual_layers: 1,
        });
        let err = h
            .build_draft_tree(0, &[vec![1], vec![]], TreeTopology::Greedy)
            .unwrap_err();
        assert!(matches!(err, Error::Sampling(_)));
    }

    // NOTE(review): `fixed_distribution` presumably builds a mock `Decoder` that
    // reports the same distribution at every step. Confirm in `crate::model::mock`.
    use crate::model::mock::fixed_distribution;
    use rand::SeedableRng;

    // Head-draft callback that ignores the history and always proposes `per_head`.
    fn fixed_head_draft(per_head: Vec<Vec<u32>>) -> HeadDraftFn {
        Box::new(move |_history| per_head.clone())
    }

    // Vector summing to 1: 0.001 on every token, the remaining mass on `peak_idx`.
    fn vocab_peak_at(vocab_size: usize, peak_idx: usize) -> Vec<f32> {
        let mut p = vec![0.001f32; vocab_size];
        let remainder = 1.0 - 0.001 * (vocab_size as f32 - 1.0);
        p[peak_idx] = remainder;
        p
    }

    // Heads always propose the target's argmax (5): every output token must be 5.
    #[test]
    fn medusa_greedy_oracle_head_accepts_all() {
        let vocab = 16;
        let mut target = fixed_distribution(vocab_peak_at(vocab, 5));
        let heads = MedusaHeads::from_config(MedusaConfig {
            n_heads: 4,
            hidden_size: 1,
            vocab_size: vocab,
            residual_layers: 1,
        });
        let head_draft = fixed_head_draft(vec![vec![5], vec![5], vec![5], vec![5]]);
        let cfg = MedusaRunConfig {
            topology: TreeTopology::Greedy,
            top_k_per_head: 1,
            acceptance: Acceptance::Greedy,
        };
        let mut rng = rand::rngs::StdRng::seed_from_u64(1);
        let out = run_medusa(&mut target, &heads, head_draft, &[7u32], 20, &cfg, &mut rng).unwrap();
        assert_eq!(out.len(), 20);
        for &t in &out {
            assert_eq!(t, 5, "expected target argmax (5), got {t}");
        }
    }

    // Heads propose 7 but the target's argmax is 5. Every draft is rejected, so each
    // round commits only the bonus token, which is still 5.
    #[test]
    fn medusa_greedy_wrong_head_falls_back_to_bonus_only() {
        let vocab = 16;
        let mut target = fixed_distribution(vocab_peak_at(vocab, 5));
        let heads = MedusaHeads::from_config(MedusaConfig {
            n_heads: 3,
            hidden_size: 1,
            vocab_size: vocab,
            residual_layers: 1,
        });
        let head_draft = fixed_head_draft(vec![vec![7], vec![7], vec![7]]);
        let cfg = MedusaRunConfig {
            topology: TreeTopology::Greedy,
            top_k_per_head: 1,
            acceptance: Acceptance::Greedy,
        };
        let mut rng = rand::rngs::StdRng::seed_from_u64(2);
        let out = run_medusa(&mut target, &heads, head_draft, &[1u32], 12, &cfg, &mut rng).unwrap();
        assert_eq!(out.len(), 12);
        for &t in &out {
            assert_eq!(t, 5);
        }
    }

    // Each head offers the right token (5) and a wrong one (99). The path search must
    // settle on the all-5 branch.
    #[test]
    fn medusa_cartesian_picks_correct_branch() {
        let vocab = 128;
        let mut target = fixed_distribution(vocab_peak_at(vocab, 5));
        let heads = MedusaHeads::from_config(MedusaConfig {
            n_heads: 2,
            hidden_size: 1,
            vocab_size: vocab,
            residual_layers: 1,
        });
        let head_draft = fixed_head_draft(vec![vec![5, 99], vec![5, 99]]);
        let cfg = MedusaRunConfig {
            topology: TreeTopology::CartesianProduct,
            top_k_per_head: 2,
            acceptance: Acceptance::Greedy,
        };
        let mut rng = rand::rngs::StdRng::seed_from_u64(3);
        let out = run_medusa(&mut target, &heads, head_draft, &[1u32], 9, &cfg, &mut rng).unwrap();
        assert_eq!(out.len(), 9);
        for &t in &out {
            assert_eq!(t, 5);
        }
    }

    // Token 25 carries little mass under the target, so typical acceptance rejects
    // it. Every output token is then the argmax bonus (0).
    #[test]
    fn medusa_typical_acceptance_threshold_blocks_low_mass_token() {
        let vocab = 50;
        let mut probs = vec![0.01f32; vocab];
        probs[0] = 1.0 - 0.01 * (vocab as f32 - 1.0);
        let mut target = fixed_distribution(probs);
        let heads = MedusaHeads::from_config(MedusaConfig {
            n_heads: 2,
            hidden_size: 1,
            vocab_size: vocab,
            residual_layers: 1,
        });
        let head_draft = fixed_head_draft(vec![vec![25], vec![25]]);
        let cfg = MedusaRunConfig {
            topology: TreeTopology::Greedy,
            top_k_per_head: 1,
            acceptance: Acceptance::Typical {
                epsilon: 0.5,
                delta: 1.0,
            },
        };
        let mut rng = rand::rngs::StdRng::seed_from_u64(4);
        let out = run_medusa(&mut target, &heads, head_draft, &[1u32], 6, &cfg, &mut rng).unwrap();
        assert_eq!(out.len(), 6);
        for &t in &out {
            assert_eq!(t, 0);
        }
    }

    // `evaluate_tree` observes speculative paths but must leave the history length
    // unchanged.
    #[test]
    fn evaluate_tree_restores_target_history() {
        let vocab = 8;
        let mut target = fixed_distribution(vocab_peak_at(vocab, 3));
        Decoder::observe(&mut target, &[0u32, 1, 2]).unwrap();
        let pre = Decoder::history_len(&target);
        let tree = DraftTree::linear(2, &[5, 6, 7]);
        let _ = evaluate_tree(&mut target, &tree, pre).unwrap();
        assert_eq!(
            Decoder::history_len(&target),
            pre,
            "evaluate_tree must restore history"
        );
    }
}