use crate::{
methods::{
eagle::{EagleDraftCandle, EagleRunConfig},
eagle3::{Eagle3DraftCandle, Eagle3RunConfig},
medusa::{MedusaHeads, MedusaHeadsCandle, MedusaRunConfig},
Method,
},
model::{loader::ModelSource, Decoder, TreeDecoder},
sampling::{sample_from_distribution, softmax_with_temperature, top_p_filter},
Error, Result,
};
/// Draft-side component that can be attached to a `SpeculateEngine`.
///
/// Each variant pairs with the same-named `Method` (see `matches_method`).
pub enum SpeculateDraft {
/// A boxed `Decoder` trait object, used on the `Method::Vanilla` path.
Vanilla(Box<dyn Decoder + Send>),
/// Medusa heads plus their skeleton, used on the `Method::Medusa` path.
Medusa {
/// Heads object handed by reference to `run_medusa_real`.
heads: Box<MedusaHeadsCandle>,
/// Skeleton handed to `run_medusa_real` alongside `heads`.
skeleton: Box<MedusaHeads>,
},
/// Draft handed to `run_eagle` on the `Method::Eagle2` path.
Eagle2(Box<EagleDraftCandle>),
/// Draft handed to `run_eagle3` on the `Method::Eagle3` path.
Eagle3(Box<Eagle3DraftCandle>),
}
impl std::fmt::Debug for SpeculateDraft {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
SpeculateDraft::Vanilla(_) => f.write_str("SpeculateDraft::Vanilla(_)"),
SpeculateDraft::Medusa { .. } => f.write_str("SpeculateDraft::Medusa { .. }"),
SpeculateDraft::Eagle2(_) => f.write_str("SpeculateDraft::Eagle2(_)"),
SpeculateDraft::Eagle3(_) => f.write_str("SpeculateDraft::Eagle3(_)"),
}
}
}
impl SpeculateDraft {
    /// Reports whether this draft variant is the one `method` runs with.
    ///
    /// Every variant pairs with exactly one `Method`; any other method
    /// (including ones that have no draft variant at all) yields `false`.
    fn matches_method(&self, method: Method) -> bool {
        // Exhaustive on purpose: a new draft variant must declare its method.
        let paired = match self {
            Self::Vanilla(_) => Method::Vanilla,
            Self::Medusa { .. } => Method::Medusa,
            Self::Eagle2(_) => Method::Eagle2,
            Self::Eagle3(_) => Method::Eagle3,
        };
        method == paired
    }

    /// Bare variant name, as shown in the engine's `Debug` output.
    fn variant_name(&self) -> &'static str {
        match *self {
            Self::Vanilla(_) => "Vanilla",
            Self::Medusa { .. } => "Medusa",
            Self::Eagle2(_) => "Eagle2",
            Self::Eagle3(_) => "Eagle3",
        }
    }
}
/// Generation engine: resolved settings plus the optionally attached target
/// model, draft component, and per-method run configurations.
pub struct SpeculateEngine {
/// Settings produced by the builder. `with_eagle_draft`, `with_eagle3_draft`
/// and `with_medusa` overwrite `config.method`.
config: EngineConfig,
/// Target model; `None` until `with_target` is called.
target: Option<Box<dyn TreeDecoder + Send>>,
/// Draft component; `None` until one of the draft-attaching methods is called.
draft: Option<SpeculateDraft>,
/// Passed to `run_eagle` on the `Method::Eagle2` path.
eagle_run_config: EagleRunConfig,
/// Passed to `run_eagle3` on the `Method::Eagle3` path.
eagle3_run_config: Eagle3RunConfig,
/// Passed to `run_medusa_real` on the `Method::Medusa` path.
medusa_run_config: MedusaRunConfig,
}
impl std::fmt::Debug for SpeculateEngine {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SpeculateEngine")
.field("config", &self.config)
.field("target_loaded", &self.target.is_some())
.field(
"draft",
&self.draft.as_ref().map(|d| d.variant_name()),
)
.finish()
}
}
/// Per-call generation limits: a token budget and the token ids that end
/// generation once emitted.
#[derive(Debug, Clone, Default)]
pub struct GenerationOptions {
    /// Upper bound on the number of tokens to produce.
    pub max_new_tokens: usize,
    /// Token ids that terminate generation; the stop token itself is still
    /// part of the output.
    pub stop_tokens: Vec<u32>,
}

impl GenerationOptions {
    /// Options with the given budget and no stop tokens.
    pub fn new(max_new_tokens: usize) -> Self {
        Self {
            max_new_tokens,
            ..Self::default()
        }
    }

    /// Appends one stop token to the existing set.
    pub fn with_stop(self, tok: u32) -> Self {
        let Self {
            max_new_tokens,
            mut stop_tokens,
        } = self;
        stop_tokens.push(tok);
        Self {
            max_new_tokens,
            stop_tokens,
        }
    }

    /// Replaces the whole stop-token set with `toks`; anything added earlier
    /// through `with_stop` is discarded.
    pub fn with_stops(self, toks: Vec<u32>) -> Self {
        Self {
            stop_tokens: toks,
            ..self
        }
    }
}
/// Resolved engine settings, as assembled by `SpeculateEngineBuilder::build`.
#[derive(Debug, Clone)]
pub struct EngineConfig {
/// Source of the target model (the builder obtains it via `ModelSource::parse`).
pub target: ModelSource,
/// Optional source of the draft model; the builder refuses to build without
/// one when `method.needs_draft_model()` is true.
pub draft: Option<ModelSource>,
/// Decoding method that `SpeculateEngine::generate_tokens_with` dispatches on.
pub method: Method,
/// Builder default is 256.
/// NOTE(review): no generation path visible in this file reads this field —
/// confirm where it is consumed.
pub default_max_tokens: usize,
/// Temperature passed to `softmax_with_temperature` on the autoregressive
/// path and copied into `VanillaConfig`; builder default is 0.7.
pub temperature: f32,
/// Top-p threshold; the autoregressive path applies `top_p_filter` only when
/// this is below 1.0. Also copied into `VanillaConfig`; builder default is 0.95.
pub top_p: f32,
/// Copied into `VanillaConfig::draft_lookahead`; builder default is 4.
pub draft_lookahead: usize,
/// `Some(s)` seeds a `StdRng` for each generation call; `None` uses
/// `rand::thread_rng()`.
pub seed: Option<u64>,
}
impl SpeculateEngine {
pub fn builder() -> SpeculateEngineBuilder {
SpeculateEngineBuilder::default()
}
pub fn preset_for(model_name: &str) -> Result<Self> {
let preset = crate::presets::lookup(model_name)
.ok_or_else(|| Error::UnknownPreset(model_name.to_string()))?;
SpeculateEngineBuilder::default()
.target_model(&preset.target)
.method(preset.method)
.maybe_draft_model(preset.draft.as_deref())
.draft_lookahead(preset.draft_lookahead)
.temperature(preset.temperature)
.top_p(preset.top_p)
.build()
}
pub fn config(&self) -> &EngineConfig {
&self.config
}
pub fn with_target<D: TreeDecoder + Send + 'static>(mut self, target: D) -> Self {
self.target = Some(Box::new(target));
self
}
pub fn with_draft<D: Decoder + Send + 'static>(mut self, draft: D) -> Self {
self.draft = Some(SpeculateDraft::Vanilla(Box::new(draft)));
self
}
pub fn with_eagle_draft(mut self, draft: EagleDraftCandle) -> Self {
self.draft = Some(SpeculateDraft::Eagle2(Box::new(draft)));
if self.config.method != Method::Eagle2 {
self.config.method = Method::Eagle2;
}
self
}
pub fn with_eagle3_draft(mut self, draft: Eagle3DraftCandle) -> Self {
self.draft = Some(SpeculateDraft::Eagle3(Box::new(draft)));
if self.config.method != Method::Eagle3 {
self.config.method = Method::Eagle3;
}
self
}
pub fn with_medusa(mut self, heads: MedusaHeadsCandle, skeleton: MedusaHeads) -> Self {
self.draft = Some(SpeculateDraft::Medusa {
heads: Box::new(heads),
skeleton: Box::new(skeleton),
});
if self.config.method != Method::Medusa {
self.config.method = Method::Medusa;
}
self
}
pub fn eagle_run_config(mut self, cfg: EagleRunConfig) -> Self {
self.eagle_run_config = cfg;
self
}
pub fn eagle3_run_config(mut self, cfg: Eagle3RunConfig) -> Self {
self.eagle3_run_config = cfg;
self
}
pub fn medusa_run_config(mut self, cfg: MedusaRunConfig) -> Self {
self.medusa_run_config = cfg;
self
}
pub fn is_ready(&self) -> bool {
if self.target.is_none() {
return false;
}
if self.config.method.needs_draft_model() {
match &self.draft {
Some(d) if d.matches_method(self.config.method) => {}
_ => return false,
}
}
true
}
pub fn generate_tokens(&mut self, prompt: &[u32], max_new_tokens: usize) -> Result<Vec<u32>> {
self.generate_tokens_with(prompt, &GenerationOptions::new(max_new_tokens), |_| true)
}
pub fn generate_tokens_with<F>(
&mut self,
prompt: &[u32],
opts: &GenerationOptions,
on_token: F,
) -> Result<Vec<u32>>
where
F: FnMut(u32) -> bool,
{
if !self.is_ready() {
return Err(Error::MissingField(
"models not loaded — call with_target / with_draft first",
));
}
let mut rng: Box<dyn rand::RngCore> = match self.config.seed {
Some(s) => {
use rand::SeedableRng;
Box::new(rand::rngs::StdRng::seed_from_u64(s))
}
None => Box::new(rand::thread_rng()),
};
match self.config.method {
Method::Autoregressive => {
let target = self.target.as_mut().unwrap();
run_autoregressive(
target.as_mut(),
prompt,
opts,
&self.config,
&mut rng,
on_token,
)
}
Method::Vanilla => {
let target = self.target.as_mut().unwrap();
let SpeculateDraft::Vanilla(draft) = self.draft.as_mut().unwrap() else {
return Err(Error::MissingField(
"method=Vanilla but draft is not a vanilla Decoder; \
call with_draft(_) — for EAGLE/Medusa methods use \
with_eagle_draft / with_eagle3_draft / with_medusa.",
));
};
let cfg = crate::methods::vanilla::VanillaConfig {
draft_lookahead: self.config.draft_lookahead,
temperature: self.config.temperature,
top_p: self.config.top_p,
};
crate::methods::vanilla::run_vanilla_sd_with(
target.as_mut() as &mut dyn Decoder,
draft.as_mut(),
prompt,
opts,
&cfg,
&mut rng,
on_token,
)
}
Method::Medusa => {
let target = self.target.as_mut().unwrap();
let SpeculateDraft::Medusa { heads, skeleton } = self.draft.as_mut().unwrap() else {
return Err(Error::MissingField(
"method=Medusa but draft is not Medusa heads; call with_medusa(_, _).",
));
};
let out = crate::methods::medusa::run_medusa_real(
target.as_mut(),
heads.as_ref(),
skeleton.as_ref(),
prompt,
opts.max_new_tokens,
&self.medusa_run_config,
&mut rng,
)?;
stream_and_stop(out, opts, on_token)
}
Method::Eagle2 => {
let target = self.target.as_mut().unwrap();
let SpeculateDraft::Eagle2(draft) = self.draft.as_mut().unwrap() else {
return Err(Error::MissingField(
"method=Eagle2 but draft is not an EAGLE-2 draft; call with_eagle_draft(_).",
));
};
let out = crate::methods::eagle::run_eagle(
target.as_mut(),
draft.as_mut(),
prompt,
opts.max_new_tokens,
&self.eagle_run_config,
&mut rng,
)?;
stream_and_stop(out, opts, on_token)
}
Method::Eagle3 => {
let target = self.target.as_mut().unwrap();
let SpeculateDraft::Eagle3(draft) = self.draft.as_mut().unwrap() else {
return Err(Error::MissingField(
"method=Eagle3 but draft is not an EAGLE-3 draft; call with_eagle3_draft(_).",
));
};
let out = crate::methods::eagle3::run_eagle3(
target.as_mut(),
draft.as_mut(),
prompt,
opts.max_new_tokens,
&self.eagle3_run_config,
&mut rng,
)?;
stream_and_stop(out, opts, on_token)
}
}
}
pub fn generate(&mut self, prompt: &str, max_tokens: usize) -> Result<String> {
if !self.is_ready() {
return Err(Error::MissingField(
"models not loaded — call with_target / with_draft first",
));
}
let target = self.target.as_ref().unwrap();
let prompt_ids = target.encode(prompt, true)?;
let stops = target.eos_token_ids();
let opts = GenerationOptions::new(max_tokens).with_stops(stops);
let out_ids = self.generate_tokens_with(&prompt_ids, &opts, |_| true)?;
let target = self.target.as_ref().unwrap();
target.decode(&out_ids, true)
}
}
/// Replays an already-generated token sequence through `on_token` and cuts it
/// at the first halting token.
///
/// Tokens are visited in order. The callback runs for each one; the sequence
/// ends, inclusively, at the first token for which the callback returns
/// `false` or which appears in `opts.stop_tokens`. Later tokens are neither
/// reported nor returned. The result is always `Ok`.
fn stream_and_stop<F: FnMut(u32) -> bool>(
    out: Vec<u32>,
    opts: &GenerationOptions,
    mut on_token: F,
) -> Result<Vec<u32>> {
    let mut delivered = out;
    // `position` stops at the first match, so the callback sees exactly the
    // tokens that are kept. The callback is consulted before the stop list.
    let halt_at = delivered
        .iter()
        .position(|&tok| !on_token(tok) || opts.stop_tokens.contains(&tok));
    if let Some(idx) = halt_at {
        // Keep the halting token itself.
        delivered.truncate(idx + 1);
    }
    Ok(delivered)
}
/// Plain token-by-token sampling from `target`, with no draft model.
///
/// The decoder is reset and fed `prompt`. Each step then takes the next
/// logits, applies `softmax_with_temperature` with `config.temperature`,
/// applies `top_p_filter` when `config.top_p < 1.0`, samples one token, and
/// feeds it back to the decoder.
///
/// Generation ends after `opts.max_new_tokens` tokens, or right after the
/// first token for which `on_token` returns `false` or which is listed in
/// `opts.stop_tokens`. That token is included in the output.
///
/// # Errors
/// Propagates any error from the decoder or the sampling helpers.
fn run_autoregressive<T, R, F>(
    target: &mut T,
    prompt: &[u32],
    opts: &GenerationOptions,
    config: &EngineConfig,
    rng: &mut R,
    mut on_token: F,
) -> Result<Vec<u32>>
where
    T: Decoder + ?Sized,
    R: rand::Rng + ?Sized,
    F: FnMut(u32) -> bool,
{
    // Upper bound on the up-front reservation. `max_new_tokens` is
    // caller-controlled and may be huge ("run until a stop token"), and
    // reserving it verbatim would panic or attempt an enormous allocation.
    // The vector still grows on demand past this bound.
    const MAX_PREALLOC_TOKENS: usize = 4096;

    target.reset();
    target.observe(prompt)?;
    let mut out = Vec::with_capacity(opts.max_new_tokens.min(MAX_PREALLOC_TOKENS));
    for _ in 0..opts.max_new_tokens {
        let logits = target.next_logits()?;
        let mut probs = softmax_with_temperature(&logits, config.temperature)?;
        if config.top_p < 1.0 {
            top_p_filter(&mut probs, config.top_p)?;
        }
        // NOTE(review): `as u32` truncates silently if the sampled index ever
        // exceeds u32::MAX — presumably impossible for a vocabulary index; confirm.
        let tok = sample_from_distribution(rng, &probs)? as u32;
        // The decoder observes the token before the halt check, so its state
        // includes the final token.
        target.observe(&[tok])?;
        out.push(tok);
        if !on_token(tok) || opts.stop_tokens.contains(&tok) {
            break;
        }
    }
    Ok(out)
}
/// Builder for `SpeculateEngine`. Every field starts as `None`; `build`
/// fills in defaults for the ones left unset.
#[derive(Debug, Default, Clone)]
pub struct SpeculateEngineBuilder {
/// Target model source; required — `build` fails with `MissingField` without it.
target: Option<ModelSource>,
/// Draft model source; required by `build` only when the method needs a draft model.
draft: Option<ModelSource>,
/// Decoding method; `build` falls back to `Method::Autoregressive`.
method: Option<Method>,
/// `build` falls back to 256.
default_max_tokens: Option<usize>,
/// `build` falls back to 0.7.
temperature: Option<f32>,
/// `build` falls back to 0.95.
top_p: Option<f32>,
/// `build` falls back to 4.
draft_lookahead: Option<usize>,
/// Optional RNG seed, copied into the config unchanged.
seed: Option<u64>,
}
impl SpeculateEngineBuilder {
pub fn target_model(mut self, source: &str) -> Self {
self.target = Some(ModelSource::parse(source));
self
}
pub fn draft_model(mut self, source: &str) -> Self {
self.draft = Some(ModelSource::parse(source));
self
}
pub fn maybe_draft_model(mut self, source: Option<&str>) -> Self {
if let Some(s) = source {
self.draft = Some(ModelSource::parse(s));
}
self
}
pub fn draft_path(self, source: &str) -> Self {
self.draft_model(source)
}
pub fn method(mut self, m: Method) -> Self {
self.method = Some(m);
self
}
pub fn default_max_tokens(mut self, n: usize) -> Self {
self.default_max_tokens = Some(n);
self
}
pub fn temperature(mut self, t: f32) -> Self {
self.temperature = Some(t);
self
}
pub fn top_p(mut self, p: f32) -> Self {
self.top_p = Some(p);
self
}
pub fn draft_lookahead(mut self, n: usize) -> Self {
self.draft_lookahead = Some(n);
self
}
pub fn seed(mut self, s: u64) -> Self {
self.seed = Some(s);
self
}
pub fn build(self) -> Result<SpeculateEngine> {
let target = self.target.ok_or(Error::MissingField("target_model"))?;
let method = self.method.unwrap_or(Method::Autoregressive);
if method.needs_draft_model() && self.draft.is_none() {
return Err(Error::UnsupportedMethod {
method: method.name(),
reason: "method requires a draft model; call .draft_model(...)".into(),
});
}
let config = EngineConfig {
target,
draft: self.draft,
method,
default_max_tokens: self.default_max_tokens.unwrap_or(256),
temperature: self.temperature.unwrap_or(0.7),
top_p: self.top_p.unwrap_or(0.95),
draft_lookahead: self.draft_lookahead.unwrap_or(4),
seed: self.seed,
};
Ok(SpeculateEngine {
config,
target: None,
draft: None,
eagle_run_config: EagleRunConfig::default(),
eagle3_run_config: Eagle3RunConfig::default(),
medusa_run_config: MedusaRunConfig::default(),
})
}
}
// Unit tests for the builder's validation and for the token-level generation
// paths, driven by the mock decoder from `crate::model::mock`.
#[cfg(test)]
mod tests {
use super::*;
use crate::model::mock::fixed_distribution;
// `build()` without a target must fail with `MissingField`.
#[test]
fn builder_requires_target() {
let err = SpeculateEngineBuilder::default().build().unwrap_err();
assert!(matches!(err, Error::MissingField(_)));
}
// Selecting `Method::Vanilla` without a draft source is rejected at build time.
#[test]
fn vanilla_method_requires_draft_in_config() {
let err = SpeculateEngineBuilder::default()
.target_model("meta-llama/Llama-3.1-8B-Instruct")
.method(Method::Vanilla)
.build()
.unwrap_err();
assert!(matches!(err, Error::UnsupportedMethod { .. }));
}
// Autoregressive builds without a draft; the engine is not ready until a
// target is attached.
#[test]
fn autoregressive_does_not_require_draft() {
let engine = SpeculateEngineBuilder::default()
.target_model("meta-llama/Llama-3.1-8B-Instruct")
.method(Method::Autoregressive)
.build()
.unwrap();
assert_eq!(engine.config().method, Method::Autoregressive);
assert!(engine.config().draft.is_none());
assert!(!engine.is_ready(), "no model attached yet");
}
// With a draft source set, Vanilla builds and keeps the explicit lookahead.
#[test]
fn vanilla_with_draft_succeeds() {
let engine = SpeculateEngineBuilder::default()
.target_model("meta-llama/Llama-3.1-8B-Instruct")
.draft_model("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
.method(Method::Vanilla)
.draft_lookahead(6)
.build()
.unwrap();
assert_eq!(engine.config().draft_lookahead, 6);
assert!(engine.config().draft.is_some());
}
// Autoregressive generation against a mock target yields exactly the
// requested number of tokens, all inside the 3-entry vocabulary.
#[test]
fn generate_tokens_runs_autoregressive_with_mock() {
let target = fixed_distribution(vec![0.5, 0.3, 0.2]);
let mut engine = SpeculateEngineBuilder::default()
.target_model("dummy")
.method(Method::Autoregressive)
.seed(42)
.build()
.unwrap()
.with_target(target);
assert!(engine.is_ready());
let out = engine.generate_tokens(&[7u32], 8).unwrap();
assert_eq!(out.len(), 8);
for &t in &out {
assert!(t < 3, "produced token {t} outside vocab");
}
}
// Vanilla speculative decoding with mock target and draft yields exactly the
// requested number of tokens.
#[test]
fn generate_tokens_runs_vanilla_sd_with_mocks() {
let target = fixed_distribution(vec![0.6, 0.3, 0.1]);
let draft = fixed_distribution(vec![0.33, 0.33, 0.34]);
let mut engine = SpeculateEngineBuilder::default()
.target_model("dummy-target")
.draft_model("dummy-draft")
.method(Method::Vanilla)
.draft_lookahead(3)
.seed(99)
.build()
.unwrap()
.with_target(target)
.with_draft(draft);
assert!(engine.is_ready());
let out = engine.generate_tokens(&[1u32], 12).unwrap();
assert_eq!(out.len(), 12);
}
// The mock puts all probability mass on token 5, which is also the stop
// token, so generation must end after the first emitted token.
#[test]
fn generate_tokens_with_stops_at_stop_token() {
let mut probs = vec![0.0f32; 8];
probs[5] = 1.0;
let target = fixed_distribution(probs);
let mut engine = SpeculateEngineBuilder::default()
.target_model("dummy")
.method(Method::Autoregressive)
.seed(1)
.build()
.unwrap()
.with_target(target);
let opts = GenerationOptions::new(64).with_stop(5);
let out = engine
.generate_tokens_with(&[0u32], &opts, |_| true)
.unwrap();
assert_eq!(out, vec![5], "should stop after first emitted EOS");
}
// The callback returns false on its fifth call; that fifth token is still
// part of the output.
#[test]
fn generate_tokens_with_callback_can_halt_early() {
let target = fixed_distribution(vec![0.0, 0.0, 1.0, 0.0]);
let mut engine = SpeculateEngineBuilder::default()
.target_model("dummy")
.method(Method::Autoregressive)
.seed(1)
.build()
.unwrap()
.with_target(target);
let mut count = 0;
let opts = GenerationOptions::new(100);
let out = engine
.generate_tokens_with(&[0u32], &opts, |_| {
count += 1;
count < 5
})
.unwrap();
assert_eq!(out.len(), 5, "callback should stop after 5 tokens");
}
// Every returned token is reported to the callback, in order.
#[test]
fn generate_tokens_with_callback_streams_each_token() {
let target = fixed_distribution(vec![0.0, 0.0, 1.0, 0.0]);
let mut engine = SpeculateEngineBuilder::default()
.target_model("dummy")
.method(Method::Autoregressive)
.seed(1)
.build()
.unwrap()
.with_target(target);
let mut seen = Vec::new();
let opts = GenerationOptions::new(7);
let out = engine
.generate_tokens_with(&[0u32], &opts, |t| {
seen.push(t);
true
})
.unwrap();
assert_eq!(seen, out, "callback sequence should match returned tokens");
assert_eq!(out.len(), 7);
}
// Target and draft both always produce token 5 (the stop token), so the
// vanilla path must return a single token.
#[test]
fn vanilla_sd_with_options_respects_stop_tokens() {
let mut probs = vec![0.0f32; 8];
probs[5] = 1.0;
let target = fixed_distribution(probs.clone());
let draft = fixed_distribution(probs);
let mut engine = SpeculateEngineBuilder::default()
.target_model("dummy-target")
.draft_model("dummy-draft")
.method(Method::Vanilla)
.draft_lookahead(4)
.seed(99)
.build()
.unwrap()
.with_target(target)
.with_draft(draft);
let opts = GenerationOptions::new(64).with_stop(5);
let out = engine
.generate_tokens_with(&[1u32], &opts, |_| true)
.unwrap();
assert_eq!(out.len(), 1, "vanilla SD should stop at first EOS");
assert_eq!(out[0], 5);
}
// No method is set, so the builder defaults to Autoregressive. `generate`
// must fail with a message mentioning "Backend" or "tokenizer".
// NOTE(review): presumably the mock target's `encode` produces that error —
// confirm in `crate::model::mock`.
#[test]
fn generate_text_is_explicit_about_unsupported_path() {
let target = fixed_distribution(vec![0.5, 0.5]);
let mut engine = SpeculateEngineBuilder::default()
.target_model("dummy")
.build()
.unwrap()
.with_target(target);
let err = engine.generate("hi", 5).unwrap_err();
let msg = format!("{err}");
assert!(
msg.contains("Backend") || msg.contains("tokenizer"),
"expected guidance toward the lower-level path; got: {msg}"
);
}
}