use crate::engine::{EngineConfig, InferenceEngine};
use crate::error::{RuntimeError, RuntimeResult};
/// Configuration for [`SpeculativeEngine`]: a target model, a draft model that proposes
/// candidate tokens for it, and the parameters of the speculative sampling loop.
#[derive(Debug, Clone)]
pub struct SpeculativeConfig {
    /// Engine configuration for the target model, which verifies the draft's proposals.
    pub target: EngineConfig,
    /// Engine configuration for the draft model, which proposes candidate tokens.
    pub draft: EngineConfig,
    /// Maximum number of draft tokens proposed per verification round.
    pub num_speculative: usize,
    /// Seed for the sampling RNG; `None` makes [`SpeculativeEngine::new`] fall back to a
    /// fixed built-in seed.
    pub seed: Option<u64>,
}
impl SpeculativeConfig {
    /// Draft tokens proposed per round unless the caller overrides `num_speculative`.
    const DEFAULT_NUM_SPECULATIVE: usize = 4;

    /// Pairs a `target` model configuration with a `draft` model configuration.
    ///
    /// Defaults: [`Self::DEFAULT_NUM_SPECULATIVE`] (4) speculative tokens per round and
    /// no explicit RNG seed.
    pub fn new(target: EngineConfig, draft: EngineConfig) -> Self {
        let num_speculative = Self::DEFAULT_NUM_SPECULATIVE;
        Self {
            seed: None,
            num_speculative,
            draft,
            target,
        }
    }
}
/// Speculative-decoding engine: a draft model proposes candidate tokens and a target
/// model accepts or rejects them.
pub struct SpeculativeEngine {
    /// Model that proposes up to `num_speculative` candidate tokens per round.
    draft: InferenceEngine,
    /// Model that verifies candidates; its tokenizer and decoder are used for text I/O.
    target: InferenceEngine,
    /// Maximum number of draft tokens proposed per round.
    num_speculative: usize,
    /// Deterministic RNG used for drafting, acceptance tests and bonus/residual sampling.
    rng: Xorshift64,
}
impl SpeculativeEngine {
    /// Loads the draft and target models described by `config`.
    ///
    /// The sampling RNG is seeded from `config.seed`, or from a fixed built-in constant
    /// when no seed is given.
    ///
    /// # Errors
    /// Returns the first error from `InferenceEngine::load_model` (draft model first).
    pub fn new(config: SpeculativeConfig) -> RuntimeResult<Self> {
        let seed = config.seed.unwrap_or(0x517cc1b727220a95u64);
        let mut draft = InferenceEngine::new(config.draft);
        draft.load_model()?;
        let mut target = InferenceEngine::new(config.target);
        target.load_model()?;
        Ok(Self {
            draft,
            target,
            num_speculative: config.num_speculative,
            rng: Xorshift64::new(seed),
        })
    }

    /// Generates up to `max_tokens` tokens continuing `prompt` with speculative sampling.
    ///
    /// Each round, the draft model proposes up to `num_speculative` tokens. It proposes
    /// fewer if the token budget is nearly spent or if it proposes EOS. The target model
    /// checks the proposals in order:
    /// - A candidate `x` is accepted with probability `min(1, p_target(x) / p_draft(x))`.
    /// - On the first rejection, a replacement is drawn from the normalized residual
    ///   `max(0, p_target - p_draft)` at that same position.
    /// - If every candidate is accepted, one extra token is sampled from the target's
    ///   next-token distribution.
    ///
    /// Generation stops after `max_tokens` emitted tokens, or when the target accepts or
    /// samples EOS.
    ///
    /// `callback` is invoked once per emitted token with its decoded text. The returned
    /// string is the concatenation of everything passed to `callback`.
    ///
    /// Both engines are reset at the start, so separate calls only share the RNG state.
    ///
    /// # Errors
    /// Propagates errors from tokenization, prefill, forward passes and token decoding.
    pub fn generate(
        &mut self,
        prompt: &str,
        max_tokens: usize,
        mut callback: impl FnMut(&str),
    ) -> RuntimeResult<String> {
        self.draft.reset();
        self.target.reset();
        let prompt_tokens = self.target.tokenize(prompt)?;
        if prompt_tokens.is_empty() {
            return Ok(String::new());
        }

        // Cache invariant for both engines at the start of every round: the cache holds
        // every committed token except the last one, and the round feeds that last token
        // via `forward_one` (the same state `resync_draft` restores). Prefilling the
        // whole prompt here would make the first `forward_one` process the final prompt
        // token twice.
        let history = &prompt_tokens[..prompt_tokens.len() - 1];
        if !history.is_empty() {
            self.draft.prefill(history)?;
            self.target.prefill(history)?;
        }

        let mut all_tokens: Vec<u32> = prompt_tokens;
        let mut generated = String::new();
        let mut tokens_generated = 0usize;

        while tokens_generated < max_tokens {
            // Never propose more tokens than the remaining budget can absorb.
            let k = self.num_speculative.min(max_tokens - tokens_generated);
            let last_token = *all_tokens
                .last()
                .ok_or_else(|| RuntimeError::ModelLoadError {
                    message: "token history is unexpectedly empty".to_string(),
                })?;

            // ---- Draft phase ----
            // `draft_logits_per_pos[i]` holds the draft logits that `draft_tokens[i]` was
            // sampled from. The residual after a rejection needs the draft distribution
            // at that exact position.
            let mut draft_tokens: Vec<u32> = Vec::with_capacity(k);
            let mut draft_probs: Vec<f32> = Vec::with_capacity(k);
            let mut draft_logits_per_pos: Vec<Vec<f32>> = Vec::with_capacity(k);
            // False only when drafting stopped on EOS. EOS is proposed, but it is never fed
            // to the draft model.
            let mut draft_consumed_all = true;
            let mut logits = self.draft.forward_one(last_token)?;
            for _ in 0..k {
                let (token, prob) = sample_with_prob(&logits, &mut self.rng);
                draft_tokens.push(token);
                draft_probs.push(prob);
                draft_logits_per_pos.push(logits);
                if self.draft.is_eos(token) {
                    // EOS is a proposal like any other; the target decides whether to stop.
                    draft_consumed_all = false;
                    break;
                }
                // Feed every proposal. If all of them are accepted, the draft cache then
                // already matches the committed history and needs no resync.
                logits = self.draft.forward_one(token)?;
            }

            // ---- Verification phase ----
            let mut accepted = 0usize;
            let mut residual_token: Option<u32> = None;
            let mut target_logits = self.target.forward_one(last_token)?;
            let candidates = draft_tokens.iter().zip(&draft_probs).enumerate();
            for (i, (&draft_tok, &p_draft)) in candidates {
                let target_probs = softmax(&target_logits);
                let p_target = target_probs
                    .get(draft_tok as usize)
                    .copied()
                    .unwrap_or(0.0f32);
                let u = self.rng.next_f32();
                let accept_threshold = (p_target / p_draft.max(1e-10)).min(1.0);
                // Strict `<`: `u` can be exactly 0.0, and a token with zero target
                // probability must never be accepted.
                if u < accept_threshold {
                    accepted += 1;
                    if self.target.is_eos(draft_tok) {
                        // Emit the tokens accepted before EOS, then stop.
                        commit_and_emit(
                            &draft_tokens[..accepted - 1],
                            &mut self.target,
                            &mut all_tokens,
                            &mut generated,
                            &mut tokens_generated,
                            &mut callback,
                        )?;
                        return Ok(generated);
                    }
                    target_logits = self.target.forward_one(draft_tok)?;
                } else {
                    let q = softmax_draft_at(&draft_logits_per_pos[i], &draft_tokens, i);
                    residual_token = Some(sample_residual(&target_probs, &q, &mut self.rng));
                    break;
                }
            }

            commit_and_emit(
                &draft_tokens[..accepted],
                &mut self.target,
                &mut all_tokens,
                &mut generated,
                &mut tokens_generated,
                &mut callback,
            )?;
            if tokens_generated >= max_tokens {
                // Budget spent: a bonus token would overshoot `max_tokens`.
                break;
            }

            // One extra target token per round. After a rejection it is the residual
            // sample; otherwise it is a fresh sample from the target's logits after the
            // last accepted proposal.
            let rejected = residual_token.is_some();
            let bonus = match residual_token {
                Some(token) => token,
                None => sample_with_prob(&target_logits, &mut self.rng).0,
            };
            if self.target.is_eos(bonus) {
                break;
            }
            // The bonus is intentionally not fed to the target here. As the new last
            // token, it is fed at the start of the next round, which keeps the invariant.
            commit_and_emit(
                &[bonus],
                &mut self.target,
                &mut all_tokens,
                &mut generated,
                &mut tokens_generated,
                &mut callback,
            )?;
            // After a rejection, or an EOS proposal that was never fed, the draft cache
            // does not match the committed history, so rebuild it from scratch.
            if rejected || !draft_consumed_all {
                resync_draft(&mut self.draft, &all_tokens)?;
            }
        }
        Ok(generated)
    }
}
/// Rebuilds `draft`'s cache from scratch so that it holds every token of `all_tokens`
/// except the final one, which the next round feeds through `forward_one`.
///
/// With zero or one token, the engine is only reset.
///
/// # Errors
/// Propagates any error from `InferenceEngine::prefill`.
fn resync_draft(draft: &mut InferenceEngine, all_tokens: &[u32]) -> RuntimeResult<()> {
    draft.reset();
    match all_tokens.split_last() {
        Some((_, history)) if !history.is_empty() => {
            draft.prefill(history)?;
        }
        _ => {}
    }
    Ok(())
}
/// Commits `tokens` to the output, in order. Each token is:
/// - decoded with `target`,
/// - passed to `callback` as text,
/// - appended to `generated` and `all_tokens`,
/// - counted in `tokens_generated`.
///
/// If decoding fails, the function returns immediately. Tokens before the failing one
/// stay committed; the failing token is not recorded anywhere.
///
/// # Errors
/// Propagates any error from `InferenceEngine::decode_token`.
fn commit_and_emit(
    tokens: &[u32],
    target: &mut InferenceEngine,
    all_tokens: &mut Vec<u32>,
    generated: &mut String,
    tokens_generated: &mut usize,
    callback: &mut impl FnMut(&str),
) -> RuntimeResult<()> {
    tokens.iter().try_for_each(|&token| -> RuntimeResult<()> {
        let piece = target.decode_token(token)?;
        callback(&piece);
        generated.push_str(&piece);
        all_tokens.push(token);
        *tokens_generated += 1;
        Ok(())
    })
}
/// Samples a token index from `softmax(logits)` and returns it with its probability.
///
/// Uses inverse-CDF sampling with a single draw from `rng`. Empty `logits` yield
/// `(0, 1.0)`.
///
/// Float rounding can leave the accumulated mass just below the draw. In that case the
/// last index with non-zero probability is returned, rather than the final index. This
/// prevents a zero-probability trailing token from being picked purely because of
/// rounding.
fn sample_with_prob(logits: &[f32], rng: &mut Xorshift64) -> (u32, f32) {
    if logits.is_empty() {
        return (0, 1.0);
    }
    let probs = softmax(logits);
    let r = rng.next_f32();
    let mut cumulative = 0.0f32;
    for (i, &p) in probs.iter().enumerate() {
        cumulative += p;
        if r < cumulative {
            return (i as u32, p);
        }
    }
    // Rounding fallback: prefer the last token that actually carries probability mass.
    let idx = probs
        .iter()
        .rposition(|&p| p > 0.0)
        .unwrap_or(probs.len() - 1);
    (idx as u32, probs[idx])
}
/// Samples an index from the residual distribution `norm(max(0, p_target - p_draft))`.
///
/// This is the replacement draw used after a draft token is rejected.
/// - Inputs of different lengths are compared as if zero-padded to the longer one.
/// - If the residual has (numerically) no mass, `p_target` is sampled directly.
/// - Empty inputs return `0`.
/// - If float rounding leaves the accumulated mass just below the draw, the last index
///   with non-zero probability is returned rather than the final index.
fn sample_residual(p_target: &[f32], p_draft: &[f32], rng: &mut Xorshift64) -> u32 {
    let len = p_target.len().max(p_draft.len());
    let mut residual: Vec<f32> = (0..len)
        .map(|i| {
            let pt = p_target.get(i).copied().unwrap_or(0.0);
            let pd = p_draft.get(i).copied().unwrap_or(0.0);
            (pt - pd).max(0.0)
        })
        .collect();
    let z: f32 = residual.iter().sum();
    if z > 1e-10 {
        for v in &mut residual {
            *v /= z;
        }
    } else {
        // Identical (or numerically identical) distributions leave no residual mass;
        // fall back to the target distribution itself.
        residual.clear();
        residual.extend_from_slice(p_target);
    }
    let r = rng.next_f32();
    let mut cumulative = 0.0f32;
    for (i, &p) in residual.iter().enumerate() {
        cumulative += p;
        if r < cumulative {
            return i as u32;
        }
    }
    // Rounding fallback: prefer the last index that actually carries probability mass.
    residual
        .iter()
        .rposition(|&p| p > 0.0)
        .unwrap_or(residual.len().saturating_sub(1)) as u32
}
/// Returns the draft model's full probability distribution for candidate `candidate_idx`.
///
/// `draft_logits_at_pos` must be the draft logits from which `draft_tokens[candidate_idx]`
/// was sampled; the result is simply their softmax. `draft_tokens` and `candidate_idx`
/// are not used by the computation.
fn softmax_draft_at(
    draft_logits_at_pos: &[f32],
    _draft_tokens: &[u32],
    _candidate_idx: usize,
) -> Vec<f32> {
    softmax(draft_logits_at_pos)
}
/// Converts `logits` into a probability distribution using a numerically stable softmax.
///
/// The maximum logit is subtracted before exponentiating, to avoid overflow. Empty input
/// gives an empty vector. Infinite logits are handled explicitly instead of producing
/// NaN from `inf - inf`:
/// - if any logit is `+inf`, the mass is split evenly among the `+inf` entries;
/// - if no logit exceeds `-inf` (fully masked, or all NaN), the result is uniform.
fn softmax(logits: &[f32]) -> Vec<f32> {
    if logits.is_empty() {
        return Vec::new();
    }
    let max_val = logits.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    if max_val == f32::INFINITY {
        let winners = logits.iter().filter(|&&v| v == f32::INFINITY).count();
        let share = 1.0 / winners as f32;
        return logits
            .iter()
            .map(|&v| if v == f32::INFINITY { share } else { 0.0 })
            .collect();
    }
    if max_val == f32::NEG_INFINITY {
        let share = 1.0 / logits.len() as f32;
        return vec![share; logits.len()];
    }
    let mut exps: Vec<f32> = logits.iter().map(|&v| (v - max_val).exp()).collect();
    let sum: f32 = exps.iter().sum();
    if sum > 0.0 {
        for v in &mut exps {
            *v /= sum;
        }
    }
    exps
}
/// Minimal xorshift64 pseudo-random generator used for all sampling decisions.
///
/// Fully deterministic for a given seed; not intended for cryptographic use.
struct Xorshift64 {
    /// Current generator state. It is never zero, because zero is a fixed point of the
    /// xorshift step.
    state: u64,
}

impl Xorshift64 {
    /// Creates a generator from `seed`.
    ///
    /// A seed of `0` is replaced by a fixed non-zero constant, because an all-zero state
    /// would make every output zero.
    fn new(seed: u64) -> Self {
        let state = match seed {
            0 => 0x517cc1b727220a95u64,
            nonzero => nonzero,
        };
        Self { state }
    }

    /// Advances the state by one xorshift64 step (shifts 13, 7, 17) and returns it.
    fn next_u64(&mut self) -> u64 {
        self.state ^= self.state << 13;
        self.state ^= self.state >> 7;
        self.state ^= self.state << 17;
        self.state
    }

    /// Returns a uniform `f32` in `[0, 1)`, built from the top 24 bits of the next output.
    /// Those 24 bits fit exactly in an `f32` mantissa.
    fn next_f32(&mut self) -> f32 {
        let top_bits = self.next_u64() >> 40;
        top_bits as f32 / 16_777_216.0
    }
}
// Unit tests for the RNG, softmax/sampling helpers and config defaults, plus smoke tests
// that run `SpeculativeEngine` against the synthetic GGUF model and tokenizer provided by
// `oxillama_gguf::test_utils`.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Xorshift64 ----

    #[test]
    fn test_xorshift_range() {
        let mut rng = Xorshift64::new(42);
        for _ in 0..100_000 {
            let v = rng.next_f32();
            assert!((0.0..1.0).contains(&v), "xorshift_f32 out of range: {v}");
        }
    }

    #[test]
    fn test_xorshift_different_seeds() {
        let mut rng1 = Xorshift64::new(0x1111_1111_1111_1111u64);
        let mut rng2 = Xorshift64::new(0x2222_2222_2222_2222u64);
        let any_different = (0..10).any(|_| rng1.next_f32() != rng2.next_f32());
        assert!(
            any_different,
            "two different seeds must produce at least one differing value in 10 steps"
        );
    }

    // ---- softmax ----

    #[test]
    fn test_softmax_sums_to_one() {
        let logits = vec![1.0f32, 2.0, 0.5, -1.0, 3.0];
        let probs = softmax(&logits);
        let sum: f32 = probs.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-5,
            "softmax should sum to 1.0, got {sum}"
        );
    }

    #[test]
    fn test_softmax_empty() {
        assert!(softmax(&[]).is_empty());
    }

    // ---- SpeculativeConfig ----

    #[test]
    fn test_speculative_config_defaults() {
        let target = EngineConfig {
            model_path: "target.gguf".to_string(),
            ..EngineConfig::default()
        };
        let draft = EngineConfig {
            model_path: "draft.gguf".to_string(),
            ..EngineConfig::default()
        };
        let cfg = SpeculativeConfig::new(target.clone(), draft.clone());
        assert_eq!(cfg.num_speculative, 4);
        assert_eq!(cfg.target.model_path, "target.gguf");
        assert_eq!(cfg.draft.model_path, "draft.gguf");
    }

    #[test]
    fn test_speculative_config_override() {
        let target = EngineConfig::default();
        let draft = EngineConfig::default();
        let cfg = SpeculativeConfig {
            num_speculative: 8,
            ..SpeculativeConfig::new(target, draft)
        };
        assert_eq!(cfg.num_speculative, 8);
    }

    // Compile-time check: fails to build if `SpeculativeEngine` stops being `Send`.
    #[test]
    fn test_speculative_engine_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<SpeculativeEngine>();
    }

    // ---- sample_residual / sample_with_prob / softmax_draft_at ----

    #[test]
    fn test_sample_residual_identical_distributions() {
        let p = vec![0.25f32, 0.25, 0.25, 0.25];
        let mut rng = Xorshift64::new(99);
        let token = sample_residual(&p, &p, &mut rng);
        assert!((token as usize) < p.len());
    }

    #[test]
    fn test_sample_residual_peaked() {
        let p_target = vec![0.0f32, 0.0, 1.0, 0.0];
        let p_draft = vec![0.0f32, 0.0, 0.0, 0.0];
        let mut rng = Xorshift64::new(7);
        for _ in 0..100 {
            let token = sample_residual(&p_target, &p_draft, &mut rng);
            assert_eq!(token, 2, "should always pick index 2");
        }
    }

    #[test]
    fn test_sample_with_prob_empty_logits() {
        let mut rng = Xorshift64::new(1);
        let (token, prob) = sample_with_prob(&[], &mut rng);
        assert_eq!(token, 0);
        assert!((prob - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_sample_with_prob_single() {
        let logits = vec![5.0f32];
        let mut rng = Xorshift64::new(42);
        for _ in 0..50 {
            let (token, prob) = sample_with_prob(&logits, &mut rng);
            assert_eq!(token, 0, "single-element: must pick 0");
            assert!((prob - 1.0).abs() < 1e-5, "single-element prob must be 1.0");
        }
    }

    #[test]
    fn test_sample_with_prob_peaked() {
        let logits = vec![-1000.0f32, -1000.0, 1000.0, -1000.0];
        let mut rng = Xorshift64::new(99);
        for _ in 0..100 {
            let (token, _prob) = sample_with_prob(&logits, &mut rng);
            assert_eq!(token, 2, "peaked distribution must always return index 2");
        }
    }

    #[test]
    fn test_softmax_draft_at_sums_to_one() {
        let logits = vec![0.5f32, 1.5, -0.5, 2.0];
        let draft_tokens = vec![1u32, 3u32];
        let probs = softmax_draft_at(&logits, &draft_tokens, 0);
        let sum: f32 = probs.iter().sum();
        assert!(
            (sum - 1.0).abs() < 1e-5,
            "softmax_draft_at must sum to 1, got {sum}"
        );
        assert_eq!(probs.len(), logits.len());
    }

    #[test]
    fn test_softmax_draft_at_empty() {
        let probs = softmax_draft_at(&[], &[], 0);
        assert!(probs.is_empty());
    }

    #[test]
    fn test_xorshift_zero_seed_nonzero_output() {
        let mut rng = Xorshift64::new(0);
        let v = rng.next_u64();
        assert_ne!(v, 0, "zero seed must be remapped to non-zero state");
    }

    #[test]
    fn test_softmax_uniform_logits() {
        let logits = vec![1.0f32; 4];
        let probs = softmax(&logits);
        for &p in &probs {
            assert!((p - 0.25).abs() < 1e-5, "uniform logits → p=0.25, got {p}");
        }
    }

    #[test]
    fn test_sample_residual_empty() {
        let mut rng = Xorshift64::new(3);
        let token = sample_residual(&[], &[], &mut rng);
        assert_eq!(token, 0, "empty distributions → fallback index 0");
    }

    // ---- SpeculativeEngine smoke tests (synthetic model) ----

    // Explicit import; this appears redundant with `use super::*` above but is harmless.
    use crate::engine::InferenceEngine;

    // Builds an engine whose draft and target are both loaded from the same in-memory
    // synthetic GGUF model. This bypasses `SpeculativeEngine::new`, which calls
    // `load_model()`. The RNG seed is fixed at 42.
    fn make_loaded_engine(
        num_speculative: usize,
    ) -> crate::error::RuntimeResult<SpeculativeEngine> {
        let gguf_bytes = oxillama_gguf::test_utils::build_minimal_llama_gguf();
        let tok_json = oxillama_gguf::test_utils::minimal_tokenizer_json();
        let mut draft = InferenceEngine::new(EngineConfig::default());
        draft.load_model_from_bytes(&gguf_bytes, tok_json)?;
        let mut target = InferenceEngine::new(EngineConfig::default());
        target.load_model_from_bytes(&gguf_bytes, tok_json)?;
        Ok(SpeculativeEngine {
            draft,
            target,
            num_speculative,
            rng: Xorshift64::new(42),
        })
    }

    #[test]
    fn test_speculative_engine_loads_from_bytes() {
        let result = make_loaded_engine(4);
        assert!(
            result.is_ok(),
            "SpeculativeEngine should load from synthetic GGUF: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_speculative_engine_new_fails_with_missing_model() {
        let target = EngineConfig {
            model_path: "/nonexistent/target.gguf".to_string(),
            ..EngineConfig::default()
        };
        let draft = EngineConfig {
            model_path: "/nonexistent/draft.gguf".to_string(),
            ..EngineConfig::default()
        };
        let cfg = SpeculativeConfig::new(target, draft);
        let result = SpeculativeEngine::new(cfg);
        assert!(result.is_err(), "should fail with missing model files");
    }

    // The generation smoke tests below return early (effectively skipping) when the
    // synthetic model cannot be loaded.
    #[test]
    fn test_speculative_engine_generate_short_prompt() {
        let mut engine = match make_loaded_engine(2) {
            Ok(e) => e,
            Err(e) => {
                eprintln!("skip test_speculative_engine_generate_short_prompt: {e}");
                return;
            }
        };
        let result = engine.generate("abc", 5, |_tok| {});
        assert!(
            result.is_ok(),
            "generate should succeed on synthetic model: {:?}",
            result.err()
        );
    }

    #[test]
    fn test_speculative_engine_generate_multiple_rounds() {
        let mut engine = match make_loaded_engine(2) {
            Ok(e) => e,
            Err(e) => {
                eprintln!("skip test_speculative_engine_generate_multiple_rounds: {e}");
                return;
            }
        };
        let result = engine.generate("hello", 10, |_tok| {});
        assert!(
            result.is_ok(),
            "multi-round generate should succeed: {:?}",
            result.err()
        );
    }

    // The returned string must equal the concatenation of all callback pieces.
    #[test]
    fn test_speculative_engine_generate_callback_accumulates() {
        let mut engine = match make_loaded_engine(2) {
            Ok(e) => e,
            Err(e) => {
                eprintln!("skip test_speculative_engine_generate_callback_accumulates: {e}");
                return;
            }
        };
        let mut cb_output = String::new();
        let result = engine.generate("ab", 4, |tok| cb_output.push_str(tok));
        let generated = match result {
            Ok(s) => s,
            Err(e) => {
                eprintln!("skip callback accumulation check: {e}");
                return;
            }
        };
        assert_eq!(
            generated, cb_output,
            "returned string must equal callback accumulation"
        );
    }

    // Only checks that an empty prompt does not panic; the result itself is ignored.
    #[test]
    fn test_speculative_engine_generate_empty_prompt() {
        let mut engine = match make_loaded_engine(2) {
            Ok(e) => e,
            Err(e) => {
                eprintln!("skip test_speculative_engine_generate_empty_prompt: {e}");
                return;
            }
        };
        let result = engine.generate("", 5, |_| {});
        let _ = result;
    }

    #[test]
    fn test_speculative_config_seed_none_by_default() {
        let cfg = SpeculativeConfig::new(EngineConfig::default(), EngineConfig::default());
        assert!(
            cfg.seed.is_none(),
            "seed should be None by default, got {:?}",
            cfg.seed
        );
    }

    #[test]
    fn test_speculative_config_num_speculative_default_is_4() {
        let cfg = SpeculativeConfig::new(EngineConfig::default(), EngineConfig::default());
        assert_eq!(cfg.num_speculative, 4);
    }

    // Two engines built identically with the same RNG seed must produce identical text.
    // The comparison only runs when both runs succeed.
    #[test]
    fn test_speculative_engine_deterministic_with_seed() {
        let mk = |seed: u64| -> Option<String> {
            let gguf_bytes = oxillama_gguf::test_utils::build_minimal_llama_gguf();
            let tok_json = oxillama_gguf::test_utils::minimal_tokenizer_json();
            let mut draft = InferenceEngine::new(EngineConfig::default());
            draft.load_model_from_bytes(&gguf_bytes, tok_json).ok()?;
            let mut target = InferenceEngine::new(EngineConfig::default());
            target.load_model_from_bytes(&gguf_bytes, tok_json).ok()?;
            let mut engine = SpeculativeEngine {
                draft,
                target,
                num_speculative: 2,
                rng: Xorshift64::new(seed),
            };
            engine.generate("test", 4, |_| {}).ok()
        };
        let run1 = mk(0xdead_beef_cafe_babe);
        let run2 = mk(0xdead_beef_cafe_babe);
        if run1.is_some() && run2.is_some() {
            assert_eq!(run1, run2, "identical seeds must produce identical output");
        }
    }

    // Calling `generate` twice on one engine must work; `generate` resets both engines.
    #[test]
    fn test_speculative_engine_generate_twice() {
        let mut engine = match make_loaded_engine(2) {
            Ok(e) => e,
            Err(e) => {
                eprintln!("skip test_speculative_engine_generate_twice: {e}");
                return;
            }
        };
        let r1 = engine.generate("first", 3, |_| {});
        let r2 = engine.generate("second", 3, |_| {});
        assert!(r1.is_ok(), "first generate should succeed: {:?}", r1.err());
        assert!(
            r2.is_ok(),
            "second generate should succeed (state reset): {:?}",
            r2.err()
        );
    }
}