use smol_str::format_smolstr;
use crate::{
array::Array,
error::{
EmptyInputPayload, Error, InvariantViolationPayload, OutOfRangePayload, RankMismatchPayload,
Result, try_extend_from_slice, try_with_capacity,
},
lm::{
cache::{KvCache, can_trim_prompt_cache, trim_prompt_cache},
generate::{FinishReason, GenConfig, GenerationResponse, make_logits_processors, make_sampler},
model::Model,
},
ops,
};
#[cfg(feature = "tokenizer-stream")]
use crate::tokenizer::StreamingDetokenizer as _;
/// Running counters for a speculative-decoding run.
#[derive(Debug, Clone, Copy, Default)]
pub struct GenerationStats {
    /// Draft tokens proposed by the draft model and scored by the target.
    pub proposed_drafts: usize,
    /// Draft tokens that matched the target's choice and were emitted.
    pub accepted_drafts: usize,
    /// Tokens emitted in total, whether drafted or sampled by the target.
    pub generated_tokens: usize,
}
impl GenerationStats {
    /// Fraction of proposed drafts that were accepted, or `0.0` when nothing
    /// has been proposed yet (avoids a 0/0 NaN).
    pub fn accept_rate(&self) -> f32 {
        match self.proposed_drafts {
            0 => 0.0,
            proposed => self.accepted_drafts as f32 / proposed as f32,
        }
    }
}
/// Draft-model settings for speculative decoding.
pub struct DraftConfig {
    /// Model that proposes tokens for the target model to verify.
    pub draft_model: Box<dyn Model>,
    /// Maximum number of draft tokens proposed per speculative step; each
    /// step further caps this by the remaining `max_tokens` budget.
    pub n_draft_tokens: usize,
}
/// Run speculative decoding to completion.
///
/// Drains [`speculative_stream_generate`] and returns the concatenated text
/// together with the statistics reported by the last response.
///
/// # Errors
///
/// Returns the first error yielded by the stream, including construction and
/// validation errors.
pub fn speculative_generate(
    target: &dyn Model,
    tokenizer: &crate::tokenizer::Tokenizer,
    prompt: &[u32],
    target_cache: Vec<Box<dyn KvCache>>,
    draft_cache: Vec<Box<dyn KvCache>>,
    draft_cfg: DraftConfig,
    cfg: GenConfig,
) -> Result<(String, GenerationStats)> {
    let mut stream = speculative_stream_generate(
        target,
        tokenizer,
        prompt,
        target_cache,
        draft_cache,
        draft_cfg,
        cfg,
    );
    stream.try_fold(
        (String::new(), GenerationStats::default()),
        |(mut text, _), item| {
            let chunk = item?;
            text.push_str(chunk.text());
            Ok((text, chunk.stats))
        },
    )
}
/// One streamed token from [`speculative_stream_generate`].
#[derive(Debug)]
pub struct SpeculativeResponse {
    /// Regular generation payload (text segment, token, timing, finish reason).
    pub response: GenerationResponse,
    /// `true` when the token was an accepted draft, `false` when the target
    /// sampled it itself.
    pub from_draft: bool,
    /// Cumulative statistics up to and including this token.
    pub stats: GenerationStats,
}
impl SpeculativeResponse {
    /// Text segment decoded for this token.
    pub fn text(&self) -> &str {
        self.response.text.as_str()
    }
}
impl std::ops::Deref for SpeculativeResponse {
    type Target = GenerationResponse;
    /// Expose the inner [`GenerationResponse`] fields directly.
    fn deref(&self) -> &GenerationResponse {
        let Self { response, .. } = self;
        response
    }
}
/// Iterator returned by [`speculative_stream_generate`]; yields one
/// [`SpeculativeResponse`] per generated token.
pub struct SpeculativeStream<'a> {
    /// Decoding engine; `None` when construction failed (see `pending_err`).
    driver: Option<SpeculativeDriver<'a>>,
    /// Construction error, reported once by the first call to `next`.
    pending_err: Option<Error>,
    /// Incremental detokenizer producing each response's text segment.
    detok: crate::tokenizer::wrapper::BoxedDetokenizer,
    /// Prompt length in tokens, reported in every response.
    prompt_tokens: usize,
    /// Token ids that end generation; an EOS token is never fed to `detok`.
    eos: Vec<u32>,
    /// Token budget; reaching it finishes with `FinishReason::Length`.
    max_tokens: usize,
    /// Whether responses carry per-token log-probabilities.
    collect_logprobs: bool,
    /// Non-EOS tokens emitted so far.
    n: usize,
    /// Set once the stream has ended (EOS, length, error or exhaustion).
    finished: bool,
    /// Start of the current timing window: prompt processing until the first
    /// token arrives, generation afterwards.
    tic: std::time::Instant,
    /// Prompt throughput in tokens/s, measured when the first token arrives.
    prompt_tps: f64,
}
impl SpeculativeStream<'_> {
    /// Flush the detokenizer and return whatever text it still held back.
    pub fn finalize_tail(&mut self) -> String {
        let detok = &mut self.detok;
        detok.finalize();
        detok.last_segment()
    }
}
impl Iterator for SpeculativeStream<'_> {
    type Item = Result<SpeculativeResponse>;

    /// Pull the next token from the driver and wrap it in a response.
    ///
    /// The stream ends after an EOS token (reported with
    /// `FinishReason::Eos`), after `max_tokens` non-EOS tokens
    /// (`FinishReason::Length`), after the first error, or when the driver
    /// runs dry.
    fn next(&mut self) -> Option<Self::Item> {
        // A construction error is surfaced exactly once, then the stream is done.
        if let Some(err) = self.pending_err.take() {
            self.finished = true;
            return Some(Err(err));
        }
        if self.finished {
            return None;
        }
        let driver = self.driver.as_mut()?;
        let step = driver.next_token();
        let TokenOut {
            token,
            logprobs,
            from_draft,
            stats,
        } = match step {
            Ok(Some(out)) => out,
            Ok(None) => {
                self.finished = true;
                return None;
            }
            Err(err) => {
                self.finished = true;
                return Some(Err(err));
            }
        };

        if self.n == 0 {
            // First token: everything so far was prompt processing; restart
            // the clock so generation throughput excludes it.
            let prompt_secs = self.tic.elapsed().as_secs_f64();
            self.prompt_tps = if prompt_secs > 0.0 {
                self.prompt_tokens as f64 / prompt_secs
            } else {
                0.0
            };
            self.tic = std::time::Instant::now();
        }
        let gen_secs = self.tic.elapsed().as_secs_f64();

        let (generation_tokens, finish_reason) = if self.eos.contains(&token) {
            // EOS is reported but never detokenized.
            (self.n + 1, Some(FinishReason::Eos))
        } else {
            self.detok.add_token(token);
            self.n += 1;
            let reason = (self.n >= self.max_tokens).then_some(FinishReason::Length);
            (self.n, reason)
        };
        if finish_reason.is_some() {
            self.finished = true;
            self.detok.finalize();
        }
        let text = self.detok.last_segment();
        let generation_tps = if gen_secs > 0.0 {
            generation_tokens as f64 / gen_secs
        } else {
            0.0
        };

        Some(Ok(SpeculativeResponse {
            response: GenerationResponse {
                text,
                token,
                logprobs: self.collect_logprobs.then_some(logprobs),
                prompt_tokens: self.prompt_tokens,
                prompt_tps: self.prompt_tps,
                generation_tokens,
                generation_tps,
                peak_memory_bytes: crate::memory::peak_memory().ok(),
                finish_reason,
            },
            from_draft,
            stats,
        }))
    }
}
/// Start a streaming speculative generation.
///
/// The tokenizer's EOS ids replace whatever `cfg.eos` held. Construction
/// errors (non-trimmable caches, empty prompt, invalid config) are not
/// returned here; the returned stream yields them from its first `next` call.
pub fn speculative_stream_generate<'a>(
    target: &'a dyn Model,
    tokenizer: &'a crate::tokenizer::Tokenizer,
    prompt: &[u32],
    target_cache: Vec<Box<dyn KvCache>>,
    draft_cache: Vec<Box<dyn KvCache>>,
    draft_cfg: DraftConfig,
    mut cfg: GenConfig,
) -> SpeculativeStream<'a> {
    cfg.eos = tokenizer.eos_token_ids_iter().collect();
    let eos: Vec<u32> = cfg.eos.clone();
    let max_tokens = cfg.max_tokens;
    let collect_logprobs = cfg.collect_logprobs;
    let prompt_tokens = prompt.len();

    // Park a construction error so the iterator can report it lazily.
    let mut pending_err = None;
    let driver = SpeculativeDriver::new(
        target,
        draft_cfg,
        prompt.to_vec(),
        target_cache,
        draft_cache,
        &cfg,
    )
    .map_err(|err| pending_err = Some(err))
    .ok();

    SpeculativeStream {
        driver,
        pending_err,
        detok: tokenizer.detokenizer(),
        prompt_tokens,
        eos,
        max_tokens,
        collect_logprobs,
        n: 0,
        finished: false,
        tic: std::time::Instant::now(),
        prompt_tps: 0.0,
    }
}
/// One token handed from the driver to the stream.
struct TokenOut {
    /// Token id.
    token: u32,
    /// Target-model log-probabilities at this position.
    logprobs: Array,
    /// Whether the token was an accepted draft.
    from_draft: bool,
    /// Cumulative statistics after this token was committed.
    stats: GenerationStats,
}
/// A token produced by a speculative step, queued until `next_token` hands it
/// out; `delta` is folded into the running stats at that point.
struct PendingToken {
    token: u32,
    logprobs: Array,
    from_draft: bool,
    delta: StatsDelta,
}
/// Contribution of a single pending token to [`GenerationStats`].
#[derive(Debug, Clone, Copy, Default)]
struct StatsDelta {
    proposed: usize,
    accepted: usize,
    generated: usize,
}
/// Engine behind [`SpeculativeStream`]. It owns both models' caches and hands
/// out tokens one at a time, running a full draft-then-verify step whenever
/// the `pending` queue is empty.
struct SpeculativeDriver<'a> {
    /// Model that verifies drafts and samples the bonus token.
    target: &'a dyn Model,
    /// Model that proposes draft tokens.
    draft_model: Box<dyn Model>,
    /// Maximum drafts per step (from `DraftConfig`).
    n_draft_tokens: usize,
    /// Total token budget (from `GenConfig::max_tokens`).
    max_tokens: usize,
    /// Target caches; must be trimmable so rejected drafts can be rolled back.
    target_cache: Vec<Box<dyn KvCache>>,
    /// Draft caches; must be trimmable for the same reason.
    draft_cache: Vec<Box<dyn KvCache>>,
    /// Sampler used for both draft and target sampling.
    sampler: crate::lm::generate::Sampler,
    /// Logits processors applied, with token history, before every sample.
    processors: Vec<crate::lm::generate::LogitsProcessor>,
    /// Committed tokens passed to the processors. Starts empty; prompt tokens
    /// consumed by `prefill` are never added.
    history: Vec<u32>,
    /// Tokens queued for output so far (counted against `max_tokens`).
    produced: usize,
    /// Whether `prefill` has run.
    prefilled: bool,
    /// Chunk size for prompt prefill (at least 1).
    prefill_step_size: usize,
    /// Prompt token ids, consumed by `prefill`.
    prompt: Vec<u32>,
    /// Tokens the target has not processed yet; prepended to the drafts.
    y_input: Vec<u32>,
    /// Tokens the draft model has not processed yet; seeds the next draft run.
    draft_y_input: Vec<u32>,
    /// Tokens produced by the last step but not yet handed out.
    pending: std::collections::VecDeque<PendingToken>,
    /// Running totals, updated as pending tokens are handed out.
    stats: GenerationStats,
    /// Set once the budget is used up; no further steps run.
    exhausted: bool,
}
impl<'a> SpeculativeDriver<'a> {
    /// Validate inputs and build a driver. No model work happens here;
    /// `prefill` runs lazily on the first `next_token` call.
    ///
    /// # Errors
    ///
    /// - `InvariantViolation` if `target_cache` or `draft_cache` is not
    ///   trimmable; rejected drafts are rolled back with `trim_prompt_cache`.
    /// - `EmptyInput` if `prompt` is empty.
    /// - Any error from `cfg.validate()`, `make_sampler` or
    ///   `make_logits_processors`.
    fn new(
        target: &'a dyn Model,
        draft_cfg: DraftConfig,
        prompt: Vec<u32>,
        target_cache: Vec<Box<dyn KvCache>>,
        draft_cache: Vec<Box<dyn KvCache>>,
        cfg: &GenConfig,
    ) -> Result<Self> {
        if !can_trim_prompt_cache(&target_cache) {
            return Err(Error::InvariantViolation(InvariantViolationPayload::new(
                "speculative_generate: target_cache",
                "must be trimmable (see mlx-lm generate.py:529-533)",
            )));
        }
        if !can_trim_prompt_cache(&draft_cache) {
            return Err(Error::InvariantViolation(InvariantViolationPayload::new(
                "speculative_generate: draft_cache",
                "must be trimmable (see mlx-lm generate.py:529-533)",
            )));
        }
        if prompt.is_empty() {
            return Err(Error::EmptyInput(EmptyInputPayload::new(
                "speculative_generate: prompt",
            )));
        }
        cfg.validate()?;
        let sampler = make_sampler(
            cfg.temp,
            cfg.top_p,
            cfg.min_p,
            cfg.min_tokens_to_keep,
            cfg.top_k,
            cfg.xtc_probability,
            cfg.xtc_threshold,
            &cfg.xtc_special_tokens,
            cfg.seed,
        )?;
        let processors = make_logits_processors(
            &cfg.logit_bias,
            cfg.repetition_penalty,
            cfg.repetition_context_size,
            cfg.presence_penalty,
            cfg.presence_context_size,
            cfg.frequency_penalty,
            cfg.frequency_context_size,
        )?;
        Ok(Self {
            target,
            draft_model: draft_cfg.draft_model,
            n_draft_tokens: draft_cfg.n_draft_tokens,
            max_tokens: cfg.max_tokens,
            target_cache,
            draft_cache,
            sampler,
            processors,
            history: Vec::new(),
            produced: 0,
            prefilled: false,
            prefill_step_size: cfg.prefill_step_size.max(1),
            prompt,
            y_input: Vec::new(),
            draft_y_input: Vec::new(),
            pending: std::collections::VecDeque::new(),
            stats: GenerationStats::default(),
            exhausted: false,
        })
    }

    /// Feed every prompt token except the last through both models, in
    /// chunks of at most `prefill_step_size`. Then seed `y_input` and
    /// `draft_y_input` with the final prompt token.
    ///
    /// After this call `y_input` holds exactly one token;
    /// `run_speculative_step` relies on that.
    fn prefill(&mut self) -> Result<()> {
        let Some((&last, body)) = self.prompt.split_last() else {
            return Err(Error::EmptyInput(EmptyInputPayload::new(
                "speculative_generate: prompt",
            )));
        };
        for chunk in body.chunks(self.prefill_step_size) {
            let window = token_window(chunk)?;
            // Only the cache side effect matters here; the logits are discarded.
            let _ = self.target.forward(&window, &mut self.target_cache)?;
            let _ = self.draft_model.forward(&window, &mut self.draft_cache)?;
        }
        // The prompt is never read again. Release it so a long prompt is not
        // kept alive for the whole generation.
        self.prompt = Vec::new();
        self.y_input = vec![last];
        self.draft_y_input = vec![last];
        self.prefilled = true;
        Ok(())
    }

    /// Return the next token, running a speculative step if nothing is
    /// queued. Returns `Ok(None)` once the token budget is exhausted.
    fn next_token(&mut self) -> Result<Option<TokenOut>> {
        if let Some(t) = self.pending.pop_front() {
            return Ok(Some(self.commit_pending(t)));
        }
        if self.exhausted {
            return Ok(None);
        }
        if self.produced >= self.max_tokens {
            self.exhausted = true;
            return Ok(None);
        }
        if !self.prefilled {
            self.prefill()?;
        }
        self.run_speculative_step()?;
        Ok(self.pending.pop_front().map(|t| self.commit_pending(t)))
    }

    /// Fold a pending token's stats delta into the running totals and
    /// convert it into a [`TokenOut`] carrying the updated totals.
    fn commit_pending(&mut self, t: PendingToken) -> TokenOut {
        self.stats.proposed_drafts += t.delta.proposed;
        self.stats.accepted_drafts += t.delta.accepted;
        self.stats.generated_tokens += t.delta.generated;
        TokenOut {
            token: t.token,
            logprobs: t.logprobs,
            from_draft: t.from_draft,
            stats: self.stats,
        }
    }

    /// Run one draft-then-verify round.
    ///
    /// 1. The draft model proposes up to `n_draft_tokens` tokens.
    /// 2. The target scores them in a single forward pass.
    /// 3. The longest prefix agreeing with the target is accepted.
    /// 4. The target's own token at the first position not covered by an
    ///    accepted draft is appended as a bonus token, budget permitting.
    ///
    /// Output tokens are queued in `pending`, and both caches are trimmed
    /// back to the accepted prefix.
    fn run_speculative_step(&mut self) -> Result<()> {
        let remaining = self.max_tokens.saturating_sub(self.produced);
        let num_draft = self.n_draft_tokens.min(remaining);
        let draft_tokens = self.draft_generate(num_draft)?;

        // Target input: the uncommitted token(s) followed by every draft.
        let mut combined: Vec<u32> = try_with_capacity(self.y_input.len() + draft_tokens.len())?;
        try_extend_from_slice(&mut combined, &self.y_input)?;
        try_extend_from_slice(&mut combined, &draft_tokens)?;
        // One prediction per draft position plus one for the bonus token.
        let n_predict = num_draft + 1;
        let combined_arr = token_window(&combined)?;
        let logits = self.target.forward(&combined_arr, &mut self.target_cache)?;
        let per_pos_logits = last_n_positions(&logits, n_predict)?;

        let mut target_tokens: Vec<u32> = try_with_capacity(n_predict)?;
        let mut target_logprobs: Vec<Array> = try_with_capacity(n_predict)?;
        let have_procs = !self.processors.is_empty();
        let mut history_snapshot = if have_procs {
            self.history.clone()
        } else {
            Vec::new()
        };
        for pos in 0..n_predict {
            let mut row = slice_position(&per_pos_logits, pos as i32)?;
            if have_procs {
                // Row `pos` is conditioned on `combined[..=pos]`. This indexing
                // is only valid because `y_input` always holds exactly one token.
                try_extend_from_slice(&mut history_snapshot, &combined[pos..pos + 1])?;
                for p in &self.processors {
                    row = p.apply(&history_snapshot, &row)?;
                }
            }
            let lse = ops::reduction::logsumexp(&row, true)?;
            let logprobs = ops::arithmetic::subtract(&row, &lse)?;
            let mut sampled = self.sampler.sample(&logprobs)?;
            let tok = sampled.item::<u32>()?;
            target_tokens.push(tok);
            target_logprobs.push(ops::shape::squeeze_axes(&logprobs, &[0])?);
        }

        // Accept drafts while they agree with the target; stop early if the
        // token budget runs out.
        let mut n_accept = 0usize;
        let mut hit_max = false;
        for i in 0..num_draft {
            let t_n = target_tokens[i];
            if t_n != draft_tokens[i] {
                break;
            }
            n_accept += 1;
            self.produced += 1;
            let lp = std::mem::replace(&mut target_logprobs[i], empty_logprobs()?);
            self.pending.push_back(PendingToken {
                token: t_n,
                logprobs: lp,
                from_draft: true,
                delta: StatsDelta {
                    proposed: 1,
                    accepted: 1,
                    generated: 1,
                },
            });
            if self.produced >= self.max_tokens {
                hit_max = true;
                break;
            }
        }

        // Bonus token: the target's choice right after the accepted prefix.
        // Its delta also accounts for the rejected drafts of this step.
        let mut bonus: Option<u32> = None;
        if !hit_max && self.produced < self.max_tokens {
            let token = target_tokens[n_accept];
            let logprobs = std::mem::replace(&mut target_logprobs[n_accept], empty_logprobs()?);
            self.produced += 1;
            self.pending.push_back(PendingToken {
                token,
                logprobs,
                from_draft: false,
                delta: StatsDelta {
                    proposed: num_draft - n_accept,
                    accepted: 0,
                    generated: 1,
                },
            });
            bonus = Some(token);
        }

        // Commit the input token plus the accepted drafts. The bonus becomes
        // the next step's input and is committed then.
        let committed_len = self.y_input.len() + n_accept;
        try_extend_from_slice(&mut self.history, &combined[..committed_len])?;

        // Roll back rejected drafts. The target consumed every draft. The
        // draft model never consumed its last draft, so it has one fewer
        // position to drop.
        let target_trim = num_draft - n_accept;
        let draft_trim = (num_draft.saturating_sub(n_accept)).saturating_sub(1);
        if target_trim > 0 {
            trim_prompt_cache(&mut self.target_cache, target_trim)?;
        }
        if draft_trim > 0 {
            trim_prompt_cache(&mut self.draft_cache, draft_trim)?;
        }

        match bonus {
            // Next step starts from the bonus token. If every draft was
            // accepted, the draft model has not yet consumed the last draft,
            // so feed it as well.
            Some(bonus) if self.produced < self.max_tokens => {
                self.y_input = vec![bonus];
                if num_draft > 0 && n_accept == num_draft {
                    let mut d = try_with_capacity(2)?;
                    d.push(draft_tokens[num_draft - 1]);
                    d.push(bonus);
                    self.draft_y_input = d;
                } else {
                    self.draft_y_input = vec![bonus];
                }
            }
            _ => self.exhausted = true,
        }
        Ok(())
    }

    /// Sample `num_draft` tokens autoregressively from the draft model,
    /// starting from `draft_y_input`.
    ///
    /// The last sampled token is never forwarded through the draft model.
    /// The logits processors see `history` followed by the tokens drafted
    /// so far.
    fn draft_generate(&mut self, num_draft: usize) -> Result<Vec<u32>> {
        if num_draft == 0 {
            return Ok(Vec::new());
        }
        let mut drafts = try_with_capacity(num_draft)?;
        let mut y = self.draft_y_input.clone();
        let have_procs = !self.processors.is_empty();
        let mut draft_history = if have_procs {
            self.history.clone()
        } else {
            Vec::new()
        };
        // Only the last input token is uncommitted. Any earlier one in
        // `draft_y_input` is a draft already appended to `history`.
        let mut next_history_token: u32 = *y.last().ok_or_else(|| {
            Error::EmptyInput(EmptyInputPayload::new(
                "speculative_generate: draft_y_input",
            ))
        })?;
        for _ in 0..num_draft {
            let arr = token_window(&y)?;
            let logits = self.draft_model.forward(&arr, &mut self.draft_cache)?;
            let last = last_n_positions(&logits, 1)?;
            let mut row = slice_position(&last, 0)?;
            if have_procs {
                draft_history.push(next_history_token);
                for p in &self.processors {
                    row = p.apply(&draft_history, &row)?;
                }
            }
            let lse = ops::reduction::logsumexp(&row, true)?;
            let lp = ops::arithmetic::subtract(&row, &lse)?;
            let mut sampled = self.sampler.sample(&lp)?;
            let tok = sampled.item::<u32>()?;
            drafts.push(tok);
            y = vec![tok];
            next_history_token = tok;
        }
        Ok(drafts)
    }
}
fn token_window(ids: &[u32]) -> Result<Array> {
let mut row: Vec<i32> = try_with_capacity(ids.len())?;
row.extend(ids.iter().map(|&t| t as i32));
Array::from_slice::<i32>(&row, &(1usize, row.len()))
}
/// Slice the last `n` sequence positions out of `[B, S, V]` logits, giving a
/// `[B, n, V]` array.
///
/// # Errors
///
/// - `RankMismatch` if `logits` is not rank 3.
/// - `OutOfRange` if `S` or `V` is zero, or if `n` is not in `1..=S`.
fn last_n_positions(logits: &Array, n: usize) -> Result<Array> {
    let shape = logits.shape();
    if shape.len() != 3 {
        return Err(Error::RankMismatch(RankMismatchPayload::new(
            "speculative_generate: expected [B, S, V] logits from `forward`",
            shape.len() as u32,
            shape.to_vec(),
        )));
    }
    let (batch, seq, vocab) = (shape[0], shape[1], shape[2]);
    if seq == 0 || vocab == 0 {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "speculative_generate: forward logits axes (S and V)",
            "must be >= 1 to slice the last n positions",
            format_smolstr!("S={seq}, V={vocab}"),
        )));
    }
    if !(1..=seq).contains(&n) {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "speculative_generate: slice n",
            "must be in 1..=S",
            format_smolstr!("n={n} (S={seq})"),
        )));
    }
    let end = seq as i32;
    let begin = end - n as i32;
    ops::indexing::slice(
        logits,
        &[0, begin, 0],
        &[batch as i32, end, vocab as i32],
        &[1, 1, 1],
    )
}
/// Extract sequence position `pos` from `[B, S, V]` logits as a `[B, V]`
/// array.
///
/// # Errors
///
/// - `RankMismatch` if `logits` is not rank 3.
/// - `OutOfRange` if `pos` is not in `0..S`.
fn slice_position(logits: &Array, pos: i32) -> Result<Array> {
    let shape = logits.shape();
    if shape.len() != 3 {
        return Err(Error::RankMismatch(RankMismatchPayload::new(
            "slice_position: expected [B, S, V]",
            shape.len() as u32,
            shape.to_vec(),
        )));
    }
    let seq = shape[1] as i32;
    if !(0..seq).contains(&pos) {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "slice_position: pos",
            "must be in 0..S",
            format_smolstr!("pos={pos} (S={seq})"),
        )));
    }
    let (batch, vocab) = (shape[0] as i32, shape[2] as i32);
    let single = ops::indexing::slice(
        logits,
        &[0, pos, 0],
        &[batch, pos + 1, vocab],
        &[1, 1, 1],
    )?;
    // Drop the now length-1 sequence axis.
    ops::shape::squeeze_axes(&single, &[1])
}
/// Zero-length `f32` placeholder, swapped into a logprobs slot once that
/// slot's array has been moved out.
fn empty_logprobs() -> Result<Array> {
    const NO_VALUES: [f32; 0] = [];
    Array::from_slice::<f32>(&NO_VALUES, &(0usize,))
}